diff --git a/webui_pages/utils.py b/webui_pages/utils.py index fc1253a..1f7ab5d 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -8,12 +8,15 @@ from configs.model_config import ( KB_ROOT_PATH, LLM_MODEL, llm_model_dict, + VECTOR_SEARCH_TOP_K, + SEARCH_ENGINE_TOP_K, ) import httpx import asyncio from server.chat.openai_chat import OpenAiChatMsgIn from fastapi.responses import StreamingResponse import contextlib +import json def set_httpx_timeout(timeout=60.0): @@ -30,26 +33,6 @@ KB_ROOT_PATH = Path(KB_ROOT_PATH) set_httpx_timeout() -def get_kb_list() -> List[str]: - ''' - 获取知识库列表 - ''' - kb_list = os.listdir(KB_ROOT_PATH) - return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()] - - -def get_kb_files(kb: str) -> List[str]: - ''' - 获取某个知识库下包含的所有文件(只包括根目录一级) - ''' - kb = KB_ROOT_PATH / kb / "content" - if kb.is_dir(): - kb_files = os.listdir(kb) - return kb_files - else: - return [] - - def run_async(cor): ''' 在同步环境中运行异步代码. @@ -174,7 +157,7 @@ class ApiRequest: except: retry -= 1 - def _fastapi_stream2generator(self, response: StreamingResponse): + def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): ''' 将api.py中视图函数返回的StreamingResponse转化为同步生成器 ''' @@ -182,15 +165,27 @@ class ApiRequest: loop = asyncio.get_event_loop() except: loop = asyncio.new_event_loop() - return iter_over_async(response.body_iterator, loop) + + for chunk in iter_over_async(response.body_iterator, loop): + if as_json and chunk: + yield json.loads(chunk) + else: + yield chunk - def _httpx_stream2generator(self,response: contextlib._GeneratorContextManager): + def _httpx_stream2generator( + self, + response: contextlib._GeneratorContextManager, + as_json: bool = False, + ): ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' with response as r: for chunk in r.iter_text(None): - yield chunk + if as_json and chunk: + yield json.loads(chunk) + else: + yield chunk # 对话相关操作 @@ -254,7 +249,7 @@ class ApiRequest: self, query: str, knowledge_base_name: str, - top_k: int = 3, + top_k: int = VECTOR_SEARCH_TOP_K, no_remote_api: bool = None, ): ''' @@ -266,20 +261,20 @@ class ApiRequest: if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat response = knowledge_base_chat(query, knowledge_base_name, top_k) - return self._fastapi_stream2generator(response) + return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/chat/knowledge_base_chat", json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k}, stream=True, ) - return self._httpx_stream2generator(response) + return self._httpx_stream2generator(response, as_json=True) def search_engine_chat( self, query: str, search_engine_name: str, - top_k: int, + top_k: int = SEARCH_ENGINE_TOP_K, no_remote_api: bool = None, ): ''' @@ -291,14 +286,14 @@ class ApiRequest: if no_remote_api: from server.chat.search_engine_chat import search_engine_chat response = search_engine_chat(query, search_engine_name, top_k) - return self._fastapi_stream2generator(response) + return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/chat/search_engine_chat", json={"query": query, "search_engine_name": search_engine_name, "top_k": top_k}, stream=True, ) - return self._httpx_stream2generator(response) + return self._httpx_stream2generator(response, as_json=True) # 知识库相关操作