From df61f31c8e199ee224f7c802d0d763c3a831d280 Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Thu, 3 Aug 2023 15:22:46 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9knowledge=5Fbase=5Fchat?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0:=E5=88=A9=E7=94=A8=E7=BC=93=E5=AD=98?= =?UTF-8?q?=E9=81=BF=E5=85=8D=E6=AF=8F=E6=AC=A1=E8=AF=B7=E6=B1=82=E9=87=8D?= =?UTF-8?q?=E6=96=B0=E5=8A=A0=E8=BD=BD=E5=90=91=E9=87=8F=E5=BA=93=EF=BC=9B?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0top=5Fk=E5=8F=82=E6=95=B0=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat/knowledge_base_chat.py | 39 ++++++++++++++++++++++++++---- webui_pages/dialogue/dialogue.py | 5 ++-- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 8af417e..852e9ca 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -12,10 +12,43 @@ from langchain.prompts import PromptTemplate from langchain.vectorstores import FAISS from langchain.embeddings.huggingface import HuggingFaceEmbeddings from server.knowledge_base.utils import get_vs_path +from functools import lru_cache + + +@lru_cache(1) +def load_embeddings(model: str, device: str): + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}) + return embeddings + + +@lru_cache(1) +def load_vector_store( + knowledge_base_name: str, + embedding_model: str, + embedding_device: str, +): + embeddings = load_embeddings(embedding_model, embedding_device) + vs_path = get_vs_path(knowledge_base_name) + search_index = FAISS.load_local(vs_path, embeddings) + return search_index + + +def lookup_vs( + query: str, + knowledge_base_name: str, + top_k: int = 3, + embedding_model: str = EMBEDDING_MODEL, + embedding_device: str = EMBEDDING_DEVICE, +): + search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device) + docs = search_index.similarity_search(query, k=top_k) + return docs def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"), knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), + top_k: int = Body(3, description="匹配向量数"), ): async def knowledge_base_chat_iterator(query: str, knowledge_base_name: str, @@ -30,11 +63,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp model_name=LLM_MODEL ) - vs_path = get_vs_path(knowledge_base_name) - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL], - model_kwargs={'device': EMBEDDING_DEVICE}) - search_index = FAISS.load_local(vs_path, embeddings) - docs = search_index.similarity_search(query, k=4) + docs = lookup_vs(query, knowledge_base_name, top_k) context = "\n".join([doc.page_content for doc in docs]) prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"]) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 06e37c7..ea0f3b6 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -52,7 +52,6 @@ def dialogue_page(api: ApiRequest): cols = st.columns(2) chat_list = chat_box.get_chat_names() - print(chat_list, chat_box.cur_chat_name) try: index = chat_list.index(chat_box.cur_chat_name) except: @@ -74,7 +73,7 @@ def dialogue_page(api: ApiRequest): on_change=on_kb_change, key="selected_kb", ) - top_k = st.slider("匹配知识条数:", 1, 20, 3, disabled=True) + top_k = st.slider("匹配知识条数:", 1, 20, 3) score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True) chunk_content = st.checkbox("关联上下文", False, disabled=True) chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) @@ -95,7 +94,7 @@ def dialogue_page(api: ApiRequest): elif dialogue_mode == "知识库问答": chat_box.ai_say(f"正在查询知识库: `{selected_kb}` ...") text = "" - for t in api.knowledge_base_chat(prompt, selected_kb): + for t in api.knowledge_base_chat(prompt, selected_kb, top_k): text += t chat_box.update_msg(text) chat_box.update_msg(text, streaming=False)