update api and webui:
1. 增加search_docs接口,返回原始知识库检索文档,close #1103 2. 为FAISS检索增加score_threshold参数。milvus和PG暂不支持
This commit is contained in:
parent
f3a1247629
commit
67b8ebef52
|
|
@ -14,8 +14,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat,
|
|||
search_engine_chat)
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||
update_doc, download_doc, recreate_vector_store)
|
||||
update_doc, download_doc, recreate_vector_store,
|
||||
search_docs, DocumentWithScore)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from typing import List
|
||||
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
|
|
@ -83,6 +86,12 @@ def create_app():
|
|||
summary="获取知识库内的文件列表"
|
||||
)(list_docs)
|
||||
|
||||
app.post("/knowledge_base/search_docs",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=List[DocumentWithScore],
|
||||
summary="搜索知识库"
|
||||
)(search_docs)
|
||||
|
||||
app.post("/knowledge_base/upload_doc",
|
||||
tags=["Knowledge Base Management"],
|
||||
response_model=BaseResponse,
|
||||
|
|
|
|||
|
|
@ -1,26 +1,27 @@
|
|||
from fastapi import Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
|
||||
VECTOR_SEARCH_TOP_K)
|
||||
VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from server.chat.utils import wrap_done
|
||||
from server.utils import BaseResponse
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
from typing import AsyncIterable, List, Optional
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional
|
||||
from server.chat.utils import History
|
||||
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
|
||||
import json
|
||||
import os
|
||||
from urllib.parse import urlencode
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
|
||||
|
||||
def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
|
|
@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
|
|||
openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
|
||||
model_name=LLM_MODEL
|
||||
)
|
||||
docs = kb.search_docs(query, top_k)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
context = "\n".join([doc.page_content for doc in docs])
|
||||
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
|
|
|
|||
|
|
@ -1,13 +1,32 @@
|
|||
import os
|
||||
import urllib
|
||||
from fastapi import File, Form, Body, Query, UploadFile
|
||||
from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL
|
||||
from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
import json
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from typing import List
|
||||
from typing import List, Dict
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
||||
class DocumentWithScore(Document):
|
||||
score: float = None
|
||||
|
||||
|
||||
def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]),
|
||||
knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
|
||||
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
|
||||
score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
|
||||
) -> List[DocumentWithScore]:
|
||||
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
|
||||
if kb is None:
|
||||
return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []}
|
||||
docs = kb.search_docs(query, top_k, score_threshold)
|
||||
data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs]
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def list_docs(
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import (
|
|||
list_docs_from_db, get_file_detail, delete_file_from_db
|
||||
)
|
||||
|
||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K,
|
||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
||||
from server.knowledge_base.utils import (
|
||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||
|
|
@ -112,9 +112,10 @@ class KBService(ABC):
|
|||
def search_docs(self,
|
||||
query: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
):
|
||||
embeddings = self._load_embeddings()
|
||||
docs = self.do_search(query, top_k, embeddings)
|
||||
docs = self.do_search(query, top_k, score_threshold, embeddings)
|
||||
return docs
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -81,12 +81,13 @@ class FaissKBService(KBService):
|
|||
def do_search(self,
|
||||
query: str,
|
||||
top_k: int,
|
||||
embeddings: Embeddings,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
embeddings: Embeddings = None,
|
||||
) -> List[Document]:
|
||||
search_index = load_vector_store(self.kb_name,
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD)
|
||||
docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
|
|
|
|||
|
|
@ -45,7 +45,8 @@ class MilvusKBService(KBService):
|
|||
def do_drop_kb(self):
|
||||
self.milvus.col.drop()
|
||||
|
||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||
# todo: support score threshold
|
||||
self._load_milvus(embeddings=embeddings)
|
||||
return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD)
|
||||
|
||||
|
|
|
|||
|
|
@ -43,7 +43,8 @@ class PGKBService(KBService):
|
|||
'''))
|
||||
connect.commit()
|
||||
|
||||
def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]:
|
||||
def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]:
|
||||
# todo: support score threshold
|
||||
self._load_pg_vector(embeddings=embeddings)
|
||||
return self.pg_vector.similarity_search(query, top_k)
|
||||
|
||||
|
|
|
|||
Binary file not shown.
|
After Width: | Height: | Size: 7.1 KiB |
|
|
@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest):
|
|||
key="selected_kb",
|
||||
)
|
||||
kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3)
|
||||
# score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
|
||||
score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01)
|
||||
# chunk_content = st.checkbox("关联上下文", False, disabled=True)
|
||||
# chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
|
||||
elif dialogue_mode == "搜索引擎问答":
|
||||
|
|
@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest):
|
|||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||
])
|
||||
text = ""
|
||||
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history):
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
chat_box.update_msg(text, 0)
|
||||
|
|
@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest):
|
|||
])
|
||||
text = ""
|
||||
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
text += d["answer"]
|
||||
chat_box.update_msg(text, 0)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from configs.model_config import (
|
|||
DEFAULT_VS_TYPE,
|
||||
KB_ROOT_PATH,
|
||||
LLM_MODEL,
|
||||
SCORE_THRESHOLD,
|
||||
VECTOR_SEARCH_TOP_K,
|
||||
SEARCH_ENGINE_TOP_K,
|
||||
logger,
|
||||
|
|
@ -312,6 +313,7 @@ class ApiRequest:
|
|||
query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: float = SCORE_THRESHOLD,
|
||||
history: List[Dict] = [],
|
||||
stream: bool = True,
|
||||
no_remote_api: bool = None,
|
||||
|
|
@ -326,6 +328,7 @@ class ApiRequest:
|
|||
"query": query,
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
"history": history,
|
||||
"stream": stream,
|
||||
"local_doc_url": no_remote_api,
|
||||
|
|
|
|||
Loading…
Reference in New Issue