from threading import Thread import json from typing import List, Generator import torch import chromadb from pydantic import BaseModel from starlette.middleware.cors import CORSMiddleware from starlette.responses import StreamingResponse from fastapi import FastAPI from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer from org_transformer_offline import init # === 경로 설정 (모두 로컬) === QWEN_MODEL_PATH = "./models/Qwen3-0.6B" # 전역 변수 설정 _model = None _tokenizer = None # 2. 벡터 DB 설정 persist_directory = "./chroma_db" chroma_client = chromadb.PersistentClient(path=persist_directory) collection = chroma_client.get_or_create_collection( name="orgchart", ) def search_employees(data: List[dict], query: str) -> List[dict]: """ 직원 데이터에서 검색어가 포함된 항목을 필터링합니다. (현재 API에서 직접 사용되지 않으나 유틸리티 목적으로 유지) Args: data (List[dict]): 직원 데이터 리스트 query (str): 검색어 Returns: List[dict]: 필터링된 직원 리스트 """ if not query: return data query = query.lower().strip() # 모든 필드값 중 검색어가 포함된 항목 필터링 filtered = [ emp for emp in data if any(query in str(value).lower() for value in emp.values() if value) ] return filtered def get_qwen_model(): """ Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다. torch.compile을 사용하여 추론 속도를 최적화합니다. Returns: tuple: (model, tokenizer) """ global _model, _tokenizer if _model is not None: return _model, _tokenizer # 토크나이저 로드 _tokenizer = AutoTokenizer.from_pretrained( QWEN_MODEL_PATH, trust_remote_code=True, local_files_only=True ) # 모델 로드 _model = AutoModelForCausalLM.from_pretrained( QWEN_MODEL_PATH, dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장 device_map="auto", trust_remote_code=True, local_files_only=True ) # ✅ torch.compile() 적용 (PyTorch 2.0+) if hasattr(torch, 'compile'): try: print("🚀 torch.compile() 적용 중...") # mode="reduce-overhead": 추론 시 추천 # dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합 _model = torch.compile( _model, mode="reduce-overhead", dynamic=True ) print("✅ torch.compile() 성공!") except Exception as e: print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}") pass # 실패하면 원본 사용 return _model, _tokenizer def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator: """ 사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다. Args: sessionId (str): 세션 ID (사용자 구분) query (str): 사용자 질문 top_k (int): 검색할 문서 개수 min_similarity (float): 최소 유사도 임계값 Returns: Generator: 스트리밍 응답 제너레이터 """ from datetime import datetime # 관련 문서 검색 (top_k보다 여유 있게 가져옴) results = collection.query( query_texts=[query], n_results=top_k + 2, # 여유분 확보 where={"sessionId": sessionId} ) print(f"검색 결과: {results}") if not results['documents'] or not results['documents'][0]: def generate_empty(): yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n" return generate_empty # 유사도 계산 및 필터링 filtered_docs = [] if results['distances'] and results['distances'][0]: for doc, dist in zip(results['documents'][0], results['distances'][0]): similarity = 1 - dist if similarity >= min_similarity: filtered_docs.append((doc, similarity)) if len(filtered_docs) >= top_k: break print(f"필터링된 문서: {filtered_docs}") if not filtered_docs: def generate_low_sim(): yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n" return generate_low_sim # 컨텍스트 생성 context_parts = [] for i, (doc, sim) in enumerate(filtered_docs): context_parts.append(f"[청크 {i+1} | 유사도: {sim:.3f}]\n{doc}") context = "\n\n".join(context_parts) # 모델 로드 model, tokenizer = get_qwen_model() sub_query = '' if query.find('부') > -1: sub_query = '***부로 끝나는 단어는 부서 맵핑' # 프롬프트 구성 messages = [ { "role": "system", "content": ( ''' 당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다. hint) {} '''.format(sub_query) ) }, { "role": "user", "content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}" } ] print(f'Messages: {messages}') # 토큰화 text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요 ) model_inputs = tokenizer([text], return_tensors="pt").to(model.device) # 스트리머 설정 streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) # 생성 인자 설정 generation_kwargs = dict( **model_inputs, streamer=streamer, max_new_tokens=150, do_sample=True, temperature=0.3, top_p=0.9, pad_token_id=tokenizer.eos_token_id ) # 별도 스레드에서 생성 실행 thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # 제너레이터 함수 정의 def generate(): for new_text in streamer: if new_text: print(f'new Text: {new_text}') yield json.dumps({"kind": "text", "text": new_text}) + "\n" print(f'End time: {datetime.now()}') return generate app = FastAPI() # CORS 설정 추가 app.add_middleware( CORSMiddleware, allow_origins=["*"], # 모든 출처 허용 allow_credentials=True, allow_methods=["*"], # 모든 HTTP 메서드 허용 allow_headers=["*"], # 모든 헤더 허용 ) # === 1. 기초 데이터 주입 API === class Item(BaseModel): context: list sessionId: str @app.post("/set-data") async def set_data(query: Item): """ 클라이언트로부터 받은 인사 데이터를 자연어 문장으로 변환하여 벡터 DB에 저장합니다. 기존 세션 데이터는 삭제 후 재생성됩니다. """ # 기존 데이터 삭제 collection.delete( where={"sessionId": query.sessionId} ) # 삭제 확인 (디버깅용) remaining_count = collection.get(where={"sessionId": query.sessionId}) print(f"남은 데이터 수: {len(remaining_count['ids'])}") doc_list = [] for q in query.context: # 각 필드를 안전하게 추출 (None 방어) name = q.get('name') or "" dept = q.get('deptNm') or "" grade = q.get('gradeNm') or "" position = q.get('ptsnNm') or "" office_phone = q.get('ofcePhn') or "" mobile_phone = q.get('mblPhn') or "" chief_name = q.get('chiefNm') or "" state_code = q.get('state') or "" # 상태 코드 한글화 state_map = {'C': '재직', 'T': '퇴사', 'H': '휴직'} status = state_map.get(state_code, "정보없음") # [핵심] 검색 엔진이 좋아할만한 서술형 문장 생성 # 부서명과 이름을 앞부분에 배치하여 가중치 유도 if name == '': doc = ( f"부서: {dept}. " f"해당 {dept}의 부서장은 {chief_name}입니다." ) else: doc = ( f"부서: {dept}. 이름: {name}. {dept} 소속의 {name} {grade}입니다. " f"직위는 {position}이며 현재 {status} 중입니다. " f"사내 전화번호(사선)는 {office_phone}입니다." ) doc_list.append(doc) init(query.sessionId, doc_list) return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."} @app.get("/") def question(sessionId: str, query: str): """ 질의응답 API 엔드포인트 """ generate = query_and_summarize_stream(sessionId=sessionId, query=query) return StreamingResponse(generate(), media_type="application/x-ndjson") # 개발용 실행 (직접 실행 시) if __name__ == "__main__": import uvicorn print("서버 시작: uvicorn manual:app --reload") # uvicorn.run(app, host="0.0.0.0", port=8000)