Upgrade notes

The database tables have changed, so existing knowledge bases need to be rebuilt.

New features
- Added a FileDocModel table that stores the mapping between knowledge files and vector-store Document IDs, together with the document metadata, to make vector-store lookups easier.
- Added database helper functions for FileDocModel (these are mainly called by KBService; users normally do not need them directly):
  - list_docs_from_db: retrieve Document IDs by knowledge base name, file name, and metadata
  - delete_docs_from_db: delete the file-doc mappings for a given knowledge base name and file name
  - add_docs_to_db: add file-doc mappings
- Added a list_docs method to KBService that retrieves Documents by file name or metadata (see the usage sketch after this list). Currently only FAISS is supported; Milvus/PG will gain support automatically once they implement get_doc_by_id.
- Removed the torch dependency from server.utils.
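
A minimal usage sketch of the new interfaces (not part of the commit), assuming a configured deployment with a FAISS knowledge base named "samples" that already contains a file "test.pdf"; those names and the metadata filter are hypothetical:

```python
# Hedged sketch: the knowledge base name, file name, and metadata key/value are made up.
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
from server.db.repository.knowledge_file_repository import list_docs_from_db

kb = FaissKBService("samples")

# Retrieve full Documents through the service layer (FAISS only for now;
# other backends return None until they implement get_doc_by_id).
docs = kb.list_docs(file_name="test.pdf")

# Or query just the id/metadata pairs recorded in the new file_doc table.
doc_infos = list_docs_from_db(kb_name="samples",
                              file_name="test.pdf",
                              metadata={"source": "test.pdf"})
for info in doc_infos:
    print(info["id"], info["metadata"])
```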

To do
- The Milvus/PG kb_service implementations need to implement get_doc_by_id.
liunux4odoo 2023-09-01 22:54:57 +08:00
parent 96770cca53
commit 55e417a263
11 changed files with 139 additions and 25 deletions

View File

```diff
@@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 from configs.model_config import SQLALCHEMY_DATABASE_URI
+import json
 
-engine = create_engine(SQLALCHEMY_DATABASE_URI)
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URI,
+    json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
+)
 
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 Base = declarative_base()
```

View File

```diff
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
 from server.db.base import Base
@@ -23,3 +23,18 @@ class KnowledgeFileModel(Base):
 
     def __repr__(self):
         return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
+
+
+class FileDocModel(Base):
+    """
+    File-to-vector-store document mapping model
+    """
+    __tablename__ = 'file_doc'
+    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
+    kb_name = Column(String(50), comment='knowledge base name')
+    file_name = Column(String(255), comment='file name')
+    doc_id = Column(String(50), comment="vector store document ID")
+    meta_data = Column(JSON, default={})
+
+    def __repr__(self):
+        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"
```

View File

```diff
@@ -1,7 +1,65 @@
 from server.db.models.knowledge_base_model import KnowledgeBaseModel
-from server.db.models.knowledge_file_model import KnowledgeFileModel
+from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
 from server.db.session import with_session
 from server.knowledge_base.utils import KnowledgeFile
+from typing import List, Dict
+
+
+@with_session
+def list_docs_from_db(session,
+                      kb_name: str,
+                      file_name: str = None,
+                      metadata: Dict = {},
+                      ) -> List[Dict]:
+    '''
+    List all Documents for a given file in a given knowledge base.
+    Return format: [{"id": str, "metadata": dict}, ...]
+    '''
+    docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
+    if file_name:
+        docs = docs.filter_by(file_name=file_name)
+    for k, v in metadata.items():
+        docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))
+    return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]
+
+
+@with_session
+def delete_docs_from_db(session,
+                        kb_name: str,
+                        file_name: str = None,
+                        ) -> List[Dict]:
+    '''
+    Delete all Documents for a given file in a given knowledge base and return the deleted records.
+    Return format: [{"id": str, "metadata": dict}, ...]
+    '''
+    docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
+    query = session.query(FileDocModel).filter_by(kb_name=kb_name)
+    if file_name:
+        query = query.filter_by(file_name=file_name)
+    query.delete()
+    session.commit()
+    return docs
+
+
+@with_session
+def add_docs_to_db(session,
+                   kb_name: str,
+                   file_name: str,
+                   doc_infos: List[Dict]):
+    '''
+    Add the Document info for a given file in a given knowledge base to the database.
+    doc_infos format: [{"id": str, "metadata": dict}, ...]
+    '''
+    for d in doc_infos:
+        obj = FileDocModel(
+            kb_name=kb_name,
+            file_name=file_name,
+            doc_id=d["id"],
+            meta_data=d["metadata"],
+        )
+        session.add(obj)
+    return True
+
+
 @with_session
@@ -20,7 +78,9 @@ def list_files_from_db(session, kb_name):
 def add_file_to_db(session,
                    kb_file: KnowledgeFile,
                    docs_count: int = 0,
-                   custom_docs: bool = False,):
+                   custom_docs: bool = False,
+                   doc_infos: List[str] = [],  # format: [{"id": str, "metadata": dict}, ...]
+                   ):
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
     if kb:
         # if the file already exists, update its info and version number
@@ -52,6 +112,7 @@ def add_file_to_db(session,
             )
             kb.file_count += 1
             session.add(new_file)
+        add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
     return True
@@ -61,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
                               kb_name=kb_file.kb_name).first()
     if existing_file:
         session.delete(existing_file)
+        delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
         session.commit()
 
         kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
@@ -73,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
 @with_session
 def delete_files_from_db(session, knowledge_base_name: str):
     session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
+    session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
     if kb:
         kb.file_count = 0
```
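
The metadata filter in `list_docs_from_db` relies on SQLAlchemy JSON subscripting (`meta_data[k].as_string()`). A self-contained sketch of that pattern, using an in-memory SQLite table that only mimics the real `file_doc` schema (requires SQLite built with JSON1, which is standard in recent Python builds):

```python
# Standalone illustration of the JSON filtering pattern used by list_docs_from_db.
# The engine URL, table, and rows are illustrative, not project code.
from sqlalchemy import create_engine, Column, Integer, JSON
from sqlalchemy.orm import declarative_base, Session

Base = declarative_base()

class Doc(Base):
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    meta_data = Column(JSON, default={})

engine = create_engine("sqlite://")  # throwaway in-memory database
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Doc(meta_data={"source": "a.md", "page": 1}),
                     Doc(meta_data={"source": "b.md", "page": 2})])
    session.commit()

    # Same idea as FileDocModel.meta_data[k].as_string() == str(v):
    # subscript the JSON column and compare the value as text.
    query = session.query(Doc).filter(Doc.meta_data["source"].as_string() == "a.md")
    print([d.id for d in query.all()])  # -> [1]
```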

View File

```diff
@@ -14,7 +14,8 @@ from server.db.repository.knowledge_base_repository import (
 )
 from server.db.repository.knowledge_file_repository import (
     add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
-    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
+    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
+    list_docs_from_db,
 )
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
@@ -25,6 +26,7 @@ from server.knowledge_base.utils import (
 )
 from server.utils import embedding_device
 from typing import List, Union, Dict
+from typing import List, Union, Dict, Optional
 
 
 class SupportedVSType:
@@ -88,8 +90,11 @@ class KBService(ABC):
         if docs:
             self.delete_doc(kb_file)
-            self.do_add_doc(docs, **kwargs)
-            status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
+            doc_infos = self.do_add_doc(docs, **kwargs)
+            status = add_file_to_db(kb_file,
+                                    custom_docs=custom_docs,
+                                    docs_count=len(docs),
+                                    doc_infos=doc_infos)
         else:
             status = False
         return status
@@ -132,6 +137,18 @@ class KBService(ABC):
         docs = self.do_search(query, top_k, score_threshold, embeddings)
         return docs
 
+    # TODO: milvus/pg need to implement this method
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
+    def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
+        '''
+        Retrieve Documents by file_name or metadata
+        '''
+        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
+        docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
+        return docs
+
     @abstractmethod
     def do_create_kb(self):
         """
@@ -181,7 +198,7 @@ class KBService(ABC):
     @abstractmethod
     def do_add_doc(self,
                    docs: List[Document],
-                   ):
+                   ) -> List[Dict]:
         """
         Add documents to the knowledge base (subclasses implement their own logic)
         """
```

View File

```diff
@@ -12,7 +12,7 @@ from functools import lru_cache
 from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
-from typing import List
+from typing import List, Dict, Optional
 from langchain.docstore.document import Document
 from server.utils import torch_gc, embedding_device
@@ -88,6 +88,10 @@ class FaissKBService(KBService):
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
 
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        vector_store = self.load_vector_store()
+        return vector_store.docstore._dict.get(id)
+
     def do_init(self):
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()
@@ -114,14 +118,15 @@ class FaissKBService(KBService):
     def do_add_doc(self,
                    docs: List[Document],
                    **kwargs,
-                   ):
+                   ) -> List[Dict]:
         vector_store = self.load_vector_store()
-        vector_store.add_documents(docs)
+        ids = vector_store.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
         torch_gc()
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
-        return vector_store
+        return doc_infos
 
     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
```
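
The FAISS changes above hinge on two langchain behaviours: `add_documents` returns the generated ids, and those ids key the in-memory docstore that `get_doc_by_id` reads. A standalone sanity check of that round trip (FakeEmbeddings and the sample texts are illustrative; requires `langchain` and `faiss-cpu`):

```python
# Verify the id round trip that do_add_doc/get_doc_by_id depend on.
# FakeEmbeddings stands in for the real embedding model so the sketch runs anywhere.
from langchain.vectorstores import FAISS
from langchain.embeddings import FakeEmbeddings
from langchain.docstore.document import Document

embeddings = FakeEmbeddings(size=32)
docs = [Document(page_content="hello", metadata={"source": "a.md"}),
        Document(page_content="world", metadata={"source": "b.md"})]

vector_store = FAISS.from_texts(["seed"], embeddings)  # start from a dummy entry
ids = vector_store.add_documents(docs)                 # ids later stored in file_doc

# The id/metadata pairs returned by do_add_doc look like this:
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
print(doc_infos)

# get_doc_by_id in FaissKBService performs exactly this lookup:
print(vector_store.docstore._dict.get(ids[0]))
```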

View File

```diff
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Optional
 import numpy as np
 from faiss import normalize_L2
@@ -22,6 +22,10 @@ class MilvusKBService(KBService):
         from pymilvus import Collection
         return Collection(milvus_name)
 
+    # TODO:
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
     @staticmethod
     def search(milvus_name, content, limit=3):
         search_params = {
@@ -54,8 +58,10 @@ class MilvusKBService(KBService):
         self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
-    def do_add_doc(self, docs: List[Document], **kwargs):
-        self.milvus.add_documents(docs)
+    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+        ids = self.milvus.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+        return doc_infos
 
     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         filepath = kb_file.filepath.replace('\\', '\\\\')
```

View File

```diff
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Optional
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
@@ -24,6 +24,10 @@ class PGKBService(KBService):
                                  distance_strategy=DistanceStrategy.EUCLIDEAN,
                                  connection_string=kbs_config.get("pg").get("connection_uri"))
 
+    # TODO:
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
     def do_init(self):
         self._load_pg_vector()
@@ -51,8 +55,10 @@ class PGKBService(KBService):
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))
 
-    def do_add_doc(self, docs: List[Document], **kwargs):
-        self.pg_vector.add_documents(docs)
+    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+        ids = self.pg_vector.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+        return doc_infos
 
     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         with self.pg_vector.connect() as connect:
```

View File

```diff
@@ -8,7 +8,7 @@ from server.db.repository.knowledge_file_repository import add_file_to_db
 from server.db.base import Base, engine
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Literal, Callable, Any, List
+from typing import Literal, Any, List
 
 pool = ThreadPoolExecutor(os.cpu_count())
```

View File

```diff
@@ -1,7 +1,6 @@
 import pydantic
 from pydantic import BaseModel
 from typing import List
-import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
@@ -69,6 +68,7 @@ class ChatMessage(BaseModel):
         }
 
 def torch_gc():
+    import torch
     if torch.cuda.is_available():
         # with torch.cuda.device(DEVICE):
         torch.cuda.empty_cache()
```
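
The server.utils change simply moves the torch import from module level into `torch_gc`, so importing server.utils no longer requires torch at all. A small sketch of the same lazy-import pattern; the `try/except` guard is an extra illustration, not part of the commit:

```python
def torch_gc():
    # Lazy import: torch is loaded only when GPU cleanup is actually requested,
    # so modules importing this helper do not need torch installed.
    try:
        import torch
    except ImportError:
        return  # torch not available: nothing to clean up
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```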

View File

```diff
@@ -20,8 +20,8 @@ from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
-                          llm_device, embedding_device, get_model_worker_config)
-from server.utils import MakeFastAPIOffline, FastAPI
+                          llm_device, embedding_device, get_model_worker_config,
+                          MakeFastAPIOffline, FastAPI)
 import argparse
 from typing import Tuple, List
 from configs import VERSION
```

View File

```diff
@@ -1,4 +1,3 @@
-from doctest import testfile
 import requests
 import json
 import sys
```