diff --git a/server/db/base.py b/server/db/base.py index 3a8529b..1d911c0 100644 --- a/server/db/base.py +++ b/server/db/base.py @@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from configs.model_config import SQLALCHEMY_DATABASE_URI +import json -engine = create_engine(SQLALCHEMY_DATABASE_URI) + +engine = create_engine( + SQLALCHEMY_DATABASE_URI, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), +) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() - diff --git a/server/db/models/knowledge_file_model.py b/server/db/models/knowledge_file_model.py index 3d885ea..c5784d1 100644 --- a/server/db/models/knowledge_file_model.py +++ b/server/db/models/knowledge_file_model.py @@ -1,4 +1,4 @@ -from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func +from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func from server.db.base import Base @@ -23,3 +23,18 @@ class KnowledgeFileModel(Base): def __repr__(self): return f"" + + +class FileDocModel(Base): + """ + Maps a file of a knowledge base to one document stored in the vector store (table 'file_doc') + """ + __tablename__ = 'file_doc' + id = Column(Integer, primary_key=True, autoincrement=True, comment='ID') + kb_name = Column(String(50), comment='知识库名称') + file_name = Column(String(255), comment='文件名称') + doc_id = Column(String(50), comment="向量库文档ID") + meta_data = Column(JSON, default=dict) + + def __repr__(self): + return f"FileDoc(id={self.id!r}, kb_name={self.kb_name!r}, file_name={self.file_name!r}, doc_id={self.doc_id!r}, meta_data={self.meta_data!r})" diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index 6277ad6..08417a4 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -1,7 +1,65 @@ from server.db.models.knowledge_base_model import KnowledgeBaseModel -from server.db.models.knowledge_file_model import KnowledgeFileModel +from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel from 
server.db.session import with_session from server.knowledge_base.utils import KnowledgeFile +from typing import List, Dict + + +@with_session +def list_docs_from_db(session, + kb_name: str, + file_name: str = None, + metadata: Dict = {}, + ) -> List[Dict]: + ''' + List the Documents recorded for a knowledge base, optionally restricted to one file and to rows whose JSON meta_data matches every key/value in metadata (compared as strings). + Return format: [{"id": str, "metadata": dict}, ...] + ''' + docs = session.query(FileDocModel).filter_by(kb_name=kb_name) + if file_name: + docs = docs.filter_by(file_name=file_name) + for k, v in metadata.items(): + docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v)) + + return [{"id": x.doc_id, "metadata": x.meta_data} for x in docs.all()] + + +@with_session +def delete_docs_from_db(session, + kb_name: str, + file_name: str = None, + ) -> List[Dict]: + ''' + Delete the Documents recorded for a knowledge base (only those of file_name when it is given) and return the deleted Documents. + Return format: [{"id": str, "metadata": dict}, ...] + ''' + docs = list_docs_from_db(kb_name=kb_name, file_name=file_name) + query = session.query(FileDocModel).filter_by(kb_name=kb_name) + if file_name: + query = query.filter_by(file_name=file_name) + query.delete() + session.commit() + return docs + + +@with_session +def add_docs_to_db(session, + kb_name: str, + file_name: str, + doc_infos: List[Dict]): + ''' + Add one FileDocModel row per Document of the given knowledge-base file to the database. + doc_infos format: [{"id": str, "metadata": dict}, ...] + ''' + for d in doc_infos: + obj = FileDocModel( + kb_name=kb_name, + file_name=file_name, + doc_id=d["id"], + meta_data=d["metadata"], + ) + session.add(obj) + return True @with_session @@ -20,7 +78,9 @@ def list_files_from_db(session, kb_name): def add_file_to_db(session, kb_file: KnowledgeFile, docs_count: int = 0, - custom_docs: bool = False,): + custom_docs: bool = False, + doc_infos: List[Dict] = [], # format: [{"id": str, "metadata": dict}, ...] 
+ ): kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() if kb: # 如果已经存在该文件,则更新文件信息与版本号 @@ -52,6 +112,7 @@ def add_file_to_db(session, ) kb.file_count += 1 session.add(new_file) + add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos) return True @@ -61,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): kb_name=kb_file.kb_name).first() if existing_file: session.delete(existing_file) + delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename) session.commit() kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first() @@ -73,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile): @with_session def delete_files_from_db(session, knowledge_base_name: str): session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete() - + session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete() kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first() if kb: kb.file_count = 0 diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 79b1518..ca0919e 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -14,7 +14,8 @@ from server.db.repository.knowledge_base_repository import ( ) from server.db.repository.knowledge_file_repository import ( add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db, - count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db + count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db, + list_docs_from_db, ) from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, @@ -25,6 +26,7 @@ from server.knowledge_base.utils import ( ) from server.utils import embedding_device from typing import List, Union, Dict +from typing import List, Union, Dict, Optional class SupportedVSType: @@ -88,8 
+90,11 @@ class KBService(ABC): if docs: self.delete_doc(kb_file) - self.do_add_doc(docs, **kwargs) - status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs)) + doc_infos = self.do_add_doc(docs, **kwargs) + status = add_file_to_db(kb_file, + custom_docs=custom_docs, + docs_count=len(docs), + doc_infos=doc_infos) else: status = False return status @@ -132,6 +137,18 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold, embeddings) return docs + # TODO: milvus/pg still need to implement this method; this default always returns None + def get_doc_by_id(self, id: str) -> Optional[Document]: + return None + + def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]: + ''' + Retrieve Documents by file_name and/or metadata; each entry is whatever get_doc_by_id returns for that id, so it is None when the id cannot be resolved + ''' + doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata) + docs = [self.get_doc_by_id(x["id"]) for x in doc_infos] + return docs + @abstractmethod def do_create_kb(self): """ @@ -181,7 +198,7 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], - ): + ) -> List[Dict]: """ 向知识库添加文档子类实自己逻辑 """ diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index f17b2da..15cc790 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -12,7 +12,7 @@ from functools import lru_cache from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings -from typing import List +from typing import List, Dict, Optional from langchain.docstore.document import Document from server.utils import torch_gc, embedding_device @@ -88,6 +88,10 @@ class FaissKBService(KBService): def refresh_vs_cache(self): refresh_vs_cache(self.kb_name) + def get_doc_by_id(self, id: str) -> Optional[Document]: + vector_store = self.load_vector_store() + return vector_store.docstore._dict.get(id) 
+ def do_init(self): self.kb_path = self.get_kb_path() self.vs_path = self.get_vs_path() @@ -114,14 +118,15 @@ class FaissKBService(KBService): def do_add_doc(self, docs: List[Document], **kwargs, - ): + ) -> List[Dict]: vector_store = self.load_vector_store() - vector_store.add_documents(docs) + ids = vector_store.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] torch_gc() if not kwargs.get("not_refresh_vs_cache"): vector_store.save_local(self.vs_path) self.refresh_vs_cache() - return vector_store + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index ac7712b..9819713 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict, Optional import numpy as np from faiss import normalize_L2 @@ -22,6 +22,10 @@ class MilvusKBService(KBService): from pymilvus import Collection return Collection(milvus_name) + # TODO: look the Document up in Milvus by its id; this stub always returns None + def get_doc_by_id(self, id: str) -> Optional[Document]: + return None + @staticmethod def search(milvus_name, content, limit=3): search_params = { @@ -54,8 +58,10 @@ class MilvusKBService(KBService): self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings)) return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], **kwargs): - self.milvus.add_documents(docs) + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + ids = self.milvus.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): filepath = kb_file.filepath.replace('\\', '\\\\') diff --git 
a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 3e3dd52..8e05b42 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Dict, Optional from langchain.embeddings.base import Embeddings from langchain.schema import Document @@ -24,6 +24,10 @@ class PGKBService(KBService): distance_strategy=DistanceStrategy.EUCLIDEAN, connection_string=kbs_config.get("pg").get("connection_uri")) + # TODO: look the Document up in PGVector by its id; this stub always returns None + def get_doc_by_id(self, id: str) -> Optional[Document]: + return None + def do_init(self): self._load_pg_vector() @@ -51,8 +55,10 @@ class PGKBService(KBService): return score_threshold_process(score_threshold, top_k, self.pg_vector.similarity_search_with_score(query, top_k)) - def do_add_doc(self, docs: List[Document], **kwargs): - self.pg_vector.add_documents(docs) + def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: + ids = self.pg_vector.add_documents(docs) + doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] + return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): with self.pg_vector.connect() as connect: diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 4285b79..129dd53 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -8,7 +8,7 @@ from server.db.repository.knowledge_file_repository import add_file_to_db from server.db.base import Base, engine import os from concurrent.futures import ThreadPoolExecutor -from typing import Literal, Callable, Any, List +from typing import Literal, Any, List pool = ThreadPoolExecutor(os.cpu_count()) diff --git a/server/utils.py b/server/utils.py index d716582..0f9d2df 100644 --- a/server/utils.py +++ b/server/utils.py @@ -1,7 +1,6 @@ import pydantic from pydantic import BaseModel from typing import List -import 
torch from fastapi import FastAPI from pathlib import Path import asyncio @@ -69,6 +68,7 @@ class ChatMessage(BaseModel): } def torch_gc(): + import torch if torch.cuda.is_available(): # with torch.cuda.device(DEVICE): torch.cuda.empty_cache() diff --git a/startup.py b/startup.py index 5ef05ce..3a21010 100644 --- a/startup.py +++ b/startup.py @@ -20,8 +20,8 @@ from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_OPENAI_API, ) from server.utils import (fschat_controller_address, fschat_model_worker_address, fschat_openai_api_address, set_httpx_timeout, - llm_device, embedding_device, get_model_worker_config) -from server.utils import MakeFastAPIOffline, FastAPI + llm_device, embedding_device, get_model_worker_config, + MakeFastAPIOffline, FastAPI) import argparse from typing import Tuple, List from configs import VERSION diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py index cefeec0..51bbac1 100644 --- a/tests/api/test_kb_api.py +++ b/tests/api/test_kb_api.py @@ -1,4 +1,3 @@ -from doctest import testfile import requests import json import sys