修改knowledge_base_chat实现:利用缓存避免每次请求重新加载向量库;增加top_k参数。
This commit is contained in:
parent
e1698ce12e
commit
df61f31c8e
|
|
@ -12,10 +12,43 @@ from langchain.prompts import PromptTemplate
|
||||||
from langchain.vectorstores import FAISS
|
from langchain.vectorstores import FAISS
|
||||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||||
from server.knowledge_base.utils import get_vs_path
|
from server.knowledge_base.utils import get_vs_path
|
||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1)
def load_embeddings(model: str, device: str):
    """Build a HuggingFaceEmbeddings instance for *model* on *device*.

    Memoized (cache size 1) so repeated requests reuse the last-loaded
    embedding model instead of reloading it from disk each call.
    """
    return HuggingFaceEmbeddings(
        model_name=embedding_model_dict[model],
        model_kwargs={'device': device},
    )
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(1)
def load_vector_store(
    knowledge_base_name: str,
    embedding_model: str,
    embedding_device: str,
):
    """Load the FAISS index for a knowledge base, reusing cached embeddings.

    Memoized (cache size 1) so successive chats against the same knowledge
    base skip the FAISS load. NOTE(review): a stale cache entry survives
    knowledge-base rebuilds until the arguments change — confirm acceptable.
    """
    index_path = get_vs_path(knowledge_base_name)
    embedding_fn = load_embeddings(embedding_model, embedding_device)
    return FAISS.load_local(index_path, embedding_fn)
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_vs(
    query: str,
    knowledge_base_name: str,
    top_k: int = 3,
    embedding_model: str = EMBEDDING_MODEL,
    embedding_device: str = EMBEDDING_DEVICE,
):
    """Return the *top_k* documents most similar to *query*.

    The vector store (and its embedding model) come from the lru-cached
    loaders, so only the similarity search runs per request.
    """
    vs = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
    return vs.similarity_search(query, k=top_k)
|
||||||
|
|
||||||
|
|
||||||
def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
|
def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
|
||||||
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
|
||||||
|
top_k: int = Body(3, description="匹配向量数"),
|
||||||
):
|
):
|
||||||
async def knowledge_base_chat_iterator(query: str,
|
async def knowledge_base_chat_iterator(query: str,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
|
|
@ -30,11 +63,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
||||||
model_name=LLM_MODEL
|
model_name=LLM_MODEL
|
||||||
)
|
)
|
||||||
|
|
||||||
vs_path = get_vs_path(knowledge_base_name)
|
docs = lookup_vs(query, knowledge_base_name, top_k)
|
||||||
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
|
|
||||||
model_kwargs={'device': EMBEDDING_DEVICE})
|
|
||||||
search_index = FAISS.load_local(vs_path, embeddings)
|
|
||||||
docs = search_index.similarity_search(query, k=4)
|
|
||||||
context = "\n".join([doc.page_content for doc in docs])
|
context = "\n".join([doc.page_content for doc in docs])
|
||||||
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -52,7 +52,6 @@ def dialogue_page(api: ApiRequest):
|
||||||
|
|
||||||
cols = st.columns(2)
|
cols = st.columns(2)
|
||||||
chat_list = chat_box.get_chat_names()
|
chat_list = chat_box.get_chat_names()
|
||||||
print(chat_list, chat_box.cur_chat_name)
|
|
||||||
try:
|
try:
|
||||||
index = chat_list.index(chat_box.cur_chat_name)
|
index = chat_list.index(chat_box.cur_chat_name)
|
||||||
except:
|
except:
|
||||||
|
|
@ -74,7 +73,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
on_change=on_kb_change,
|
on_change=on_kb_change,
|
||||||
key="selected_kb",
|
key="selected_kb",
|
||||||
)
|
)
|
||||||
top_k = st.slider("匹配知识条数:", 1, 20, 3, disabled=True)
|
top_k = st.slider("匹配知识条数:", 1, 20, 3)
|
||||||
score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True)
|
score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True)
|
||||||
chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||||
chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||||
|
|
@ -95,7 +94,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
elif dialogue_mode == "知识库问答":
|
elif dialogue_mode == "知识库问答":
|
||||||
chat_box.ai_say(f"正在查询知识库: `{selected_kb}` ...")
|
chat_box.ai_say(f"正在查询知识库: `{selected_kb}` ...")
|
||||||
text = ""
|
text = ""
|
||||||
for t in api.knowledge_base_chat(prompt, selected_kb):
|
for t in api.knowledge_base_chat(prompt, selected_kb, top_k):
|
||||||
text += t
|
text += t
|
||||||
chat_box.update_msg(text)
|
chat_box.update_msg(text)
|
||||||
chat_box.update_msg(text, streaming=False)
|
chat_box.update_msg(text, streaming=False)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue