테스트
This commit is contained in:
124
util/summarize_query.py
Normal file
124
util/summarize_query.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import Generator
|
||||
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from org_offline import get_qwen_model
|
||||
|
||||
|
||||
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
||||
"""
|
||||
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
||||
|
||||
Args:
|
||||
sessionId (str): 세션 ID (사용자 구분)
|
||||
query (str): 사용자 질문
|
||||
top_k (int): 검색할 문서 개수
|
||||
min_similarity (float): 최소 유사도 임계값
|
||||
|
||||
Returns:
|
||||
Generator: 스트리밍 응답 제너레이터
|
||||
"""
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
|
||||
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
||||
from org_offline import collection
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=top_k + 2, # 여유분 확보
|
||||
where={"sessionId": sessionId}
|
||||
)
|
||||
|
||||
print(f"검색 결과: {results}")
|
||||
|
||||
if not results['documents'] or not results['documents'][0]:
|
||||
def generate_empty():
|
||||
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
|
||||
return generate_empty
|
||||
|
||||
# 유사도 계산 및 필터링
|
||||
filtered_docs = []
|
||||
if results['distances'] and results['distances'][0]:
|
||||
for doc, dist in zip(results['documents'][0], results['distances'][0]):
|
||||
similarity = 1 - dist
|
||||
if similarity >= min_similarity:
|
||||
filtered_docs.append((doc, similarity))
|
||||
if len(filtered_docs) >= top_k:
|
||||
break
|
||||
print(f"필터링된 문서: {filtered_docs}")
|
||||
|
||||
if not filtered_docs:
|
||||
def generate_low_sim():
|
||||
yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n"
|
||||
return generate_low_sim
|
||||
|
||||
# 컨텍스트 생성
|
||||
context_parts = []
|
||||
for i, (doc, sim) in enumerate(filtered_docs):
|
||||
context_parts.append(f"[청크 {i+1} | 유사도: {sim:.3f}]\n{doc}")
|
||||
context = "\n\n".join(context_parts)
|
||||
|
||||
# 모델 로드
|
||||
model, tokenizer = get_qwen_model()
|
||||
|
||||
sub_query = ''
|
||||
if query.find('부') > -1:
|
||||
sub_query = '***부로 끝나는 단어는 부서 맵핑'
|
||||
|
||||
# 프롬프트 구성
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
'''
|
||||
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
||||
hint) {}
|
||||
'''.format(sub_query)
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
|
||||
}
|
||||
]
|
||||
|
||||
print(f'Messages: {messages}')
|
||||
|
||||
# 토큰화
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
# 스트리머 설정
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# 생성 인자 설정
|
||||
generation_kwargs = dict(
|
||||
**model_inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=150,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# 별도 스레드에서 생성 실행
|
||||
from threading import Thread
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
# 제너레이터 함수 정의
|
||||
def generate():
|
||||
for new_text in streamer:
|
||||
if new_text:
|
||||
print(f'new Text: {new_text}')
|
||||
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
||||
print(f'End time: {datetime.now()}')
|
||||
|
||||
return generate
|
||||
Reference in New Issue
Block a user