테스트
This commit is contained in:
@@ -29,3 +29,6 @@ model.save('./models/all-MiniLM-L6-v2')
|
|||||||
uvicorn manual:app --host 0.0.0.0 --port 8040 --reload
|
uvicorn manual:app --host 0.0.0.0 --port 8040 --reload
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 윈도우 환경 (라데온 사용법)
|
||||||
|
pip install torch-directml
|
||||||
9
app.py
9
app.py
@@ -14,8 +14,10 @@ from config.util.org_filter import extract_keywords_simple
|
|||||||
from repository.usersRepository import findAll
|
from repository.usersRepository import findAll
|
||||||
from config.ai.call_llm_model import get_qwen_model
|
from config.ai.call_llm_model import get_qwen_model
|
||||||
from config.db.chroma import collection
|
from config.db.chroma import collection
|
||||||
|
import torch_directml
|
||||||
|
# DirectML 디바이스 선언
|
||||||
store_data = {}
|
store_data = {}
|
||||||
|
device = torch_directml.device()
|
||||||
|
|
||||||
def search_employees(data: List[dict], query: str) -> List[dict]:
|
def search_employees(data: List[dict], query: str) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
@@ -79,7 +81,6 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
|||||||
Returns:
|
Returns:
|
||||||
Generator: 스트리밍 응답 제너레이터
|
Generator: 스트리밍 응답 제너레이터
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if ai :
|
if ai :
|
||||||
if not results['documents'] or not results['documents'][0]:
|
if not results['documents'] or not results['documents'][0]:
|
||||||
def generate_empty():
|
def generate_empty():
|
||||||
@@ -135,7 +136,7 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
|||||||
enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
return_dict=True
|
return_dict=True
|
||||||
).to(model.device)
|
).to(device)
|
||||||
|
|
||||||
# 스트리머 설정
|
# 스트리머 설정
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
@@ -205,7 +206,7 @@ def query_summarize_simple(query: str) :
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
)
|
)
|
||||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||||
|
|
||||||
# conduct text completion
|
# conduct text completion
|
||||||
generated_ids = model.generate(
|
generated_ids = model.generate(
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||||
from sympy.printing.pytorch import torch
|
from sympy.printing.pytorch import torch
|
||||||
|
import torch_directml
|
||||||
|
device = torch_directml.device()
|
||||||
|
|
||||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||||
|
|
||||||
@@ -31,19 +33,22 @@ def get_qwen_model():
|
|||||||
# load_in_4bit=True,
|
# load_in_4bit=True,
|
||||||
# bnb_4bit_use_double_quant=True, # 추가 압축
|
# bnb_4bit_use_double_quant=True, # 추가 압축
|
||||||
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
|
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
|
||||||
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지
|
# bnb_4bit_compute_dtype=torch.float16 # 연산 속도 유지
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# 모델 로드
|
# 모델 로드
|
||||||
_model = AutoModelForCausalLM.from_pretrained(
|
_model = AutoModelForCausalLM.from_pretrained(
|
||||||
QWEN_MODEL_PATH,
|
QWEN_MODEL_PATH,
|
||||||
# quantization_config=bnb_config, #gpu
|
# quantization_config=bnb_config, #gpu
|
||||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
dtype=torch.float16, # CPU: bfloat16, GPU: float16 권장
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
local_files_only=True
|
local_files_only=True,
|
||||||
|
low_cpu_mem_usage=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_model.to(device)
|
||||||
|
|
||||||
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
||||||
if hasattr(torch, 'compile'):
|
if hasattr(torch, 'compile'):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -5,6 +5,9 @@ from transformers import TextIteratorStreamer
|
|||||||
from config.ai.call_llm_model import get_qwen_model
|
from config.ai.call_llm_model import get_qwen_model
|
||||||
from config.db.chroma import collection
|
from config.db.chroma import collection
|
||||||
|
|
||||||
|
import torch_directml
|
||||||
|
device = torch_directml.device()
|
||||||
|
|
||||||
|
|
||||||
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
||||||
"""
|
"""
|
||||||
@@ -92,7 +95,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
)
|
)
|
||||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
model_inputs = tokenizer([text], return_tensors="pt").to(device)
|
||||||
|
|
||||||
# 스트리머 설정
|
# 스트리머 설정
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user