테스트

This commit is contained in:
2026-02-07 15:35:08 +09:00
parent 58decec7f5
commit 99df2fde77
4 changed files with 21 additions and 9 deletions

View File

@@ -28,4 +28,7 @@ model.save('./models/all-MiniLM-L6-v2')
# 실행방법 # 실행방법
uvicorn manual:app --host 0.0.0.0 --port 8040 --reload uvicorn manual:app --host 0.0.0.0 --port 8040 --reload
# 윈도우 환경 (라데온 사용법)
pip install torch-directml

9
app.py
View File

@@ -14,8 +14,10 @@ from config.util.org_filter import extract_keywords_simple
from repository.usersRepository import findAll from repository.usersRepository import findAll
from config.ai.call_llm_model import get_qwen_model from config.ai.call_llm_model import get_qwen_model
from config.db.chroma import collection from config.db.chroma import collection
import torch_directml
# DirectML 디바이스 선언
store_data = {} store_data = {}
device = torch_directml.device()
def search_employees(data: List[dict], query: str) -> List[dict]: def search_employees(data: List[dict], query: str) -> List[dict]:
""" """
@@ -79,7 +81,6 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
Returns: Returns:
Generator: 스트리밍 응답 제너레이터 Generator: 스트리밍 응답 제너레이터
""" """
if ai : if ai :
if not results['documents'] or not results['documents'][0]: if not results['documents'] or not results['documents'][0]:
def generate_empty(): def generate_empty():
@@ -135,7 +136,7 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요 enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
return_tensors="pt", return_tensors="pt",
return_dict=True return_dict=True
).to(model.device) ).to(device)
# 스트리머 설정 # 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -205,7 +206,7 @@ def query_summarize_simple(query: str) :
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
) )
model_inputs = tokenizer([text], return_tensors="pt").to(model.device) model_inputs = tokenizer([text], return_tensors="pt").to(device)
# conduct text completion # conduct text completion
generated_ids = model.generate( generated_ids = model.generate(

View File

@@ -1,5 +1,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sympy.printing.pytorch import torch from sympy.printing.pytorch import torch
import torch_directml
device = torch_directml.device()
QWEN_MODEL_PATH = "./models/Qwen3-0.6B" QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
@@ -31,19 +33,22 @@ def get_qwen_model():
# load_in_4bit=True, # load_in_4bit=True,
# bnb_4bit_use_double_quant=True, # 추가 압축 # bnb_4bit_use_double_quant=True, # 추가 압축
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식 # bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지 # bnb_4bit_compute_dtype=torch.float16 # 연산 속도 유지
# ) # )
# 모델 로드 # 모델 로드
_model = AutoModelForCausalLM.from_pretrained( _model = AutoModelForCausalLM.from_pretrained(
QWEN_MODEL_PATH, QWEN_MODEL_PATH,
# quantization_config=bnb_config, #gpu # quantization_config=bnb_config, #gpu
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장 dtype=torch.float16, # CPU: bfloat16, GPU: float16 권장
device_map="auto", device_map="auto",
trust_remote_code=True, trust_remote_code=True,
local_files_only=True local_files_only=True,
low_cpu_mem_usage=True
) )
_model.to(device)
# ✅ torch.compile() 적용 (PyTorch 2.0+) # ✅ torch.compile() 적용 (PyTorch 2.0+)
if hasattr(torch, 'compile'): if hasattr(torch, 'compile'):
try: try:

View File

@@ -5,6 +5,9 @@ from transformers import TextIteratorStreamer
from config.ai.call_llm_model import get_qwen_model from config.ai.call_llm_model import get_qwen_model
from config.db.chroma import collection from config.db.chroma import collection
import torch_directml
device = torch_directml.device()
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator: def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
""" """
@@ -92,7 +95,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
) )
model_inputs = tokenizer([text], return_tensors="pt").to(model.device) model_inputs = tokenizer([text], return_tensors="pt").to(device)
# 스트리머 설정 # 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)