diff --git a/README.md b/README.md index dcb1296..0e79ac0 100644 --- a/README.md +++ b/README.md @@ -28,4 +28,7 @@ model.save('./models/all-MiniLM-L6-v2') # 실행방법 uvicorn manual:app --host 0.0.0.0 --port 8040 --reload - \ No newline at end of file + + +# 윈도우 환경 (라데온 사용법) +pip install torch-directml \ No newline at end of file diff --git a/app.py b/app.py index 8a4707e..71d6ce3 100644 --- a/app.py +++ b/app.py @@ -14,8 +14,10 @@ from config.util.org_filter import extract_keywords_simple from repository.usersRepository import findAll from config.ai.call_llm_model import get_qwen_model from config.db.chroma import collection - +import torch_directml +# DirectML 디바이스 선언 store_data = {} +device = torch_directml.device() def search_employees(data: List[dict], query: str) -> List[dict]: """ @@ -79,7 +81,6 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0. Returns: Generator: 스트리밍 응답 제너레이터 """ - if ai : if not results['documents'] or not results['documents'][0]: def generate_empty(): @@ -135,7 +136,7 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0. enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요 return_tensors="pt", return_dict=True - ).to(model.device) + ).to(device) # 스트리머 설정 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) @@ -205,7 +206,7 @@ def query_summarize_simple(query: str) : add_generation_prompt=True, enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 ) - model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + model_inputs = tokenizer([text], return_tensors="pt").to(device) # conduct text completion generated_ids = model.generate( diff --git a/config/ai/call_llm_model.py b/config/ai/call_llm_model.py index d09aa41..0c1b32e 100644 --- a/config/ai/call_llm_model.py +++ b/config/ai/call_llm_model.py @@ -1,5 +1,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from sympy.printing.pytorch import torch +import torch_directml +device = torch_directml.device() QWEN_MODEL_PATH = "./models/Qwen3-0.6B" @@ -31,19 +33,22 @@ def get_qwen_model(): # load_in_4bit=True, # bnb_4bit_use_double_quant=True, # 추가 압축 # bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식 - # bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지 + # bnb_4bit_compute_dtype=torch.float16 # 연산 속도 유지 # ) # 모델 로드 _model = AutoModelForCausalLM.from_pretrained( QWEN_MODEL_PATH, # quantization_config=bnb_config, #gpu - dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장 + dtype=torch.float16, # CPU: bfloat16, GPU: float16 권장 device_map="auto", trust_remote_code=True, - local_files_only=True + local_files_only=True, + low_cpu_mem_usage=True ) + _model.to(device) + # ✅ torch.compile() 적용 (PyTorch 2.0+) if hasattr(torch, 'compile'): try: diff --git a/config/ai/summarize_query.py b/config/ai/summarize_query.py index 971fcd9..c2f9579 100644 --- a/config/ai/summarize_query.py +++ b/config/ai/summarize_query.py @@ -5,6 +5,9 @@ from transformers import TextIteratorStreamer from config.ai.call_llm_model import get_qwen_model from config.db.chroma import collection +import torch_directml +device = torch_directml.device() + def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator: """ @@ -92,7 +95,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s add_generation_prompt=True, enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 ) - model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + model_inputs = tokenizer([text], return_tensors="pt").to(device) # 스트리머 설정 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)