升级注意

数据库表发生变化,需要重建知识库

 新功能
- 增加FileDocModel库表,存储知识文件与向量库Document ID对应关系以及元数据,便于检索向量库
- 增加FileDocModel对应的数据库操作函数(这些函数主要是给KBService调用,用户一般无需使用):
  - list_docs_from_db: 根据知识库名称、文件名称、元数据检索对应的Document IDs
  - delete_docs_from_db: 根据知识库名称、文件名称删除对应的file-doc映射
  - add_docs_to_db: 添加对应的file-doc映射
- KBService增加list_docs方法,可以根据文件名、元数据检索Document。当前仅支持FAISS,待milvus/pg实现get_doc_by_id方法后即自动支持。
- 去除server.utils对torch的依赖

 待完善
- milvus/pg kb_service需要实现get_doc_by_id方法
This commit is contained in:
liunux4odoo 2023-09-01 22:54:57 +08:00
parent 96770cca53
commit 55e417a263
11 changed files with 139 additions and 25 deletions

View File

@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from configs.model_config import SQLALCHEMY_DATABASE_URI
import json
engine = create_engine(SQLALCHEMY_DATABASE_URI)
engine = create_engine(
SQLALCHEMY_DATABASE_URI,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base()

View File

@ -1,4 +1,4 @@
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
from server.db.base import Base
@ -23,3 +23,18 @@ class KnowledgeFileModel(Base):
def __repr__(self):
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
class FileDocModel(Base):
    """
    File-to-vector-store-document mapping model.

    Each row links one knowledge file (identified by kb_name + file_name) to
    the ID of a single Document in the vector store, and keeps that
    Document's metadata as JSON.
    """
    __tablename__ = 'file_doc'
    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
    kb_name = Column(String(50), comment='知识库名称')
    file_name = Column(String(255), comment='文件名称')
    doc_id = Column(String(50), comment="向量库文档ID")
    # Named `meta_data` because `metadata` is reserved on SQLAlchemy
    # declarative classes (it refers to the MetaData registry of Base).
    meta_data = Column(JSON, default={})

    def __repr__(self):
        # Read the JSON column `meta_data`; `self.metadata` would print the
        # MetaData registry of Base instead of this row's stored metadata.
        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"

View File

@ -1,7 +1,65 @@
from server.db.models.knowledge_base_model import KnowledgeBaseModel
from server.db.models.knowledge_file_model import KnowledgeFileModel
from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
from server.db.session import with_session
from server.knowledge_base.utils import KnowledgeFile
from typing import List, Dict
@with_session
def list_docs_from_db(session,
                      kb_name: str,
                      file_name: str = None,
                      metadata: Dict = None,
                      ) -> List[Dict]:
    '''
    List the Documents recorded for a knowledge base, optionally narrowed
    down by file name and by metadata key/value pairs.

    session: database session; callers in this file omit it, so it is
        presumably injected by the with_session decorator.
    kb_name: knowledge base name to match.
    file_name: when truthy, only rows for this file are returned.
    metadata: key/value pairs that must all match the stored JSON metadata;
        values are compared as strings. None or empty means no metadata filter.

    Returns a list shaped like [{"id": str, "metadata": dict}, ...].
    '''
    docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
    if file_name:
        docs = docs.filter_by(file_name=file_name)
    # `metadata or {}` keeps the old default behaviour without a shared
    # mutable default argument, and tolerates an explicit None.
    for k, v in (metadata or {}).items():
        docs = docs.filter(FileDocModel.meta_data[k].as_string() == str(v))
    # The JSON column is `meta_data`; `x.metadata` is the MetaData registry
    # of Base, not the row's stored metadata.
    return [{"id": x.doc_id, "metadata": x.meta_data} for x in docs.all()]
@with_session
def delete_docs_from_db(session,
                        kb_name: str,
                        file_name: str = None,
                        ) -> List[Dict]:
    '''
    Remove the file-doc mapping rows of a knowledge base (optionally of a
    single file) and hand back what was removed.

    The matching rows are read via list_docs_from_db before the delete is
    issued, so the return value describes the deleted Documents.

    Returns a list shaped like [{"id": str, "metadata": dict}, ...].
    '''
    removed = list_docs_from_db(kb_name=kb_name, file_name=file_name)

    # Build the equality criteria once; file_name only narrows when truthy.
    criteria = {"kb_name": kb_name}
    if file_name:
        criteria["file_name"] = file_name

    session.query(FileDocModel).filter_by(**criteria).delete()
    session.commit()
    return removed
@with_session
def add_docs_to_db(session,
                   kb_name: str,
                   file_name: str,
                   doc_infos: List[Dict]):
    '''
    Record the Documents belonging to one file of a knowledge base.

    doc_infos is shaped like [{"id": str, "metadata": dict}, ...]; one
    FileDocModel row is added to the session per entry. Always returns True.
    '''
    for info in doc_infos:
        session.add(
            FileDocModel(
                kb_name=kb_name,
                file_name=file_name,
                doc_id=info["id"],
                meta_data=info["metadata"],
            )
        )
    return True
@with_session
@ -20,7 +78,9 @@ def list_files_from_db(session, kb_name):
def add_file_to_db(session,
kb_file: KnowledgeFile,
docs_count: int = 0,
custom_docs: bool = False,):
custom_docs: bool = False,
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...]
):
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
if kb:
# 如果已经存在该文件,则更新文件信息与版本号
@ -52,6 +112,7 @@ def add_file_to_db(session,
)
kb.file_count += 1
session.add(new_file)
add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
return True
@ -61,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
kb_name=kb_file.kb_name).first()
if existing_file:
session.delete(existing_file)
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
session.commit()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
@ -73,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
@with_session
def delete_files_from_db(session, knowledge_base_name: str):
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
if kb:
kb.file_count = 0

View File

@ -14,7 +14,8 @@ from server.db.repository.knowledge_base_repository import (
)
from server.db.repository.knowledge_file_repository import (
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
list_docs_from_db,
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
@ -25,6 +26,7 @@ from server.knowledge_base.utils import (
)
from server.utils import embedding_device
from typing import List, Union, Dict
from typing import List, Union, Dict, Optional
class SupportedVSType:
@ -88,8 +90,11 @@ class KBService(ABC):
if docs:
self.delete_doc(kb_file)
self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
doc_infos = self.do_add_doc(docs, **kwargs)
status = add_file_to_db(kb_file,
custom_docs=custom_docs,
docs_count=len(docs),
doc_infos=doc_infos)
else:
status = False
return status
@ -132,6 +137,18 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
# TODO: the milvus/pg services still need to implement this method
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Look up a single Document by its vector-store ID.

    This base implementation is a placeholder and always returns None.
    """
    return None
def list_docs(self, file_name: str = None, metadata: Optional[Dict] = None) -> List[Document]:
    '''
    Retrieve Documents of this knowledge base by file name and/or metadata.

    file_name: optional file name filter passed to list_docs_from_db.
    metadata: optional metadata key/value filter; None or empty means no
        metadata filtering.

    Returns one entry per matching database record, obtained through
    get_doc_by_id; an entry is whatever that method returns for the ID
    (the placeholder implementation returns None).
    '''
    # Normalise None to an empty dict here so the repository function always
    # receives a mapping, without using a shared mutable default argument.
    doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata or {})
    return [self.get_doc_by_id(x["id"]) for x in doc_infos]
@abstractmethod
def do_create_kb(self):
"""
@ -181,7 +198,7 @@ class KBService(ABC):
@abstractmethod
def do_add_doc(self,
docs: List[Document],
):
) -> List[Dict]:
"""
向知识库添加文档子类实自己逻辑
"""

View File

@ -12,7 +12,7 @@ from functools import lru_cache
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
from langchain.vectorstores import FAISS
from langchain.embeddings.base import Embeddings
from typing import List
from typing import List, Dict, Optional
from langchain.docstore.document import Document
from server.utils import torch_gc, embedding_device
@ -88,6 +88,10 @@ class FaissKBService(KBService):
def refresh_vs_cache(self):
refresh_vs_cache(self.kb_name)
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Fetch one Document from the FAISS vector store by its ID.

    Returns None when the docstore has no entry for the given ID.
    """
    docstore = self.load_vector_store().docstore
    # `_dict` is the docstore's private mapping (presumably id -> Document);
    # .get() gives None for unknown ids instead of raising.
    return docstore._dict.get(id)
def do_init(self):
self.kb_path = self.get_kb_path()
self.vs_path = self.get_vs_path()
@ -114,14 +118,15 @@ class FaissKBService(KBService):
def do_add_doc(self,
docs: List[Document],
**kwargs,
):
) -> List[Dict]:
vector_store = self.load_vector_store()
vector_store.add_documents(docs)
ids = vector_store.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
torch_gc()
if not kwargs.get("not_refresh_vs_cache"):
vector_store.save_local(self.vs_path)
self.refresh_vs_cache()
return vector_store
return doc_infos
def do_delete_doc(self,
kb_file: KnowledgeFile,

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict, Optional
import numpy as np
from faiss import normalize_L2
@ -22,6 +22,10 @@ class MilvusKBService(KBService):
from pymilvus import Collection
return Collection(milvus_name)
# TODO: implement Document lookup by ID for Milvus
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Look up a single Document by its vector-store ID.

    Not implemented for Milvus yet: this placeholder always returns None.
    """
    return None
@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
@ -54,8 +58,10 @@ class MilvusKBService(KBService):
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs):
self.milvus.add_documents(docs)
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
    """
    Add `docs` to the Milvus store and describe what was stored.

    Returns one {"id": ..., "metadata": ...} dict per document, pairing each
    ID returned by add_documents with the metadata of the document at the
    same position.
    """
    stored_ids = self.milvus.add_documents(docs)
    return [
        {"id": doc_id, "metadata": document.metadata}
        for doc_id, document in zip(stored_ids, docs)
    ]
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
filepath = kb_file.filepath.replace('\\', '\\\\')

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Dict, Optional
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
@ -24,6 +24,10 @@ class PGKBService(KBService):
distance_strategy=DistanceStrategy.EUCLIDEAN,
connection_string=kbs_config.get("pg").get("connection_uri"))
# TODO: implement Document lookup by ID for PGVector
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Look up a single Document by its vector-store ID.

    Not implemented for PG yet: this placeholder always returns None.
    """
    return None
def do_init(self):
self._load_pg_vector()
@ -51,8 +55,10 @@ class PGKBService(KBService):
return score_threshold_process(score_threshold, top_k,
self.pg_vector.similarity_search_with_score(query, top_k))
def do_add_doc(self, docs: List[Document], **kwargs):
self.pg_vector.add_documents(docs)
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
    """
    Add `docs` to the PGVector store and describe what was stored.

    Returns one {"id": ..., "metadata": ...} dict per document, pairing each
    ID returned by add_documents with the metadata of the document at the
    same position.
    """
    stored_ids = self.pg_vector.add_documents(docs)
    return [
        {"id": doc_id, "metadata": document.metadata}
        for doc_id, document in zip(stored_ids, docs)
    ]
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
with self.pg_vector.connect() as connect:

View File

@ -8,7 +8,7 @@ from server.db.repository.knowledge_file_repository import add_file_to_db
from server.db.base import Base, engine
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Literal, Callable, Any, List
from typing import Literal, Any, List
pool = ThreadPoolExecutor(os.cpu_count())

View File

@ -1,7 +1,6 @@
import pydantic
from pydantic import BaseModel
from typing import List
import torch
from fastapi import FastAPI
from pathlib import Path
import asyncio
@ -69,6 +68,7 @@ class ChatMessage(BaseModel):
}
def torch_gc():
import torch
if torch.cuda.is_available():
# with torch.cuda.device(DEVICE):
torch.cuda.empty_cache()

View File

@ -20,8 +20,8 @@ from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN,
FSCHAT_OPENAI_API, )
from server.utils import (fschat_controller_address, fschat_model_worker_address,
fschat_openai_api_address, set_httpx_timeout,
llm_device, embedding_device, get_model_worker_config)
from server.utils import MakeFastAPIOffline, FastAPI
llm_device, embedding_device, get_model_worker_config,
MakeFastAPIOffline, FastAPI)
import argparse
from typing import Tuple, List
from configs import VERSION

View File

@ -1,4 +1,3 @@
from doctest import testfile
import requests
import json
import sys