添加Milvus库 (#1011)
This commit is contained in:
parent
18d31f5116
commit
0f46185cfb
|
|
@ -280,10 +280,13 @@ BING_SUBSCRIPTION_KEY = ""
|
||||||
|
|
||||||
kbs_config = {
|
kbs_config = {
|
||||||
"faiss": {
|
"faiss": {
|
||||||
|
|
||||||
},
|
},
|
||||||
"milvus": {
|
"milvus": {
|
||||||
"milvus_host": "192.168.50.128",
|
"host": "127.0.0.1",
|
||||||
"milvus_port": 19530
|
"port": "19530",
|
||||||
|
"user": "",
|
||||||
|
"password": "",
|
||||||
|
"secure": False,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||||
|
|
@ -195,11 +195,9 @@ class KBService(ABC):
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
vector_store_type: str = "faiss",
|
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
):
|
):
|
||||||
self.kb_name = knowledge_base_name
|
self.kb_name = knowledge_base_name
|
||||||
self.vs_type = vector_store_type
|
|
||||||
self.embed_model = embed_model
|
self.embed_model = embed_model
|
||||||
self.kb_path = get_kb_path(self.kb_name)
|
self.kb_path = get_kb_path(self.kb_name)
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
|
|
@ -212,7 +210,7 @@ class KBService(ABC):
|
||||||
if not os.path.exists(self.doc_path):
|
if not os.path.exists(self.doc_path):
|
||||||
os.makedirs(self.doc_path)
|
os.makedirs(self.doc_path)
|
||||||
self.do_create_kb()
|
self.do_create_kb()
|
||||||
status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
|
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
def clear_vs(self):
|
def clear_vs(self):
|
||||||
|
|
@ -225,7 +223,7 @@ class KBService(ABC):
|
||||||
"""
|
"""
|
||||||
删除知识库
|
删除知识库
|
||||||
"""
|
"""
|
||||||
self.do_remove_kb()
|
self.do_drop_kb()
|
||||||
status = delete_kb_from_db(self.kb_name)
|
status = delete_kb_from_db(self.kb_name)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
@ -245,7 +243,7 @@ class KBService(ABC):
|
||||||
"""
|
"""
|
||||||
if os.path.exists(kb_file.filepath):
|
if os.path.exists(kb_file.filepath):
|
||||||
os.remove(kb_file.filepath)
|
os.remove(kb_file.filepath)
|
||||||
self.do_delete(kb_file)
|
self.do_delete_doc(kb_file)
|
||||||
status = delete_file_from_db(kb_file)
|
status = delete_file_from_db(kb_file)
|
||||||
return status
|
return status
|
||||||
|
|
||||||
|
|
@ -293,7 +291,7 @@ class KBService(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_remove_kb(self):
|
def do_drop_kb(self):
|
||||||
"""
|
"""
|
||||||
删除知识库子类实自己逻辑
|
删除知识库子类实自己逻辑
|
||||||
"""
|
"""
|
||||||
|
|
@ -320,7 +318,7 @@ class KBService(ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def do_delete(self,
|
def do_delete_doc(self,
|
||||||
kb_file: KnowledgeFile):
|
kb_file: KnowledgeFile):
|
||||||
"""
|
"""
|
||||||
从知识库删除文档子类实自己逻辑
|
从知识库删除文档子类实自己逻辑
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ class DefaultKBService(KBService):
|
||||||
def do_init(self):
|
def do_init(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_remove_kbs(self):
|
def do_drop_kbs(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_search(self):
|
def do_search(self):
|
||||||
|
|
@ -23,5 +23,5 @@ class DefaultKBService(KBService):
|
||||||
def do_insert_one_knowledge(self):
|
def do_insert_one_knowledge(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def do_delete(self):
|
def do_delete_doc(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ class FaissKBService(KBService):
|
||||||
if not os.path.exists(self.vs_path):
|
if not os.path.exists(self.vs_path):
|
||||||
os.makedirs(self.vs_path)
|
os.makedirs(self.vs_path)
|
||||||
|
|
||||||
def do_remove_kb(self):
|
def do_drop_kb(self):
|
||||||
shutil.rmtree(self.kb_path)
|
shutil.rmtree(self.kb_path)
|
||||||
|
|
||||||
def do_search(self,
|
def do_search(self,
|
||||||
|
|
@ -105,7 +105,7 @@ class FaissKBService(KBService):
|
||||||
vector_store.save_local(self.vs_path)
|
vector_store.save_local(self.vs_path)
|
||||||
refresh_vs_cache(self.kb_name)
|
refresh_vs_cache(self.kb_name)
|
||||||
|
|
||||||
def do_delete(self,
|
def do_delete_doc(self,
|
||||||
kb_file: KnowledgeFile):
|
kb_file: KnowledgeFile):
|
||||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||||
|
|
@ -113,7 +113,6 @@ class FaissKBService(KBService):
|
||||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||||
if len(ids) == 0:
|
if len(ids) == 0:
|
||||||
return None
|
return None
|
||||||
print(len(ids))
|
|
||||||
vector_store = delete_doc_from_faiss(vector_store, ids)
|
vector_store = delete_doc_from_faiss(vector_store, ids)
|
||||||
vector_store.save_local(self.vs_path)
|
vector_store.save_local(self.vs_path)
|
||||||
refresh_vs_cache(self.kb_name)
|
refresh_vs_cache(self.kb_name)
|
||||||
|
|
|
||||||
|
|
@ -1,65 +1,79 @@
|
||||||
from pymilvus import (
|
from typing import List
|
||||||
connections,
|
|
||||||
utility,
|
|
||||||
FieldSchema,
|
|
||||||
CollectionSchema,
|
|
||||||
DataType,
|
|
||||||
Collection,
|
|
||||||
)
|
|
||||||
|
|
||||||
from server.knowledge_base.kb_service.base import KBService
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.schema import Document
|
||||||
|
from langchain.vectorstores import Milvus
|
||||||
|
|
||||||
|
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||||
|
from server.knowledge_base import KnowledgeFile
|
||||||
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db
|
||||||
|
|
||||||
|
|
||||||
def get_collection(milvus_name):
|
class MilvusKBService(KBService):
|
||||||
return Collection(milvus_name)
|
milvus: Milvus
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_collection(milvus_name):
|
||||||
|
from pymilvus import Collection
|
||||||
|
return Collection(milvus_name)
|
||||||
|
|
||||||
def search(milvus_name, content, limit=3):
|
@staticmethod
|
||||||
search_params = {
|
def search(milvus_name, content, limit=3):
|
||||||
"metric_type": "L2",
|
search_params = {
|
||||||
"params": {"nprobe": 10},
|
|
||||||
}
|
|
||||||
c = get_collection(milvus_name)
|
|
||||||
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"])
|
|
||||||
|
|
||||||
|
|
||||||
class MilvusKBService():
|
|
||||||
milvus_host: str
|
|
||||||
milvus_port: int
|
|
||||||
dim: int
|
|
||||||
|
|
||||||
def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530,
|
|
||||||
dim=8):
|
|
||||||
|
|
||||||
super().__init__(knowledge_base_name, vector_store_type)
|
|
||||||
self.milvus_host = milvus_host
|
|
||||||
self.milvus_port = milvus_port
|
|
||||||
self.dim = dim
|
|
||||||
|
|
||||||
def connect(self):
|
|
||||||
connections.connect("default", host=self.milvus_host, port=self.milvus_port)
|
|
||||||
|
|
||||||
def create_collection(self, milvus_name):
|
|
||||||
fields = [
|
|
||||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
|
|
||||||
FieldSchema(name="content", dtype=DataType.STRING),
|
|
||||||
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim)
|
|
||||||
]
|
|
||||||
schema = CollectionSchema(fields)
|
|
||||||
collection = Collection(milvus_name, schema)
|
|
||||||
index = {
|
|
||||||
"index_type": "IVF_FLAT",
|
|
||||||
"metric_type": "L2",
|
"metric_type": "L2",
|
||||||
"params": {"nlist": 128},
|
"params": {"nprobe": 10},
|
||||||
}
|
}
|
||||||
collection.create_index("embeddings", index)
|
c = MilvusKBService.get_collection(milvus_name)
|
||||||
collection.load()
|
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])
|
||||||
return collection
|
|
||||||
|
|
||||||
def insert_collection(self, milvus_name, content=[]):
|
def do_create_kb(self):
|
||||||
get_collection(milvus_name).insert(dataset)
|
pass
|
||||||
|
|
||||||
|
def vs_type(self) -> str:
|
||||||
|
return SupportedVSType.MILVUS
|
||||||
|
|
||||||
|
def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
|
||||||
|
_embeddings = embeddings
|
||||||
|
if _embeddings is None:
|
||||||
|
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||||
|
self.milvus = Milvus(embedding_function=_embeddings,
|
||||||
|
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
|
||||||
|
|
||||||
|
def do_init(self):
|
||||||
|
self._load_milvus()
|
||||||
|
|
||||||
|
def do_drop_kb(self):
|
||||||
|
self.milvus.col.drop()
|
||||||
|
|
||||||
|
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||||
|
self._load_milvus(embeddings=embeddings)
|
||||||
|
return self.milvus.similarity_search(query, top_k)
|
||||||
|
|
||||||
|
def add_doc(self, kb_file: KnowledgeFile):
|
||||||
|
"""
|
||||||
|
向知识库添加文件
|
||||||
|
"""
|
||||||
|
docs = kb_file.file2text()
|
||||||
|
self.milvus.add_documents(docs)
|
||||||
|
status = add_doc_to_db(kb_file)
|
||||||
|
return status
|
||||||
|
|
||||||
|
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def do_delete_doc(self, kb_file: KnowledgeFile):
|
||||||
|
filepath = kb_file.filepath.replace('\\', '\\\\')
|
||||||
|
delete_list = [item.get("pk") for item in
|
||||||
|
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
|
||||||
|
self.milvus.col.delete(expr=f'pk in {delete_list}')
|
||||||
|
|
||||||
|
def do_clear_vs(self):
|
||||||
|
self.milvus.col.drop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
milvusService = MilvusService(milvus_host='192.168.50.128')
|
milvusService = MilvusKBService("test")
|
||||||
milvusService.insert_collection(test,dataset)
|
# milvusService.add_doc(KnowledgeFile("test.pdf", "test"))
|
||||||
|
# milvusService.delete_doc(KnowledgeFile("test.pdf", "test"))
|
||||||
|
milvusService.do_drop_kb()
|
||||||
|
print(milvusService.search_docs("测试"))
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
||||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||||
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
|
|
||||||
|
|
||||||
class KBServiceFactory:
|
class KBServiceFactory:
|
||||||
|
|
@ -11,6 +12,9 @@ class KBServiceFactory:
|
||||||
if SupportedVSType.FAISS == vector_store_type:
|
if SupportedVSType.FAISS == vector_store_type:
|
||||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||||
return FaissKBService(kb_name)
|
return FaissKBService(kb_name)
|
||||||
|
elif SupportedVSType.MILVUS == vector_store_type:
|
||||||
|
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
|
||||||
|
return MilvusKBService(kb_name)
|
||||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||||
return DefaultKBService(kb_name)
|
return DefaultKBService(kb_name)
|
||||||
|
|
||||||
|
|
@ -28,9 +32,9 @@ class KBServiceFactory:
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
||||||
init_db()
|
init_db()
|
||||||
KBService.create_kbs()
|
KBService.create_kb()
|
||||||
KBService = KBServiceFactory.get_default()
|
KBService = KBServiceFactory.get_default()
|
||||||
print(KBService.list_kbs())
|
print(KBService.list_kbs())
|
||||||
KBService = KBServiceFactory.get_service_by_name("test")
|
KBService = KBServiceFactory.get_service_by_name("test")
|
||||||
print(KBService.list_docs())
|
print(KBService.list_docs())
|
||||||
KBService.drop_kbs()
|
KBService.drop_kb()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue