Langchain-Chatchat/server/knowledge_base/kb_service/milvus_kb_service.py

114 lines
4.1 KiB
Python
Raw Normal View History

from typing import List, Dict, Optional
2023-08-07 16:32:34 +08:00
from langchain.schema import Document
from langchain.vectorstores.milvus import Milvus
2023-08-07 16:32:34 +08:00
from configs import kbs_config
2023-08-27 11:21:10 +08:00
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, EmbeddingsFunAdapter, \
score_threshold_process
from server.knowledge_base.utils import KnowledgeFile
2023-08-07 16:32:34 +08:00
class MilvusKBService(KBService):
    """Knowledge-base service that stores documents in Milvus through LangChain's ``Milvus`` wrapper.

    Each knowledge base maps to one Milvus collection named ``self.kb_name``; the
    wrapper is (re)built by ``_load_milvus`` from ``kbs_config["milvus"]`` (connection
    arguments) and ``kbs_config["milvus_kwargs"]`` (index/search parameters).
    """

    # LangChain Milvus wrapper; assigned by ``_load_milvus``.
    milvus: Milvus

    @staticmethod
    def get_collection(milvus_name):
        """Return a raw ``pymilvus.Collection`` handle for the collection ``milvus_name``."""
        # Imported locally (presumably to avoid a module-level dependency on pymilvus).
        from pymilvus import Collection
        return Collection(milvus_name)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch stored documents by primary key.

        Args:
            ids: primary-key values of the rows to fetch.

        Returns:
            One ``Document`` per returned row: the text field becomes ``page_content``
            and every other returned field becomes metadata. Empty when ``ids`` is empty
            or the collection does not exist.
        """
        result = []
        if not ids:
            return result
        if self.milvus.col:
            # NOTE(review): ids are interpolated as-is; if the pk field is INT64,
            # string ids render as quoted literals -- confirm Milvus accepts that.
            data_list = self.milvus.col.query(expr=f'pk in {ids}', output_fields=["*"])
            for data in data_list:
                # Same text field that do_add_doc strips out of metadata.
                text = data.pop(self.milvus._text_field)
                result.append(Document(page_content=text, metadata=data))
        return result

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete rows by primary key.

        Args:
            ids: primary-key values of the rows to delete.

        Returns:
            True once the delete has been issued (or there was nothing to delete:
            ``ids`` empty or the collection does not exist).
        """
        if ids and self.milvus.col:
            self.milvus.col.delete(expr=f'pk in {ids}')
        return True

    @staticmethod
    def search(milvus_name, content, limit=3):
        """Run a raw pymilvus vector search against collection ``milvus_name``.

        Args:
            milvus_name: collection name.
            content: query vector(s) passed straight to ``Collection.search``.
            limit: maximum number of hits per query vector.

        NOTE(review): this searches the ``"embeddings"`` field and returns ``"content"``,
        which may not match the text/vector field names of collections created through
        ``self.milvus`` -- looks like legacy code; confirm before relying on it.
        """
        search_params = {
            "metric_type": "L2",
            "params": {"nprobe": 10},
        }
        c = MilvusKBService.get_collection(milvus_name)
        return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])

    def do_create_kb(self):
        """No explicit creation step; the collection is presumably created on first insert via ``self.milvus``."""
        pass

    def vs_type(self) -> str:
        """Return the vector-store type identifier for this service."""
        return SupportedVSType.MILVUS

    def _load_milvus(self):
        """(Re)create the LangChain ``Milvus`` wrapper for this knowledge base's collection."""
        # Missing "milvus_kwargs" or missing keys fall back to None, i.e. the wrapper's defaults.
        milvus_kwargs = kbs_config.get("milvus_kwargs") or {}
        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                             collection_name=self.kb_name,
                             connection_args=kbs_config.get("milvus"),
                             index_params=milvus_kwargs.get("index_params"),
                             search_params=milvus_kwargs.get("search_params"),
                             )

    def do_init(self):
        """Initialise the service by loading the Milvus wrapper."""
        self._load_milvus()

    def do_drop_kb(self):
        """Release and drop this knowledge base's collection, if it exists."""
        if self.milvus.col:
            self.milvus.col.release()
            self.milvus.col.drop()

    def do_search(self, query: str, top_k: int, score_threshold: float):
        """Embed ``query`` and return up to ``top_k`` (Document, score) hits filtered by ``score_threshold``.

        The thresholding itself is delegated to ``score_threshold_process``.
        """
        # Rebuilds the wrapper before every search (presumably so a collection created
        # or dropped since do_init is picked up -- confirm).
        self._load_milvus()
        embed_func = EmbeddingsFunAdapter(self.embed_model)
        embeddings = embed_func.embed_query(query)
        docs = self.milvus.similarity_search_with_score_by_vector(embeddings, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Insert ``docs`` into the collection.

        Document metadata is normalised in place before insertion.

        Returns:
            One ``{"id": <primary key>, "metadata": <normalised metadata>}`` dict per document.
        """
        # TODO: workaround for bug #10492 in langchain
        for doc in docs:
            # Store every metadata value as a string.
            for k, v in doc.metadata.items():
                doc.metadata[k] = str(v)
            # Give every document every known collection field (empty string if absent).
            for field in self.milvus.fields:
                doc.metadata.setdefault(field, "")
            # The text and vector columns are written by the wrapper itself, not from metadata.
            doc.metadata.pop(self.milvus._text_field, None)
            doc.metadata.pop(self.milvus._vector_field, None)

        ids = self.milvus.add_documents(docs)
        doc_infos = [{"id": doc_id, "metadata": doc.metadata} for doc_id, doc in zip(ids, docs)]
        return doc_infos

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every row whose ``source`` metadata equals ``kb_file.filepath``."""
        if self.milvus.col:
            # Escape backslashes (e.g. Windows paths) and double quotes so the path is a
            # valid double-quoted string literal inside the Milvus expression.
            filepath = kb_file.filepath.replace('\\', '\\\\').replace('"', '\\"')
            delete_list = [item.get("pk") for item in
                           self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
            # Skip the delete call entirely when no rows matched.
            if delete_list:
                self.milvus.col.delete(expr=f'pk in {delete_list}')

    def do_clear_vs(self):
        """Empty the vector store by dropping the collection and re-initialising the wrapper."""
        if self.milvus.col:
            self.do_drop_kb()
            self.do_init()
if __name__ == '__main__':
    # Manual smoke test: create the relational tables, then fetch one document by id.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)

    sample_id = "444022434274215486"
    milvus_service = MilvusKBService("test")
    # milvus_service.add_doc(KnowledgeFile("README.md", "test"))
    print(milvus_service.get_doc_by_ids([sample_id]))
    # milvus_service.delete_doc(KnowledgeFile("README.md", "test"))
    # milvus_service.do_drop_kb()
    # print(milvus_service.search_docs("How do I start the API service?"))