update api and webui:

1. 增加search_docs接口,返回原始知识库检索文档,close #1103
2. 为FAISS检索增加score_threshold参数。milvus和PG暂不支持
This commit is contained in:
liunux4odoo 2023-08-16 13:18:58 +08:00
parent f3a1247629
commit 67b8ebef52
10 changed files with 53 additions and 17 deletions

View File

@ -14,8 +14,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
update_doc, download_doc, recreate_vector_store)
update_doc, download_doc, recreate_vector_store,
search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse
from typing import List
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -83,6 +86,12 @@ def create_app():
summary="获取知识库内的文件列表"
)(list_docs)
app.post("/knowledge_base/search_docs",
tags=["Knowledge Base Management"],
response_model=List[DocumentWithScore],
summary="搜索知识库"
)(search_docs)
app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,

View File

@ -1,26 +1,27 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
VECTOR_SEARCH_TOP_K)
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.chat.utils import wrap_done
from server.utils import BaseResponse
from langchain.chat_models import ChatOpenAI
from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
from typing import AsyncIterable, List, Optional
import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
import os
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
history: List[History] = Body([],
description="历史对话",
examples=[[
@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
model_name=LLM_MODEL
)
docs = kb.search_docs(query, top_k)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
context = "\n".join([doc.page_content for doc in docs])
chat_prompt = ChatPromptTemplate.from_messages(

View File

@ -1,13 +1,32 @@
import os
import urllib
from fastapi import File, Form, Body, Query, UploadFile
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from server.utils import BaseResponse, ListResponse
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
from fastapi.responses import StreamingResponse, FileResponse
import json
from server.knowledge_base.kb_service.base import KBServiceFactory
from typing import List
from typing import List, Dict
from langchain.docstore.document import Document
class DocumentWithScore(Document):
score: float = None
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值取值范围在0-1之间SCORE越小相关度越高取到1相当于不筛选建议设置在0.5左右", ge=0, le=1),
) -> List[DocumentWithScore]:
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
docs = kb.search_docs(query, top_k, score_threshold)
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
return data
async def list_docs(

View File

@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
list_docs_from_db, get_file_detail, delete_file_from_db
)
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
EMBEDDING_DEVICE, EMBEDDING_MODEL)
from server.knowledge_base.utils import (
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
@ -112,9 +112,10 @@ class KBService(ABC):
def search_docs(self,
query: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
):
embeddings = self._load_embeddings()
docs = self.do_search(query, top_k, embeddings)
docs = self.do_search(query, top_k, score_threshold, embeddings)
return docs
@abstractmethod

View File

@ -81,12 +81,13 @@ class FaissKBService(KBService):
def do_search(self,
query: str,
top_k: int,
embeddings: Embeddings,
score_threshold: float = SCORE_THRESHOLD,
embeddings: Embeddings = None,
) -> List[Document]:
search_index = load_vector_store(self.kb_name,
embeddings=embeddings,
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
return docs
def do_add_doc(self,

View File

@ -45,7 +45,8 @@ class MilvusKBService(KBService):
def do_drop_kb(self):
self.milvus.col.drop()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_milvus(embeddings=embeddings)
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)

View File

@ -43,7 +43,8 @@ class PGKBService(KBService):
'''))
connect.commit()
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
# todo: support score threshold
self._load_pg_vector(embeddings=embeddings)
return self.pg_vector.similarity_search(query, top_k)

BIN
server/static/favicon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.1 KiB

View File

@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
key="selected_kb",
)
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
# score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答":
@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
Markdown("...", in_expander=True, title="知识库匹配结果"),
])
text = ""
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
if error_msg := check_error_msg(t): # check whether error occured
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)
@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
])
text = ""
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
if error_msg := check_error_msg(t): # check whether error occured
if error_msg := check_error_msg(d): # check whether error occured
st.error(error_msg)
text += d["answer"]
chat_box.update_msg(text, 0)

View File

@ -6,6 +6,7 @@ from configs.model_config import (
DEFAULT_VS_TYPE,
KB_ROOT_PATH,
LLM_MODEL,
SCORE_THRESHOLD,
VECTOR_SEARCH_TOP_K,
SEARCH_ENGINE_TOP_K,
logger,
@ -312,6 +313,7 @@ class ApiRequest:
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: float = SCORE_THRESHOLD,
history: List[Dict] = [],
stream: bool = True,
no_remote_api: bool = None,
@ -326,6 +328,7 @@ class ApiRequest:
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
"history": history,
"stream": stream,
"local_doc_url": no_remote_api,