테스트
This commit is contained in:
31
app.py
31
app.py
@@ -113,26 +113,29 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
|||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": (
|
"content": (
|
||||||
'''
|
f'''
|
||||||
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
||||||
|
다음 데이터를 참고하세요
|
||||||
|
{context}
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
|
"content": f"질문: {query}"
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# 토큰화
|
# 토큰화
|
||||||
text = tokenizer.apply_chat_template(
|
model_inputs = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
tokenize=False,
|
tokenize=True,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
)
|
return_tensors="pt",
|
||||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
return_dict=True
|
||||||
|
).to(model.device)
|
||||||
|
|
||||||
# 스트리머 설정
|
# 스트리머 설정
|
||||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
@@ -145,6 +148,7 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
|||||||
do_sample=True,
|
do_sample=True,
|
||||||
temperature=0.3,
|
temperature=0.3,
|
||||||
top_p=0.9,
|
top_p=0.9,
|
||||||
|
use_cache=True, # 속도 향상 핵심
|
||||||
pad_token_id=tokenizer.eos_token_id
|
pad_token_id=tokenizer.eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -248,6 +252,17 @@ class Item(BaseModel):
|
|||||||
context: list
|
context: list
|
||||||
sessionId: str
|
sessionId: str
|
||||||
|
|
||||||
|
import gc
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def clear_memory():
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
gc.collect()
|
||||||
|
|
||||||
|
# API 호출 사이사이나 요약 작업 직후에 실행
|
||||||
|
# clear_memory()
|
||||||
|
|
||||||
@app.post("/set-data")
|
@app.post("/set-data")
|
||||||
async def set_data(query: Item):
|
async def set_data(query: Item):
|
||||||
"""
|
"""
|
||||||
@@ -311,7 +326,7 @@ def question(sessionId: str, query: str):
|
|||||||
print(sessionId, query)
|
print(sessionId, query)
|
||||||
generate = query_select_summarize_stream(results, query, ai=False)
|
generate = query_select_summarize_stream(results, query, ai=False)
|
||||||
|
|
||||||
|
clear_memory()
|
||||||
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
|
||||||
|
from sympy.printing.pytorch import torch
|
||||||
|
|
||||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||||
|
|
||||||
@@ -25,10 +26,18 @@ def get_qwen_model():
|
|||||||
local_files_only=True
|
local_files_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# # 4-bit 양자화 설정 gpu
|
||||||
|
# bnb_config = BitsAndBytesConfig(
|
||||||
|
# load_in_4bit=True,
|
||||||
|
# bnb_4bit_use_double_quant=True, # 추가 압축
|
||||||
|
# bnb_4bit_quant_type="nf4", # 성능이 좋은 양자화 방식
|
||||||
|
# bnb_4bit_compute_dtype=torch.bfloat16 # 연산 속도 유지
|
||||||
|
# )
|
||||||
|
|
||||||
# 모델 로드
|
# 모델 로드
|
||||||
from sympy.printing.pytorch import torch
|
|
||||||
_model = AutoModelForCausalLM.from_pretrained(
|
_model = AutoModelForCausalLM.from_pretrained(
|
||||||
QWEN_MODEL_PATH,
|
QWEN_MODEL_PATH,
|
||||||
|
# quantization_config=bnb_config, #gpu
|
||||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
||||||
device_map="auto",
|
device_map="auto",
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
|
|||||||
Reference in New Issue
Block a user