升级注意
数据库表发生变化,需要重建知识库

新功能
- 增加FileDocModel库表,存储知识文件与向量库Document ID对应关系以及元数据,便于检索向量库
- 增加FileDocModel对应的数据库操作函数(这些函数主要是给KBService调用,用户一般无需使用):
  - list_docs_from_db: 根据知识库名称、文件名称、元数据检索对应的Document IDs
  - delete_docs_from_db: 根据知识库名称、文件名称删除对应的file-doc映射
  - add_docs_to_db: 添加对应的file-doc映射
- KBService增加list_docs方法,可以根据文件名、元数据检索Document。当前仅支持FAISS,待milvus/pg实现get_doc_by_id方法后即自动支持。
- 去除server.utils对torch的依赖

待完善
- milvus/pg kb_service需要实现get_doc_by_id方法
This commit is contained in:
parent
96770cca53
commit
55e417a263
|
|
@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
|
|||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs.model_config import SQLALCHEMY_DATABASE_URI
|
||||
import json
|
||||
|
||||
engine = create_engine(SQLALCHEMY_DATABASE_URI)
|
||||
|
||||
engine = create_engine(
|
||||
SQLALCHEMY_DATABASE_URI,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
|
||||
|
||||
from server.db.base import Base
|
||||
|
||||
|
|
@ -23,3 +23,18 @@ class KnowledgeFileModel(Base):
|
|||
|
||||
def __repr__(self):
|
||||
return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
|
||||
|
||||
|
||||
class FileDocModel(Base):
    """
    Mapping between a knowledge file and its documents in the vector store.

    Each row ties one vector-store Document ID (``doc_id``) to the knowledge
    base (``kb_name``) and file (``file_name``) it is recorded under, plus
    the metadata stored for that document (``meta_data``).
    """
    __tablename__ = 'file_doc'
    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
    kb_name = Column(String(50), comment='知识库名称')
    file_name = Column(String(255), comment='文件名称')
    doc_id = Column(String(50), comment="向量库文档ID")
    # The column is called ``meta_data`` because ``metadata`` is reserved on
    # declarative classes (it is the table registry). A callable default
    # gives every row its own dict instead of one shared ``{}`` instance.
    meta_data = Column(JSON, default=dict)

    def __repr__(self):
        # Read the ``meta_data`` column; ``self.metadata`` would print the
        # SQLAlchemy MetaData registry rather than this row's metadata.
        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.meta_data}')>"
|
||||
|
|
|
|||
|
|
@ -1,7 +1,65 @@
|
|||
from server.db.models.knowledge_base_model import KnowledgeBaseModel
|
||||
from server.db.models.knowledge_file_model import KnowledgeFileModel
|
||||
from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
|
||||
from server.db.session import with_session
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
@with_session
def list_docs_from_db(session,
                      kb_name: str,
                      file_name: str = None,
                      metadata: Dict = None,
                      ) -> List[Dict]:
    '''
    List the Documents recorded for a knowledge base.

    session:   database session; callers in this module omit it and rely on
               the ``with_session`` decorator to supply it.
    kb_name:   knowledge base whose file-doc rows are queried.
    file_name: when given (truthy), only rows of this file are returned.
    metadata:  optional key/value pairs; a row matches only if, for every
               key, its stored metadata value compared as a string equals
               ``str(value)``. ``None`` or an empty dict applies no filter.

    Return format: [{"id": str, "metadata": dict}, ...]
    '''
    query = session.query(FileDocModel).filter_by(kb_name=kb_name)
    if file_name:
        query = query.filter_by(file_name=file_name)
    for k, v in (metadata or {}).items():
        query = query.filter(FileDocModel.meta_data[k].as_string() == str(v))

    # Use the ``meta_data`` column: ``x.metadata`` is the SQLAlchemy MetaData
    # registry shared by all models, not the metadata stored with the row.
    return [{"id": x.doc_id, "metadata": x.meta_data} for x in query.all()]
|
||||
|
||||
|
||||
@with_session
def delete_docs_from_db(session,
                        kb_name: str,
                        file_name: str = None,
                        ) -> List[Dict]:
    '''
    Delete the file-doc rows of a knowledge base (only those of ``file_name``
    when it is given) and return the Documents that were recorded for them.

    Return format: [{"id": str, "metadata": dict}, ...]
    '''
    # Snapshot the rows before the bulk delete so they can still be returned.
    removed = list_docs_from_db(kb_name=kb_name, file_name=file_name)

    criteria = {"kb_name": kb_name}
    if file_name:
        criteria["file_name"] = file_name
    session.query(FileDocModel).filter_by(**criteria).delete()
    session.commit()
    return removed
|
||||
|
||||
|
||||
@with_session
def add_docs_to_db(session,
                   kb_name: str,
                   file_name: str,
                   doc_infos: List[Dict]):
    '''
    Record the Documents of one file of a knowledge base in the database,
    one file-doc row per entry of ``doc_infos``.

    doc_infos format: [{"id": str, "metadata": dict}, ...]

    Always returns True. No commit is issued here; presumably the
    ``with_session`` decorator takes care of it (to be confirmed).
    '''
    for info in doc_infos:
        session.add(FileDocModel(
            kb_name=kb_name,
            file_name=file_name,
            doc_id=info["id"],
            meta_data=info["metadata"],
        ))
    return True
|
||||
|
||||
|
||||
@with_session
|
||||
|
|
@ -20,7 +78,9 @@ def list_files_from_db(session, kb_name):
|
|||
def add_file_to_db(session,
|
||||
kb_file: KnowledgeFile,
|
||||
docs_count: int = 0,
|
||||
custom_docs: bool = False,):
|
||||
custom_docs: bool = False,
|
||||
doc_infos: List[str] = [], # 形式:[{"id": str, "metadata": dict}, ...]
|
||||
):
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||
if kb:
|
||||
# 如果已经存在该文件,则更新文件信息与版本号
|
||||
|
|
@ -52,6 +112,7 @@ def add_file_to_db(session,
|
|||
)
|
||||
kb.file_count += 1
|
||||
session.add(new_file)
|
||||
add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
|
||||
return True
|
||||
|
||||
|
||||
|
|
@ -61,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
|
|||
kb_name=kb_file.kb_name).first()
|
||||
if existing_file:
|
||||
session.delete(existing_file)
|
||||
delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
|
||||
session.commit()
|
||||
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
|
||||
|
|
@ -73,7 +135,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
|
|||
@with_session
|
||||
def delete_files_from_db(session, knowledge_base_name: str):
|
||||
session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
|
||||
|
||||
session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
|
||||
kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
|
||||
if kb:
|
||||
kb.file_count = 0
|
||||
|
|
|
|||
|
|
@ -14,7 +14,8 @@ from server.db.repository.knowledge_base_repository import (
|
|||
)
|
||||
from server.db.repository.knowledge_file_repository import (
|
||||
add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
|
||||
count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
|
||||
list_docs_from_db,
|
||||
)
|
||||
|
||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
|
|
@ -25,6 +26,7 @@ from server.knowledge_base.utils import (
|
|||
)
|
||||
from server.utils import embedding_device
|
||||
from typing import List, Union, Dict
|
||||
from typing import List, Union, Dict, Optional
|
||||
|
||||
|
||||
class SupportedVSType:
|
||||
|
|
@ -88,8 +90,11 @@ class KBService(ABC):
|
|||
|
||||
if docs:
|
||||
self.delete_doc(kb_file)
|
||||
self.do_add_doc(docs, **kwargs)
|
||||
status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
|
||||
doc_infos = self.do_add_doc(docs, **kwargs)
|
||||
status = add_file_to_db(kb_file,
|
||||
custom_docs=custom_docs,
|
||||
docs_count=len(docs),
|
||||
doc_infos=doc_infos)
|
||||
else:
|
||||
status = False
|
||||
return status
|
||||
|
|
@ -132,6 +137,18 @@ class KBService(ABC):
|
|||
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
||||
return docs
|
||||
|
||||
# TODO: the milvus/pg services need to implement this method
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Return the Document stored under ``id``.

    This base implementation is a placeholder and always returns None.
    """
    return None
|
||||
|
||||
def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
    '''
    Look up Documents of this knowledge base by file name and/or metadata.

    The matching Document IDs come from the database (``list_docs_from_db``)
    and each one is resolved through ``get_doc_by_id``, so an entry is
    whatever that method returns for the ID (possibly None).
    '''
    found = []
    for info in list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata):
        found.append(self.get_doc_by_id(info["id"]))
    return found
|
||||
|
||||
@abstractmethod
|
||||
def do_create_kb(self):
|
||||
"""
|
||||
|
|
@ -181,7 +198,7 @@ class KBService(ABC):
|
|||
@abstractmethod
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
):
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
向知识库添加文档子类实自己逻辑
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from functools import lru_cache
|
|||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc, embedding_device
|
||||
|
||||
|
|
@ -88,6 +88,10 @@ class FaissKBService(KBService):
|
|||
def refresh_vs_cache(self):
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """
    Fetch one Document from the loaded vector store by its ID.

    Returns None when the ID is not present.
    """
    docstore = self.load_vector_store().docstore
    # ``_dict`` is a private attribute of the docstore; presumably its
    # id -> Document mapping (confirm against the langchain version in use).
    return docstore._dict.get(id)
|
||||
|
||||
def do_init(self):
|
||||
self.kb_path = self.get_kb_path()
|
||||
self.vs_path = self.get_vs_path()
|
||||
|
|
@ -114,14 +118,15 @@ class FaissKBService(KBService):
|
|||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
**kwargs,
|
||||
):
|
||||
) -> List[Dict]:
|
||||
vector_store = self.load_vector_store()
|
||||
vector_store.add_documents(docs)
|
||||
ids = vector_store.add_documents(docs)
|
||||
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
|
||||
torch_gc()
|
||||
if not kwargs.get("not_refresh_vs_cache"):
|
||||
vector_store.save_local(self.vs_path)
|
||||
self.refresh_vs_cache()
|
||||
return vector_store
|
||||
return doc_infos
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile,
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
import numpy as np
|
||||
from faiss import normalize_L2
|
||||
|
|
@ -22,6 +22,10 @@ class MilvusKBService(KBService):
|
|||
from pymilvus import Collection
|
||||
return Collection(milvus_name)
|
||||
|
||||
# TODO: implement Document lookup by ID for Milvus
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """Placeholder: lookup by ID is not implemented yet, so this always returns None."""
    return None
|
||||
|
||||
@staticmethod
|
||||
def search(milvus_name, content, limit=3):
|
||||
search_params = {
|
||||
|
|
@ -54,8 +58,10 @@ class MilvusKBService(KBService):
|
|||
self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
|
||||
return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs):
|
||||
self.milvus.add_documents(docs)
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
    """
    Add ``docs`` to the Milvus store and describe what was added.

    Returns one ``{"id": ..., "metadata": ...}`` dict per document, pairing
    the IDs returned by ``add_documents`` with the documents in order.
    """
    new_ids = self.milvus.add_documents(docs)
    return [
        {"id": doc_id, "metadata": doc.metadata}
        for doc_id, doc in zip(new_ids, docs)
    ]
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.schema import Document
|
||||
|
|
@ -24,6 +24,10 @@ class PGKBService(KBService):
|
|||
distance_strategy=DistanceStrategy.EUCLIDEAN,
|
||||
connection_string=kbs_config.get("pg").get("connection_uri"))
|
||||
|
||||
# TODO: implement Document lookup by ID for PGVector
def get_doc_by_id(self, id: str) -> Optional[Document]:
    """Placeholder: lookup by ID is not implemented yet, so this always returns None."""
    return None
|
||||
|
||||
def do_init(self):
|
||||
self._load_pg_vector()
|
||||
|
||||
|
|
@ -51,8 +55,10 @@ class PGKBService(KBService):
|
|||
return score_threshold_process(score_threshold, top_k,
|
||||
self.pg_vector.similarity_search_with_score(query, top_k))
|
||||
|
||||
def do_add_doc(self, docs: List[Document], **kwargs):
|
||||
self.pg_vector.add_documents(docs)
|
||||
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
    """
    Add ``docs`` to the PGVector store and describe what was added.

    Returns one ``{"id": ..., "metadata": ...}`` dict per document, pairing
    the IDs returned by ``add_documents`` with the documents in order.
    """
    new_ids = self.pg_vector.add_documents(docs)
    return [
        {"id": doc_id, "metadata": doc.metadata}
        for doc_id, doc in zip(new_ids, docs)
    ]
|
||||
|
||||
def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
|
||||
with self.pg_vector.connect() as connect:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from server.db.repository.knowledge_file_repository import add_file_to_db
|
|||
from server.db.base import Base, engine
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Literal, Callable, Any, List
|
||||
from typing import Literal, Any, List
|
||||
|
||||
|
||||
pool = ThreadPoolExecutor(os.cpu_count())
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import pydantic
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
import torch
|
||||
from fastapi import FastAPI
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
|
@ -69,6 +68,7 @@ class ChatMessage(BaseModel):
|
|||
}
|
||||
|
||||
def torch_gc():
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
# with torch.cuda.device(DEVICE):
|
||||
torch.cuda.empty_cache()
|
||||
|
|
|
|||
|
|
@ -20,8 +20,8 @@ from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN,
|
|||
FSCHAT_OPENAI_API, )
|
||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||
fschat_openai_api_address, set_httpx_timeout,
|
||||
llm_device, embedding_device, get_model_worker_config)
|
||||
from server.utils import MakeFastAPIOffline, FastAPI
|
||||
llm_device, embedding_device, get_model_worker_config,
|
||||
MakeFastAPIOffline, FastAPI)
|
||||
import argparse
|
||||
from typing import Tuple, List
|
||||
from configs import VERSION
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from doctest import testfile
|
||||
import requests
|
||||
import json
|
||||
import sys
|
||||
|
|
|
|||
Loading…
Reference in New Issue