diff --git a/manual.py b/manual.py index 15868f6..ac93bc1 100644 --- a/manual.py +++ b/manual.py @@ -1,10 +1,12 @@ import os +from threading import Thread import torch from sentence_transformers import SentenceTransformer import chromadb +import json from chromadb.utils import embedding_functions -from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM +from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer from fastapi import FastAPI # 2. 벡터 DB 설정 @@ -99,12 +101,46 @@ def query_and_summarize(job: str, query: str, top_k: int = 3): print("\n\n\n\n\n") return content +def query_and_summarize_stream(job: str, query: str): + # 1. 문서 검색 (기존과 동일) + results = collection.query(query_texts=[query], n_results=1, where={"dept": job}) + top_doc = results['documents'][0][0] + + model, tokenizer = get_qwen_model() + + # 2. 메시지 구성 + messages = [ + {"role": "system", "content": "당신은 회사 재무/회계 업무 전문 어시스턴트입니다."}, + {"role": "user", "content": f"다음 문서를 참고하세요:\n{top_doc}\n\n질문: {query}"} + ] + + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + + # 3. 스트리머 설정 + # skip_prompt=True를 해야 입력한 질문이 다시 나오지 않습니다. + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + # 4. 별도 스레드에서 생성 실행 (비동기 처리를 위함) + generation_kwargs = dict(model_inputs, streamer=streamer, max_new_tokens=500) + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + + # 5. 제너레이터 함수 정의 + def generate(): + for new_text in streamer: + if new_text: + # 클라이언트가 JSON으로 받길 원한다면 형식을 맞춰줍니다. + yield json.dumps({"kind": "text", "text": new_text}) + "\n" + + return generate + app = FastAPI() @app.get("/") def question(query: str) : user_query = query - answer = query_and_summarize(job="FI", query=user_query) + answer = query_and_summarize_stream(job="FI", query=user_query) return {"answer": answer} # 예시 사용 diff --git a/manual_offline.py b/manual_offline.py index 1b1356e..d466bd9 100644 --- a/manual_offline.py +++ b/manual_offline.py @@ -1,10 +1,15 @@ import os +from threading import Thread + import torch from sentence_transformers import SentenceTransformer import chromadb from chromadb.utils import embedding_functions -from transformers import AutoTokenizer, AutoModelForCausalLM +from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer +import json from fastapi import FastAPI +from fastapi.responses import StreamingResponse +from fastapi.middleware.cors import CORSMiddleware # # === 경로 설정 (모두 로컬) === QWEN_MODEL_PATH = "./models/Qwen3-0.6B" @@ -130,17 +135,128 @@ def query_and_summarize(job: str, query: str, top_k: int = 3, min_similarity: fl print(f'9 {datetime.now()}') return response +def query_and_summarize_stream(job: str, query: str, top_k: int = 3, min_similarity: float = 0.2): + from datetime import datetime + print(f'1 {datetime.now()}') + # 관련 문서 검색 (top_k보다 여유 있게 가져옴) + results = collection.query( + query_texts=[query], + n_results=top_k + 2, # 여유분 확보 + where={"dept": job} + ) + print(f'{datetime.now()}') + + if not results['documents'][0]: + def generate_empty(): + yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n" + return generate_empty + + print(f'2 {datetime.now()}') + # 유사도 계산 및 필터링 + filtered_docs = [] + for doc, dist in zip(results['documents'][0], results['distances'][0]): + similarity = 1 - dist + if similarity >= min_similarity: + filtered_docs.append((doc, similarity)) + if len(filtered_docs) >= top_k: + break + print(f'3 {datetime.now()}') + + if not filtered_docs: + def generate_low_sim(): + yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n" + return generate_low_sim + + # 컨텍스트 생성 + context_parts = [] + for i, (doc, sim) in enumerate(filtered_docs): + context_parts.append(f"[청크 {i+1} | 유사도: {sim:.3f}]\n{doc}") + context = "\n\n".join(context_parts) + print(f'4 {datetime.now()}') + + # 모델 로드 + model, tokenizer = get_qwen_model() + print(f'5 {datetime.now()}') + + # 프롬프트 구성 + messages = [ + { + "role": "system", + "content": ( + "당신은 회사 재무/회계 업무 전문 어시스턴트입니다. " + "사용자에게 제공된 여러 청크를 종합하여, 정확하고 상세하게 답변하세요. " + "필요시 문서 내용을 직접 인용하거나 요약해도 됩니다. " + "추측하지 말고, 문서에 근거한 정보만 사용하세요." + ) + }, + { + "role": "user", + "content": f"다음 문서들을 참고하세요:\n\n{context}\n\n질문: {query}" + } + ] + print(f'6 {datetime.now()}') + + # 토큰화 + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + print(f'7 {datetime.now()}') + + # 스트리머 설정 + streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) + + # 생성 인자 설정 + generation_kwargs = dict( + **model_inputs, + streamer=streamer, + max_new_tokens=150, + do_sample=True, + temperature=0.3, + top_p=0.9, + pad_token_id=tokenizer.eos_token_id + ) + + # 별도 스레드에서 생성 실행 + thread = Thread(target=model.generate, kwargs=generation_kwargs) + thread.start() + print(f'8 {datetime.now()}') + + # 제너레이터 함수 정의 + def generate(): + for new_text in streamer: + if new_text: + yield json.dumps({"kind": "text", "text": new_text}) + "\n" + print(f'9 {datetime.now()}') + + return generate + # FastAPI 앱 app = FastAPI() +# CORS 설정 추가 +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # 모든 출처 허용 + allow_credentials=True, + allow_methods=["*"], # 모든 HTTP 메서드 허용 + allow_headers=["*"], # 모든 헤더 허용 +) + @app.get("/") def question(query: str): - answer = query_and_summarize(job="FI", query=query) - return {"answer": answer} + # answer = query_and_summarize(job="FI", query=query) + # return {"answer": answer} + generate = query_and_summarize_stream(job="FI", query=query) + return StreamingResponse(generate(), media_type="application/x-ndjson") # 개발용 실행 (직접 실행 시) if __name__ == "__main__": + query_and_summarize_stream(job="FI", query='외화 송금 방법?') import uvicorn - print("서버 시작: uvicorn manual:app --reload") + # print("서버 시작: uvicorn manual_offline:app --reload") # 예시 질의 (주석 해제 시 직접 테스트 가능) # print(query_and_summarize("FI", "외화 송금 절차는 어떻게 되나요?")) \ No newline at end of file