Rework the knowledge_base_chat implementation: cache the embeddings and vector store so they are not reloaded on every request; add a top_k parameter.
commit df61f31c8e
parent e1698ce12e
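Note on the caching approach (a minimal standalone sketch with toy names, not code from this commit): functools.lru_cache memoizes the loader per argument tuple, and maxsize=1 keeps only the most recent result, so repeated requests against the same knowledge base reuse one in-memory index while switching bases evicts the old one.

from functools import lru_cache

@lru_cache(1)
def load_vector_store(knowledge_base_name: str):
    # Runs only on a cache miss; stands in for FAISS.load_local(vs_path, embeddings).
    print(f"loading {knowledge_base_name} from disk")
    return object()

a = load_vector_store("samples")
b = load_vector_store("samples")  # cache hit: same object, no reload
assert a is b
load_vector_store("other_kb")     # miss: with maxsize=1 this evicts "samples"
load_vector_store("samples")      # loads from disk again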
@@ -12,10 +12,43 @@ from langchain.prompts import PromptTemplate
 from langchain.vectorstores import FAISS
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from server.knowledge_base.utils import get_vs_path
+from functools import lru_cache
 
 
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
+
+
+@lru_cache(1)
+def load_vector_store(
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
+):
+    embeddings = load_embeddings(embedding_model, embedding_device)
+    vs_path = get_vs_path(knowledge_base_name)
+    search_index = FAISS.load_local(vs_path, embeddings)
+    return search_index
+
+
+def lookup_vs(
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = 3,
+        embedding_model: str = EMBEDDING_MODEL,
+        embedding_device: str = EMBEDDING_DEVICE,
+):
+    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
+    docs = search_index.similarity_search(query, k=top_k)
+    return docs
+
+
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
+                        top_k: int = Body(3, description="匹配向量数"),
                         ):
     async def knowledge_base_chat_iterator(query: str,
                                            knowledge_base_name: str,
@@ -30,11 +63,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
             model_name=LLM_MODEL
         )
 
-        vs_path = get_vs_path(knowledge_base_name)
-        embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
-                                           model_kwargs={'device': EMBEDDING_DEVICE})
-        search_index = FAISS.load_local(vs_path, embeddings)
-        docs = search_index.similarity_search(query, k=4)
+        docs = lookup_vs(query, knowledge_base_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
         prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
 
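For reference, a self-contained sketch (toy FastAPI app, not the project's router; the route path here is hypothetical) of how the added `top_k: int = Body(3, ...)` parameter behaves: with several Body parameters, FastAPI expects a JSON object and falls back to the declared default when the key is omitted.

from fastapi import Body, FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

@app.post("/chat")  # hypothetical path for this sketch
def chat(query: str = Body(...), top_k: int = Body(3)):
    return {"query": query, "top_k": top_k}

client = TestClient(app)
assert client.post("/chat", json={"query": "你好"}).json()["top_k"] == 3            # default applies
assert client.post("/chat", json={"query": "你好", "top_k": 5}).json()["top_k"] == 5  # client override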
@@ -52,7 +52,6 @@ def dialogue_page(api: ApiRequest):
 
     cols = st.columns(2)
     chat_list = chat_box.get_chat_names()
-    print(chat_list, chat_box.cur_chat_name)
     try:
         index = chat_list.index(chat_box.cur_chat_name)
     except:
@@ -74,7 +73,7 @@ def dialogue_page(api: ApiRequest):
                 on_change=on_kb_change,
                 key="selected_kb",
             )
-            top_k = st.slider("匹配知识条数:", 1, 20, 3, disabled=True)
+            top_k = st.slider("匹配知识条数:", 1, 20, 3)
             score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True)
             chunk_content = st.checkbox("关联上下文", False, disabled=True)
             chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
@@ -95,7 +94,7 @@ def dialogue_page(api: ApiRequest):
         elif dialogue_mode == "知识库问答":
             chat_box.ai_say(f"正在查询知识库: `{selected_kb}` ...")
             text = ""
-            for t in api.knowledge_base_chat(prompt, selected_kb):
+            for t in api.knowledge_base_chat(prompt, selected_kb, top_k):
                 text += t
                 chat_box.update_msg(text)
             chat_box.update_msg(text, streaming=False)
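The loop above follows an accumulate-and-repaint streaming pattern; a toy sketch with hypothetical stand-ins for the api and chat_box objects:

def fake_stream():  # stands in for api.knowledge_base_chat(prompt, selected_kb, top_k)
    yield from ["根据", "已知信息,", "……"]

text = ""
for t in fake_stream():
    text += t
    print(f"\r{text}", end="")  # stands in for chat_box.update_msg(text)
print()  # final repaint; stands in for chat_box.update_msg(text, streaming=False)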