336 lines
9.7 KiB
Python
336 lines
9.7 KiB
Python
|
|
from abc import ABC, abstractmethod
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sqlite3
|
||
|
|
from functools import lru_cache
|
||
|
|
|
||
|
|
from langchain.embeddings import HuggingFaceEmbeddings
|
||
|
|
from langchain.embeddings.base import Embeddings
|
||
|
|
from langchain.docstore.document import Document
|
||
|
|
|
||
|
|
from configs.model_config import (DB_ROOT_PATH, kbs_config, VECTOR_SEARCH_TOP_K,
|
||
|
|
embedding_model_dict, EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||
|
|
import datetime
|
||
|
|
from server.knowledge_base.utils import (get_kb_path, get_doc_path)
|
||
|
|
from server.knowledge_base.knowledge_file import KnowledgeFile
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
|
||
|
|
class SupportedVSType:
    """Namespace of string identifiers for the supported vector-store types."""

    # Backend identifiers, assigned in a single statement.
    FAISS, MILVUS, DEFAULT = 'faiss', 'milvus', 'default'
|
||
|
|
|
||
|
|
|
||
|
|
def init_db():
    """Create the SQLite metadata tables if they do not exist yet.

    Two tables are created in the database file at ``DB_ROOT_PATH``:

    * ``knowledge_base``  -- one row per knowledge base (name, vector store
      type, embedding model, file count, creation time).
    * ``knowledge_files`` -- one row per file added to a knowledge base
      (name, extension, owning knowledge base, loader/splitter names,
      version, creation time).
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        # ``with conn`` commits on success and rolls back on error, but it
        # does not close the connection -- hence the surrounding try/finally.
        with conn:
            conn.execute('''CREATE TABLE IF NOT EXISTS knowledge_base
                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
                             kb_name TEXT,
                             vs_type TEXT,
                             embed_model TEXT,
                             file_count INTEGER,
                             create_time DATETIME)''')
            conn.execute('''CREATE TABLE IF NOT EXISTS knowledge_files
                            (id INTEGER PRIMARY KEY AUTOINCREMENT,
                             file_name TEXT,
                             file_ext TEXT,
                             kb_name TEXT,
                             document_loader_name TEXT,
                             text_splitter_name TEXT,
                             file_version INTEGER,
                             create_time DATETIME)''')
    finally:
        conn.close()
|
||
|
|
|
||
|
|
|
||
|
|
def list_kbs_from_db():
    """Return the names of knowledge bases whose ``file_count`` is above 0.

    Returns:
        A list of ``kb_name`` values from the ``knowledge_base`` table.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        # Read-only query: nothing to commit, but always release the handle.
        rows = conn.execute(
            'SELECT kb_name FROM knowledge_base WHERE file_count > 0'
        ).fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows]
|
||
|
|
|
||
|
|
|
||
|
|
def add_kb_to_db(kb_name, vs_type, embed_model):
    """Insert a new row into ``knowledge_base`` with ``file_count`` set to 0.

    Values are bound as SQL parameters, so names containing quotes are
    stored verbatim instead of breaking (or injecting into) the statement.

    Args:
        kb_name: Knowledge base name.
        vs_type: Vector store type to record.
        embed_model: Embedding model name to record.

    Returns:
        True once the row has been committed.
    """
    create_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                """INSERT INTO knowledge_base
                   (kb_name, vs_type, embed_model, file_count, create_time)
                   VALUES (?, ?, ?, 0, ?)""",
                (kb_name, vs_type, embed_model, create_time),
            )
    finally:
        conn.close()
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def kb_exists(kb_name):
    """Return True if a ``knowledge_base`` row with this name exists.

    The name is bound as a parameter. The previous ``kb_name="{kb_name}"``
    form was unsafe because SQLite treats a double-quoted token as a column
    identifier when one matches. For example, ``kb_exists("kb_name")`` matched
    every row.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        (count,) = conn.execute(
            'SELECT COUNT(*) FROM knowledge_base WHERE kb_name = ?',
            (kb_name,),
        ).fetchone()
    finally:
        conn.close()
    return count > 0
|
||
|
|
|
||
|
|
|
||
|
|
def load_kb_from_db(kb_name):
    """Look up a knowledge base's stored configuration.

    Args:
        kb_name: Knowledge base name (bound as a SQL parameter).

    Returns:
        A ``(kb_name, vs_type, embed_model)`` tuple for the first matching
        row, or ``(None, None, None)`` when no row matches.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        row = conn.execute(
            'SELECT kb_name, vs_type, embed_model FROM knowledge_base '
            'WHERE kb_name = ?',
            (kb_name,),
        ).fetchone()
    finally:
        conn.close()
    name, vs_type, embed_model = row if row else (None, None, None)
    return name, vs_type, embed_model
|
||
|
|
|
||
|
|
|
||
|
|
def delete_kb_from_db(kb_name):
    """Delete a knowledge base row and all of its file rows.

    Both deletes run in one transaction. The name is bound as a parameter.
    With the old double-quoted f-string, a name equal to a column name (e.g.
    ``"kb_name"``) was resolved as that column and wiped every row.

    Returns:
        True once the deletes have been committed.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        with conn:  # both deletes commit together or not at all
            conn.execute('DELETE FROM knowledge_base WHERE kb_name = ?',
                         (kb_name,))
            conn.execute('DELETE FROM knowledge_files WHERE kb_name = ?',
                         (kb_name,))
    finally:
        conn.close()
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def list_docs_from_db(kb_name):
    """Return the ``file_name`` of every file recorded for a knowledge base.

    Args:
        kb_name: Knowledge base name (bound as a SQL parameter).

    Returns:
        A list of file names from ``knowledge_files``.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        rows = conn.execute(
            'SELECT file_name FROM knowledge_files WHERE kb_name = ?',
            (kb_name,),
        ).fetchall()
    finally:
        conn.close()
    return [row[0] for row in rows]
|
||
|
|
|
||
|
|
|
||
|
|
def add_doc_to_db(kb_file: KnowledgeFile):
    """Record a file in ``knowledge_files`` for its knowledge base.

    If a row for ``(kb_file.filename, kb_file.kb_name)`` already exists, its
    ``file_version`` is incremented. Otherwise a new row is inserted with
    version 0, and the owning knowledge base's ``file_count`` is incremented.
    The count changes only on insert, so re-adding the same file does not
    inflate it. ``delete_file_from_db`` decrements it once per removed row.

    All values are bound as SQL parameters, so file names containing quotes
    are handled correctly.

    Returns:
        True once the change has been committed.
    """
    key = (kb_file.filename, kb_file.kb_name)
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        with conn:  # commit on success, roll back on error
            already_recorded = conn.execute(
                'SELECT 1 FROM knowledge_files '
                'WHERE file_name = ? AND kb_name = ?',
                key,
            ).fetchone() is not None
            if already_recorded:
                conn.execute(
                    'UPDATE knowledge_files SET file_version = file_version + 1 '
                    'WHERE file_name = ? AND kb_name = ?',
                    key,
                )
            else:
                create_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                conn.execute(
                    """INSERT INTO knowledge_files
                       (file_name, file_ext, kb_name, document_loader_name,
                        text_splitter_name, file_version, create_time)
                       VALUES (?, ?, ?, ?, ?, 0, ?)""",
                    (kb_file.filename, kb_file.ext, kb_file.kb_name,
                     kb_file.document_loader_name, kb_file.text_splitter_name,
                     create_time),
                )
                conn.execute(
                    'UPDATE knowledge_base SET file_count = file_count + 1 '
                    'WHERE kb_name = ?',
                    (kb_file.kb_name,),
                )
    finally:
        conn.close()
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def delete_file_from_db(kb_file: KnowledgeFile):
    """Remove a file's row(s) from ``knowledge_files`` and adjust the count.

    The owning knowledge base's ``file_count`` is decreased by the number of
    rows actually deleted. Deleting a file that was never recorded leaves the
    count unchanged instead of driving it down, or below zero. Values are
    bound as SQL parameters.

    Returns:
        True once the change has been committed.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        with conn:  # delete + count update commit together
            deleted = conn.execute(
                'DELETE FROM knowledge_files WHERE file_name = ? AND kb_name = ?',
                (kb_file.filename, kb_file.kb_name),
            ).rowcount
            if deleted > 0:
                conn.execute(
                    'UPDATE knowledge_base SET file_count = file_count - ? '
                    'WHERE kb_name = ?',
                    (deleted, kb_file.kb_name),
                )
    finally:
        conn.close()
    return True
|
||
|
|
|
||
|
|
|
||
|
|
def doc_exists(kb_file: KnowledgeFile):
    """Return True if ``knowledge_files`` has a row for this file and KB.

    File name and knowledge base name are bound as SQL parameters. A
    double-quoted f-string value can be resolved by SQLite as a column name,
    and it breaks on names containing quotes.
    """
    conn = sqlite3.connect(DB_ROOT_PATH)
    try:
        (count,) = conn.execute(
            'SELECT COUNT(*) FROM knowledge_files '
            'WHERE file_name = ? AND kb_name = ?',
            (kb_file.filename, kb_file.kb_name),
        ).fetchone()
    finally:
        conn.close()
    return count > 0
|
||
|
|
|
||
|
|
|
||
|
|
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build a ``HuggingFaceEmbeddings`` for ``model`` on ``device``.

    ``model`` is a key into ``embedding_model_dict``. Its value is passed as
    ``model_name``. Because of ``lru_cache(1)``, only the most recently
    requested ``(model, device)`` pair stays cached.
    """
    model_name = embedding_model_dict[model]
    return HuggingFaceEmbeddings(model_name=model_name,
                                 model_kwargs={'device': device})
|
||
|
|
|
||
|
|
|
||
|
|
class KBService(ABC):
    """Abstract base for a knowledge base backed by a vector store.

    This base class manages the knowledge base's document folder on disk and
    its metadata rows in SQLite, through the module-level ``*_db`` helpers.
    Concrete subclasses supply the vector-store work through the abstract
    ``do_*`` hooks.
    """

    def __init__(self,
                 knowledge_base_name: str,
                 vector_store_type: str = "faiss",
                 embed_model: str = EMBEDDING_MODEL,
                 ):
        """Bind the service to one knowledge base.

        Args:
            knowledge_base_name: Name of the knowledge base.
            vector_store_type: Vector store type recorded in the metadata DB.
            embed_model: Key into ``embedding_model_dict`` used when loading
                embeddings.

        Calls the subclass hook ``do_init()`` as the final step.
        """
        self.kb_name = knowledge_base_name
        # NOTE(review): this instance attribute shadows the abstract
        # ``vs_type`` method declared below, so ``self.vs_type`` is always the
        # constructor argument (default "faiss"), never the subclass's
        # ``vs_type()``. Confirm subclasses pass the right value.
        self.vs_type = vector_store_type
        self.embed_model = embed_model
        self.kb_path = get_kb_path(self.kb_name)
        self.doc_path = get_doc_path(self.kb_name)
        # Run the backend's own setup last, once the paths above are set.
        self.do_init()

    def create_kb(self):
        """Create the knowledge base.

        Ensures the document folder exists, lets the subclass create its
        vector store, then records the knowledge base in the metadata DB.

        Returns:
            The result of ``add_kb_to_db`` (True on success).
        """
        # exist_ok avoids the race between an existence check and makedirs.
        os.makedirs(self.doc_path, exist_ok=True)
        self.do_create_kb()
        return add_kb_to_db(self.kb_name, self.vs_type, self.embed_model)

    def clear_vs(self):
        """Remove all vectors from this knowledge base's vector store.

        Delegates entirely to ``do_clear_vs``. Uploaded files and metadata
        rows are not touched here, and nothing is re-embedded.
        """
        self.do_clear_vs()

    def drop_kb(self):
        """Delete the knowledge base.

        Lets the subclass remove its storage, then deletes the knowledge
        base row and all of its file rows from the metadata DB.

        Returns:
            The result of ``delete_kb_from_db`` (True on success).
        """
        self.do_remove_kb()
        return delete_kb_from_db(self.kb_name)

    def add_doc(self, kb_file: KnowledgeFile):
        """Add a file to the knowledge base.

        The file is converted to documents with ``kb_file.file2text()``. The
        documents are passed to the subclass with the cached embeddings for
        ``self.embed_model`` on ``EMBEDDING_DEVICE``. The file is then
        recorded in the metadata DB.

        Returns:
            The result of ``add_doc_to_db`` (True on success).
        """
        docs = kb_file.file2text()
        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
        self.do_add_doc(docs, embeddings)
        return add_doc_to_db(kb_file)

    def delete_doc(self, kb_file: KnowledgeFile):
        """Delete a file from the knowledge base.

        Removes the file from disk if present, lets the subclass delete its
        vectors, then removes the file's metadata row.

        Returns:
            The result of ``delete_file_from_db`` (True on success).
        """
        try:
            os.remove(kb_file.filepath)
        except FileNotFoundError:
            # Already gone from disk: still clean up vectors and metadata.
            pass
        self.do_delete(kb_file)
        return delete_file_from_db(kb_file)

    def exist_doc(self, file_name: str):
        """Return True if ``file_name`` is recorded for this knowledge base."""
        return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
                                        filename=file_name))

    def list_docs(self):
        """Return the file names recorded for this knowledge base."""
        return list_docs_from_db(self.kb_name)

    def search_docs(self,
                    query: str,
                    top_k: int = VECTOR_SEARCH_TOP_K,
                    embedding_device: str = EMBEDDING_DEVICE, ):
        """Search the knowledge base for ``query``.

        Loads the (cached) embeddings for ``self.embed_model`` on
        ``embedding_device`` and returns what ``do_search`` returns for the
        ``top_k`` best matches.
        """
        embeddings = load_embeddings(self.embed_model, embedding_device)
        return self.do_search(query, top_k, embeddings)

    @abstractmethod
    def do_create_kb(self):
        """Subclass hook: create the backend's storage for a new KB."""
        pass

    @staticmethod
    def list_kbs_type():
        """Return the vector store types configured in ``kbs_config``."""
        return list(kbs_config.keys())

    @classmethod
    def list_kbs(cls):
        """Return the names of knowledge bases that have at least one file."""
        return list_kbs_from_db()

    @classmethod
    def exists(cls,
               knowledge_base_name: str):
        """Return True if the knowledge base is recorded in the metadata DB."""
        return kb_exists(knowledge_base_name)

    @abstractmethod
    def vs_type(self) -> str:
        """Subclass hook: the vector store type name.

        NOTE(review): ``__init__`` assigns ``self.vs_type``, which shadows
        this method on every instance.
        """
        pass

    @abstractmethod
    def do_init(self):
        """Subclass hook: backend-specific initialisation, run by ``__init__``."""
        pass

    @abstractmethod
    def do_remove_kb(self):
        """Subclass hook: delete the backend's storage for this KB."""
        pass

    @abstractmethod
    def do_search(self,
                  query: str,
                  top_k: int,
                  embeddings: Embeddings,
                  ) -> List[Document]:
        """Subclass hook: return up to ``top_k`` documents matching ``query``."""
        pass

    @abstractmethod
    def do_add_doc(self,
                   docs: List[Document],
                   embeddings: Embeddings):
        """Subclass hook: embed ``docs`` and add them to the vector store."""
        pass

    @abstractmethod
    def do_delete(self,
                  kb_file: KnowledgeFile):
        """Subclass hook: remove the vectors belonging to ``kb_file``."""
        pass

    @abstractmethod
    def do_clear_vs(self):
        """Subclass hook: remove all vectors from this KB's vector store."""
        pass
|