add KBService and KBServiceFactory class
This commit is contained in:
parent
06af3f4c5e
commit
18d31f5116
|
|
@ -277,3 +277,13 @@ BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
|
|||
# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
|
||||
# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
|
||||
BING_SUBSCRIPTION_KEY = ""
|
||||
|
||||
kbs_config = {
|
||||
"faiss": {
|
||||
|
||||
},
|
||||
"milvus": {
|
||||
"milvus_host": "192.168.50.128",
|
||||
"milvus_port": 19530
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,335 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
import os
|
||||
import sqlite3
|
||||
from functools import lru_cache
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
|
||||
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||
import datetime
|
||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from typing import List
|
||||
|
||||
|
||||
class SupportedVSType:
|
||||
FAISS = 'faiss'
|
||||
MILVUS = 'milvus'
|
||||
DEFAULT = 'default'
|
||||
|
||||
|
||||
def init_db():
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute('''CREATE TABLE if not exists knowledge_base
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
kb_name TEXT,
|
||||
vs_type TEXT,
|
||||
embed_model TEXT,
|
||||
file_count INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
c.execute('''CREATE TABLE if not exists knowledge_files
|
||||
(id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
file_name TEXT,
|
||||
file_ext TEXT,
|
||||
kb_name TEXT,
|
||||
document_loader_name TEXT,
|
||||
text_splitter_name TEXT,
|
||||
file_version INTEGER,
|
||||
create_time DATETIME) ''')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def list_kbs_from_db():
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT kb_name
|
||||
FROM knowledge_base
|
||||
WHERE file_count>0 ''')
|
||||
kbs = [i[0] for i in c.fetchall() if i]
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return kbs
|
||||
|
||||
|
||||
def add_kb_to_db(kb_name, vs_type, embed_model):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# Insert a row of data
|
||||
c.execute(f"""INSERT INTO knowledge_base
|
||||
(kb_name, vs_type, embed_model, file_count, create_time)
|
||||
VALUES
|
||||
('{kb_name}','{vs_type}','{embed_model}',
|
||||
0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
|
||||
def kb_exists(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT COUNT(*)
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
status = True if c.fetchone()[0] else False
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return status
|
||||
|
||||
|
||||
def load_kb_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT kb_name, vs_type, embed_model
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
resp = c.fetchone()
|
||||
if resp:
|
||||
kb_name, vs_type, embed_model = resp
|
||||
else:
|
||||
kb_name, vs_type, embed_model = None, None, None
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return kb_name, vs_type, embed_model
|
||||
|
||||
|
||||
def delete_kb_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''DELETE
|
||||
FROM knowledge_base
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
c.execute(f"""DELETE
|
||||
FROM knowledge_files
|
||||
WHERE kb_name="{kb_name}"
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
|
||||
def list_docs_from_db(kb_name):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT file_name
|
||||
FROM knowledge_files
|
||||
WHERE kb_name="{kb_name}" ''')
|
||||
kbs = [i[0] for i in c.fetchall() if i]
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return kbs
|
||||
|
||||
|
||||
def add_doc_to_db(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# Insert a row of data
|
||||
c.execute(
|
||||
f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
|
||||
record_exist = c.fetchone()
|
||||
if record_exist is not None:
|
||||
c.execute(f"""UPDATE knowledge_files
|
||||
SET file_version = file_version + 1
|
||||
WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
|
||||
""")
|
||||
else:
|
||||
c.execute(f"""INSERT INTO knowledge_files
|
||||
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
|
||||
VALUES
|
||||
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
|
||||
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
|
||||
c.execute(f"""UPDATE knowledge_base
|
||||
SET file_count = file_count + 1
|
||||
WHERE kb_name="{kb_file.kb_name}"
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
|
||||
def delete_file_from_db(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
# Insert a row of data
|
||||
c.execute(f"""DELETE
|
||||
FROM knowledge_files
|
||||
WHERE file_name="{kb_file.filename}"
|
||||
AND kb_name="{kb_file.kb_name}"
|
||||
""")
|
||||
c.execute(f"""UPDATE knowledge_base
|
||||
SET file_count = file_count - 1
|
||||
WHERE kb_name="{kb_file.kb_name}"
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return True
|
||||
|
||||
|
||||
def doc_exists(kb_file: KnowledgeFile):
|
||||
conn = sqlite3.connect(DB_ROOT_PATH)
|
||||
c = conn.cursor()
|
||||
c.execute(f'''SELECT COUNT(*)
|
||||
FROM knowledge_files
|
||||
WHERE file_name="{kb_file.filename}"
|
||||
AND kb_name="{kb_file.kb_name}" ''')
|
||||
status = True if c.fetchone()[0] else False
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return status
|
||||
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
||||
model_kwargs={'device': device})
|
||||
return embeddings
|
||||
|
||||
|
||||
class KBService(ABC):
|
||||
|
||||
def __init__(self,
|
||||
knowledge_base_name: str,
|
||||
vector_store_type: str = "faiss",
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
):
|
||||
self.kb_name = knowledge_base_name
|
||||
self.vs_type = vector_store_type
|
||||
self.embed_model = embed_model
|
||||
self.kb_path = get_kb_path(self.kb_name)
|
||||
self.doc_path = get_doc_path(self.kb_name)
|
||||
self.do_init()
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
"""
|
||||
if not os.path.exists(self.doc_path):
|
||||
os.makedirs(self.doc_path)
|
||||
self.do_create_kb()
|
||||
status = add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)
|
||||
return status
|
||||
|
||||
def clear_vs(self):
|
||||
"""
|
||||
用知识库中已上传文件重建向量库
|
||||
"""
|
||||
self.do_clear_vs()
|
||||
|
||||
def drop_kb(self):
|
||||
"""
|
||||
删除知识库
|
||||
"""
|
||||
self.do_remove_kb()
|
||||
status = delete_kb_from_db(self.kb_name)
|
||||
return status
|
||||
|
||||
def add_doc(self, kb_file: KnowledgeFile):
|
||||
"""
|
||||
向知识库添加文件
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
self.do_add_doc(docs, embeddings)
|
||||
status = add_doc_to_db(kb_file)
|
||||
return status
|
||||
|
||||
def delete_doc(self, kb_file: KnowledgeFile):
|
||||
"""
|
||||
从知识库删除文件
|
||||
"""
|
||||
if os.path.exists(kb_file.filepath):
|
||||
os.remove(kb_file.filepath)
|
||||
self.do_delete(kb_file)
|
||||
status = delete_file_from_db(kb_file)
|
||||
return status
|
||||
|
||||
def exist_doc(self, file_name: str):
|
||||
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
|
||||
filename=file_name))
|
||||
|
||||
def list_docs(self):
|
||||
return list_docs_from_db(self.kb_name)
|
||||
|
||||
def search_docs(self,
|
||||
query: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
embedding_device: str = EMBEDDING_DEVICE, ):
|
||||
embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||
docs = self.do_search(query, top_k, embeddings)
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
def do_create_kb(self):
|
||||
"""
|
||||
创建知识库子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def list_kbs_type():
|
||||
return list(kbs_config.keys())
|
||||
|
||||
@classmethod
|
||||
def list_kbs(cls):
|
||||
return list_kbs_from_db()
|
||||
|
||||
@classmethod
|
||||
def exists(cls,
|
||||
knowledge_base_name: str):
|
||||
return kb_exists(knowledge_base_name)
|
||||
|
||||
@abstractmethod
|
||||
def vs_type(self) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_init(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_remove_kb(self):
|
||||
"""
|
||||
删除知识库子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_search(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
embeddings: Embeddings,
|
||||
) -> List[Document]:
|
||||
"""
|
||||
搜索知识库子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings):
|
||||
"""
|
||||
向知识库添加文档子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_delete(self,
|
||||
kb_file: KnowledgeFile):
|
||||
"""
|
||||
从知识库删除文档子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def do_clear_vs(self):
|
||||
"""
|
||||
从知识库删除全部向量子类实自己逻辑
|
||||
"""
|
||||
pass
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
from server.knowledge_base.kb_service.base import KBService
|
||||
|
||||
|
||||
class DefaultKBService(KBService):
|
||||
def vs_type(self) -> str:
|
||||
return "default"
|
||||
|
||||
def do_create_kbs(self):
|
||||
pass
|
||||
|
||||
def do_init(self):
|
||||
pass
|
||||
|
||||
def do_remove_kbs(self):
|
||||
pass
|
||||
|
||||
def do_search(self):
|
||||
pass
|
||||
|
||||
def do_insert_multi_knowledge(self):
|
||||
pass
|
||||
|
||||
def do_insert_one_knowledge(self):
|
||||
pass
|
||||
|
||||
def do_delete(self):
|
||||
pass
|
||||
|
|
@ -0,0 +1,125 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||
from functools import lru_cache
|
||||
from server.knowledge_base.utils import get_vs_path
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc
|
||||
import numpy as np
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
@lru_cache(CACHED_VS_NUM)
|
||||
def load_vector_store(
|
||||
knowledge_base_name: str,
|
||||
embeddings: Embeddings,
|
||||
tick: int, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
):
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
search_index = FAISS.load_local(vs_path, embeddings)
|
||||
return search_index
|
||||
|
||||
|
||||
def refresh_vs_cache(kb_name: str):
|
||||
"""
|
||||
make vector store cache refreshed when next loading
|
||||
"""
|
||||
_VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
|
||||
|
||||
|
||||
def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
|
||||
overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
|
||||
if not overlapping:
|
||||
raise ValueError("ids do not exist in the current object")
|
||||
_reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
|
||||
index_to_delete = [_reversed_index[i] for i in ids]
|
||||
vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
|
||||
for _id in index_to_delete:
|
||||
del vector_store.index_to_docstore_id[_id]
|
||||
# Remove items from docstore.
|
||||
overlapping2 = set(ids).intersection(vector_store.docstore._dict)
|
||||
if not overlapping2:
|
||||
raise ValueError(f"Tried to delete ids that does not exist: {ids}")
|
||||
for _id in ids:
|
||||
vector_store.docstore._dict.pop(_id)
|
||||
return vector_store
|
||||
|
||||
|
||||
class FaissKBService(KBService):
|
||||
vs_path: str
|
||||
kb_path: str
|
||||
|
||||
def vs_type(self) -> str:
|
||||
return SupportedVSType.FAISS
|
||||
|
||||
@staticmethod
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(FaissKBService.get_kb_path(knowledge_base_name), "vector_store")
|
||||
|
||||
@staticmethod
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
|
||||
def do_init(self):
|
||||
self.kb_path = FaissKBService.get_kb_path(self.kb_name)
|
||||
self.vs_path = FaissKBService.get_vs_path(self.kb_name)
|
||||
|
||||
def do_create_kb(self):
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
|
||||
def do_remove_kb(self):
|
||||
shutil.rmtree(self.kb_path)
|
||||
|
||||
def do_search(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
embeddings: Embeddings,
|
||||
) -> List[Document]:
|
||||
search_index = load_vector_store(self.kb_name,
|
||||
embeddings,
|
||||
_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
docs = search_index.similarity_search(query, k=top_k)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings):
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
vector_store.add_documents(docs)
|
||||
torch_gc()
|
||||
else:
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def do_delete(self,
|
||||
kb_file: KnowledgeFile):
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
if len(ids) == 0:
|
||||
return None
|
||||
print(len(ids))
|
||||
vector_store = delete_doc_from_faiss(vector_store, ids)
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
return True
|
||||
else:
|
||||
return None
|
||||
|
||||
def do_clear_vs(self):
|
||||
shutil.rmtree(self.vs_path)
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
from pymilvus import (
|
||||
connections,
|
||||
utility,
|
||||
FieldSchema,
|
||||
CollectionSchema,
|
||||
DataType,
|
||||
Collection,
|
||||
)
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService
|
||||
|
||||
|
||||
def get_collection(milvus_name):
|
||||
return Collection(milvus_name)
|
||||
|
||||
|
||||
def search(milvus_name, content, limit=3):
|
||||
search_params = {
|
||||
"metric_type": "L2",
|
||||
"params": {"nprobe": 10},
|
||||
}
|
||||
c = get_collection(milvus_name)
|
||||
return c.search(content, "embeddings", search_params, limit=limit, output_fields=["random"])
|
||||
|
||||
|
||||
class MilvusKBService():
|
||||
milvus_host: str
|
||||
milvus_port: int
|
||||
dim: int
|
||||
|
||||
def __init__(self, knowledge_base_name: str, vector_store_type: str, milvus_host="localhost", milvus_port=19530,
|
||||
dim=8):
|
||||
|
||||
super().__init__(knowledge_base_name, vector_store_type)
|
||||
self.milvus_host = milvus_host
|
||||
self.milvus_port = milvus_port
|
||||
self.dim = dim
|
||||
|
||||
def connect(self):
|
||||
connections.connect("default", host=self.milvus_host, port=self.milvus_port)
|
||||
|
||||
def create_collection(self, milvus_name):
|
||||
fields = [
|
||||
FieldSchema(name="pk", dtype=DataType.INT64, is_primary=True, auto_id=False),
|
||||
FieldSchema(name="content", dtype=DataType.STRING),
|
||||
FieldSchema(name="embeddings", dtype=DataType.FLOAT_VECTOR, dim=self.dim)
|
||||
]
|
||||
schema = CollectionSchema(fields)
|
||||
collection = Collection(milvus_name, schema)
|
||||
index = {
|
||||
"index_type": "IVF_FLAT",
|
||||
"metric_type": "L2",
|
||||
"params": {"nlist": 128},
|
||||
}
|
||||
collection.create_index("embeddings", index)
|
||||
collection.load()
|
||||
return collection
|
||||
|
||||
def insert_collection(self, milvus_name, content=[]):
|
||||
get_collection(milvus_name).insert(dataset)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
milvusService = MilvusService(milvus_host='192.168.50.128')
|
||||
milvusService.insert_collection(test,dataset)
|
||||
|
|
@ -5,33 +5,18 @@ import shutil
|
|||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
|
||||
KB_ROOT_PATH, DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
|
||||
DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
|
||||
from server.utils import torch_gc
|
||||
from functools import lru_cache
|
||||
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||||
from typing import List
|
||||
import numpy as np
|
||||
from server.knowledge_base.utils import (get_kb_path, get_doc_path, get_vs_path)
|
||||
|
||||
SUPPORTED_VS_TYPES = ["faiss", "milvus"]
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
|
||||
|
||||
def get_doc_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||
|
||||
@lru_cache(1)
|
||||
def load_embeddings(model: str, device: str):
|
||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, init_db, load_kb_from_db
|
||||
from server.knowledge_base.kb_service.default_kb_service import DefaultKBService
|
||||
|
||||
|
||||
class KBServiceFactory:
|
||||
|
||||
@staticmethod
|
||||
def get_service(kb_name: str,
|
||||
vector_store_type: SupportedVSType
|
||||
) -> KBService:
|
||||
if SupportedVSType.FAISS == vector_store_type:
|
||||
from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
|
||||
return FaissKBService(kb_name)
|
||||
elif SupportedVSType.DEFAULT == vector_store_type:
|
||||
return DefaultKBService(kb_name)
|
||||
|
||||
@staticmethod
|
||||
def get_service_by_name(kb_name: str
|
||||
) -> KBService:
|
||||
kb_name, vs_type = load_kb_from_db(kb_name)
|
||||
return KBServiceFactory.get_service(kb_name, vs_type)
|
||||
|
||||
@staticmethod
|
||||
def get_default():
|
||||
return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
KBService = KBServiceFactory.get_service("test", SupportedVSType.FAISS)
|
||||
init_db()
|
||||
KBService.create_kbs()
|
||||
KBService = KBServiceFactory.get_default()
|
||||
print(KBService.list_kbs())
|
||||
KBService = KBServiceFactory.get_service_by_name("test")
|
||||
print(KBService.list_docs())
|
||||
KBService.drop_kbs()
|
||||
|
|
@ -1,5 +1,20 @@
|
|||
import os.path
|
||||
from configs.model_config import KB_ROOT_PATH
|
||||
|
||||
def validate_kb_name(knowledge_base_id: str) -> bool:
|
||||
# 检查是否包含预期外的字符或路径攻击关键字
|
||||
if "../" in knowledge_base_id:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_kb_path(knowledge_base_name: str):
|
||||
return os.path.join(KB_ROOT_PATH, knowledge_base_name)
|
||||
|
||||
def get_doc_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "content")
|
||||
|
||||
def get_vs_path(knowledge_base_name: str):
|
||||
return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
|
||||
|
||||
def get_file_path(knowledge_base_name: str, doc_name: str):
|
||||
return os.path.join(get_doc_path(knowledge_base_name), doc_name)
|
||||
Loading…
Reference in New Issue