update kb_doc_api.py

This commit is contained in:
imClumsyPanda 2023-08-06 18:32:10 +08:00
parent a447529c2e
commit b91d96ab0c
5 changed files with 140 additions and 115 deletions

View File

@ -1,7 +1,7 @@
from fastapi import Body from fastapi import Body
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from configs.model_config import llm_model_dict, LLM_MODEL from configs.model_config import llm_model_dict, LLM_MODEL
from .utils import wrap_done from server.chat.utils import wrap_done
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain import LLMChain from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler

View File

@ -4,15 +4,13 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K) VECTOR_SEARCH_TOP_K)
from server.chat.utils import wrap_done from server.chat.utils import wrap_done
from server.utils import BaseResponse from server.utils import BaseResponse
import os
from server.knowledge_base.utils import get_kb_path
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain import LLMChain from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable from typing import AsyncIterable
import asyncio import asyncio
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from server.knowledge_base.utils import lookup_vs from server.knowledge_base.knowledge_base import KnowledgeBase
import json import json
@ -20,12 +18,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
): ):
kb_path = get_kb_path(knowledge_base_name) if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
if not os.path.exists(kb_path):
return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
async def knowledge_base_chat_iterator(query: str, async def knowledge_base_chat_iterator(query: str,
knowledge_base_name: str, kb: KnowledgeBase,
top_k: int, top_k: int,
) -> AsyncIterable[str]: ) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler() callback = AsyncIteratorCallbackHandler()
@ -37,7 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL model_name=LLM_MODEL
) )
docs = lookup_vs(query, knowledge_base_name, top_k) docs = kb.search_docs(query, top_k)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"]) prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
@ -60,5 +58,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"docs": source_documents}) "docs": source_documents})
await task await task
return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k), return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
media_type="text/event-stream") media_type="text/event-stream")

View File

@ -1,10 +1,8 @@
import os import os
import urllib import urllib
import shutil
from fastapi import File, Form, UploadFile from fastapi import File, Form, UploadFile
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path, from server.knowledge_base.utils import (validate_kb_name)
get_file_path, refresh_vs_cache, get_vs_path)
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
import json import json
from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.knowledge_file import KnowledgeFile
@ -16,8 +14,7 @@ async def list_docs(knowledge_base_name: str):
return ListResponse(code=403, msg="Don't attack me", data=[]) return ListResponse(code=403, msg="Don't attack me", data=[])
knowledge_base_name = urllib.parse.unquote(knowledge_base_name) knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
kb_path = get_kb_path(knowledge_base_name) if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
if not os.path.exists(kb_path):
return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[]) return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
else: else:
all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs() all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
@ -42,9 +39,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
knowledge_base_name=knowledge_base_name) knowledge_base_name=knowledge_base_name)
if (os.path.exists(kb_file.filepath) if (os.path.exists(kb_file.filepath)
and not override and not override
and os.path.getsize(kb_file.filepath) == len(file_content) and os.path.getsize(kb_file.filepath) == len(file_content)
): ):
# TODO: filesize 不同后的处理
file_status = f"文件 {kb_file.filename} 已存在。" file_status = f"文件 {kb_file.filename} 已存在。"
return BaseResponse(code=404, msg=file_status) return BaseResponse(code=404, msg=file_status)
@ -83,6 +81,7 @@ async def update_doc():
# refresh_vs_cache(knowledge_base_name) # refresh_vs_cache(knowledge_base_name)
pass pass
async def download_doc(): async def download_doc():
# TODO: 下载文件 # TODO: 下载文件
pass pass
@ -93,19 +92,16 @@ async def recreate_vector_store(knowledge_base_name: str):
recreate vector store from the content. recreate vector store from the content.
this is useful when user can copy files to content folder directly instead of upload through network. this is useful when user can copy files to content folder directly instead of upload through network.
''' '''
async def output(kb_name): kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
vs_path = get_vs_path(kb_name)
if os.path.isdir(vs_path):
shutil.rmtree(vs_path)
os.mkdir(vs_path)
print(f"start to recreate vectore in {vs_path}")
docs = (await list_docs(kb_name)).data async def output(kb: KnowledgeBase):
kb.recreate_vs()
print(f"start to recreate vector store of {kb.kb_name}")
docs = kb.list_docs()
for i, filename in enumerate(docs): for i, filename in enumerate(docs):
kb_file = KnowledgeFile(filename=filename, kb_file = KnowledgeFile(filename=filename,
knowledge_base_name=kb_name) knowledge_base_name=kb.kb_name)
print(f"processing {kb_file.filepath} to vector store.") print(f"processing {kb_file.filepath} to vector store.")
kb = KnowledgeBase.load(knowledge_base_name=kb_name)
kb.add_doc(kb_file) kb.add_doc(kb_file)
yield json.dumps({ yield json.dumps({
"total": len(docs), "total": len(docs),
@ -113,4 +109,4 @@ async def recreate_vector_store(knowledge_base_name: str):
"doc": filename, "doc": filename,
}) })
return StreamingResponse(output(knowledge_base_name), media_type="text/event-stream") return StreamingResponse(output(kb), media_type="text/event-stream")

View File

@ -3,15 +3,62 @@ import sqlite3
import datetime import datetime
import shutil import shutil
from langchain.vectorstores import FAISS from langchain.vectorstores import FAISS
from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path, from langchain.embeddings.huggingface import HuggingFaceEmbeddings
refresh_vs_cache, load_embeddings) from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, KB_ROOT_PATH, DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
EMBEDDING_DEVICE, DB_ROOT_PATH)
from server.utils import torch_gc from server.utils import torch_gc
from functools import lru_cache
from server.knowledge_base.knowledge_file import KnowledgeFile from server.knowledge_base.knowledge_file import KnowledgeFile
from typing import List
import numpy as np
SUPPORTED_VS_TYPES = ["faiss", "milvus"] SUPPORTED_VS_TYPES = ["faiss", "milvus"]
_VECTOR_STORE_TICKS = {}
def get_kb_path(knowledge_base_name: str):
    """Return the root directory path of the named knowledge base.

    The result is ``KB_ROOT_PATH`` joined with ``knowledge_base_name``. The
    path is only computed; nothing is created or checked on disk.
    """
    kb_root = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_root
def get_doc_path(knowledge_base_name: str):
    """Return the path of the ``content`` directory of the named knowledge base."""
    kb_root = get_kb_path(knowledge_base_name)
    return os.path.join(kb_root, "content")
def get_vs_path(knowledge_base_name: str):
    """Return the path of the ``vector_store`` directory of the named knowledge base."""
    kb_root = get_kb_path(knowledge_base_name)
    return os.path.join(kb_root, "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of ``doc_name`` inside the knowledge base's content directory."""
    content_dir = get_doc_path(knowledge_base_name)
    return os.path.join(content_dir, doc_name)
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build the HuggingFace embeddings object for ``model`` on ``device``.

    ``model`` is looked up in ``embedding_model_dict`` and the mapped value is
    used as the model name. ``lru_cache(1)`` keeps only the most recently
    requested (model, device) pair, so repeated calls with the same arguments
    return the same object.
    """
    model_name = embedding_model_dict[model]
    device_kwargs = {'device': device}
    return HuggingFaceEmbeddings(model_name=model_name, model_kwargs=device_kwargs)
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
        knowledge_base_name: str,
        embedding_model: str,
        embedding_device: str,
        tick: int,  # cache key only: changed by upload_doc etc. so the store is reloaded
):
    """Load the saved FAISS index of a knowledge base.

    The index is read from the knowledge base's ``vector_store`` directory
    using the embeddings returned by ``load_embeddings``. ``tick`` is never
    read in the body; it only takes part in the ``lru_cache`` key, so passing
    a new value forces a fresh load.
    """
    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
    embedding_fn = load_embeddings(embedding_model, embedding_device)
    index_dir = get_vs_path(knowledge_base_name)
    return FAISS.load_local(index_dir, embedding_fn)
def refresh_vs_cache(kb_name: str):
    """
    Bump the tick stored for ``kb_name`` in ``_VECTOR_STORE_TICKS``.

    A knowledge base with no tick yet is treated as being at 0, so its first
    refresh stores 1.
    """
    previous_tick = _VECTOR_STORE_TICKS.get(kb_name, 0)
    _VECTOR_STORE_TICKS[kb_name] = previous_tick + 1
def list_kbs_from_db(): def list_kbs_from_db():
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)
@ -149,6 +196,7 @@ def list_docs_from_db(kb_name):
conn.close() conn.close()
return kbs return kbs
def add_doc_to_db(kb_file: KnowledgeFile): def add_doc_to_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor() c = conn.cursor()
@ -164,14 +212,23 @@ def add_doc_to_db(kb_file: KnowledgeFile):
create_time DATETIME) ''') create_time DATETIME) ''')
# Insert a row of data # Insert a row of data
# TODO: 同名文件添加至知识库时file_version增加 # TODO: 同名文件添加至知识库时file_version增加
c.execute(f"""INSERT INTO knowledge_files c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time) record_exist = c.fetchone()
VALUES if record_exist is not None:
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}', c.execute(f"""UPDATE knowledge_files
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""") SET file_version = file_version + 1
WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
""")
else:
c.execute(f"""INSERT INTO knowledge_files
(file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
VALUES
('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
'{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
conn.commit() conn.commit()
conn.close() conn.close()
def delete_file_from_db(kb_file: KnowledgeFile): def delete_file_from_db(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor() c = conn.cursor()
@ -195,6 +252,7 @@ def delete_file_from_db(kb_file: KnowledgeFile):
conn.close() conn.close()
return True return True
def doc_exists(kb_file: KnowledgeFile): def doc_exists(kb_file: KnowledgeFile):
conn = sqlite3.connect(DB_ROOT_PATH) conn = sqlite3.connect(DB_ROOT_PATH)
c = conn.cursor() c = conn.cursor()
@ -217,6 +275,24 @@ def doc_exists(kb_file: KnowledgeFile):
return status return status
def delete_doc_from_faiss(vector_store: "FAISS", ids: List[str]):
    """Remove the documents with the given docstore ids from a FAISS store.

    The vectors are removed from ``vector_store.index``, the position -> id
    map ``vector_store.index_to_docstore_id`` is rebuilt, and the entries are
    dropped from ``vector_store.docstore._dict``. All three are mutated in
    place.

    Args:
        vector_store: store exposing ``index``, ``index_to_docstore_id`` and
            ``docstore._dict``.
        ids: docstore ids to delete; duplicates are ignored.

    Returns:
        The same ``vector_store`` object, after mutation.

    Raises:
        ValueError: if ``ids`` is empty, or if any requested id is missing
            from the index map or the docstore. Nothing is modified in that
            case.
    """
    index_to_id = vector_store.index_to_docstore_id
    docstore_dict = vector_store.docstore._dict
    requested = set(ids)

    # Validate everything up front so a bad id can never leave the index, the
    # map and the docstore out of sync with each other.
    known_ids = set(index_to_id.values())
    if not requested.intersection(known_ids):
        raise ValueError("ids do not exist in the current object")
    missing = requested.difference(known_ids) | requested.difference(docstore_dict)
    if missing:
        raise ValueError(
            f"Tried to delete ids that does not exist: {sorted(missing, key=str)}")

    doomed_positions = sorted(
        position for position, doc_id in index_to_id.items() if doc_id in requested)
    vector_store.index.remove_ids(np.array(doomed_positions, dtype=np.int64))

    # Sequential FAISS indexes compact on removal: the surviving vectors are
    # shifted down to positions 0..n-1. The map must be renumbered the same
    # way, otherwise later searches resolve positions to the wrong documents.
    survivors = [doc_id for _, doc_id in sorted(index_to_id.items())
                 if doc_id not in requested]
    index_to_id.clear()
    index_to_id.update(enumerate(survivors))

    for doc_id in requested:
        docstore_dict.pop(doc_id)
    return vector_store
class KnowledgeBase: class KnowledgeBase:
def __init__(self, def __init__(self,
knowledge_base_name: str, knowledge_base_name: str,
@ -249,21 +325,25 @@ class KnowledgeBase:
pass pass
return True return True
def recreate_vs(self):
    """Delete the on-disk vector store and call ``self.create()`` again.

    Only the ``faiss`` store type is handled; for any other ``vs_type`` the
    call does nothing.
    """
    if self.vs_type in ["faiss"]:
        # The vector store directory may not exist yet (nothing indexed so
        # far); an unguarded rmtree would raise FileNotFoundError and abort
        # the rebuild before it starts.
        if os.path.isdir(self.vs_path):
            shutil.rmtree(self.vs_path)
        self.create()
def add_doc(self, kb_file: KnowledgeFile): def add_doc(self, kb_file: KnowledgeFile):
docs = kb_file.file2text() docs = kb_file.file2text()
vs_path = get_vs_path(self.kb_name)
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if self.vs_type in ["faiss"]: if self.vs_type in ["faiss"]:
if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path): if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(vs_path, embeddings) vector_store = FAISS.load_local(self.vs_path, embeddings)
vector_store.add_documents(docs) vector_store.add_documents(docs)
torch_gc() torch_gc()
else: else:
if not os.path.exists(vs_path): if not os.path.exists(self.vs_path):
os.makedirs(vs_path) os.makedirs(self.vs_path)
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
torch_gc() torch_gc()
vector_store.save_local(vs_path) vector_store.save_local(self.vs_path)
add_doc_to_db(kb_file) add_doc_to_db(kb_file)
refresh_vs_cache(self.kb_name) refresh_vs_cache(self.kb_name)
elif self.vs_type in ["milvus"]: elif self.vs_type in ["milvus"]:
@ -275,7 +355,18 @@ class KnowledgeBase:
os.remove(kb_file.filepath) os.remove(kb_file.filepath)
if self.vs_type in ["faiss"]: if self.vs_type in ["faiss"]:
# TODO: 从FAISS向量库中删除文档 # TODO: 从FAISS向量库中删除文档
delete_file_from_db(kb_file) embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
vector_store = FAISS.load_local(self.vs_path, embeddings)
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
if len(ids) == 0:
return None
print(len(ids))
vector_store = delete_doc_from_faiss(vector_store, ids)
vector_store.save_local(self.vs_path)
refresh_vs_cache(self.kb_name)
delete_file_from_db(kb_file)
return True
def exist_doc(self, file_name: str): def exist_doc(self, file_name: str):
return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name, return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
@ -284,6 +375,17 @@ class KnowledgeBase:
def list_docs(self): def list_docs(self):
return list_docs_from_db(self.kb_name) return list_docs_from_db(self.kb_name)
def search_docs(self,
                query: str,
                top_k: int = VECTOR_SEARCH_TOP_K,
                embedding_device: str = EMBEDDING_DEVICE, ):
    """Run a similarity search for ``query`` against this knowledge base.

    The vector store is obtained through the cached ``load_vector_store``,
    keyed by this knowledge base's current entry in ``_VECTOR_STORE_TICKS``
    (``None`` when no tick has been recorded yet).

    Returns whatever ``similarity_search(query, k=top_k)`` yields.
    """
    current_tick = _VECTOR_STORE_TICKS.get(self.kb_name)
    vector_store = load_vector_store(self.kb_name,
                                     self.embed_model,
                                     embedding_device,
                                     current_tick)
    return vector_store.similarity_search(query, k=top_k)
@classmethod @classmethod
def exists(cls, def exists(cls,
knowledge_base_name: str): knowledge_base_name: str):
@ -316,4 +418,5 @@ if __name__ == "__main__":
# kb = KnowledgeBase("123", "faiss") # kb = KnowledgeBase("123", "faiss")
# kb.create() # kb.create()
kb = KnowledgeBase.load(knowledge_base_name="123") kb = KnowledgeBase.load(knowledge_base_name="123")
kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
print() print()

View File

@ -1,77 +1,5 @@
import os
from configs.model_config import KB_ROOT_PATH
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
from functools import lru_cache
_VECTOR_STORE_TICKS = {}
def get_kb_path(knowledge_base_name: str):
    """Return ``KB_ROOT_PATH`` joined with ``knowledge_base_name``.

    Only the path string is built; the directory is neither created nor
    checked for existence.
    """
    kb_dir = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_dir
def get_doc_path(knowledge_base_name: str):
    """Return the ``content`` directory path under the knowledge base root."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "content")
def get_vs_path(knowledge_base_name: str):
    """Return the ``vector_store`` directory path under the knowledge base root."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "vector_store")
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of ``doc_name`` within the knowledge base's content directory."""
    docs_dir = get_doc_path(knowledge_base_name)
    return os.path.join(docs_dir, doc_name)
def validate_kb_name(knowledge_base_id: str) -> bool: def validate_kb_name(knowledge_base_id: str) -> bool:
# 检查是否包含预期外的字符或路径攻击关键字 # 检查是否包含预期外的字符或路径攻击关键字
if "../" in knowledge_base_id: if "../" in knowledge_base_id:
return False return False
return True return True
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Return a ``HuggingFaceEmbeddings`` instance for ``model`` on ``device``.

    ``model`` is a key of ``embedding_model_dict``; the mapped value is passed
    as ``model_name``. The single-slot ``lru_cache`` reuses the object while
    the same (model, device) pair keeps being requested.
    """
    resolved_name = embedding_model_dict[model]
    kwargs = {'device': device}
    return HuggingFaceEmbeddings(model_name=resolved_name, model_kwargs=kwargs)
@lru_cache(CACHED_VS_NUM)
def load_vector_store(
        knowledge_base_name: str,
        embedding_model: str,
        embedding_device: str,
        tick: int,  # cache key only: changed by upload_doc etc. so the store is reloaded
):
    """Load a knowledge base's FAISS index from its ``vector_store`` directory.

    ``tick`` is not used by the body. It exists so callers can change the
    ``lru_cache`` key and thereby force the index to be read from disk again.
    """
    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
    embedding_fn = load_embeddings(embedding_model, embedding_device)
    store_dir = get_vs_path(knowledge_base_name)
    return FAISS.load_local(store_dir, embedding_fn)
def lookup_vs(
        query: str,
        knowledge_base_name: str,
        top_k: int = VECTOR_SEARCH_TOP_K,
        embedding_model: str = EMBEDDING_MODEL,
        embedding_device: str = EMBEDDING_DEVICE,
):
    """Run a similarity search for ``query`` in the named knowledge base.

    The vector store comes from the cached ``load_vector_store``, keyed by the
    knowledge base's current entry in ``_VECTOR_STORE_TICKS`` (``None`` when
    no tick has been recorded yet).

    Returns whatever ``similarity_search(query, k=top_k)`` yields.
    """
    current_tick = _VECTOR_STORE_TICKS.get(knowledge_base_name)
    vector_store = load_vector_store(knowledge_base_name,
                                     embedding_model,
                                     embedding_device,
                                     current_tick)
    return vector_store.similarity_search(query, k=top_k)
def refresh_vs_cache(kb_name: str):
    """
    Increment the tick recorded for ``kb_name`` in ``_VECTOR_STORE_TICKS``.

    A missing entry counts as 0, so the first refresh stores 1.
    """
    last_tick = _VECTOR_STORE_TICKS.get(kb_name, 0)
    _VECTOR_STORE_TICKS[kb_name] = last_tick + 1