修改knowledge_base_chat实现:利用缓存避免每次请求重新加载向量库;增加top_k参数。

This commit is contained in:
liunux4odoo 2023-08-03 15:22:46 +08:00
parent e1698ce12e
commit df61f31c8e
2 changed files with 36 additions and 8 deletions

View File

@ -12,10 +12,43 @@ from langchain.prompts import PromptTemplate
from langchain.vectorstores import FAISS
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from server.knowledge_base.utils import get_vs_path
from functools import lru_cache
@lru_cache(1)
def load_embeddings(model: str, device: str):
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
model_kwargs={'device': device})
return embeddings
@lru_cache(1)
def load_vector_store(
knowledge_base_name: str,
embedding_model: str,
embedding_device: str,
):
embeddings = load_embeddings(embedding_model, embedding_device)
vs_path = get_vs_path(knowledge_base_name)
search_index = FAISS.load_local(vs_path, embeddings)
return search_index
def lookup_vs(
query: str,
knowledge_base_name: str,
top_k: int = 3,
embedding_model: str = EMBEDDING_MODEL,
embedding_device: str = EMBEDDING_DEVICE,
):
search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
docs = search_index.similarity_search(query, k=top_k)
return docs
def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(3, description="匹配向量数"),
):
async def knowledge_base_chat_iterator(query: str,
knowledge_base_name: str,
@ -30,11 +63,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
model_name=LLM_MODEL
)
vs_path = get_vs_path(knowledge_base_name)
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
model_kwargs={'device': EMBEDDING_DEVICE})
search_index = FAISS.load_local(vs_path, embeddings)
docs = search_index.similarity_search(query, k=4)
docs = lookup_vs(query, knowledge_base_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])

View File

@ -52,7 +52,6 @@ def dialogue_page(api: ApiRequest):
cols = st.columns(2)
chat_list = chat_box.get_chat_names()
print(chat_list, chat_box.cur_chat_name)
try:
index = chat_list.index(chat_box.cur_chat_name)
except:
@ -74,7 +73,7 @@ def dialogue_page(api: ApiRequest):
on_change=on_kb_change,
key="selected_kb",
)
top_k = st.slider("匹配知识条数:", 1, 20, 3, disabled=True)
top_k = st.slider("匹配知识条数:", 1, 20, 3)
score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True)
chunk_content = st.checkbox("关联上下文", False, disabled=True)
chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@ -95,7 +94,7 @@ def dialogue_page(api: ApiRequest):
elif dialogue_mode == "知识库问答":
chat_box.ai_say(f"正在查询知识库: `{selected_kb}` ...")
text = ""
for t in api.knowledge_base_chat(prompt, selected_kb):
for t in api.knowledge_base_chat(prompt, selected_kb, top_k):
text += t
chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False)