diff --git a/server/chat/chat.py b/server/chat/chat.py
index 41e2528..3239cda 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,7 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import llm_model_dict, LLM_MODEL
-from .utils import wrap_done
+from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 68d99e4..4eaf3fd 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -4,15 +4,13 @@
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                   VECTOR_SEARCH_TOP_K)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
-import os
-from server.knowledge_base.utils import get_kb_path
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from server.knowledge_base.utils import lookup_vs
+from server.knowledge_base.knowledge_base import KnowledgeBase
 import json
 
@@ -20,12 +18,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
-    kb_path = get_kb_path(knowledge_base_name)
-    if not os.path.exists(kb_path):
+    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
        return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
 
     async def knowledge_base_chat_iterator(query: str,
-                                           knowledge_base_name: str,
+                                           kb: KnowledgeBase,
                                            top_k: int,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
@@ -37,7 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             model_name=LLM_MODEL
         )
-        docs = lookup_vs(query, knowledge_base_name, top_k)
+        docs = kb.search_docs(query, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
 
@@ -60,5 +58,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                              "docs": source_documents})
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
                              media_type="text/event-stream")
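
Note: knowledge_base_chat now resolves the knowledge base through the KnowledgeBase object API rather than filesystem paths (get_kb_path / lookup_vs). A minimal sketch of the new call pattern, illustrative only and not part of the patch ("samples" and the query string are made-up values):

    from server.knowledge_base.knowledge_base import KnowledgeBase

    if KnowledgeBase.exists(knowledge_base_name="samples"):
        kb = KnowledgeBase.load(knowledge_base_name="samples")
        # replaces the old lookup_vs(query, knowledge_base_name, top_k)
        docs = kb.search_docs("example query", top_k=3)
        context = "\n".join(doc.page_content for doc in docs)
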
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index d116c58..62de2d9 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -1,10 +1,8 @@
 import os
 import urllib
-import shutil
 from fastapi import File, Form, UploadFile
 from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
-                                         get_file_path, refresh_vs_cache, get_vs_path)
+from server.knowledge_base.utils import (validate_kb_name)
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.knowledge_file import KnowledgeFile
@@ -16,8 +14,7 @@ async def list_docs(knowledge_base_name: str):
         return ListResponse(code=403, msg="Don't attack me", data=[])
 
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
-    kb_path = get_kb_path(knowledge_base_name)
-    if not os.path.exists(kb_path):
+    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
         all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
@@ -42,9 +39,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
                             knowledge_base_name=knowledge_base_name)
 
     if (os.path.exists(kb_file.filepath)
-        and not override
-        and os.path.getsize(kb_file.filepath) == len(file_content)
+            and not override
+            and os.path.getsize(kb_file.filepath) == len(file_content)
     ):
+        # TODO: filesize 不同后的处理
         file_status = f"文件 {kb_file.filename} 已存在。"
         return BaseResponse(code=404, msg=file_status)
 
@@ -83,6 +81,7 @@ async def update_doc():
     # refresh_vs_cache(knowledge_base_name)
     pass
 
+
 async def download_doc():
     # TODO: 下载文件
     pass
@@ -93,24 +92,21 @@ async def recreate_vector_store(knowledge_base_name: str):
     recreate vector store from the content.
     this is usefull when user can copy files to content folder directly instead of upload through network.
     '''
-    async def output(kb_name):
-        vs_path = get_vs_path(kb_name)
-        if os.path.isdir(vs_path):
-            shutil.rmtree(vs_path)
-        os.mkdir(vs_path)
-        print(f"start to recreate vectore in {vs_path}")
+    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
 
-        docs = (await list_docs(kb_name)).data
+    async def output(kb: KnowledgeBase):
+        kb.recreate_vs()
+        print(f"start to recreate vector store of {kb.kb_name}")
+        docs = kb.list_docs()
         for i, filename in enumerate(docs):
             kb_file = KnowledgeFile(filename=filename,
-                                    knowledge_base_name=kb_name)
+                                    knowledge_base_name=kb.kb_name)
             print(f"processing {kb_file.filepath} to vector store.")
-            kb = KnowledgeBase.load(knowledge_base_name=kb_name)
             kb.add_doc(kb_file)
             yield json.dumps({
                 "total": len(docs),
                 "finished": i + 1,
                 "doc": filename,
             })
-
-    return StreamingResponse(output(knowledge_base_name), media_type="text/event-stream")
+
+    return StreamingResponse(output(kb), media_type="text/event-stream")
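
Note: recreate_vector_store now streams one json.dumps progress object per processed file. A hedged client sketch for consuming that stream; the host, port, and route are assumptions (they are not defined in this diff), and it assumes each received chunk carries exactly one JSON object, which holds for small messages but is not guaranteed by HTTP:

    import json
    import requests

    # Assumed mount point for the endpoint above; adjust to your routing.
    resp = requests.post("http://127.0.0.1:7861/knowledge_base/recreate_vector_store",
                         params={"knowledge_base_name": "samples"}, stream=True)
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        progress = json.loads(chunk)  # {"total": ..., "finished": ..., "doc": ...}
        print(f"{progress['finished']}/{progress['total']} {progress['doc']}")
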
diff --git a/server/knowledge_base/knowledge_base.py b/server/knowledge_base/knowledge_base.py
index de7f661..b4d5311 100644
--- a/server/knowledge_base/knowledge_base.py
+++ b/server/knowledge_base/knowledge_base.py
@@ -3,15 +3,62 @@ import sqlite3
 import datetime
 import shutil
 from langchain.vectorstores import FAISS
-from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
-                                         refresh_vs_cache, load_embeddings)
-from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL,
-                                  EMBEDDING_DEVICE, DB_ROOT_PATH)
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
+                                  KB_ROOT_PATH, DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
 from server.utils import torch_gc
+from functools import lru_cache
 from server.knowledge_base.knowledge_file import KnowledgeFile
+from typing import List
+import numpy as np
 
 SUPPORTED_VS_TYPES = ["faiss", "milvus"]
+_VECTOR_STORE_TICKS = {}
+
+
+def get_kb_path(knowledge_base_name: str):
+    return os.path.join(KB_ROOT_PATH, knowledge_base_name)
+
+
+def get_doc_path(knowledge_base_name: str):
+    return os.path.join(get_kb_path(knowledge_base_name), "content")
+
+
+def get_vs_path(knowledge_base_name: str):
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
+
+
+def get_file_path(knowledge_base_name: str,
+                  doc_name: str):
+    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
+
+
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
+
+
+@lru_cache(CACHED_VS_NUM)
+def load_vector_store(
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
+        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
+):
+    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
+    embeddings = load_embeddings(embedding_model, embedding_device)
+    vs_path = get_vs_path(knowledge_base_name)
+    search_index = FAISS.load_local(vs_path, embeddings)
+    return search_index
+
+
+def refresh_vs_cache(kb_name: str):
+    """
+    make vector store cache refreshed when next loading
+    """
+    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
 
 
 def list_kbs_from_db():
     conn = sqlite3.connect(DB_ROOT_PATH)
@@ -149,6 +196,7 @@ def list_docs_from_db(kb_name):
     conn.close()
     return kbs
 
+
 def add_doc_to_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
@@ -164,14 +212,23 @@ def add_doc_to_db(kb_file: KnowledgeFile):
                     create_time DATETIME) ''')
     # Insert a row of data
     # TODO: 同名文件添加至知识库时,file_version增加
-    c.execute(f"""INSERT INTO knowledge_files
-                  (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
-                  VALUES
-                  ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
-                  '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
+    c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
+    record_exist = c.fetchone()
+    if record_exist is not None:
+        c.execute(f"""UPDATE knowledge_files
+                      SET file_version = file_version + 1
+                      WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
+                  """)
+    else:
+        c.execute(f"""INSERT INTO knowledge_files
+                      (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
+                      VALUES
+                      ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
+                      '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
     conn.commit()
     conn.close()
 
+
 def delete_file_from_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
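
Caveat on the add_doc_to_db hunk above: interpolating kb_file fields into SQL via f-strings breaks on names containing quotes and is open to SQL injection. A sketch of the same SELECT/UPDATE/INSERT flow using sqlite3 placeholders (an alternative, not what the patch does; the function name here is hypothetical):

    import sqlite3
    import datetime

    def add_doc_to_db_parameterized(conn: sqlite3.Connection, kb_file) -> None:
        # Same upsert logic as add_doc_to_db, with ? placeholders instead of f-strings.
        c = conn.cursor()
        c.execute("SELECT 1 FROM knowledge_files WHERE file_name=? AND kb_name=?",
                  (kb_file.filename, kb_file.kb_name))
        if c.fetchone() is not None:
            c.execute("UPDATE knowledge_files SET file_version = file_version + 1 "
                      "WHERE file_name=? AND kb_name=?",
                      (kb_file.filename, kb_file.kb_name))
        else:
            c.execute("INSERT INTO knowledge_files "
                      "(file_name, file_ext, kb_name, document_loader_name, "
                      "text_splitter_name, file_version, create_time) "
                      "VALUES (?, ?, ?, ?, ?, 0, ?)",
                      (kb_file.filename, kb_file.ext, kb_file.kb_name,
                       kb_file.document_loader_name, kb_file.text_splitter_name,
                       datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
        conn.commit()
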
@@ -195,6 +252,7 @@ def delete_file_from_db(kb_file: KnowledgeFile):
     conn.close()
     return True
 
+
 def doc_exists(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
@@ -217,6 +275,24 @@
     return status
 
 
+def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
+    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
+    if not overlapping:
+        raise ValueError("ids do not exist in the current object")
+    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
+    index_to_delete = [_reversed_index[i] for i in ids]
+    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
+    for _id in index_to_delete:
+        del vector_store.index_to_docstore_id[_id]
+    # Remove items from docstore.
+    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
+    if not overlapping2:
+        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
+    for _id in ids:
+        vector_store.docstore._dict.pop(_id)
+    return vector_store
+
+
 class KnowledgeBase:
     def __init__(self,
                  knowledge_base_name: str,
@@ -249,21 +325,25 @@ class KnowledgeBase:
             pass
         return True
 
+    def recreate_vs(self):
+        if self.vs_type in ["faiss"]:
+            shutil.rmtree(self.vs_path)
+            self.create()
+
     def add_doc(self, kb_file: KnowledgeFile):
         docs = kb_file.file2text()
-        vs_path = get_vs_path(self.kb_name)
         embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
         if self.vs_type in ["faiss"]:
-            if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, embeddings)
+            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
+                vector_store = FAISS.load_local(self.vs_path, embeddings)
                 vector_store.add_documents(docs)
                 torch_gc()
             else:
-                if not os.path.exists(vs_path):
-                    os.makedirs(vs_path)
+                if not os.path.exists(self.vs_path):
+                    os.makedirs(self.vs_path)
                 vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
                 torch_gc()
-            vector_store.save_local(vs_path)
+            vector_store.save_local(self.vs_path)
             add_doc_to_db(kb_file)
             refresh_vs_cache(self.kb_name)
         elif self.vs_type in ["milvus"]:
@@ -275,7 +355,18 @@ class KnowledgeBase:
             os.remove(kb_file.filepath)
         if self.vs_type in ["faiss"]:
             # TODO: 从FAISS向量库中删除文档
-            delete_file_from_db(kb_file)
+            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
+                vector_store = FAISS.load_local(self.vs_path, embeddings)
+                ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+                if len(ids) == 0:
+                    return None
+                print(len(ids))
+                vector_store = delete_doc_from_faiss(vector_store, ids)
+                vector_store.save_local(self.vs_path)
+                refresh_vs_cache(self.kb_name)
+            delete_file_from_db(kb_file)
+            return True
 
     def exist_doc(self, file_name: str):
         return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
@@ -284,6 +375,17 @@ class KnowledgeBase:
     def list_docs(self):
         return list_docs_from_db(self.kb_name)
 
+    def search_docs(self,
+                    query: str,
+                    top_k: int = VECTOR_SEARCH_TOP_K,
+                    embedding_device: str = EMBEDDING_DEVICE, ):
+        search_index = load_vector_store(self.kb_name,
+                                         self.embed_model,
+                                         embedding_device,
+                                         _VECTOR_STORE_TICKS.get(self.kb_name))
+        docs = search_index.similarity_search(query, k=top_k)
+        return docs
+
     @classmethod
     def exists(cls,
                knowledge_base_name: str):
@@ -316,4 +418,5 @@ if __name__ == "__main__":
     # kb = KnowledgeBase("123", "faiss")
     # kb.create()
     kb = KnowledgeBase.load(knowledge_base_name="123")
+    kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
     print()
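
For orientation: delete_doc_from_faiss works on FAISS's two parallel mappings, index_to_docstore_id (vector position -> docstore id) and docstore._dict (docstore id -> Document). It inverts the first mapping, removes the vector positions with faiss's remove_ids, then drops the same ids from both dicts. A usage sketch assembled only from pieces defined in this diff; "samples" and "README.md" are made-up names and the snippet is illustrative, not part of the patch:

    vs_path = get_vs_path("samples")
    embeddings = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)
    vector_store = FAISS.load_local(vs_path, embeddings)

    # Collect docstore ids whose chunks came from the file being removed.
    target = get_file_path("samples", "README.md")
    ids = [_id for _id, doc in vector_store.docstore._dict.items()
           if doc.metadata["source"] == target]
    if ids:
        vector_store = delete_doc_from_faiss(vector_store, ids)
        vector_store.save_local(vs_path)
        refresh_vs_cache("samples")  # invalidate the lru_cache'd loader
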
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index df35234..5773f71 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,77 +1,5 @@
-import os
-from configs.model_config import KB_ROOT_PATH
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
-from functools import lru_cache
-
-
-_VECTOR_STORE_TICKS = {}
-
-
-def get_kb_path(knowledge_base_name: str):
-    return os.path.join(KB_ROOT_PATH, knowledge_base_name)
-
-
-def get_doc_path(knowledge_base_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), "content")
-
-
-def get_vs_path(knowledge_base_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
-
-
-def get_file_path(knowledge_base_name: str, doc_name: str):
-    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
-
-
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # 检查是否包含预期外的字符或路径攻击关键字
     if "../" in knowledge_base_id:
         return False
     return True
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
-    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def lookup_vs(
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = VECTOR_SEARCH_TOP_K,
-        embedding_model: str = EMBEDDING_MODEL,
-        embedding_device: str = EMBEDDING_DEVICE,
-):
-    search_index = load_vector_store(knowledge_base_name,
-                                     embedding_model,
-                                     embedding_device,
-                                     _VECTOR_STORE_TICKS.get(knowledge_base_name))
-    docs = search_index.similarity_search(query, k=top_k)
-    return docs
-
-
-def refresh_vs_cache(kb_name: str):
-    """
-    make vector store cache refreshed when next loading
-    """
-    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
-
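
Finally, a note on the tick mechanism this diff moves from utils.py into knowledge_base.py: lru_cache keys on every argument, so threading _VECTOR_STORE_TICKS.get(name) through load_vector_store lets refresh_vs_cache force a cache miss on the next load without touching the cache itself. A self-contained illustration of the pattern (generic stand-in names, not code from the patch):

    from functools import lru_cache

    _TICKS = {}

    @lru_cache(maxsize=8)
    def load_resource(name: str, tick):
        # Body runs only on a cache miss; stands in for FAISS.load_local(...).
        print(f"(re)loading {name}")
        return object()

    def refresh(name: str):
        _TICKS[name] = _TICKS.get(name, 0) + 1

    load_resource("samples", _TICKS.get("samples"))  # loads (tick is None at first, as in the patch)
    load_resource("samples", _TICKS.get("samples"))  # served from cache
    refresh("samples")                               # tick changes,
    load_resource("samples", _TICKS.get("samples"))  # so this reloads
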