테스트

This commit is contained in:
2026-02-07 15:35:08 +09:00
parent 58decec7f5
commit 99df2fde77
4 changed files with 21 additions and 9 deletions

View File

@@ -1,5 +1,7 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sympy.printing.pytorch import torch
import torch_directml
device = torch_directml.device()
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
@@ -31,19 +33,22 @@ def get_qwen_model():
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True, # 추가 압축
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지
# bnb_4bit_compute_dtype=torch.float16 # 연산 속도 유지
# )
# 모델 로드
_model = AutoModelForCausalLM.from_pretrained(
QWEN_MODEL_PATH,
# quantization_config=bnb_config, #gpu
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
dtype=torch.float16, # CPU: bfloat16, GPU: float16 권장
device_map="auto",
trust_remote_code=True,
local_files_only=True
local_files_only=True,
low_cpu_mem_usage=True
)
_model.to(device)
# ✅ torch.compile() 적용 (PyTorch 2.0+)
if hasattr(torch, 'compile'):
try:

View File

@@ -5,6 +5,9 @@ from transformers import TextIteratorStreamer
from config.ai.call_llm_model import get_qwen_model
from config.db.chroma import collection
import torch_directml
device = torch_directml.device()
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
"""
@@ -92,7 +95,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
# 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)