Upgrade notes

The database schema has changed; knowledge bases need to be rebuilt.

New features
- Add the FileDocModel table, which stores the mapping between knowledge files and vector-store Document IDs together with their metadata, making the vector store searchable by file.
- Add the database access functions for FileDocModel (these are mainly called by KBService; users normally do not need them):
  - list_docs_from_db: look up Document IDs by knowledge base name, file name, and metadata
  - delete_docs_from_db: delete the file-doc mappings for a given knowledge base name and file name
  - add_docs_to_db: add file-doc mappings
- Add a list_docs method to KBService that retrieves Documents by file name or metadata (see the usage sketch below). Currently only FAISS is supported; milvus/pg will be supported automatically once they implement the get_doc_by_id method.
- Remove server.utils' dependency on torch.

To be improved
- The milvus/pg kb_service classes need to implement the get_doc_by_id method.
parent 96770cca53
commit 55e417a263
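A minimal usage sketch of the new retrieval path referenced above (the knowledge-base and file names are invented; FaissKBService is the only backend in this commit whose get_doc_by_id actually resolves documents):

```python
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService

kb = FaissKBService("samples")  # hypothetical knowledge base name

# Matching doc IDs are looked up in the new file_doc table by file name
# and/or metadata, then each Document is resolved via get_doc_by_id().
docs = kb.list_docs(file_name="readme.md",
                    metadata={"source": "readme.md"})
```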
@@ -3,10 +3,14 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import sessionmaker
 from configs.model_config import SQLALCHEMY_DATABASE_URI
+import json
 
-engine = create_engine(SQLALCHEMY_DATABASE_URI)
+engine = create_engine(
+    SQLALCHEMY_DATABASE_URI,
+    json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
+)
 
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 
 Base = declarative_base()
 
 
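For context on the json_serializer change above: SQLAlchemy serializes JSON columns through json.dumps, whose default escapes non-ASCII characters, so Chinese metadata would be stored as \uXXXX sequences. A quick illustration:

```python
import json

json.dumps({"source": "知识库"})                      # '{"source": "\\u77e5\\u8bc6\\u5e93"}'
json.dumps({"source": "知识库"}, ensure_ascii=False)  # '{"source": "知识库"}'
```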
@@ -1,4 +1,4 @@
-from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, func
+from sqlalchemy import Column, Integer, String, DateTime, Float, Boolean, JSON, func
 from server.db.base import Base
 
 

@@ -23,3 +23,18 @@ class KnowledgeFileModel(Base):
 
     def __repr__(self):
         return f"<KnowledgeFile(id='{self.id}', file_name='{self.file_name}', file_ext='{self.file_ext}', kb_name='{self.kb_name}', document_loader_name='{self.document_loader_name}', text_splitter_name='{self.text_splitter_name}', file_version='{self.file_version}', create_time='{self.create_time}')>"
+
+
+class FileDocModel(Base):
+    """
+    Mapping between knowledge files and vector-store documents
+    """
+    __tablename__ = 'file_doc'
+    id = Column(Integer, primary_key=True, autoincrement=True, comment='ID')
+    kb_name = Column(String(50), comment='知识库名称')
+    file_name = Column(String(255), comment='文件名称')
+    doc_id = Column(String(50), comment="向量库文档ID")
+    meta_data = Column(JSON, default={})
+
+    def __repr__(self):
+        return f"<FileDoc(id='{self.id}', kb_name='{self.kb_name}', file_name='{self.file_name}', doc_id='{self.doc_id}', metadata='{self.metadata}')>"

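To make the new table concrete: each row maps one vector-store document back to the knowledge file it came from. A sketch of a single record (all values invented):

```python
FileDocModel(
    kb_name="samples",                  # hypothetical knowledge base
    file_name="readme.md",              # hypothetical source file
    doc_id="7f3c9e02-...",              # ID assigned by the vector store
    meta_data={"source": "readme.md"},  # copied from Document.metadata
)
```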
@@ -1,7 +1,65 @@
 from server.db.models.knowledge_base_model import KnowledgeBaseModel
-from server.db.models.knowledge_file_model import KnowledgeFileModel
+from server.db.models.knowledge_file_model import KnowledgeFileModel, FileDocModel
 from server.db.session import with_session
 from server.knowledge_base.utils import KnowledgeFile
+from typing import List, Dict
+
+
+@with_session
+def list_docs_from_db(session,
+                      kb_name: str,
+                      file_name: str = None,
+                      metadata: Dict = {},
+                      ) -> List[Dict]:
+    '''
+    List all Documents for a given file in a knowledge base.
+    Return form: [{"id": str, "metadata": dict}, ...]
+    '''
+    docs = session.query(FileDocModel).filter_by(kb_name=kb_name)
+    if file_name:
+        docs = docs.filter_by(file_name=file_name)
+    for k, v in metadata.items():
+        docs = docs.filter(FileDocModel.meta_data[k].as_string()==str(v))
+
+    return [{"id": x.doc_id, "metadata": x.metadata} for x in docs.all()]
+
+
+@with_session
+def delete_docs_from_db(session,
+                        kb_name: str,
+                        file_name: str = None,
+                        ) -> List[Dict]:
+    '''
+    Delete all Documents for a given file in a knowledge base, and return the deleted Documents.
+    Return form: [{"id": str, "metadata": dict}, ...]
+    '''
+    docs = list_docs_from_db(kb_name=kb_name, file_name=file_name)
+    query = session.query(FileDocModel).filter_by(kb_name=kb_name)
+    if file_name:
+        query = query.filter_by(file_name=file_name)
+    query.delete()
+    session.commit()
+    return docs
+
+
+@with_session
+def add_docs_to_db(session,
+                   kb_name: str,
+                   file_name: str,
+                   doc_infos: List[Dict]):
+    '''
+    Add the Document info for a given file in a knowledge base to the database.
+    doc_infos form: [{"id": str, "metadata": dict}, ...]
+    '''
+    for d in doc_infos:
+        obj = FileDocModel(
+            kb_name=kb_name,
+            file_name=file_name,
+            doc_id=d["id"],
+            meta_data=d["metadata"],
+        )
+        session.add(obj)
+    return True
 
 
 @with_session

@@ -20,7 +78,9 @@ def list_files_from_db(session, kb_name):
 def add_file_to_db(session,
                    kb_file: KnowledgeFile,
                    docs_count: int = 0,
-                   custom_docs: bool = False,):
+                   custom_docs: bool = False,
+                   doc_infos: List[str] = [],  # form: [{"id": str, "metadata": dict}, ...]
+                   ):
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()
     if kb:
         # if the file already exists, update its info and version number

@@ -52,6 +112,7 @@ def add_file_to_db(session,
            )
            kb.file_count += 1
            session.add(new_file)
+        add_docs_to_db(kb_name=kb_file.kb_name, file_name=kb_file.filename, doc_infos=doc_infos)
     return True
 
 

@@ -61,6 +122,7 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
                                                    kb_name=kb_file.kb_name).first()
     if existing_file:
         session.delete(existing_file)
+        delete_docs_from_db(kb_name=kb_file.kb_name, file_name=kb_file.filename)
         session.commit()
 
    kb = session.query(KnowledgeBaseModel).filter_by(kb_name=kb_file.kb_name).first()

@@ -73,7 +135,8 @@ def delete_file_from_db(session, kb_file: KnowledgeFile):
 @with_session
 def delete_files_from_db(session, knowledge_base_name: str):
     session.query(KnowledgeFileModel).filter_by(kb_name=knowledge_base_name).delete()
+    session.query(FileDocModel).filter_by(kb_name=knowledge_base_name).delete()
     kb = session.query(KnowledgeBaseModel).filter_by(kb_name=knowledge_base_name).first()
     if kb:
         kb.file_count = 0

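The metadata filter in list_docs_from_db above composes one JSON-path equality per key, compared as strings via SQLAlchemy's .as_string(). Roughly, for invented values:

```python
docs = session.query(FileDocModel).filter_by(kb_name="samples")
# metadata={"source": "readme.md"} becomes:
docs = docs.filter(FileDocModel.meta_data["source"].as_string() == "readme.md")
```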
@@ -14,7 +14,8 @@ from server.db.repository.knowledge_base_repository import (
 )
 from server.db.repository.knowledge_file_repository import (
     add_file_to_db, delete_file_from_db, delete_files_from_db, file_exists_in_db,
-    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db
+    count_files_from_db, list_files_from_db, get_file_detail, delete_file_from_db,
+    list_docs_from_db,
 )
 
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,

@@ -25,6 +26,7 @@ from server.knowledge_base.utils import (
 )
 from server.utils import embedding_device
 from typing import List, Union, Dict
+from typing import List, Union, Dict, Optional
 
 
 class SupportedVSType:

@@ -88,8 +90,11 @@ class KBService(ABC):
 
         if docs:
             self.delete_doc(kb_file)
-            self.do_add_doc(docs, **kwargs)
-            status = add_file_to_db(kb_file, custom_docs=custom_docs, docs_count=len(docs))
+            doc_infos = self.do_add_doc(docs, **kwargs)
+            status = add_file_to_db(kb_file,
+                                    custom_docs=custom_docs,
+                                    docs_count=len(docs),
+                                    doc_infos=doc_infos)
         else:
             status = False
         return status

@@ -132,6 +137,18 @@ class KBService(ABC):
         docs = self.do_search(query, top_k, score_threshold, embeddings)
         return docs
 
+    # TODO: milvus/pg need to implement this method
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
+    def list_docs(self, file_name: str = None, metadata: Dict = {}) -> List[Document]:
+        '''
+        Retrieve Documents by file_name or metadata
+        '''
+        doc_infos = list_docs_from_db(kb_name=self.kb_name, file_name=file_name, metadata=metadata)
+        docs = [self.get_doc_by_id(x["id"]) for x in doc_infos]
+        return docs
+
     @abstractmethod
     def do_create_kb(self):
         """

@@ -181,7 +198,7 @@ class KBService(ABC):
     @abstractmethod
     def do_add_doc(self,
                    docs: List[Document],
-                   ):
+                   ) -> List[Dict]:
         """
         Add documents to the knowledge base (subclasses implement their own logic)
         """

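Note that the base-class get_doc_by_id above returns None, so until milvus/pg override it, list_docs on those backends yields None placeholders. A caller may want to filter them out defensively (a usage sketch; the file name is invented):

```python
docs = [d for d in kb.list_docs(file_name="readme.md") if d is not None]
```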
@@ -12,7 +12,7 @@ from functools import lru_cache
 from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
-from typing import List
+from typing import List, Dict, Optional
 from langchain.docstore.document import Document
 from server.utils import torch_gc, embedding_device
 

@@ -88,6 +88,10 @@ class FaissKBService(KBService):
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
 
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        vector_store = self.load_vector_store()
+        return vector_store.docstore._dict.get(id)
+
     def do_init(self):
         self.kb_path = self.get_kb_path()
         self.vs_path = self.get_vs_path()

@@ -114,14 +118,15 @@ class FaissKBService(KBService):
     def do_add_doc(self,
                    docs: List[Document],
                    **kwargs,
-                   ):
+                   ) -> List[Dict]:
         vector_store = self.load_vector_store()
-        vector_store.add_documents(docs)
+        ids = vector_store.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
         torch_gc()
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
-        return vector_store
+        return doc_infos
 
     def do_delete_doc(self,
                       kb_file: KnowledgeFile,

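The FAISS implementation above relies on an implementation detail of langchain's InMemoryDocstore: its private _dict maps document IDs to Document objects, so the lookup is a plain dict .get() that returns None for unknown IDs. This is not a public API and may change across langchain versions. Roughly (the service instance and ID are invented):

```python
store = faiss_service.load_vector_store()
doc = store.docstore._dict.get("7f3c9e02-...")  # Document or None
```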
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Optional
 
 import numpy as np
 from faiss import normalize_L2

@@ -22,6 +22,10 @@ class MilvusKBService(KBService):
         from pymilvus import Collection
         return Collection(milvus_name)
 
+    # TODO:
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
     @staticmethod
     def search(milvus_name, content, limit=3):
         search_params = {

@@ -54,8 +58,10 @@ class MilvusKBService(KBService):
         self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
-    def do_add_doc(self, docs: List[Document], **kwargs):
-        self.milvus.add_documents(docs)
+    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+        ids = self.milvus.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+        return doc_infos
 
     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         filepath = kb_file.filepath.replace('\\', '\\\\')

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Dict, Optional
 
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document

@@ -24,6 +24,10 @@ class PGKBService(KBService):
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))
 
+    # TODO:
+    def get_doc_by_id(self, id: str) -> Optional[Document]:
+        return None
+
     def do_init(self):
         self._load_pg_vector()
 

@@ -51,8 +55,10 @@ class PGKBService(KBService):
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))
 
-    def do_add_doc(self, docs: List[Document], **kwargs):
-        self.pg_vector.add_documents(docs)
+    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
+        ids = self.pg_vector.add_documents(docs)
+        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
+        return doc_infos
 
     def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
         with self.pg_vector.connect() as connect:

@@ -8,7 +8,7 @@ from server.db.repository.knowledge_file_repository import add_file_to_db
 from server.db.base import Base, engine
 import os
 from concurrent.futures import ThreadPoolExecutor
-from typing import Literal, Callable, Any, List
+from typing import Literal, Any, List
 
 
 pool = ThreadPoolExecutor(os.cpu_count())

@@ -1,7 +1,6 @@
 import pydantic
 from pydantic import BaseModel
 from typing import List
-import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio

@@ -69,6 +68,7 @@ class ChatMessage(BaseModel):
     }
 
 def torch_gc():
+    import torch
     if torch.cuda.is_available():
         # with torch.cuda.device(DEVICE):
         torch.cuda.empty_cache()

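The change above defers torch to a function-local import, so merely importing server.utils no longer requires torch to be installed; it is only loaded when torch_gc() actually runs:

```python
from server import utils  # no torch import at module load time
utils.torch_gc()          # torch is imported lazily inside the call
```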
@@ -20,8 +20,8 @@ from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
                           fschat_openai_api_address, set_httpx_timeout,
-                          llm_device, embedding_device, get_model_worker_config)
-from server.utils import MakeFastAPIOffline, FastAPI
+                          llm_device, embedding_device, get_model_worker_config,
+                          MakeFastAPIOffline, FastAPI)
 import argparse
 from typing import Tuple, List
 from configs import VERSION

@@ -1,4 +1,3 @@
-from doctest import testfile
 import requests
 import json
 import sys