테스트
This commit is contained in:
303
app.py
Normal file
303
app.py
Normal file
@@ -0,0 +1,303 @@
|
||||
from threading import Thread
|
||||
import json
|
||||
from typing import List, Generator
|
||||
|
||||
import chromadb
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import StreamingResponse
|
||||
from fastapi import FastAPI
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from config.ai.org_transformer import init
|
||||
from config.util.org_filter import extract_keywords_simple
|
||||
from repository.usersRepository import findAll
|
||||
from config.ai.call_llm_model import get_qwen_model
|
||||
from config.db.chroma import collection
|
||||
|
||||
store_data = {}
|
||||
|
||||
def search_employees(data: List[dict], query: str) -> List[dict]:
|
||||
"""
|
||||
직원 데이터에서 검색어가 포함된 항목을 필터링합니다.
|
||||
(현재 API에서 직접 사용되지 않으나 유틸리티 목적으로 유지)
|
||||
|
||||
Args:
|
||||
data (List[dict]): 직원 데이터 리스트
|
||||
query (str): 검색어
|
||||
|
||||
Returns:
|
||||
List[dict]: 필터링된 직원 리스트
|
||||
"""
|
||||
if not query:
|
||||
return data
|
||||
|
||||
query = query.lower().strip()
|
||||
|
||||
# 모든 필드값 중 검색어가 포함된 항목 필터링
|
||||
filtered = [
|
||||
emp for emp in data
|
||||
if any(query in str(value).lower() for value in emp.values() if value)
|
||||
]
|
||||
return filtered
|
||||
|
||||
|
||||
|
||||
def query_select(sessionId: str, query: str, limit_unlock: bool) :
|
||||
keywords = extract_keywords_simple(query)
|
||||
print(keywords)
|
||||
filter_conditions = [{"sessionId": sessionId}]
|
||||
if keywords['dept'] != '' :
|
||||
filter_conditions.append({"deptCd": keywords['dept']})
|
||||
if keywords['rank'] != '' :
|
||||
filter_conditions.append({"gradeCd": keywords['rank']})
|
||||
|
||||
# 2. 조건이 1개보다 많을 때만 $and로 묶기
|
||||
if len(filter_conditions) > -1:
|
||||
where_clause = {"$and": filter_conditions}
|
||||
else:
|
||||
where_clause = filter_conditions[0]
|
||||
print(where_clause)
|
||||
if limit_unlock :
|
||||
results = collection.get(
|
||||
where=where_clause
|
||||
)
|
||||
else :
|
||||
results = collection.query(
|
||||
query_texts=[keywords['keyword']],
|
||||
where=where_clause
|
||||
)
|
||||
|
||||
return results, keywords['keyword']
|
||||
|
||||
|
||||
def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.2) -> Generator:
|
||||
"""
|
||||
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
||||
|
||||
|
||||
Returns:
|
||||
Generator: 스트리밍 응답 제너레이터
|
||||
"""
|
||||
|
||||
if ai :
|
||||
if not results['documents'] or not results['documents'][0]:
|
||||
def generate_empty():
|
||||
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
|
||||
return generate_empty
|
||||
|
||||
# 유사도 계산 및 필터링
|
||||
filtered_docs = []
|
||||
if results['distances'] and results['distances'][0]:
|
||||
for doc, dist in zip(results['documents'][0], results['distances'][0]):
|
||||
similarity = 1 - dist
|
||||
if similarity >= min_similarity:
|
||||
filtered_docs.append((doc, similarity))
|
||||
if ai and len(filtered_docs) >= 5:
|
||||
break
|
||||
print(f"필터링된 문서: {filtered_docs}")
|
||||
|
||||
# 컨텍스트 생성
|
||||
context_parts = []
|
||||
for i, (doc, sim) in enumerate(filtered_docs):
|
||||
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
|
||||
context = "\n\n".join(context_parts)
|
||||
else :
|
||||
print('일반', results.get('documents'))
|
||||
context_parts = [f'검색된 사용자 수는 {len(results.get('ids'))}']
|
||||
docs = [f"{d[d.find('[이름]'): d.find('[', d.find('[이름]')+1)]} {d[d.find('[부서]'): d.find('[', d.find('[부서]')+1)]}" for d in results.get('documents')]
|
||||
|
||||
|
||||
|
||||
context = "\n".join(context_parts + docs)
|
||||
print(context)
|
||||
|
||||
# 모델 로드
|
||||
model, tokenizer = get_qwen_model()
|
||||
|
||||
# 프롬프트 구성
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
'''
|
||||
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
||||
'''
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
# 토큰화
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
# 스트리머 설정
|
||||
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
||||
|
||||
# 생성 인자 설정
|
||||
generation_kwargs = dict(
|
||||
**model_inputs,
|
||||
streamer=streamer,
|
||||
max_new_tokens=3000,
|
||||
do_sample=True,
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
|
||||
# 별도 스레드에서 생성 실행
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
# 제너레이터 함수 정의
|
||||
def generate():
|
||||
|
||||
for new_text in streamer:
|
||||
if new_text:
|
||||
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
||||
|
||||
return generate
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# CORS 설정 추가
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 모든 출처 허용
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # 모든 HTTP 메서드 허용
|
||||
allow_headers=["*"], # 모든 헤더 허용
|
||||
)
|
||||
|
||||
|
||||
|
||||
def query_summarize_simple(query: str) :
|
||||
model, tokenizer = get_qwen_model()
|
||||
# 프롬프트 구성
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
'''
|
||||
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
||||
[부서 직원수=1, 직급 직원수=2, 기타=99] 로 대답해주세요. 부가 내용은 필요 없습니다.
|
||||
'''
|
||||
)
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"질문: {query}"
|
||||
}
|
||||
]
|
||||
|
||||
# 토큰화
|
||||
text = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
||||
)
|
||||
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
||||
|
||||
# conduct text completion
|
||||
generated_ids = model.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=300,
|
||||
do_sample=True, # ✅ 샘플링 활성화
|
||||
temperature=0.3,
|
||||
top_p=0.9,
|
||||
pad_token_id=tokenizer.eos_token_id
|
||||
)
|
||||
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
||||
|
||||
# parsing thinking content
|
||||
try:
|
||||
# rindex finding 151668 (</think>)
|
||||
index = len(output_ids) - output_ids[::-1].index(151668)
|
||||
except ValueError:
|
||||
index = 0
|
||||
|
||||
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
|
||||
end_think_id = tokenizer.convert_tokens_to_ids("</think>")
|
||||
if end_think_id in output_ids:
|
||||
idx = len(output_ids) - output_ids[::-1].index(end_think_id)
|
||||
else:
|
||||
idx = 0
|
||||
content = tokenizer.decode(output_ids[idx:], skip_special_tokens=True).strip()
|
||||
|
||||
return content
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# CORS 설정 추가
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # 모든 출처 허용
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"], # 모든 HTTP 메서드 허용
|
||||
allow_headers=["*"], # 모든 헤더 허용
|
||||
)
|
||||
|
||||
# === 1. 기초 데이터 주입 API ===
|
||||
class Item(BaseModel):
|
||||
context: list
|
||||
sessionId: str
|
||||
|
||||
@app.post("/set-data")
|
||||
async def set_data(query: Item):
|
||||
"""
|
||||
클라이언트로부터 받은 인사 데이터를 자연어 문장으로 변환하여 벡터 DB에 저장합니다.
|
||||
기존 세션 데이터는 삭제 후 재생성됩니다.
|
||||
"""
|
||||
# 기존 데이터 삭제
|
||||
collection.delete(
|
||||
where={"sessionId": query.sessionId}
|
||||
)
|
||||
|
||||
# 삭제 확인 (디버깅용)
|
||||
remaining_count = collection.get(where={"sessionId": query.sessionId})
|
||||
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
||||
|
||||
|
||||
store_data[query.sessionId] = query.context
|
||||
init(query.sessionId, query.context)
|
||||
|
||||
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def question(sessionId: str, query: str):
|
||||
"""
|
||||
질의응답 API 엔드포인트
|
||||
"""
|
||||
type = query_summarize_simple(query=query)
|
||||
print(type)
|
||||
if(type == '99') :
|
||||
results, keyword = query_select(sessionId, query, False)
|
||||
print(f'단순질문 AI : {len(results)}')
|
||||
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
||||
# 부서 총직원수
|
||||
else :
|
||||
results, keyword = query_select(sessionId, query, True)
|
||||
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
||||
|
||||
|
||||
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
# 개발용 실행 (직접 실행 시)
|
||||
if __name__ == "__main__":
|
||||
print("서버 시작: uvicorn manual:app --reload")
|
||||
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
Reference in New Issue
Block a user