테스트
This commit is contained in:
174
org_offline.py
174
org_offline.py
@@ -10,7 +10,8 @@ from starlette.responses import StreamingResponse
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
||||||
|
|
||||||
from org_transformer_offline import init
|
from util.org_transformer_offline import init
|
||||||
|
from util.org_filter import extract_keywords_simple
|
||||||
|
|
||||||
# === 경로 설정 (모두 로컬) ===
|
# === 경로 설정 (모두 로컬) ===
|
||||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||||
@@ -99,29 +100,30 @@ def get_qwen_model():
|
|||||||
return _model, _tokenizer
|
return _model, _tokenizer
|
||||||
|
|
||||||
|
|
||||||
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
def query_select(sessionId: str, query: str) :
|
||||||
|
keywords = extract_keywords_simple(query)
|
||||||
|
print(keywords)
|
||||||
|
filter_list = [{"sessionId": sessionId}]
|
||||||
|
if keywords['dept'] != '' :
|
||||||
|
filter_list.append({"deptCd": keywords['dept']})
|
||||||
|
if keywords['rank'] != '' :
|
||||||
|
filter_list.append({"gradeCd": keywords['rank']})
|
||||||
|
print(filter_list)
|
||||||
|
results = collection.query(
|
||||||
|
query_texts=[keywords['keyword']],
|
||||||
|
where={"$and": filter_list},
|
||||||
|
)
|
||||||
|
return results, keywords['keyword']
|
||||||
|
|
||||||
|
|
||||||
|
def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.2) -> Generator:
|
||||||
"""
|
"""
|
||||||
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
||||||
|
|
||||||
Args:
|
|
||||||
sessionId (str): 세션 ID (사용자 구분)
|
|
||||||
query (str): 사용자 질문
|
|
||||||
top_k (int): 검색할 문서 개수
|
|
||||||
min_similarity (float): 최소 유사도 임계값
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Generator: 스트리밍 응답 제너레이터
|
Generator: 스트리밍 응답 제너레이터
|
||||||
"""
|
"""
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
|
||||||
results = collection.query(
|
|
||||||
query_texts=[query],
|
|
||||||
n_results=top_k + 2, # 여유분 확보
|
|
||||||
where={"sessionId": sessionId}
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"검색 결과: {results}")
|
|
||||||
|
|
||||||
if not results['documents'] or not results['documents'][0]:
|
if not results['documents'] or not results['documents'][0]:
|
||||||
def generate_empty():
|
def generate_empty():
|
||||||
@@ -135,15 +137,10 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
similarity = 1 - dist
|
similarity = 1 - dist
|
||||||
if similarity >= min_similarity:
|
if similarity >= min_similarity:
|
||||||
filtered_docs.append((doc, similarity))
|
filtered_docs.append((doc, similarity))
|
||||||
if len(filtered_docs) >= top_k:
|
if ai and len(filtered_docs) >= 5:
|
||||||
break
|
break
|
||||||
print(f"필터링된 문서: {filtered_docs}")
|
print(f"필터링된 문서: {filtered_docs}")
|
||||||
|
|
||||||
if not filtered_docs:
|
|
||||||
def generate_low_sim():
|
|
||||||
yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n"
|
|
||||||
return generate_low_sim
|
|
||||||
|
|
||||||
# 컨텍스트 생성
|
# 컨텍스트 생성
|
||||||
context_parts = []
|
context_parts = []
|
||||||
for i, (doc, sim) in enumerate(filtered_docs):
|
for i, (doc, sim) in enumerate(filtered_docs):
|
||||||
@@ -152,20 +149,15 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
|
|
||||||
# 모델 로드
|
# 모델 로드
|
||||||
model, tokenizer = get_qwen_model()
|
model, tokenizer = get_qwen_model()
|
||||||
|
|
||||||
sub_query = ''
|
|
||||||
if query.find('부') > -1:
|
|
||||||
sub_query = '***부로 끝나는 단어는 부서 맵핑'
|
|
||||||
|
|
||||||
# 프롬프트 구성
|
# 프롬프트 구성
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": (
|
"content": (
|
||||||
'''
|
'''
|
||||||
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
||||||
hint) {}
|
'''
|
||||||
'''.format(sub_query)
|
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -174,8 +166,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
print(f'Messages: {messages}')
|
|
||||||
|
|
||||||
# 토큰화
|
# 토큰화
|
||||||
text = tokenizer.apply_chat_template(
|
text = tokenizer.apply_chat_template(
|
||||||
messages,
|
messages,
|
||||||
@@ -207,9 +198,7 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
def generate():
|
def generate():
|
||||||
for new_text in streamer:
|
for new_text in streamer:
|
||||||
if new_text:
|
if new_text:
|
||||||
print(f'new Text: {new_text}')
|
|
||||||
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
||||||
print(f'End time: {datetime.now()}')
|
|
||||||
|
|
||||||
return generate
|
return generate
|
||||||
|
|
||||||
@@ -225,6 +214,78 @@ app.add_middleware(
|
|||||||
allow_headers=["*"], # 모든 헤더 허용
|
allow_headers=["*"], # 모든 헤더 허용
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def query_summarize_simple(query: str) :
|
||||||
|
model, tokenizer = get_qwen_model()
|
||||||
|
# 프롬프트 구성
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
'''
|
||||||
|
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
||||||
|
단순질문 일시 0 통계질문 또는 통계, 총개수, 카운트, 직원수 질문 일시 1 로 대답해주세요. 부가 내용은 필요 없습니다.
|
||||||
|
'''
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"질문: {query}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f'Messages: {messages}')
|
||||||
|
|
||||||
|
# 토큰화
|
||||||
|
text = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
|
)
|
||||||
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# conduct text completion
|
||||||
|
generated_ids = model.generate(
|
||||||
|
**model_inputs,
|
||||||
|
max_new_tokens=300,
|
||||||
|
do_sample=True, # ✅ 샘플링 활성화
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.9,
|
||||||
|
pad_token_id=tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
||||||
|
|
||||||
|
# parsing thinking content
|
||||||
|
try:
|
||||||
|
# rindex finding 151668 (</think>)
|
||||||
|
index = len(output_ids) - output_ids[::-1].index(151668)
|
||||||
|
except ValueError:
|
||||||
|
index = 0
|
||||||
|
|
||||||
|
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
|
||||||
|
end_think_id = tokenizer.convert_tokens_to_ids("</think>")
|
||||||
|
if end_think_id in output_ids:
|
||||||
|
idx = len(output_ids) - output_ids[::-1].index(end_think_id)
|
||||||
|
else:
|
||||||
|
idx = 0
|
||||||
|
content = tokenizer.decode(output_ids[idx:], skip_special_tokens=True).strip()
|
||||||
|
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# CORS 설정 추가
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 모든 출처 허용
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"], # 모든 HTTP 메서드 허용
|
||||||
|
allow_headers=["*"], # 모든 헤더 허용
|
||||||
|
)
|
||||||
|
|
||||||
# === 1. 기초 데이터 주입 API ===
|
# === 1. 기초 데이터 주입 API ===
|
||||||
class Item(BaseModel):
|
class Item(BaseModel):
|
||||||
context: list
|
context: list
|
||||||
@@ -245,38 +306,9 @@ async def set_data(query: Item):
|
|||||||
remaining_count = collection.get(where={"sessionId": query.sessionId})
|
remaining_count = collection.get(where={"sessionId": query.sessionId})
|
||||||
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
||||||
|
|
||||||
doc_list = []
|
|
||||||
for q in query.context:
|
|
||||||
# 각 필드를 안전하게 추출 (None 방어)
|
|
||||||
name = q.get('name') or ""
|
|
||||||
dept = q.get('deptNm') or ""
|
|
||||||
grade = q.get('gradeNm') or ""
|
|
||||||
position = q.get('ptsnNm') or ""
|
|
||||||
office_phone = q.get('ofcePhn') or ""
|
|
||||||
mobile_phone = q.get('mblPhn') or ""
|
|
||||||
chief_name = q.get('chiefNm') or ""
|
|
||||||
state_code = q.get('state') or ""
|
|
||||||
|
|
||||||
# 상태 코드 한글화
|
|
||||||
state_map = {'C': '재직', 'T': '퇴사', 'H': '휴직'}
|
|
||||||
status = state_map.get(state_code, "정보없음")
|
|
||||||
|
|
||||||
# [핵심] 검색 엔진이 좋아할만한 서술형 문장 생성
|
init(query.sessionId, query.context)
|
||||||
# 부서명과 이름을 앞부분에 배치하여 가중치 유도
|
|
||||||
if name == '':
|
|
||||||
doc = (
|
|
||||||
f"부서: {dept}. "
|
|
||||||
f"해당 {dept}의 부서장은 {chief_name}입니다."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
doc = (
|
|
||||||
f"부서: {dept}. 이름: {name}. {dept} 소속의 {name} {grade}입니다. "
|
|
||||||
f"직위는 {position}이며 현재 {status} 중입니다. "
|
|
||||||
f"사내 전화번호(사선)는 {office_phone}입니다."
|
|
||||||
)
|
|
||||||
doc_list.append(doc)
|
|
||||||
|
|
||||||
init(query.sessionId, doc_list)
|
|
||||||
|
|
||||||
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
||||||
|
|
||||||
@@ -286,7 +318,17 @@ def question(sessionId: str, query: str):
|
|||||||
"""
|
"""
|
||||||
질의응답 API 엔드포인트
|
질의응답 API 엔드포인트
|
||||||
"""
|
"""
|
||||||
generate = query_and_summarize_stream(sessionId=sessionId, query=query)
|
type = query_summarize_simple(query=query)
|
||||||
|
if(type == '0') :
|
||||||
|
results, keyword = query_select(sessionId, query)
|
||||||
|
print('단순질문 AI')
|
||||||
|
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
||||||
|
else :
|
||||||
|
results, keyword = query_select(sessionId, query)
|
||||||
|
print('단순질문 데이터베이스조회')
|
||||||
|
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
||||||
|
|
||||||
|
|
||||||
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
0
util/__init__.py
Normal file
0
util/__init__.py
Normal file
41
util/org_filter.py
Normal file
41
util/org_filter.py
Normal file
File diff suppressed because one or more lines are too long
@@ -41,17 +41,43 @@ def init(sessionId: str, data: List[Union[str, dict]]):
|
|||||||
|
|
||||||
# 데이터 처리: 문자열이면 그대로, 객체면 JSON 문자열로 변환
|
# 데이터 처리: 문자열이면 그대로, 객체면 JSON 문자열로 변환
|
||||||
documents = []
|
documents = []
|
||||||
for item in data:
|
|
||||||
if isinstance(item, str):
|
doc_list = []
|
||||||
documents.append(item)
|
for q in data:
|
||||||
|
# 각 필드를 안전하게 추출 (None 방어)
|
||||||
|
name = q.get('name') or ""
|
||||||
|
dept = q.get('deptNm') or ""
|
||||||
|
grade = q.get('gradeNm') or ""
|
||||||
|
position = q.get('ptsnNm') or ""
|
||||||
|
office_phone = q.get('ofcePhn') or ""
|
||||||
|
mobile_phone = q.get('mblPhn') or ""
|
||||||
|
chief_name = q.get('chiefNm') or ""
|
||||||
|
state_code = q.get('state') or ""
|
||||||
|
|
||||||
|
# 상태 코드 한글화
|
||||||
|
state_map = {'C': '재직', 'T': '퇴사', 'H': '휴직'}
|
||||||
|
status = state_map.get(state_code, "정보없음")
|
||||||
|
|
||||||
|
# [핵심] 검색 엔진이 좋아할만한 서술형 문장 생성
|
||||||
|
# 부서명과 이름을 앞부분에 배치하여 가중치 유도
|
||||||
|
if name == '':
|
||||||
|
doc = (
|
||||||
|
f"부서: {dept}. "
|
||||||
|
f"해당 {dept}의 부서장은 {chief_name}입니다."
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
documents.append(json.dumps(item, ensure_ascii=False))
|
doc = (
|
||||||
|
f"부서: {dept}. 이름: {name}. {dept} 소속의 {name} {grade}입니다. "
|
||||||
|
f"직위는 {position}이며 현재 {status} 중입니다. "
|
||||||
|
f"사내 전화번호(사선)는 {office_phone}입니다."
|
||||||
|
)
|
||||||
|
doc_list.append(doc)
|
||||||
|
|
||||||
# 벡터 DB에 추가
|
# 벡터 DB에 추가
|
||||||
collection.add(
|
collection.add(
|
||||||
documents=documents,
|
documents=doc_list,
|
||||||
ids=doc_ids,
|
ids=doc_ids,
|
||||||
metadatas=[{"sessionId": sessionId} for _ in doc_ids]
|
metadatas=[{"sessionId": sessionId, 'deptCd': d.get('deptCd') or "", 'gradeCd': d.get('gradeCd') or ""} for d in data]
|
||||||
)
|
)
|
||||||
|
|
||||||
print(f'{sessionId} init end')
|
print(f'{sessionId} init end')
|
||||||
124
util/summarize_query.py
Normal file
124
util/summarize_query.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
|
from org_offline import get_qwen_model
|
||||||
|
|
||||||
|
|
||||||
|
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
||||||
|
"""
|
||||||
|
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sessionId (str): 세션 ID (사용자 구분)
|
||||||
|
query (str): 사용자 질문
|
||||||
|
top_k (int): 검색할 문서 개수
|
||||||
|
min_similarity (float): 최소 유사도 임계값
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Generator: 스트리밍 응답 제너레이터
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
||||||
|
from org_offline import collection
|
||||||
|
results = collection.query(
|
||||||
|
query_texts=[query],
|
||||||
|
n_results=top_k + 2, # 여유분 확보
|
||||||
|
where={"sessionId": sessionId}
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"검색 결과: {results}")
|
||||||
|
|
||||||
|
if not results['documents'] or not results['documents'][0]:
|
||||||
|
def generate_empty():
|
||||||
|
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
|
||||||
|
return generate_empty
|
||||||
|
|
||||||
|
# 유사도 계산 및 필터링
|
||||||
|
filtered_docs = []
|
||||||
|
if results['distances'] and results['distances'][0]:
|
||||||
|
for doc, dist in zip(results['documents'][0], results['distances'][0]):
|
||||||
|
similarity = 1 - dist
|
||||||
|
if similarity >= min_similarity:
|
||||||
|
filtered_docs.append((doc, similarity))
|
||||||
|
if len(filtered_docs) >= top_k:
|
||||||
|
break
|
||||||
|
print(f"필터링된 문서: {filtered_docs}")
|
||||||
|
|
||||||
|
if not filtered_docs:
|
||||||
|
def generate_low_sim():
|
||||||
|
yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n"
|
||||||
|
return generate_low_sim
|
||||||
|
|
||||||
|
# 컨텍스트 생성
|
||||||
|
context_parts = []
|
||||||
|
for i, (doc, sim) in enumerate(filtered_docs):
|
||||||
|
context_parts.append(f"[청크 {i+1} | 유사도: {sim:.3f}]\n{doc}")
|
||||||
|
context = "\n\n".join(context_parts)
|
||||||
|
|
||||||
|
# 모델 로드
|
||||||
|
model, tokenizer = get_qwen_model()
|
||||||
|
|
||||||
|
sub_query = ''
|
||||||
|
if query.find('부') > -1:
|
||||||
|
sub_query = '***부로 끝나는 단어는 부서 맵핑'
|
||||||
|
|
||||||
|
# 프롬프트 구성
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": (
|
||||||
|
'''
|
||||||
|
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
||||||
|
hint) {}
|
||||||
|
'''.format(sub_query)
|
||||||
|
)
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f'Messages: {messages}')
|
||||||
|
|
||||||
|
# 토큰화
|
||||||
|
text = tokenizer.apply_chat_template(
|
||||||
|
messages,
|
||||||
|
tokenize=False,
|
||||||
|
add_generation_prompt=True,
|
||||||
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||||
|
)
|
||||||
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||||
|
|
||||||
|
# 스트리머 설정
|
||||||
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||||
|
|
||||||
|
# 생성 인자 설정
|
||||||
|
generation_kwargs = dict(
|
||||||
|
**model_inputs,
|
||||||
|
streamer=streamer,
|
||||||
|
max_new_tokens=150,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=0.3,
|
||||||
|
top_p=0.9,
|
||||||
|
pad_token_id=tokenizer.eos_token_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 별도 스레드에서 생성 실행
|
||||||
|
from threading import Thread
|
||||||
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# 제너레이터 함수 정의
|
||||||
|
def generate():
|
||||||
|
for new_text in streamer:
|
||||||
|
if new_text:
|
||||||
|
print(f'new Text: {new_text}')
|
||||||
|
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
||||||
|
print(f'End time: {datetime.now()}')
|
||||||
|
|
||||||
|
return generate
|
||||||
Reference in New Issue
Block a user