diff --git a/server/api.py b/server/api.py index 458b1d7..d86fcbb 100644 --- a/server/api.py +++ b/server/api.py @@ -14,8 +14,11 @@ from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, - update_doc, download_doc, recreate_vector_store) + update_doc, download_doc, recreate_vector_store, + search_docs, DocumentWithScore) from server.utils import BaseResponse, ListResponse +from typing import List + nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -83,6 +86,12 @@ def create_app(): summary="获取知识库内的文件列表" )(list_docs) + app.post("/knowledge_base/search_docs", + tags=["Knowledge Base Management"], + response_model=List[DocumentWithScore], + summary="搜索知识库" + )(search_docs) + app.post("/knowledge_base/upload_doc", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 0ecabf0..84c62f0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,26 +1,27 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE, - VECTOR_SEARCH_TOP_K) + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.chat.utils import wrap_done from server.utils import BaseResponse from langchain.chat_models import ChatOpenAI from langchain import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable +from typing import AsyncIterable, List, Optional import asyncio from langchain.prompts.chat import ChatPromptTemplate -from typing import List, Optional from server.chat.utils import History from server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json import os from urllib.parse import urlencode +from 
server.knowledge_base.kb_doc_api import search_docs def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), history: List[History] = Body([], description="历史对话", examples=[[ @@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], model_name=LLM_MODEL ) - docs = kb.search_docs(query, top_k) + docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) chat_prompt = ChatPromptTemplate.from_messages( diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 3f27fb1..0bf2cb7 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,13 +1,32 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL +from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory -from typing import List +from typing import List, Dict +from langchain.docstore.document import Document + + +class DocumentWithScore(Document): + score: float = None + + +def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = 
Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), + ) -> List[DocumentWithScore]: + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return []  # NOTE(review): must return a list — a dict here violates response_model=List[DocumentWithScore] and breaks the direct caller in knowledge_base_chat + docs = kb.search_docs(query, top_k, score_threshold) + data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] + + return data async def list_docs( diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ec1c692..d506f63 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import ( list_docs_from_db, get_file_detail, delete_file_from_db ) -from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, +from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EMBEDDING_DEVICE, EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, @@ -112,9 +112,10 @@ class KBService(ABC): def search_docs(self, query: str, top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, ): embeddings = self._load_embeddings() - docs = self.do_search(query, top_k, embeddings) + docs = self.do_search(query, top_k, score_threshold, embeddings) return docs @abstractmethod diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 0ef820a..5c8376f 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -81,12 +81,13 @@ class FaissKBService(KBService): def do_search(self, query: str, top_k: int, - embeddings: Embeddings, + score_threshold: float = SCORE_THRESHOLD, + embeddings: Embeddings = None, ) -> 
List[Document]: search_index = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name)) - docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD) + docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) return docs def do_add_doc(self, diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 6f1c392..f9c40c0 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -45,7 +45,8 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + # todo: support score threshold self._load_milvus(embeddings=embeddings) return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 82511bb..a3126ec 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -43,7 +43,8 @@ class PGKBService(KBService): ''')) connect.commit() - def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + # todo: support score threshold self._load_pg_vector(embeddings=embeddings) return self.pg_vector.similarity_search(query, top_k) diff --git a/server/static/favicon.png b/server/static/favicon.png new file mode 100644 index 0000000..5de8ee8 Binary files /dev/null and b/server/static/favicon.png differ diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py 
index 89e148f..a317aba 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest): key="selected_kb", ) kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3) - # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True) + score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) elif dialogue_mode == "搜索引擎问答": @@ -111,8 +111,8 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): - if error_msg := check_error_msg(t): # check whether error occured + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history): + if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) @@ -125,7 +125,7 @@ def dialogue_page(api: ApiRequest): ]) text = "" for d in api.search_engine_chat(prompt, search_engine, se_top_k): - if error_msg := check_error_msg(t): # check whether error occured + if error_msg := check_error_msg(d): # check whether error occured st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index b1d5c28..3e67ed7 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -6,6 +6,7 @@ from configs.model_config import ( DEFAULT_VS_TYPE, KB_ROOT_PATH, LLM_MODEL, + SCORE_THRESHOLD, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, logger, @@ -312,6 +313,7 @@ class ApiRequest: query: str, knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, history: List[Dict] = [], stream: bool = True, no_remote_api: bool = None, @@ -326,6 +328,7 @@ class ApiRequest: "query": query, "knowledge_base_name": 
knowledge_base_name, "top_k": top_k, + "score_threshold": score_threshold, "history": history, "stream": stream, "local_doc_url": no_remote_api,