Files
org/app.py
2026-02-07 15:35:08 +09:00

338 lines
10 KiB
Python

from threading import Thread
import json
from typing import List, Generator
import chromadb
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import StreamingResponse
from fastapi import FastAPI
from transformers import TextIteratorStreamer
from config.ai.org_transformer import init
from config.util.org_filter import extract_keywords_simple
from repository.usersRepository import findAll
from config.ai.call_llm_model import get_qwen_model
from config.db.chroma import collection
import torch_directml
# DirectML 디바이스 선언
store_data = {}
device = torch_directml.device()
def search_employees(data: List[dict], query: str) -> List[dict]:
"""
직원 데이터에서 검색어가 포함된 항목을 필터링합니다.
(현재 API에서 직접 사용되지 않으나 유틸리티 목적으로 유지)
Args:
data (List[dict]): 직원 데이터 리스트
query (str): 검색어
Returns:
List[dict]: 필터링된 직원 리스트
"""
if not query:
return data
query = query.lower().strip()
# 모든 필드값 중 검색어가 포함된 항목 필터링
filtered = [
emp for emp in data
if any(query in str(value).lower() for value in emp.values() if value)
]
return filtered
def query_select(sessionId: str, query: str, limit_unlock: bool) :
keywords = extract_keywords_simple(query)
print(keywords)
filter_conditions = [{"sessionId": sessionId}]
if keywords['dept'] != '' :
filter_conditions.append({"deptCd": keywords['dept']})
if keywords['rank'] != '' :
filter_conditions.append({"gradeCd": keywords['rank']})
# 2. 조건이 1개보다 많을 때만 $and로 묶기
if len(filter_conditions) > -1:
where_clause = {"$and": filter_conditions}
else:
where_clause = filter_conditions[0]
print(where_clause)
if limit_unlock :
results = collection.get(
where=where_clause
)
else :
results = collection.query(
query_texts=[keywords['keyword']],
where=where_clause
)
return results, keywords['keyword']
def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.2) -> Generator:
"""
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
Returns:
Generator: 스트리밍 응답 제너레이터
"""
if ai :
if not results['documents'] or not results['documents'][0]:
def generate_empty():
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
return generate_empty
# 유사도 계산 및 필터링
filtered_docs = []
if results['distances'] and results['distances'][0]:
for doc, dist in zip(results['documents'][0], results['distances'][0]):
similarity = 1 - dist
if similarity >= min_similarity:
filtered_docs.append((doc, similarity))
if ai and len(filtered_docs) >= 5:
break
print(f"필터링된 문서: {filtered_docs}")
# 컨텍스트 생성
context_parts = []
for i, (doc, sim) in enumerate(filtered_docs):
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
context = "\n\n".join(context_parts)
else :
context = results
# 모델 로드
model, tokenizer = get_qwen_model()
# 프롬프트 구성
messages = [
{
"role": "system",
"content": (
f'''
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
다음 데이터를 참고하세요
{context}
'''
)
},
{
"role": "user",
"content": f"질문: {query}"
}
]
# 토큰화
model_inputs = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
enable_thinking=False, # Qwen 모델 버전에 따라 지원 여부 확인 필요
return_tensors="pt",
return_dict=True
).to(device)
# 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 생성 인자 설정
generation_kwargs = dict(
**model_inputs,
streamer=streamer,
max_new_tokens=400,
do_sample=True,
temperature=0.3,
top_p=0.9,
use_cache=True, # 속도 향상 핵심
pad_token_id=tokenizer.eos_token_id
)
# 별도 스레드에서 생성 실행
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 제너레이터 함수 정의
def generate():
for new_text in streamer:
if new_text:
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
return generate
app = FastAPI()
# CORS 설정 추가
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 모든 출처 허용
allow_credentials=True,
allow_methods=["*"], # 모든 HTTP 메서드 허용
allow_headers=["*"], # 모든 헤더 허용
)
def query_summarize_simple(query: str) :
model, tokenizer = get_qwen_model()
# 프롬프트 구성
messages = [
{
"role": "system",
"content": (
'''
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
[부서 직원수=1, 직급 직원수=2, 기타=99] 로 대답해주세요. 부가 내용은 필요 없습니다.
'''
)
},
{
"role": "user",
"content": f"질문: {query}"
}
]
# 토큰화
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
)
model_inputs = tokenizer([text], return_tensors="pt").to(device)
# conduct text completion
generated_ids = model.generate(
**model_inputs,
max_new_tokens=600,
do_sample=True, # ✅ 샘플링 활성화
temperature=0.3,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
# parsing thinking content
try:
# rindex finding 151668 (</think>)
index = len(output_ids) - output_ids[::-1].index(151668)
except ValueError:
index = 0
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
end_think_id = tokenizer.convert_tokens_to_ids("</think>")
if end_think_id in output_ids:
idx = len(output_ids) - output_ids[::-1].index(end_think_id)
else:
idx = 0
content = tokenizer.decode(output_ids[idx:], skip_special_tokens=True).strip()
return content
app = FastAPI()
# CORS 설정 추가
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 모든 출처 허용
allow_credentials=True,
allow_methods=["*"], # 모든 HTTP 메서드 허용
allow_headers=["*"], # 모든 헤더 허용
)
# === 1. 기초 데이터 주입 API ===
class Item(BaseModel):
context: list
sessionId: str
import gc
import torch
def clear_memory():
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# API 호출 사이사이나 요약 작업 직후에 실행
# clear_memory()
@app.post("/set-data")
async def set_data(query: Item):
"""
클라이언트로부터 받은 인사 데이터를 자연어 문장으로 변환하여 벡터 DB에 저장합니다.
기존 세션 데이터는 삭제 후 재생성됩니다.
"""
# # 기존 데이터 삭제
# collection.delete(
# where={"sessionId": query.sessionId}
# )
#
# # 삭제 확인 (디버깅용)
# remaining_count = collection.get(where={"sessionId": query.sessionId})
# print(f"남은 데이터 수: {len(remaining_count['ids'])}")
depts = []
users = []
for n in query.context :
if n['userYn'] == False :
depts.append(n)
else :
users.append(n)
print(len(depts), len(users))
store_data[query.sessionId] = query.context
results = []
for d in depts :
subStr = f'[부서명] : {d['deptNm']} [상위부서명] : {d['pDeptNm']} [부서장] : {d['chiefDisplayNm']}\n'
inUsers = list(filter(lambda u : u['deptCd'] == d['deptCd'], users ))
reUsers = []
for iu in inUsers :
reUsers.append(f'** 이름 : {iu['name']} 사번 : {iu['sabun']} 직급 : {iu['gradeNm']} 직위 : {iu['ptsnNm']} **')
subStr += f'[직원수] : {len(inUsers)}\n'
subStr += f'[직원명단] : {', '.join(reUsers)}'
results.append(subStr)
# init(query.sessionId, query.context)
store_data[query.sessionId] = results
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
@app.get("/")
def question(sessionId: str, query: str):
"""
질의응답 API 엔드포인트
"""
# type = query_summarize_simple(query=query)
# print(type)
# if(type == '99') :
# results, keyword = query_select(sessionId, query, False)
# print(f'단순질문 AI : {len(results)}')
# generate = query_select_summarize_stream(results, query=keyword, ai=True)
# # 부서 총직원수
# else :
# results, keyword = query_select(sessionId, query, True)
sessionData = store_data.get(sessionId)
results = '\n-------------------------------------------\n'.join(sessionData)
print(sessionId, query)
generate = query_select_summarize_stream(results, query, ai=False)
clear_memory()
return StreamingResponse(generate(), media_type="application/x-ndjson")
# 개발용 실행 (직접 실행 시)
if __name__ == "__main__":
print("서버 시작: uvicorn manual:app --reload")
# uvicorn.run(app, host="0.0.0.0", port=8000)