From a1a7484ef4d61e398fc54530ac4ba3eccf4bd35b Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 1 Aug 2023 14:15:42 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0webui=5Futils.py=EF=BC=8C?= =?UTF-8?q?=E5=8C=85=E6=8B=AC=E5=88=B6=E4=BD=9Cwebui=E9=80=9A=E7=94=A8?= =?UTF-8?q?=E7=9A=84=E5=B7=A5=E5=85=B7=EF=BC=8C=E6=96=B9=E4=BE=BF=E4=BB=A5?= =?UTF-8?q?=E5=90=8E=E5=BC=80=E5=8F=91=E5=85=B6=E4=BB=96webui?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- webui_utils.py | 247 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 247 insertions(+) create mode 100644 webui_utils.py diff --git a/webui_utils.py b/webui_utils.py new file mode 100644 index 0000000..12dab3d --- /dev/null +++ b/webui_utils.py @@ -0,0 +1,247 @@ +# 该文件包含webui通用工具,可以被不同的webui使用 + +from typing import * +from pathlib import Path +import os +from configs.model_config import ( + KB_ROOT_PATH, + LLM_MODEL, + llm_model_dict, +) +import httpx +import asyncio +from server.chat.openai_chat import OpenAiChatMsgIn +from fastapi.responses import StreamingResponse + + +def set_httpx_timeout(timeout=60.0): + ''' + 设置httpx默认timeout到60秒。 + httpx默认timeout是5秒,在请求LLM回答时不够用。 + ''' + httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout + httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout + + +KB_ROOT_PATH = Path(KB_ROOT_PATH) +set_httpx_timeout() + + +def get_kb_list() -> List[str]: + ''' + 获取知识库列表 + ''' + kb_list = os.listdir(KB_ROOT_PATH) + return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()] + + +def get_kb_files(kb: str) -> List[str]: + ''' + 获取某个知识库下包含的所有文件(只包括根目录一级) + ''' + kb = KB_ROOT_PATH / kb / "content" + if kb.is_dir(): + kb_files = os.listdir(kb) + return kb_files + else: + return [] + + +def run_async(cor): + ''' + 在同步环境中运行异步代码. + ''' + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + return loop.run_until_complete(cor) + + +def iter_over_async(ait, loop): + ''' + 将异步生成器封装成同步生成器. + ''' + ait = ait.__aiter__() + async def get_next(): + try: + obj = await ait.__anext__() + return False, obj + except StopAsyncIteration: + return True, None + while True: + done, obj = loop.run_until_complete(get_next()) + if done: + break + yield obj + + +class ApiRequest: + ''' + api.py调用的封装,主要实现: + 1. 简化api调用方式 + 2. 实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py + ''' + def __init__( + self, + base_url: str = "http://127.0.0.1:7861", + timeout: float = 60.0, + ): + self.base_url = base_url + self.timeout = timeout + + def _parse_url(self, url: str) -> str: + if (not url.startswith("http") + and self.base_url + ): + part1 = self.base_url.strip(" /") + part2 = url.strip(" /") + return f"{part1}/{part2}" + else: + return url + + def get( + self, + url: str, + params: Union[Dict, List[Tuple], bytes] = None, + retry: int = 3, + **kwargs: Any, + ) -> Union[httpx.Response, None]: + url = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + while retry > 0: + try: + return httpx.get(url, params, **kwargs) + except: + retry -= 1 + + async def aget( + self, + url: str, + params: Union[Dict, List[Tuple], bytes] = None, + retry: int = 3, + **kwargs: Any, + ) -> Union[httpx.Response, None]: + rl = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + async with httpx.AsyncClient() as client: + while retry > 0: + try: + return await client.get(url, params, **kwargs) + except: + retry -= 1 + + def post( + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + stream: bool = False, + **kwargs: Any + ) -> Union[httpx.Response, None]: + url = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + while retry > 0: + try: + # return requests.post(url, data=data, json=json, stream=stream, **kwargs) + if stream: + return httpx.stream("POST", url, data=data, json=json, **kwargs) + else: + return httpx.post(url, data=data, json=json, **kwargs) + except: + retry -= 1 + + async def apost( + self, + url: str, + data: Dict = None, + json: Dict = None, + retry: int = 3, + **kwargs: Any + ) -> Union[httpx.Response, None]: + rl = self._parse_url(url) + kwargs.setdefault("timeout", self.timeout) + async with httpx.AsyncClient() as client: + while retry > 0: + try: + return await client.post(url, data=data, json=json, **kwargs) + except: + retry -= 1 + + def _stream2generator(self, response: StreamingResponse): + ''' + 将api.py中视图函数返回的StreamingResponse转化为同步生成器 + ''' + try: + loop = asyncio.get_event_loop() + except: + loop = asyncio.new_event_loop() + return iter_over_async(response.body_iterator, loop) + + def chat_fastchat( + self, + messages: List[Dict], + stream: bool = True, + model: str = LLM_MODEL, + temperature: float = 0.7, + max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens + no_remote_api=False, # all api view function directly + **kwargs: Any, + ): + ''' + 对应api.py/chat/fastchat接口 + ''' + msg = OpenAiChatMsgIn(**{ + "messages": messages, + "stream": stream, + "model": model, + "temperature": temperature, + "max_tokens": max_tokens, + **kwargs, + }) + + if no_remote_api: + from server.chat.openai_chat import openai_chat + response = openai_chat(msg) + return self._stream2generator(response) + else: + data = msg.dict(exclude_unset=True, exclude_none=True) + response = self.post( + "/chat/fastchat", + json=data, + stream=stream, + ) + return response + + def chat_chat( + self, + query: str, + no_remote_api: bool = False, + ): + ''' + 对应api.py/chat/chat接口 + ''' + if no_remote_api: + from server.chat.chat import chat + response = chat(query) + return self._stream2generator(response) + else: + response = self.post("/chat/chat", json=f"{query}", stream=True) + return response + + +if __name__ == "__main__": + api = ApiRequest() + # print(api.chat_fastchat( + # messages=[{"role": "user", "content": "hello"}] + # )) + + with api.chat_chat("你好") as r: + for t in r.iter_text(None): + print(t) + + r = api.chat_chat("你好", no_remote_api=True) + for t in r: + print(t) \ No newline at end of file