테스트
This commit is contained in:
1
.env
Normal file
1
.env
Normal file
@@ -0,0 +1 @@
|
||||
DATABASE_URL=postgresql://bangae1:fpdlwms1@hmsn.ink:35432/orgchart
|
||||
@@ -2,32 +2,20 @@ from threading import Thread
|
||||
import json
|
||||
from typing import List, Generator
|
||||
|
||||
import torch
|
||||
import chromadb
|
||||
from pydantic import BaseModel
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
from starlette.responses import StreamingResponse
|
||||
from fastapi import FastAPI
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from util.org_transformer_offline import init
|
||||
from util.org_filter import extract_keywords_simple
|
||||
|
||||
# === 경로 설정 (모두 로컬) ===
|
||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||
|
||||
# 전역 변수 설정
|
||||
_model = None
|
||||
_tokenizer = None
|
||||
|
||||
# 2. 벡터 DB 설정
|
||||
persist_directory = "./chroma_db"
|
||||
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
||||
|
||||
collection = chroma_client.get_or_create_collection(
|
||||
name="orgchart",
|
||||
)
|
||||
from config.ai.org_transformer import init
|
||||
from config.util.org_filter import extract_keywords_simple
|
||||
from repository.usersRepository import findAll
|
||||
from config.ai.call_llm_model import get_qwen_model
|
||||
from config.db.chroma import collection
|
||||
|
||||
store_data = {}
|
||||
|
||||
def search_employees(data: List[dict], query: str) -> List[dict]:
|
||||
"""
|
||||
@@ -54,51 +42,6 @@ def search_employees(data: List[dict], query: str) -> List[dict]:
|
||||
return filtered
|
||||
|
||||
|
||||
def get_qwen_model():
|
||||
"""
|
||||
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
|
||||
torch.compile을 사용하여 추론 속도를 최적화합니다.
|
||||
|
||||
Returns:
|
||||
tuple: (model, tokenizer)
|
||||
"""
|
||||
global _model, _tokenizer
|
||||
if _model is not None:
|
||||
return _model, _tokenizer
|
||||
|
||||
# 토크나이저 로드
|
||||
_tokenizer = AutoTokenizer.from_pretrained(
|
||||
QWEN_MODEL_PATH,
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# 모델 로드
|
||||
_model = AutoModelForCausalLM.from_pretrained(
|
||||
QWEN_MODEL_PATH,
|
||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
||||
if hasattr(torch, 'compile'):
|
||||
try:
|
||||
print("🚀 torch.compile() 적용 중...")
|
||||
# mode="reduce-overhead": 추론 시 추천
|
||||
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
|
||||
_model = torch.compile(
|
||||
_model,
|
||||
mode="reduce-overhead",
|
||||
dynamic=True
|
||||
)
|
||||
print("✅ torch.compile() 성공!")
|
||||
except Exception as e:
|
||||
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
|
||||
pass # 실패하면 원본 사용
|
||||
return _model, _tokenizer
|
||||
|
||||
|
||||
def query_select(sessionId: str, query: str, limit_unlock: bool) :
|
||||
keywords = extract_keywords_simple(query)
|
||||
@@ -160,11 +103,14 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
||||
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
|
||||
context = "\n\n".join(context_parts)
|
||||
else :
|
||||
print('일반')
|
||||
context_parts = []
|
||||
for doc in results.get('documents') :
|
||||
context_parts.append(f"{doc}")
|
||||
context = "\n".join(context_parts)
|
||||
print('일반', results.get('documents'))
|
||||
context_parts = [f'검색된 사용자 수는 {len(results.get('ids'))}']
|
||||
docs = [f"{d[d.find('[이름]'): d.find('[', d.find('[이름]')+1)]} {d[d.find('[부서]'): d.find('[', d.find('[부서]')+1)]}" for d in results.get('documents')]
|
||||
|
||||
|
||||
|
||||
context = "\n".join(context_parts + docs)
|
||||
print(context)
|
||||
|
||||
# 모델 로드
|
||||
model, tokenizer = get_qwen_model()
|
||||
@@ -245,7 +191,7 @@ def query_summarize_simple(query: str) :
|
||||
"content": (
|
||||
'''
|
||||
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
||||
단순질문 일시 0 통계질문 또는 통계, 총개수, 카운트, 직원수 질문 일시 1 로 대답해주세요. 부가 내용은 필요 없습니다.
|
||||
[부서 직원수=1, 직급 직원수=2, 기타=99] 로 대답해주세요. 부가 내용은 필요 없습니다.
|
||||
'''
|
||||
)
|
||||
},
|
||||
@@ -325,7 +271,7 @@ async def set_data(query: Item):
|
||||
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
||||
|
||||
|
||||
|
||||
store_data[query.sessionId] = query.context
|
||||
init(query.sessionId, query.context)
|
||||
|
||||
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
||||
@@ -337,13 +283,14 @@ def question(sessionId: str, query: str):
|
||||
질의응답 API 엔드포인트
|
||||
"""
|
||||
type = query_summarize_simple(query=query)
|
||||
if(type == '0') :
|
||||
print(type)
|
||||
if(type == '99') :
|
||||
results, keyword = query_select(sessionId, query, False)
|
||||
print(f'단순질문 AI : {len(results)}')
|
||||
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
||||
# 부서 총직원수
|
||||
else :
|
||||
results, keyword = query_select(sessionId, query, True)
|
||||
print(f'개수 데이터베이스조회 : {len(results.get('ids'))}')
|
||||
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
||||
|
||||
|
||||
@@ -352,6 +299,5 @@ def question(sessionId: str, query: str):
|
||||
|
||||
# 개발용 실행 (직접 실행 시)
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print("서버 시작: uvicorn manual:app --reload")
|
||||
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
0
config/ai/__init__.py
Normal file
0
config/ai/__init__.py
Normal file
53
config/ai/call_llm_model.py
Normal file
53
config/ai/call_llm_model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||
|
||||
|
||||
_model = None
|
||||
_tokenizer = None
|
||||
|
||||
def get_qwen_model():
|
||||
"""
|
||||
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
|
||||
torch.compile을 사용하여 추론 속도를 최적화합니다.
|
||||
|
||||
Returns:
|
||||
tuple: (model, tokenizer)
|
||||
"""
|
||||
global _model, _tokenizer
|
||||
if _model is not None:
|
||||
return _model, _tokenizer
|
||||
|
||||
# 토크나이저 로드
|
||||
_tokenizer = AutoTokenizer.from_pretrained(
|
||||
QWEN_MODEL_PATH,
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# 모델 로드
|
||||
from sympy.printing.pytorch import torch
|
||||
_model = AutoModelForCausalLM.from_pretrained(
|
||||
QWEN_MODEL_PATH,
|
||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
local_files_only=True
|
||||
)
|
||||
|
||||
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
||||
if hasattr(torch, 'compile'):
|
||||
try:
|
||||
print("🚀 torch.compile() 적용 중...")
|
||||
# mode="reduce-overhead": 추론 시 추천
|
||||
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
|
||||
_model = torch.compile(
|
||||
_model,
|
||||
mode="reduce-overhead",
|
||||
dynamic=True
|
||||
)
|
||||
print("✅ torch.compile() 성공!")
|
||||
except Exception as e:
|
||||
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
|
||||
pass # 실패하면 원본 사용
|
||||
return _model, _tokenizer
|
||||
@@ -67,8 +67,12 @@ def init(sessionId: str, data: List[Union[str, dict]]):
|
||||
)
|
||||
else:
|
||||
doc = (
|
||||
f"부서: {dept}. 이름: {name}. {dept} 소속의 {name} {grade}입니다. "
|
||||
f"직위는 {position}이며 현재 {status} 중입니다. "
|
||||
f" [이름]:{name}"
|
||||
f" [부서]:{dept}"
|
||||
f" [소속]:{dept}"
|
||||
f" [직급]:{grade}"
|
||||
f" [직위]:{position}"
|
||||
f" 현재 {status} 중입니다. "
|
||||
f" 사내 전화번호(사선)는 {office_phone}입니다."
|
||||
)
|
||||
doc_list.append(doc)
|
||||
@@ -2,7 +2,8 @@ from typing import Generator
|
||||
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from org_offline import get_qwen_model
|
||||
from config.ai.call_llm_model import get_qwen_model
|
||||
from config.db.chroma import collection
|
||||
|
||||
|
||||
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
||||
@@ -23,7 +24,6 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
||||
|
||||
|
||||
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
||||
from org_offline import collection
|
||||
results = collection.query(
|
||||
query_texts=[query],
|
||||
n_results=top_k + 2, # 여유분 확보
|
||||
0
config/db/__init__.py
Normal file
0
config/db/__init__.py
Normal file
9
config/db/chroma.py
Normal file
9
config/db/chroma.py
Normal file
@@ -0,0 +1,9 @@
|
||||
import chromadb
|
||||
|
||||
# 2. 벡터 DB 설정
|
||||
persist_directory = "./chroma_db"
|
||||
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
||||
|
||||
collection = chroma_client.get_or_create_collection(
|
||||
name="orgchart",
|
||||
)
|
||||
11
config/db/database.py
Normal file
11
config/db/database.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from sqlmodel import SQLModel, Field, create_engine, Session, select
|
||||
from typing import Optional
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
|
||||
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||
engine = create_engine(DATABASE_URL, echo=True)
|
||||
|
||||
|
||||
0
config/util/__init__.py
Normal file
0
config/util/__init__.py
Normal file
0
model/__init__.py
Normal file
0
model/__init__.py
Normal file
0
repository/__init__.py
Normal file
0
repository/__init__.py
Normal file
8
repository/usersRepository.py
Normal file
8
repository/usersRepository.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from sqlmodel import Session, select
|
||||
from config.db.database import engine
|
||||
from model.models import Users
|
||||
|
||||
def findAll() :
|
||||
with Session(engine) as session:
|
||||
users = session.exec(select(Users)).all()
|
||||
return users
|
||||
Reference in New Issue
Block a user