update kb_doc_api.py
parent a447529c2e
commit b91d96ab0c
@@ -1,7 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import llm_model_dict, LLM_MODEL
-from .utils import wrap_done
+from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -4,15 +4,13 @@ from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
                                   VECTOR_SEARCH_TOP_K)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
-import os
-from server.knowledge_base.utils import get_kb_path
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from server.knowledge_base.utils import lookup_vs
+from server.knowledge_base.knowledge_base import KnowledgeBase
 import json
 
 
@@ -20,12 +18,12 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
                         ):
-    kb_path = get_kb_path(knowledge_base_name)
-    if not os.path.exists(kb_path):
+    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
         return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}")
+    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
 
     async def knowledge_base_chat_iterator(query: str,
-                                           knowledge_base_name: str,
+                                           kb: KnowledgeBase,
                                            top_k: int,
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
@@ -37,7 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
             model_name=LLM_MODEL
         )
-        docs = lookup_vs(query, knowledge_base_name, top_k)
+        docs = kb.search_docs(query, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
 
@@ -60,5 +58,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                              "docs": source_documents})
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, knowledge_base_name, top_k),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
                              media_type="text/event-stream")
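The handler above now delegates the existence check and retrieval to the KnowledgeBase facade instead of touching paths and the vector store directly. For reference, a minimal client-side sketch of consuming the resulting event stream; the mount point and port are assumptions for illustration, not confirmed by this diff:

# Minimal sketch, assuming the route is mounted at /chat/knowledge_base_chat
# on localhost:7861 -- both are assumptions, not taken from the diff.
import requests

def stream_kb_chat(query: str, knowledge_base_name: str, top_k: int = 3):
    resp = requests.post(
        "http://localhost:7861/chat/knowledge_base_chat",  # hypothetical path
        json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
        stream=True,
    )
    resp.raise_for_status()
    # The endpoint streams text/event-stream chunks produced by the async iterator.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)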
@@ -1,10 +1,8 @@
 import os
 import urllib
-import shutil
 from fastapi import File, Form, UploadFile
 from server.utils import BaseResponse, ListResponse
-from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
-                                         get_file_path, refresh_vs_cache, get_vs_path)
+from server.knowledge_base.utils import (validate_kb_name)
 from fastapi.responses import StreamingResponse
 import json
 from server.knowledge_base.knowledge_file import KnowledgeFile
@@ -16,8 +14,7 @@ async def list_docs(knowledge_base_name: str):
         return ListResponse(code=403, msg="Don't attack me", data=[])
 
     knowledge_base_name = urllib.parse.unquote(knowledge_base_name)
-    kb_path = get_kb_path(knowledge_base_name)
-    if not os.path.exists(kb_path):
+    if not KnowledgeBase.exists(knowledge_base_name=knowledge_base_name):
         return ListResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}", data=[])
     else:
         all_doc_names = KnowledgeBase.load(knowledge_base_name=knowledge_base_name).list_docs()
@@ -42,9 +39,10 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
                                knowledge_base_name=knowledge_base_name)
 
     if (os.path.exists(kb_file.filepath)
             and not override
             and os.path.getsize(kb_file.filepath) == len(file_content)
     ):
+        # TODO: filesize 不同后的处理
         file_status = f"文件 {kb_file.filename} 已存在。"
         return BaseResponse(code=404, msg=file_status)
 
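The size comparison above gives upload_doc a cheap duplicate check: an existing file of identical size is refused unless override is set (the new TODO notes that a differing size still needs handling). A sketch of exercising it; the route path is a hypothetical inferred from the handler name, while the form fields match the handler's signature:

# Sketch only; /knowledge_base/upload_doc is an assumed mount point.
import requests

def upload_doc(path: str, kb_name: str, override: bool = False):
    url = "http://localhost:7861/knowledge_base/upload_doc"  # hypothetical path
    with open(path, "rb") as f:
        resp = requests.post(url,
                             files={"file": f},
                             data={"knowledge_base_name": kb_name,
                                   "override": str(override).lower()})
    return resp.json()

# Posting the same unchanged file twice without override should hit the
# code=404 "file already exists" branch shown in the hunk above.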
@@ -83,6 +81,7 @@ async def update_doc():
     # refresh_vs_cache(knowledge_base_name)
     pass
 
+
 async def download_doc():
     # TODO: 下载文件
     pass
@@ -93,19 +92,16 @@ async def recreate_vector_store(knowledge_base_name: str):
     recreate vector store from the content.
     this is usefull when user can copy files to content folder directly instead of upload through network.
     '''
-    async def output(kb_name):
-        vs_path = get_vs_path(kb_name)
-        if os.path.isdir(vs_path):
-            shutil.rmtree(vs_path)
-        os.mkdir(vs_path)
-        print(f"start to recreate vectore in {vs_path}")
+    kb = KnowledgeBase.load(knowledge_base_name=knowledge_base_name)
 
-        docs = (await list_docs(kb_name)).data
+    async def output(kb: KnowledgeBase):
+        kb.recreate_vs()
+        print(f"start to recreate vector store of {kb.kb_name}")
+        docs = kb.list_docs()
         for i, filename in enumerate(docs):
             kb_file = KnowledgeFile(filename=filename,
-                                    knowledge_base_name=kb_name)
+                                    knowledge_base_name=kb.kb_name)
             print(f"processing {kb_file.filepath} to vector store.")
-            kb = KnowledgeBase.load(knowledge_base_name=kb_name)
             kb.add_doc(kb_file)
             yield json.dumps({
                 "total": len(docs),
@@ -113,4 +109,4 @@ async def recreate_vector_store(knowledge_base_name: str):
                 "doc": filename,
             })
 
-    return StreamingResponse(output(knowledge_base_name), media_type="text/event-stream")
+    return StreamingResponse(output(kb), media_type="text/event-stream")
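recreate_vector_store keeps the progress-streaming pattern: an async generator yields one JSON object per processed file and FastAPI wraps it in a StreamingResponse. A self-contained toy version of that pattern, with illustrative names only:

# Toy version of the progress-streaming pattern used above; names are
# illustrative, not from the repository.
import asyncio
import json
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/demo/recreate")
async def demo_recreate():
    docs = ["a.md", "b.pdf", "c.txt"]  # stand-ins for kb.list_docs()

    async def output():
        for i, name in enumerate(docs):
            await asyncio.sleep(0)  # real code does the vector-store work here
            yield json.dumps({"total": len(docs), "finished": i + 1, "doc": name})

    return StreamingResponse(output(), media_type="text/event-stream")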
@@ -3,15 +3,62 @@ import sqlite3
 import datetime
 import shutil
 from langchain.vectorstores import FAISS
-from server.knowledge_base.utils import (get_vs_path, get_kb_path, get_doc_path,
-                                         refresh_vs_cache, load_embeddings)
-from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL,
-                                  EMBEDDING_DEVICE, DB_ROOT_PATH)
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE,
+                                  KB_ROOT_PATH, DB_ROOT_PATH, VECTOR_SEARCH_TOP_K, CACHED_VS_NUM)
 from server.utils import torch_gc
+from functools import lru_cache
 from server.knowledge_base.knowledge_file import KnowledgeFile
+from typing import List
+import numpy as np
 
 SUPPORTED_VS_TYPES = ["faiss", "milvus"]
 
+_VECTOR_STORE_TICKS = {}
+
+
+def get_kb_path(knowledge_base_name: str):
+    return os.path.join(KB_ROOT_PATH, knowledge_base_name)
+
+
+def get_doc_path(knowledge_base_name: str):
+    return os.path.join(get_kb_path(knowledge_base_name), "content")
+
+
+def get_vs_path(knowledge_base_name: str):
+    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
+
+
+def get_file_path(knowledge_base_name: str, doc_name: str):
+    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
+
+
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
+
+
+@lru_cache(CACHED_VS_NUM)
+def load_vector_store(
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
+        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
+):
+    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
+    embeddings = load_embeddings(embedding_model, embedding_device)
+    vs_path = get_vs_path(knowledge_base_name)
+    search_index = FAISS.load_local(vs_path, embeddings)
+    return search_index
+
+
+def refresh_vs_cache(kb_name: str):
+    """
+    make vector store cache refreshed when next loading
+    """
+    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
+
+
 def list_kbs_from_db():
     conn = sqlite3.connect(DB_ROOT_PATH)
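The tick parameter is the trick worth noting in the hunk above: load_vector_store is memoized with lru_cache, and refresh_vs_cache bumps a per-knowledge-base counter that callers pass in as an extra argument, so the first call after any write misses the cache and reloads the index from disk. A toy demonstration of the mechanism, independent of FAISS:

# Toy demonstration of the tick-based lru_cache invalidation used by
# load_vector_store/refresh_vs_cache above (no FAISS involved).
from functools import lru_cache

_TICKS = {}

@lru_cache(maxsize=8)
def load_resource(name: str, tick: int):
    print(f"(re)loading {name}")  # runs only on cache misses
    return object()

def refresh(name: str):
    _TICKS[name] = _TICKS.get(name, 0) + 1

r1 = load_resource("kb1", _TICKS.get("kb1"))  # cache miss: loads
r2 = load_resource("kb1", _TICKS.get("kb1"))  # cache hit: same object
refresh("kb1")                                # simulates upload_doc etc.
r3 = load_resource("kb1", _TICKS.get("kb1"))  # tick changed: reloads
assert r1 is r2 and r2 is not r3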
@@ -149,6 +196,7 @@ def list_docs_from_db(kb_name):
     conn.close()
     return kbs
 
+
 def add_doc_to_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
@@ -164,14 +212,23 @@ def add_doc_to_db(kb_file: KnowledgeFile):
                     create_time DATETIME) ''')
     # Insert a row of data
     # TODO: 同名文件添加至知识库时,file_version增加
-    c.execute(f"""INSERT INTO knowledge_files
-                  (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
-                  VALUES
-                  ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
-                  '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
+    c.execute(f"""SELECT 1 FROM knowledge_files WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}" """)
+    record_exist = c.fetchone()
+    if record_exist is not None:
+        c.execute(f"""UPDATE knowledge_files
+                      SET file_version = file_version + 1
+                      WHERE file_name="{kb_file.filename}" AND kb_name="{kb_file.kb_name}"
+                   """)
+    else:
+        c.execute(f"""INSERT INTO knowledge_files
+                      (file_name, file_ext, kb_name, document_loader_name, text_splitter_name, file_version, create_time)
+                      VALUES
+                      ('{kb_file.filename}','{kb_file.ext}','{kb_file.kb_name}', '{kb_file.document_loader_name}',
+                      '{kb_file.text_splitter_name}',0,'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}')""")
     conn.commit()
     conn.close()
 
 
 def delete_file_from_db(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
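One caveat on the upsert above: interpolating kb_file fields into SQL via f-strings breaks on quotes in file names and leaves the statements open to SQL injection. A sketch of the same select-then-update-or-insert flow using sqlite3 bound parameters, assuming the table layout shown in the diff:

# Same upsert flow as add_doc_to_db above, but with bound parameters;
# table and column names are taken from the diff, the helper name is ours.
import datetime
import sqlite3

def add_doc_to_db_safe(conn: sqlite3.Connection, filename, ext, kb_name,
                       loader_name, splitter_name):
    c = conn.cursor()
    c.execute("SELECT 1 FROM knowledge_files WHERE file_name=? AND kb_name=?",
              (filename, kb_name))
    if c.fetchone() is not None:
        c.execute("UPDATE knowledge_files SET file_version = file_version + 1 "
                  "WHERE file_name=? AND kb_name=?", (filename, kb_name))
    else:
        c.execute("INSERT INTO knowledge_files (file_name, file_ext, kb_name, "
                  "document_loader_name, text_splitter_name, file_version, create_time) "
                  "VALUES (?,?,?,?,?,0,?)",
                  (filename, ext, kb_name, loader_name, splitter_name,
                   datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")))
    conn.commit()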
@@ -195,6 +252,7 @@ def delete_file_from_db(kb_file: KnowledgeFile):
     conn.close()
     return True
 
+
 def doc_exists(kb_file: KnowledgeFile):
     conn = sqlite3.connect(DB_ROOT_PATH)
     c = conn.cursor()
@@ -217,6 +275,24 @@ def doc_exists(kb_file: KnowledgeFile):
     return status
 
 
+def delete_doc_from_faiss(vector_store: FAISS, ids: List[str]):
+    overlapping = set(ids).intersection(vector_store.index_to_docstore_id.values())
+    if not overlapping:
+        raise ValueError("ids do not exist in the current object")
+    _reversed_index = {v: k for k, v in vector_store.index_to_docstore_id.items()}
+    index_to_delete = [_reversed_index[i] for i in ids]
+    vector_store.index.remove_ids(np.array(index_to_delete, dtype=np.int64))
+    for _id in index_to_delete:
+        del vector_store.index_to_docstore_id[_id]
+    # Remove items from docstore.
+    overlapping2 = set(ids).intersection(vector_store.docstore._dict)
+    if not overlapping2:
+        raise ValueError(f"Tried to delete ids that does not exist: {ids}")
+    for _id in ids:
+        vector_store.docstore._dict.pop(_id)
+    return vector_store
+
+
 class KnowledgeBase:
     def __init__(self,
                  knowledge_base_name: str,
@@ -249,21 +325,25 @@ class KnowledgeBase:
             pass
         return True
 
+    def recreate_vs(self):
+        if self.vs_type in ["faiss"]:
+            shutil.rmtree(self.vs_path)
+            self.create()
+
     def add_doc(self, kb_file: KnowledgeFile):
         docs = kb_file.file2text()
-        vs_path = get_vs_path(self.kb_name)
         embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
         if self.vs_type in ["faiss"]:
-            if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
-                vector_store = FAISS.load_local(vs_path, embeddings)
+            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
+                vector_store = FAISS.load_local(self.vs_path, embeddings)
                 vector_store.add_documents(docs)
                 torch_gc()
             else:
-                if not os.path.exists(vs_path):
-                    os.makedirs(vs_path)
+                if not os.path.exists(self.vs_path):
+                    os.makedirs(self.vs_path)
                 vector_store = FAISS.from_documents(docs, embeddings)  # docs 为Document列表
                 torch_gc()
-            vector_store.save_local(vs_path)
+            vector_store.save_local(self.vs_path)
             add_doc_to_db(kb_file)
             refresh_vs_cache(self.kb_name)
         elif self.vs_type in ["milvus"]:
@@ -275,7 +355,18 @@ class KnowledgeBase:
             os.remove(kb_file.filepath)
         if self.vs_type in ["faiss"]:
             # TODO: 从FAISS向量库中删除文档
-            delete_file_from_db(kb_file)
+            embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+            if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
+                vector_store = FAISS.load_local(self.vs_path, embeddings)
+                ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+                if len(ids) == 0:
+                    return None
+                print(len(ids))
+                vector_store = delete_doc_from_faiss(vector_store, ids)
+                vector_store.save_local(self.vs_path)
+                refresh_vs_cache(self.kb_name)
+            delete_file_from_db(kb_file)
+            return True
 
     def exist_doc(self, file_name: str):
         return doc_exists(KnowledgeFile(knowledge_base_name=self.kb_name,
@@ -284,6 +375,17 @@ class KnowledgeBase:
     def list_docs(self):
         return list_docs_from_db(self.kb_name)
 
+    def search_docs(self,
+                    query: str,
+                    top_k: int = VECTOR_SEARCH_TOP_K,
+                    embedding_device: str = EMBEDDING_DEVICE, ):
+        search_index = load_vector_store(self.kb_name,
+                                         self.embed_model,
+                                         embedding_device,
+                                         _VECTOR_STORE_TICKS.get(self.kb_name))
+        docs = search_index.similarity_search(query, k=top_k)
+        return docs
+
     @classmethod
     def exists(cls,
                knowledge_base_name: str):
@@ -316,4 +418,5 @@ if __name__ == "__main__":
     # kb = KnowledgeBase("123", "faiss")
     # kb.create()
     kb = KnowledgeBase.load(knowledge_base_name="123")
+    kb.delete_doc(KnowledgeFile(knowledge_base_name="123", filename="README.md"))
    print()
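Taken together, KnowledgeBase now owns the lifecycle that used to be spread across the utils helpers. A usage sketch mirroring the __main__ block above; the knowledge base name, query, and file are illustrative:

# Lifecycle sketch for the refactored KnowledgeBase API; only the method
# names are from the diff, the data is made up.
from server.knowledge_base.knowledge_base import KnowledgeBase
from server.knowledge_base.knowledge_file import KnowledgeFile

if KnowledgeBase.exists(knowledge_base_name="samples"):
    kb = KnowledgeBase.load(knowledge_base_name="samples")
    kb.add_doc(KnowledgeFile(knowledge_base_name="samples", filename="README.md"))
    docs = kb.search_docs("什么是知识库?", top_k=3)  # served from the cached FAISS index
    print([d.page_content[:40] for d in docs])
    kb.delete_doc(KnowledgeFile(knowledge_base_name="samples", filename="README.md"))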
@@ -1,77 +1,5 @@
-import os
-from configs.model_config import KB_ROOT_PATH
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
-from functools import lru_cache
-
-
-_VECTOR_STORE_TICKS = {}
-
-
-def get_kb_path(knowledge_base_name: str):
-    return os.path.join(KB_ROOT_PATH, knowledge_base_name)
-
-
-def get_doc_path(knowledge_base_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), "content")
-
-
-def get_vs_path(knowledge_base_name: str):
-    return os.path.join(get_kb_path(knowledge_base_name), "vector_store")
-
-
-def get_file_path(knowledge_base_name: str, doc_name: str):
-    return os.path.join(get_doc_path(knowledge_base_name), doc_name)
-
-
 def validate_kb_name(knowledge_base_id: str) -> bool:
     # 检查是否包含预期外的字符或路径攻击关键字
     if "../" in knowledge_base_id:
         return False
     return True
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
-    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def lookup_vs(
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = VECTOR_SEARCH_TOP_K,
-        embedding_model: str = EMBEDDING_MODEL,
-        embedding_device: str = EMBEDDING_DEVICE,
-):
-    search_index = load_vector_store(knowledge_base_name,
-                                     embedding_model,
-                                     embedding_device,
-                                     _VECTOR_STORE_TICKS.get(knowledge_base_name))
-    docs = search_index.similarity_search(query, k=top_k)
-    return docs
-
-
-def refresh_vs_cache(kb_name: str):
-    """
-    make vector store cache refreshed when next loading
-    """
-    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1