58 lines
1.6 KiB
Python
58 lines
1.6 KiB
Python
import json
|
|
from typing import List, Union
|
|
|
|
import chromadb
|
|
from chromadb.utils import embedding_functions
|
|
|
|
# === 경로 설정 (모두 로컬) ===
|
|
EMBEDDING_MODEL_PATH = "./models/ko-sroberta-multitask"
|
|
|
|
# 2. 벡터 DB 설정
|
|
persist_directory = "./chroma_db"
|
|
chroma_client = chromadb.PersistentClient(path=persist_directory)
|
|
|
|
# 임베딩 함수 설정
|
|
embedding_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
|
|
model_name=EMBEDDING_MODEL_PATH, # 로컬 폴더 경로 가능
|
|
device="cpu",
|
|
normalize_embeddings=True
|
|
)
|
|
|
|
# 컬렉션 생성 또는 가져오기
|
|
collection = chroma_client.get_or_create_collection(
|
|
name="orgchart",
|
|
embedding_function=embedding_fn,
|
|
metadata={"hnsw:space": "cosine"}
|
|
)
|
|
|
|
|
|
def init(sessionId: str, data: List[Union[str, dict]]):
|
|
"""
|
|
데이터를 벡터 DB에 초기화(저장)합니다.
|
|
|
|
Args:
|
|
sessionId (str): 세션 ID
|
|
data (List[Union[str, dict]]): 저장할 데이터 리스트 (문자열 또는 딕셔너리)
|
|
"""
|
|
print(f'{sessionId} init start')
|
|
|
|
# 문서 ID 생성
|
|
doc_ids = [f"{sessionId}_{i}" for i in range(len(data))]
|
|
|
|
# 데이터 처리: 문자열이면 그대로, 객체면 JSON 문자열로 변환
|
|
documents = []
|
|
for item in data:
|
|
if isinstance(item, str):
|
|
documents.append(item)
|
|
else:
|
|
documents.append(json.dumps(item, ensure_ascii=False))
|
|
|
|
# 벡터 DB에 추가
|
|
collection.add(
|
|
documents=documents,
|
|
ids=doc_ids,
|
|
metadatas=[{"sessionId": sessionId} for _ in doc_ids]
|
|
)
|
|
|
|
print(f'{sessionId} init end')
|