테스트
This commit is contained in:
1
.env
Normal file
1
.env
Normal file
@@ -0,0 +1 @@
|
|||||||
|
DATABASE_URL=postgresql://bangae1:fpdlwms1@hmsn.ink:35432/orgchart
|
||||||
@@ -2,32 +2,20 @@ from threading import Thread
|
|||||||
import json
|
import json
|
||||||
from typing import List, Generator
|
from typing import List, Generator
|
||||||
|
|
||||||
import torch
|
|
||||||
import chromadb
|
import chromadb
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.middleware.cors import CORSMiddleware
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
from starlette.responses import StreamingResponse
|
from starlette.responses import StreamingResponse
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
from util.org_transformer_offline import init
|
from config.ai.org_transformer import init
|
||||||
from util.org_filter import extract_keywords_simple
|
from config.util.org_filter import extract_keywords_simple
|
||||||
|
from repository.usersRepository import findAll
|
||||||
# === 경로 설정 (모두 로컬) ===
|
from config.ai.call_llm_model import get_qwen_model
|
||||||
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
from config.db.chroma import collection
|
||||||
|
|
||||||
# 전역 변수 설정
|
|
||||||
_model = None
|
|
||||||
_tokenizer = None
|
|
||||||
|
|
||||||
# 2. 벡터 DB 설정
|
|
||||||
persist_directory = "./chroma_db"
|
|
||||||
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
|
||||||
|
|
||||||
collection = chroma_client.get_or_create_collection(
|
|
||||||
name="orgchart",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
store_data = {}
|
||||||
|
|
||||||
def search_employees(data: List[dict], query: str) -> List[dict]:
|
def search_employees(data: List[dict], query: str) -> List[dict]:
|
||||||
"""
|
"""
|
||||||
@@ -54,51 +42,6 @@ def search_employees(data: List[dict], query: str) -> List[dict]:
|
|||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
def get_qwen_model():
|
|
||||||
"""
|
|
||||||
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
|
|
||||||
torch.compile을 사용하여 추론 속도를 최적화합니다.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple: (model, tokenizer)
|
|
||||||
"""
|
|
||||||
global _model, _tokenizer
|
|
||||||
if _model is not None:
|
|
||||||
return _model, _tokenizer
|
|
||||||
|
|
||||||
# 토크나이저 로드
|
|
||||||
_tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
QWEN_MODEL_PATH,
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 모델 로드
|
|
||||||
_model = AutoModelForCausalLM.from_pretrained(
|
|
||||||
QWEN_MODEL_PATH,
|
|
||||||
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
|
||||||
device_map="auto",
|
|
||||||
trust_remote_code=True,
|
|
||||||
local_files_only=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
|
||||||
if hasattr(torch, 'compile'):
|
|
||||||
try:
|
|
||||||
print("🚀 torch.compile() 적용 중...")
|
|
||||||
# mode="reduce-overhead": 추론 시 추천
|
|
||||||
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
|
|
||||||
_model = torch.compile(
|
|
||||||
_model,
|
|
||||||
mode="reduce-overhead",
|
|
||||||
dynamic=True
|
|
||||||
)
|
|
||||||
print("✅ torch.compile() 성공!")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
|
|
||||||
pass # 실패하면 원본 사용
|
|
||||||
return _model, _tokenizer
|
|
||||||
|
|
||||||
|
|
||||||
def query_select(sessionId: str, query: str, limit_unlock: bool) :
|
def query_select(sessionId: str, query: str, limit_unlock: bool) :
|
||||||
keywords = extract_keywords_simple(query)
|
keywords = extract_keywords_simple(query)
|
||||||
@@ -160,11 +103,14 @@ def query_select_summarize_stream(results, query, ai, min_similarity: float = 0.
|
|||||||
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
|
context_parts.append(f"[유사도: {sim:.3f}]\n{doc}")
|
||||||
context = "\n\n".join(context_parts)
|
context = "\n\n".join(context_parts)
|
||||||
else :
|
else :
|
||||||
print('일반')
|
print('일반', results.get('documents'))
|
||||||
context_parts = []
|
context_parts = [f'검색된 사용자 수는 {len(results.get('ids'))}']
|
||||||
for doc in results.get('documents') :
|
docs = [f"{d[d.find('[이름]'): d.find('[', d.find('[이름]')+1)]} {d[d.find('[부서]'): d.find('[', d.find('[부서]')+1)]}" for d in results.get('documents')]
|
||||||
context_parts.append(f"{doc}")
|
|
||||||
context = "\n".join(context_parts)
|
|
||||||
|
|
||||||
|
context = "\n".join(context_parts + docs)
|
||||||
|
print(context)
|
||||||
|
|
||||||
# 모델 로드
|
# 모델 로드
|
||||||
model, tokenizer = get_qwen_model()
|
model, tokenizer = get_qwen_model()
|
||||||
@@ -245,7 +191,7 @@ def query_summarize_simple(query: str) :
|
|||||||
"content": (
|
"content": (
|
||||||
'''
|
'''
|
||||||
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
당신은 데이터 진단 전문가 입니다. 아래 질의 내용을 보고 해당 내용이 단순 질문 인지 아니면 통계, 총개수, 카운트, 직원수 질문 인지 확인해야합니다.
|
||||||
단순질문 일시 0 통계질문 또는 통계, 총개수, 카운트, 직원수 질문 일시 1 로 대답해주세요. 부가 내용은 필요 없습니다.
|
[부서 직원수=1, 직급 직원수=2, 기타=99] 로 대답해주세요. 부가 내용은 필요 없습니다.
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
@@ -325,7 +271,7 @@ async def set_data(query: Item):
|
|||||||
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
print(f"남은 데이터 수: {len(remaining_count['ids'])}")
|
||||||
|
|
||||||
|
|
||||||
|
store_data[query.sessionId] = query.context
|
||||||
init(query.sessionId, query.context)
|
init(query.sessionId, query.context)
|
||||||
|
|
||||||
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
return {"status": "success", "message": f"{len(query.context)}건의 데이터가 로드되었습니다."}
|
||||||
@@ -337,13 +283,14 @@ def question(sessionId: str, query: str):
|
|||||||
질의응답 API 엔드포인트
|
질의응답 API 엔드포인트
|
||||||
"""
|
"""
|
||||||
type = query_summarize_simple(query=query)
|
type = query_summarize_simple(query=query)
|
||||||
if(type == '0') :
|
print(type)
|
||||||
|
if(type == '99') :
|
||||||
results, keyword = query_select(sessionId, query, False)
|
results, keyword = query_select(sessionId, query, False)
|
||||||
print(f'단순질문 AI : {len(results)}')
|
print(f'단순질문 AI : {len(results)}')
|
||||||
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
generate = query_select_summarize_stream(results, query=keyword, ai=True)
|
||||||
|
# 부서 총직원수
|
||||||
else :
|
else :
|
||||||
results, keyword = query_select(sessionId, query, True)
|
results, keyword = query_select(sessionId, query, True)
|
||||||
print(f'개수 데이터베이스조회 : {len(results.get('ids'))}')
|
|
||||||
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
generate = query_select_summarize_stream(results, query=keyword, ai=False)
|
||||||
|
|
||||||
|
|
||||||
@@ -352,6 +299,5 @@ def question(sessionId: str, query: str):
|
|||||||
|
|
||||||
# 개발용 실행 (직접 실행 시)
|
# 개발용 실행 (직접 실행 시)
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
|
||||||
print("서버 시작: uvicorn manual:app --reload")
|
print("서버 시작: uvicorn manual:app --reload")
|
||||||
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
# uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||||
0
config/ai/__init__.py
Normal file
0
config/ai/__init__.py
Normal file
53
config/ai/call_llm_model.py
Normal file
53
config/ai/call_llm_model.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||||
|
|
||||||
|
QWEN_MODEL_PATH = "./models/Qwen3-0.6B"
|
||||||
|
|
||||||
|
|
||||||
|
_model = None
|
||||||
|
_tokenizer = None
|
||||||
|
|
||||||
|
def get_qwen_model():
|
||||||
|
"""
|
||||||
|
Qwen 모델과 토크나이저를 로드하거나 캐시된 인스턴스를 반환합니다.
|
||||||
|
torch.compile을 사용하여 추론 속도를 최적화합니다.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (model, tokenizer)
|
||||||
|
"""
|
||||||
|
global _model, _tokenizer
|
||||||
|
if _model is not None:
|
||||||
|
return _model, _tokenizer
|
||||||
|
|
||||||
|
# 토크나이저 로드
|
||||||
|
_tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
QWEN_MODEL_PATH,
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 모델 로드
|
||||||
|
from sympy.printing.pytorch import torch
|
||||||
|
_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
QWEN_MODEL_PATH,
|
||||||
|
dtype=torch.bfloat16, # CPU: bfloat16, GPU: float16 권장
|
||||||
|
device_map="auto",
|
||||||
|
trust_remote_code=True,
|
||||||
|
local_files_only=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# ✅ torch.compile() 적용 (PyTorch 2.0+)
|
||||||
|
if hasattr(torch, 'compile'):
|
||||||
|
try:
|
||||||
|
print("🚀 torch.compile() 적용 중...")
|
||||||
|
# mode="reduce-overhead": 추론 시 추천
|
||||||
|
# dynamic=True: 입력 길이가 유동적인 RAG 환경에 적합
|
||||||
|
_model = torch.compile(
|
||||||
|
_model,
|
||||||
|
mode="reduce-overhead",
|
||||||
|
dynamic=True
|
||||||
|
)
|
||||||
|
print("✅ torch.compile() 성공!")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ torch.compile() 실패, 원본 모델 사용: {e}")
|
||||||
|
pass # 실패하면 원본 사용
|
||||||
|
return _model, _tokenizer
|
||||||
@@ -67,9 +67,13 @@ def init(sessionId: str, data: List[Union[str, dict]]):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
doc = (
|
doc = (
|
||||||
f"부서: {dept}. 이름: {name}. {dept} 소속의 {name} {grade}입니다. "
|
f" [이름]:{name}"
|
||||||
f"직위는 {position}이며 현재 {status} 중입니다. "
|
f" [부서]:{dept}"
|
||||||
f"사내 전화번호(사선)는 {office_phone}입니다."
|
f" [소속]:{dept}"
|
||||||
|
f" [직급]:{grade}"
|
||||||
|
f" [직위]:{position}"
|
||||||
|
f" 현재 {status} 중입니다. "
|
||||||
|
f" 사내 전화번호(사선)는 {office_phone}입니다."
|
||||||
)
|
)
|
||||||
doc_list.append(doc)
|
doc_list.append(doc)
|
||||||
|
|
||||||
@@ -2,7 +2,8 @@ from typing import Generator
|
|||||||
|
|
||||||
from transformers import TextIteratorStreamer
|
from transformers import TextIteratorStreamer
|
||||||
|
|
||||||
from org_offline import get_qwen_model
|
from config.ai.call_llm_model import get_qwen_model
|
||||||
|
from config.db.chroma import collection
|
||||||
|
|
||||||
|
|
||||||
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_similarity: float = 0.2) -> Generator:
|
||||||
@@ -23,7 +24,6 @@ def query_and_summarize_stream(sessionId: str, query: str, top_k: int = 3, min_s
|
|||||||
|
|
||||||
|
|
||||||
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
# 관련 문서 검색 (top_k보다 여유 있게 가져옴)
|
||||||
from org_offline import collection
|
|
||||||
results = collection.query(
|
results = collection.query(
|
||||||
query_texts=[query],
|
query_texts=[query],
|
||||||
n_results=top_k + 2, # 여유분 확보
|
n_results=top_k + 2, # 여유분 확보
|
||||||
0
config/db/__init__.py
Normal file
0
config/db/__init__.py
Normal file
9
config/db/chroma.py
Normal file
9
config/db/chroma.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
import chromadb
|
||||||
|
|
||||||
|
# 2. 벡터 DB 설정
|
||||||
|
persist_directory = "./chroma_db"
|
||||||
|
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
||||||
|
|
||||||
|
collection = chroma_client.get_or_create_collection(
|
||||||
|
name="orgchart",
|
||||||
|
)
|
||||||
11
config/db/database.py
Normal file
11
config/db/database.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from sqlmodel import SQLModel, Field, create_engine, Session, select
|
||||||
|
from typing import Optional
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
DATABASE_URL = os.getenv("DATABASE_URL")
|
||||||
|
engine = create_engine(DATABASE_URL, echo=True)
|
||||||
|
|
||||||
|
|
||||||
0
config/util/__init__.py
Normal file
0
config/util/__init__.py
Normal file
0
model/__init__.py
Normal file
0
model/__init__.py
Normal file
0
repository/__init__.py
Normal file
0
repository/__init__.py
Normal file
8
repository/usersRepository.py
Normal file
8
repository/usersRepository.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
from sqlmodel import Session, select
|
||||||
|
from config.db.database import engine
|
||||||
|
from model.models import Users
|
||||||
|
|
||||||
|
def findAll() :
|
||||||
|
with Session(engine) as session:
|
||||||
|
users = session.exec(select(Users)).all()
|
||||||
|
return users
|
||||||
7
test.py
7
test.py
@@ -1,5 +1,6 @@
|
|||||||
# 테스트용 스크립트
|
# 테스트용 스크립트
|
||||||
text = '산재예방부 부서장은 누구야?'
|
d='[이름]:이준원[부서]:토건부[소속]:토건부[직급]:2직급[직위]:현재 재직 중입니다.'
|
||||||
|
|
||||||
# '부'라는 글자가 포함되어 있는지 확인
|
print(f'{d[d.find('[이름]'): d.find('[', d.find('[이름]')+1)]}')
|
||||||
print(text.find('부'))
|
|
||||||
|
print(f'{d[d.find('[부서]'): d.find('[', d.find('[부서]')+1)]}')
|
||||||
Reference in New Issue
Block a user