테스트

This commit is contained in:
2026-02-01 18:15:59 +09:00
parent b7efaf0542
commit 58decec7f5
2 changed files with 34 additions and 10 deletions

31
app.py
View File

@@ -113,26 +113,29 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
{ {
"role": "system", "role": "system",
"content": ( "content": (
''' f'''
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다. 당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
다음 데이터를 참고하세요
{context}
''' '''
) )
}, },
{ {
"role": "user", "role": "user",
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}" "content": f"질문: {query}"
} }
] ]
# 토큰화 # 토큰화
text = tokenizer.apply_chat_template( model_inputs = tokenizer.apply_chat_template(
messages, messages,
tokenize=False, tokenize=True,
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
) return_tensors="pt",
model_inputs = tokenizer([text], return_tensors="pt").to(model.device) return_dict=True
).to(model.device)
# 스트리머 설정 # 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
@@ -145,6 +148,7 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
do_sample=True, do_sample=True,
temperature=0.3, temperature=0.3,
top_p=0.9, top_p=0.9,
use_cache=True, # 속도 향상 핵심
pad_token_id=tokenizer.eos_token_id pad_token_id=tokenizer.eos_token_id
) )
@@ -248,6 +252,17 @@ class Item(BaseModel):
context: list context: list
sessionId: str sessionId: str
import gc
import torch
def clear_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# API 호출 사이사이나 요약 작업 직후에 실행
# clear_memory()
@app.post("/set-data") @app.post("/set-data")
async def set_data(query: Item): async def set_data(query: Item):
""" """
@@ -311,7 +326,7 @@ def question(sessionId: str, query: str):
print(sessionId, query) print(sessionId, query)
generate = query_select_summarize_stream(results, query, ai=False) generate = query_select_summarize_stream(results, query, ai=False)
clear_memory()
return StreamingResponse(generate(), media_type="application/x-ndjson") return StreamingResponse(generate(), media_type="application/x-ndjson")

View File

@@ -1,4 +1,5 @@
from transformers import AutoTokenizer, AutoModelForCausalLM from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sympy.printing.pytorch import torch
QWEN_MODEL_PATH = "./models/Qwen3-0.6B" QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
@@ -25,10 +26,18 @@ def get_qwen_model():
local_files_only=True local_files_only=True
) )
# # 4-bit 양자화 설정 gpu
# bnb_config = BitsAndBytesConfig(
# load_in_4bit=True,
# bnb_4bit_use_double_quant=True, # 추가 압축
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지
# )
# 모델 로드 # 모델 로드
from sympy.printing.pytorch import torch
_model = AutoModelForCausalLM.from_pretrained( _model = AutoModelForCausalLM.from_pretrained(
QWEN_MODEL_PATH, QWEN_MODEL_PATH,
# quantization_config=bnb_config, #gpu
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장 dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
device_map="auto", device_map="auto",
trust_remote_code=True, trust_remote_code=True,