update api and webui:

1. 增加search_docs接口,返回原始知识库检索文档,close #1103
2. 为FAISS检索增加score_threshold参数。milvus和PG暂不支持
This commit is contained in:
liunux4odoo 2023-08-16 13:18:58 +08:00
parent f3a1247629
commit 67b8ebef52
10 changed files with 53 additions and 17 deletions

View File

@ -14,8 +14,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat) search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store) update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -83,6 +86,12 @@ def create_app():
summary="获取知识库内的文件列表" summary="获取知识库内的文件列表"
)(list_docs) )(list_docs)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc", app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"], tags=["Knowledge Base Management"],
response_model=BaseResponse, response_model=BaseResponse,

View File

@ -1,26 +1,27 @@
from fastapi import Body, Request from fastapi import Body, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE, from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K) VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.chat.utils import wrap_done from server.chat.utils import wrap_done
from server.utils import BaseResponse from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain import LLMChain from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable from typing import AsyncIterable, List, Optional
import asyncio import asyncio
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json import json
import os import os
from urllib.parse import urlencode from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
history: List[History] = Body([], history: List[History] = Body([],
description="历史对话", description="历史对话",
examples=[[ examples=[[
@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL model_name=LLM_MODEL
) )
docs = kb.search_docs(query, top_k) docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages( chat_prompt = ChatPromptTemplate.from_messages(

View File

@ -1,13 +1,32 @@
import os import os
import urllib import urllib
from fastapi import File, Form, Body, Query, UploadFile from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse from fastapi.responses import StreamingResponse, FileResponse
import json import json
from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List from typing import List, Dict
from langchain.docstore.document import Document
class DocumentWithScore(Document):
score: float = None
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data
async def list_docs( async def list_docs(

View File

@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db, get_file_detail, delete_file_from_db list_docs_from_db, get_file_detail, delete_file_from_db
) )
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL) EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import ( from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@ -112,9 +112,10 @@ class KBService(ABC):
def search_docs(self, def search_docs(self,
query: str, query: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
): ):
embeddings = self._load_embeddings() embeddings = self._load_embeddings()
docs = self.do_search(query, top_k, embeddings) docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs return docs
@abstractmethod @abstractmethod

View File

@ -81,12 +81,13 @@ class FaissKBService(KBService):
def do_search(self, def do_search(self,
query: str, query: str,
top_k: int, top_k: int,
embeddings: Embeddings, score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]: ) -> List[Document]:
search_index = load_vector_store(self.kb_name, search_index = load_vector_store(self.kb_name,
embeddings=embeddings, embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name)) tick=_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD) docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs return docs
def do_add_doc(self, def do_add_doc(self,

View File

@ -45,7 +45,8 @@ class MilvusKBService(KBService):
def do_drop_kb(self): def do_drop_kb(self):
self.milvus.col.drop() self.milvus.col.drop()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_milvus(embeddings=embeddings) self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)

View File

@ -43,7 +43,8 @@ class PGKBService(KBService):
''')) '''))
connect.commit() connect.commit()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings) self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k) return self.pg_vector.similarity_search(query, top_k)

BIN
server/static/favicon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 KiB

View File

@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
key="selected_kb", key="selected_kb",
) )
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3) kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
# score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True) score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答": elif dialogue_mode == "搜索引擎问答":
@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"), Markdown("...", in_expander=True, title="知识库匹配结果"),
]) ])
text = "" text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
if error_msg := check_error_msg(t): # check whether error occured if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg) st.error(error_msg)
text += d["answer"] text += d["answer"]
chat_box.update_msg(text, 0) chat_box.update_msg(text, 0)
@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
]) ])
text = "" text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k): for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(t): # check whether error occured if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg) st.error(error_msg)
text += d["answer"] text += d["answer"]
chat_box.update_msg(text, 0) chat_box.update_msg(text, 0)

View File

@ -6,6 +6,7 @@ from configs.model_config import (
DEFAULT_VS_TYPE, DEFAULT_VS_TYPE,
KB_ROOT_PATH, KB_ROOT_PATH,
LLM_MODEL, LLM_MODEL,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K, VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K, SEARCH_ENGINE_TOP_K,
logger, logger,
@ -312,6 +313,7 @@ class ApiRequest:
query: str, query: str,
knowledge_base_name: str, knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K, top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [], history: List[Dict] = [],
stream: bool = True, stream: bool = True,
no_remote_api: bool = None, no_remote_api: bool = None,
@ -326,6 +328,7 @@ class ApiRequest:
"query": query, "query": query,
"knowledge_base_name": knowledge_base_name, "knowledge_base_name": knowledge_base_name,
"top_k": top_k, "top_k": top_k,
"score_threshold": score_threshold,
"history": history, "history": history,
"stream": stream, "stream": stream,
"local_doc_url": no_remote_api, "local_doc_url": no_remote_api,