테스트

This commit is contained in:
2026-01-28 20:58:57 +09:00
parent 95c5fb7867
commit 5f828601ce
16 changed files with 115 additions and 82 deletions

0
config/__init__.py Normal file
View File

0
config/ai/__init__.py Normal file
View File

View File

@@ -0,0 +1,53 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
_model = None
_tokenizer = None
def get_qwen_model():
"""
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
torch.compile을 사용하여 추론 속도를 최적화합니다.
Returns:
tuple: (model, tokenizer)
"""
global _model, _tokenizer
if _model is not None:
return _model, _tokenizer
# 토크나이저 로드
_tokenizer = AutoTokenizer.from_pretrained(
QWEN_MODEL_PATH,
trust_remote_code=True,
local_files_only=True
)
# 모델 로드
from sympy.printing.pytorch import torch
_model = AutoModelForCausalLM.from_pretrained(
QWEN_MODEL_PATH,
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
device_map="auto",
trust_remote_code=True,
local_files_only=True
)
# ✅ torch.compile() 적용 (PyTorch 2.0+)
if hasattr(torch, 'compile'):
try:
print("🚀 torch.compile() 적용 중...")
# mode="reduce-overhead": 추론 시 추천
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
_model = torch.compile(
_model,
mode="reduce-overhead",
dynamic=True
)
print("✅ torch.compile() 성공!")
except Exception as e:
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
pass # 실패하면 원본 사용
return _model, _tokenizer

View File

@@ -0,0 +1,87 @@
import json
from typing import List, Union
import chromadb
from chromadb.utils import embedding_functions
# === 경로 설정 (모두 로컬) ===
EMBEDDING_MODEL_PATH = "./models/ko-sroberta-multitask"
# 2. 벡터 DB 설정
persist_directory = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=persist_directory)
# 임베딩 함수 설정
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=EMBEDDING_MODEL_PATH, # 로컬 폴더 경로 가능
device="cpu",
normalize_embeddings=True
)
# 컬렉션 생성 또는 가져오기
collection = chroma_client.get_or_create_collection(
name="orgchart",
embedding_function=embedding_fn,
metadata={"hnsw:space": "cosine"}
)
def init(sessionId: str, data: List[Union[str, dict]]):
"""
데이터를 벡터 DB에 초기화(저장)합니다.
Args:
sessionId (str): 세션 ID
data (List[Union[str, dict]]): 저장할 데이터 리스트 (문자열 또는 딕셔너리)
"""
print(f'{sessionId} init start')
# 문서 ID 생성
doc_ids = [f"{sessionId}_{i}" for i in range(len(data))]
# 데이터 처리: 문자열이면 그대로, 객체면 JSON 문자열로 변환
documents = []
doc_list = []
for q in data:
# 각 필드를 안전하게 추출 (None 방어)
name = q.get('name') or ""
dept = q.get('deptNm') or ""
grade = q.get('gradeNm') or ""
position = q.get('ptsnNm') or ""
office_phone = q.get('ofcePhn') or ""
mobile_phone = q.get('mblPhn') or ""
chief_name = q.get('chiefNm') or ""
state_code = q.get('state') or ""
# 상태 코드 한글화
state_map = {'C': '재직', 'T': '퇴사', 'H': '휴직'}
status = state_map.get(state_code, "정보없음")
# [핵심] 검색 엔진이 좋아할만한 서술형 문장 생성
# 부서명과 이름을 앞부분에 배치하여 가중치 유도
if name == '':
doc = (
f"부서: {dept}. "
f"해당 {dept}의 부서장은 {chief_name}입니다."
)
else:
doc = (
f" [이름]:{name}"
f" [부서]:{dept}"
f" [소속]:{dept}"
f" [직급]:{grade}"
f" [직위]:{position}"
f" 현재 {status} 중입니다. "
f" 사내 전화번호(사선)는 {office_phone}입니다."
)
doc_list.append(doc)
# 벡터 DB에 추가
collection.add(
documents=doc_list,
ids=doc_ids,
metadatas=[{"sessionId": sessionId, 'deptCd': d.get('deptCd') or "", 'gradeCd': d.get('gradeCd') or ""} for d in data]
)
print(f'{sessionId} init end')

View File

@@ -0,0 +1,124 @@
from typing import Generator
from transformers import TextIteratorStreamer
from config.ai.call_llm_model import get_qwen_model
from config.db.chroma import collection
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
"""
사용자 질문에 대해 벡터 DB를 검색하고, LLM을 통해 답변을 스트리밍으로 생성합니다.
Args:
sessionId (str): 세션 ID (사용자 구분)
query (str): 사용자 질문
top_k (int): 검색할 문서 개수
min_similarity (float): 최소 유사도 임계값
Returns:
Generator: 스트리밍 응답 제너레이터
"""
from datetime import datetime
import json
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
results = collection.query(
query_texts=[query],
n_results=top_k + 2, # 여유분 확보
where={"sessionId": sessionId}
)
print(f"검색 결과: {results}")
if not results['documents'] or not results['documents'][0]:
def generate_empty():
yield json.dumps({"kind": "text", "text": "관련 문서를 찾을 수 없습니다."}) + "\n"
return generate_empty
# 유사도 계산 및 필터링
filtered_docs = []
if results['distances'] and results['distances'][0]:
for doc, dist in zip(results['documents'][0], results['distances'][0]):
similarity = 1 - dist
if similarity >= min_similarity:
filtered_docs.append((doc, similarity))
if len(filtered_docs) >= top_k:
break
print(f"필터링된 문서: {filtered_docs}")
if not filtered_docs:
def generate_low_sim():
yield json.dumps({"kind": "text", "text": "유사도 기준을 만족하는 문서가 없습니다."}) + "\n"
return generate_low_sim
# 컨텍스트 생성
context_parts = []
for i, (doc, sim) in enumerate(filtered_docs):
context_parts.append(f"[청크 {i+1} | 유사도: {sim:.3f}]\n{doc}")
context = "\n\n".join(context_parts)
# 모델 로드
model, tokenizer = get_qwen_model()
sub_query = ''
if query.find('') > -1:
sub_query = '***부로 끝나는 단어는 부서 맵핑'
# 프롬프트 구성
messages = [
{
"role": "system",
"content": (
'''
당신은 인사 담당 어시스던트 입니다. 인사 이동, 승진, 적정 부서 이동 등 전반적으로 모든 인사 정보에 대해 답변해야합니다.
hint) {}
'''.format(sub_query)
)
},
{
"role": "user",
"content": f"다음 데이터를 참고하세요:\n\n{context}\n\n질문: {query}"
}
]
print(f'Messages: {messages}')
# 토큰화
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=False # Qwen 모델 버전에 따라 지원 여부 확인 필요
)
model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
# 스트리머 설정
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# 생성 인자 설정
generation_kwargs = dict(
**model_inputs,
streamer=streamer,
max_new_tokens=150,
do_sample=True,
temperature=0.3,
top_p=0.9,
pad_token_id=tokenizer.eos_token_id
)
# 별도 스레드에서 생성 실행
from threading import Thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# 제너레이터 함수 정의
def generate():
for new_text in streamer:
if new_text:
print(f'new Text: {new_text}')
yield json.dumps({"kind": "text", "text": new_text}) + "\n"
print(f'End time: {datetime.now()}')
return generate

0
config/db/__init__.py Normal file
View File

9
config/db/chroma.py Normal file
View File

@@ -0,0 +1,9 @@
import chromadb
# 2. 벡터 DB 설정
persist_directory = "./chroma_db"
chroma_client = chromadb.PersistentClient(path=persist_directory)
collection = chroma_client.get_or_create_collection(
name="orgchart",
)

11
config/db/database.py Normal file
View File

@@ -0,0 +1,11 @@
from sqlmodel import SQLModel, Field, create_engine, Session, select
from typing import Optional
from dotenv import load_dotenv
import os
load_dotenv()
DATABASE_URL = os.getenv("DATABASE_URL")
engine = create_engine(DATABASE_URL, echo=True)

0
config/util/__init__.py Normal file
View File

44
config/util/org_filter.py Normal file

File diff suppressed because one or more lines are too long