358 lines
11 KiB
Python
358 lines
11 KiB
Python
from threading import Thread
|
|
import json
|
|
from typing import List, Generator
|
|
|
|
import torch
|
|
import chromadb
|
|
from pydantic import BaseModel
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
from starlette.responses import StreamingResponse
|
|
from fastapi import FastAPI
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
|
|
|
from util.org_transformer_offline import init
|
|
from util.org_filter import extract_keywords_simple
|
|
|
|
# === 경로 설정 (모두 로컬) ===
|
|
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
|
|
|
# 전역 변수 설정
|
|
_model = None
|
|
_tokenizer = None
|
|
|
|
# 2. 벡터 DB 설정
|
|
persist_directory = "./chroma_db"
|
|
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
|
|
|
collection = chroma_client.get_or_create_collection(
|
|
name="orgchart",
|
|
)
|
|
|
|
|
|
def search_employees(data: List[dict], query: str) -> List[dict]:
|
|
"""
|
|
직원 데이터에서 검색어가 포함된 항목을 필터링합니다.
|
|
(현재 API에서 직접 사용되지 않으나 유틸리티 목적으로 유지)
|
|
|
|
Args:
|
|
data (List[dict]): 직원 데이터 리스트
|
|
query (str): 검색어
|
|
|
|
Returns:
|
|
List[dict]: 필터링된 직원 리스트
|
|
"""
|
|
if not query:
|
|
return data
|
|
|
|
query = query.lower().strip()
|
|
|
|
# 모든 필드값 중 검색어가 포함된 항목 필터링
|
|
filtered = [
|
|
emp for emp in data
|
|
if any(query in str(value).lower() for value in emp.values() if value)
|
|
]
|
|
return filtered
|
|
|
|
|
|
def get_qwen_model():
|
|
"""
|
|
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
|
|
torch.compile을 사용하여 추론 속도를 최적화합니다.
|
|
|
|
Returns:
|
|
tuple: (model, tokenizer)
|
|
"""
|
|
global _model, _tokenizer
|
|
if _model is not None:
|
|
return _model, _tokenizer
|
|
|
|
# 토크나이저 로드
|
|
_tokenizer = AutoTokenizer.from_pretrained(
|
|
QWEN_MODEL_PATH,
|
|
trust_remote_code=True,
|
|
local_files_only=True
|
|
)
|
|
|
|
# 모델 로드
|
|
_model = AutoModelForCausalLM.from_pretrained(
|
|
QWEN_MODEL_PATH,
|
|
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
|
device_map="auto",
|
|
trust_remote_code=True,
|
|
local_files_only=True
|
|
)
|
|
|
|
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
|
if hasattr(torch, 'compile'):
|
|
try:
|
|
print("🚀 torch.compile() 적용 중...")
|
|
# mode="reduce-overhead": 추론 시 추천
|
|
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
|
|
_model = torch.compile(
|
|
_model,
|
|
mode="reduce-overhead",
|
|
dynamic=True
|
|
)
|
|
print("✅ torch.compile() 성공!")
|
|
except Exception as e:
|
|
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
|
|
pass # 실패하면 원본 사용
|
|
return _model, _tokenizer
|
|
|
|
|
|
def query_select(sessionId: str, query: str, limit_unlock: bool) :
|
|
keywords = extract_keywords_simple(query)
|
|
print(keywords)
|
|
filter_conditions = [{"sessionId": sessionId}]
|
|
if keywords['dept'] != '' :
|
|
filter_conditions.append({"deptCd": keywords['dept']})
|
|
if keywords['rank'] != '' :
|
|
filter_conditions.append({"gradeCd": keywords['rank']})
|
|
|
|
# 2. 조건이 1개보다 많을 때만 $and로 묶기
|
|
if len(filter_conditions) > -1:
|
|
where_clause = {"$and": filter_conditions}
|
|
else:
|
|
where_clause = filter_conditions[0]
|
|
print(where_clause)
|
|
if limit_unlock :
|
|
results = collection.get(
|
|
where=where_clause
|
|
)
|
|
else :
|
|
results = collection.query(
|
|
query_texts=[keywords['keyword']],
|
|
where=where_clause
|
|
)
|
|
|
|
return results, keywords['keyword']
|
|
|
|
|
|
def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.2) -> Generator:
|
|
"""
|
|
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
|
|
|
|
|
|
Returns:
|
|
Generator: 스트리밍 응답 제너레이터
|
|
"""
|
|
|
|
if ai :
|
|
if not results['documents'] or not results['documents'][0]:
|
|
def generate_empty():
|
|
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
|
|
return generate_empty
|
|
|
|
# 유사도 계산 및 필터링
|
|
filtered_docs = []
|
|
if results['distances'] and results['distances'][0]:
|
|
for doc, dist in zip(results['documents'][0], results['distances'][0]):
|
|
similarity = 1 - dist
|
|
if similarity >= min_similarity:
|
|
filtered_docs.append((doc, similarity))
|
|
if ai and len(filtered_docs) >= 5:
|
|
break
|
|
print(f"필터링된 문서: {filtered_docs}")
|
|
|
|
# 컨텍스트 생성
|
|
context_parts = []
|
|
for i, (doc, sim) in enumerate(filtered_docs):
|
|
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
|
|
context = "\n\n".join(context_parts)
|
|
else :
|
|
print('일반')
|
|
context_parts = []
|
|
for doc in results.get('documents') :
|
|
context_parts.append(f"{doc}")
|
|
context = "\n".join(context_parts)
|
|
|
|
# 모델 로드
|
|
model, tokenizer = get_qwen_model()
|
|
|
|
# 프롬프트 구성
|
|
messages = [
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
'''
|
|
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
|
|
'''
|
|
)
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
|
|
}
|
|
]
|
|
|
|
|
|
# 토큰화
|
|
text = tokenizer.apply_chat_template(
|
|
messages,
|
|
tokenize=False,
|
|
add_generation_prompt=True,
|
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
|
)
|
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
|
|
|
# 스트리머 설정
|
|
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
|
|
# 생성 인자 설정
|
|
generation_kwargs = dict(
|
|
**model_inputs,
|
|
streamer=streamer,
|
|
max_new_tokens=3000,
|
|
do_sample=True,
|
|
temperature=0.3,
|
|
top_p=0.9,
|
|
pad_token_id=tokenizer.eos_token_id
|
|
)
|
|
|
|
# 별도 스레드에서 생성 실행
|
|
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
|
thread.start()
|
|
|
|
# 제너레이터 함수 정의
|
|
def generate():
|
|
|
|
for new_text in streamer:
|
|
if new_text:
|
|
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
|
|
|
|
return generate
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
# CORS 설정 추가
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"], # 모든 출처 허용
|
|
allow_credentials=True,
|
|
allow_methods=["*"], # 모든 HTTP 메서드 허용
|
|
allow_headers=["*"], # 모든 헤더 허용
|
|
)
|
|
|
|
|
|
|
|
def query_summarize_simple(query: str) :
|
|
model, tokenizer = get_qwen_model()
|
|
# 프롬프트 구성
|
|
messages = [
|
|
{
|
|
"role": "system",
|
|
"content": (
|
|
'''
|
|
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
|
단순질문 일시 0 통계질문 또는 통계, 총개수, 카운트, 직원수 질문 일시 1 로 대답해주세요. 부가 내용은 필요 없습니다.
|
|
'''
|
|
)
|
|
},
|
|
{
|
|
"role": "user",
|
|
"content": f"질문: {query}"
|
|
}
|
|
]
|
|
|
|
# 토큰화
|
|
text = tokenizer.apply_chat_template(
|
|
messages,
|
|
tokenize=False,
|
|
add_generation_prompt=True,
|
|
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
|
|
)
|
|
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
|
|
|
# conduct text completion
|
|
generated_ids = model.generate(
|
|
**model_inputs,
|
|
max_new_tokens=300,
|
|
do_sample=True, # ✅ 샘플링 활성화
|
|
temperature=0.3,
|
|
top_p=0.9,
|
|
pad_token_id=tokenizer.eos_token_id
|
|
)
|
|
output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
|
|
|
|
# parsing thinking content
|
|
try:
|
|
# rindex finding 151668 (</think>)
|
|
index = len(output_ids) - output_ids[::-1].index(151668)
|
|
except ValueError:
|
|
index = 0
|
|
|
|
thinking_content = tokenizer.decode(output_ids[:index], skip_special_tokens=True).strip("\n")
|
|
end_think_id = tokenizer.convert_tokens_to_ids("</think>")
|
|
if end_think_id in output_ids:
|
|
idx = len(output_ids) - output_ids[::-1].index(end_think_id)
|
|
else:
|
|
idx = 0
|
|
content = tokenizer.decode(output_ids[idx:], skip_special_tokens=True).strip()
|
|
|
|
return content
|
|
|
|
|
|
app = FastAPI()
|
|
|
|
# CORS 설정 추가
|
|
app.add_middleware(
|
|
CORSMiddleware,
|
|
allow_origins=["*"], # 모든 출처 허용
|
|
allow_credentials=True,
|
|
allow_methods=["*"], # 모든 HTTP 메서드 허용
|
|
allow_headers=["*"], # 모든 헤더 허용
|
|
)
|
|
|
|
# === 1. 기초 데이터 주입 API ===
|
|
class Item(BaseModel):
|
|
context: list
|
|
sessionId: str
|
|
|
|
@app.post("/set-data")
|
|
async def set_data(query: Item):
|
|
"""
|
|
클라이언트로부터 받은 인사 데이터를 자연어 문장으로 변환하여 벡터 DB에 저장합니다.
|
|
기존 세션 데이터는 삭제 후 재생성됩니다.
|
|
"""
|
|
# 기존 데이터 삭제
|
|
collection.delete(
|
|
where={"sessionId": query.sessionId}
|
|
)
|
|
|
|
# 삭제 확인 (디버깅용)
|
|
remaining_count = collection.get(where={"sessionId": query.sessionId})
|
|
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
|
|
|
|
|
|
|
init(query.sessionId, query.context)
|
|
|
|
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
|
|
|
|
|
@app.get("/")
|
|
def question(sessionId: str, query: str):
|
|
"""
|
|
질의응답 API 엔드포인트
|
|
"""
|
|
type = query_summarize_simple(query=query)
|
|
if(type == '0') :
|
|
results, keyword = query_select(sessionId, query, False)
|
|
print(f'단순질문 AI : {len(results)}')
|
|
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
|
else :
|
|
results, keyword = query_select(sessionId, query, True)
|
|
print(f'개수 데이터베이스조회 : {len(results.get('ids'))}')
|
|
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
|
|
|
|
|
return StreamingResponse(generate(), media_type="application/x-ndjson")
|
|
|
|
|
|
# 개발용 실행 (직접 실행 시)
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
print("서버 시작: uvicorn manual:app --reload")
|
|
# uvicorn.run(app, host="0.0.0.0", port=8000)
|