添加Milvus库 (#1011)

This commit is contained in:
zqt996 2023-08-07 16:56:57 +08:00 committed by GitHub
parent 18d31f5116
commit 0f46185cfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 90 additions and 72 deletions

View File

@ -280,10 +280,13 @@ BING_SUBSCRIPTION_KEY = ""
kbs_config = {
"faiss": {
},
"milvus": {
"milvus_host": "192.168.50.128",
"milvus_port": 19530
"host": "127.0.0.1",
"port": "19530",
"user": "",
"password": "",
"secure": False,
}
}
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")

View File

@ -195,11 +195,9 @@ class KBService(ABC):
def __init__(self,
knowledge_base_name: str,
vector_store_type: str = "faiss",
embed_model: str = EMBEDDING_MODEL,
):
self.kb_name = knowledge_base_name
self.vs_type = vector_store_type
self.embed_model = embed_model
self.kb_path = get_kb_path(self.kb_name)
self.doc_path = get_doc_path(self.kb_name)
@ -212,7 +210,7 @@ class KBService(ABC):
if not os.path.exists(self.doc_path):
os.makedirs(self.doc_path)
self.do_create_kb()
status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
status = add_kb_to_db(self.kb_name, self.vs_type(), self.embed_model)
return status
def clear_vs(self):
@ -225,7 +223,7 @@ class KBService(ABC):
"""
删除知识库
"""
self.do_remove_kb()
self.do_drop_kb()
status = delete_kb_from_db(self.kb_name)
return status
@ -245,7 +243,7 @@ class KBService(ABC):
"""
if os.path.exists(kb_file.filepath):
os.remove(kb_file.filepath)
self.do_delete(kb_file)
self.do_delete_doc(kb_file)
status = delete_file_from_db(kb_file)
return status
@ -293,7 +291,7 @@ class KBService(ABC):
pass
@abstractmethod
def do_remove_kb(self):
def do_drop_kb(self):
"""
删除知识库子类实自己逻辑
"""
@ -320,7 +318,7 @@ class KBService(ABC):
pass
@abstractmethod
def do_delete(self,
def do_delete_doc(self,
kb_file: KnowledgeFile):
"""
从知识库删除文档子类实自己逻辑

View File

@ -11,7 +11,7 @@ class DefaultKBService(KBService):
def do_init(self):
pass
def do_remove_kbs(self):
def do_drop_kbs(self):
pass
def do_search(self):
@ -23,5 +23,5 @@ class DefaultKBService(KBService):
def do_insert_one_knowledge(self):
pass
def do_delete(self):
def do_delete_doc(self):
pass

View File

@ -76,7 +76,7 @@ class FaissKBService(KBService):
if not os.path.exists(self.vs_path):
os.makedirs(self.vs_path)
def do_remove_kb(self):
def do_drop_kb(self):
shutil.rmtree(self.kb_path)
def do_search(self,
@ -105,7 +105,7 @@ class FaissKBService(KBService):
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
def do_delete(self,
def do_delete_doc(self,
kb_file: KnowledgeFile):
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
@ -113,7 +113,6 @@ class FaissKBService(KBService):
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
print(len(ids))
vector_store = delete_doc_from_faiss(vector_store, ids)
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)

View File

@ -1,65 +1,79 @@
from pymilvus import (
connections,
utility,
FieldSchema,
CollectionSchema,
DataType,
Collection,
)
from typing import List
from server.knowledge_base.kb_service.base import KBService
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
from langchain.vectorstores import Milvus
from configs.model_config import EMBEDDING_DEVICE, kbs_config
from server.knowledge_base import KnowledgeFile
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings, add_doc_to_db
def get_collection(milvus_name):
class MilvusKBService(KBService):
milvus: Milvus
@staticmethod
def get_collection(milvus_name):
from pymilvus import Collection
return Collection(milvus_name)
def search(milvus_name, content, limit=3):
@staticmethod
def search(milvus_name, content, limit=3):
search_params = {
"metric_type": "L2",
"params": {"nprobe": 10},
}
c = get_collection(milvus_name)
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"])
c = MilvusKBService.get_collection(milvus_name)
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["content"])
def do_create_kb(self):
pass
class MilvusKBService():
milvus_host: str
milvus_port: int
dim: int
def vs_type(self) -> str:
return SupportedVSType.MILVUS
def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530,
dim=8):
def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
_embeddings = embeddings
if _embeddings is None:
_embeddings = load_embeddings(self.embed_model, embedding_device)
self.milvus = Milvus(embedding_function=_embeddings,
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
super().__init__(knowledge_base_name, vector_store_type)
self.milvus_host = milvus_host
self.milvus_port = milvus_port
self.dim = dim
def do_init(self):
self._load_milvus()
def connect(self):
connections.connect("default", host=self.milvus_host, port=self.milvus_port)
def do_drop_kb(self):
self.milvus.col.drop()
def create_collection(self, milvus_name):
fields = [
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
FieldSchema(name="content", dtype=DataType.STRING),
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim)
]
schema = CollectionSchema(fields)
collection = Collection(milvus_name, schema)
index = {
"index_type": "IVF_FLAT",
"metric_type": "L2",
"params": {"nlist": 128},
}
collection.create_index("embeddings", index)
collection.load()
return collection
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k)
def insert_collection(self, milvus_name, content=[]):
get_collection(milvus_name).insert(dataset)
def add_doc(self, kb_file: KnowledgeFile):
"""
向知识库添加文件
"""
docs = kb_file.file2text()
self.milvus.add_documents(docs)
status = add_doc_to_db(kb_file)
return status
def do_add_doc(self, docs: List[Document], embeddings: Embeddings):
pass
def do_delete_doc(self, kb_file: KnowledgeFile):
filepath = kb_file.filepath.replace('\\', '\\\\')
delete_list = [item.get("pk") for item in
self.milvus.col.query(expr=f'source == "{filepath}"', output_fields=["pk"])]
self.milvus.col.delete(expr=f'pk in {delete_list}')
def do_clear_vs(self):
self.milvus.col.drop()
if __name__ == '__main__':
milvusService = MilvusService(milvus_host='192.168.50.128')
milvusService.insert_collection(test,dataset)
milvusService = MilvusKBService("test")
# milvusService.add_doc(KnowledgeFile("test.pdf", "test"))
# milvusService.delete_doc(KnowledgeFile("test.pdf", "test"))
milvusService.do_drop_kb()
print(milvusService.search_docs("测试"))

View File

@ -1,5 +1,6 @@
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
class KBServiceFactory:
@ -11,6 +12,9 @@ class KBServiceFactory:
if SupportedVSType.FAISS == vector_store_type:
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
return FaissKBService(kb_name)
elif SupportedVSType.MILVUS == vector_store_type:
from server.knowledge_base.kb_service.milvus_kb_service import MilvusKBService
return MilvusKBService(kb_name)
elif SupportedVSType.DEFAULT == vector_store_type:
return DefaultKBService(kb_name)
@ -28,9 +32,9 @@ class KBServiceFactory:
if __name__ == '__main__':
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
init_db()
KBService.create_kbs()
KBService.create_kb()
KBService = KBServiceFactory.get_default()
print(KBService.list_kbs())
KBService = KBServiceFactory.get_service_by_name("test")
print(KBService.list_docs())
KBService.drop_kbs()
KBService.drop_kb()