增加webui_utils.py,包括制作webui通用的工具,方便以后开发其他webui
This commit is contained in:
parent
9f4567865c
commit
a1a7484ef4
|
|
@ -0,0 +1,247 @@
|
|||
# 该文件包含webui通用工具,可以被不同的webui使用
|
||||
|
||||
from typing import *
|
||||
from pathlib import Path
|
||||
import os
|
||||
from configs.model_config import (
|
||||
KB_ROOT_PATH,
|
||||
LLM_MODEL,
|
||||
llm_model_dict,
|
||||
)
|
||||
import httpx
|
||||
import asyncio
|
||||
from server.chat.openai_chat import OpenAiChatMsgIn
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
def set_httpx_timeout(timeout=60.0):
|
||||
'''
|
||||
设置httpx默认timeout到60秒。
|
||||
httpx默认timeout是5秒,在请求LLM回答时不够用。
|
||||
'''
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||
|
||||
|
||||
KB_ROOT_PATH = Path(KB_ROOT_PATH)
|
||||
set_httpx_timeout()
|
||||
|
||||
|
||||
def get_kb_list() -> List[str]:
|
||||
'''
|
||||
获取知识库列表
|
||||
'''
|
||||
kb_list = os.listdir(KB_ROOT_PATH)
|
||||
return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()]
|
||||
|
||||
|
||||
def get_kb_files(kb: str) -> List[str]:
|
||||
'''
|
||||
获取某个知识库下包含的所有文件(只包括根目录一级)
|
||||
'''
|
||||
kb = KB_ROOT_PATH / kb / "content"
|
||||
if kb.is_dir():
|
||||
kb_files = os.listdir(kb)
|
||||
return kb_files
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def run_async(cor):
|
||||
'''
|
||||
在同步环境中运行异步代码.
|
||||
'''
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
ait = ait.__aiter__()
|
||||
async def get_next():
|
||||
try:
|
||||
obj = await ait.__anext__()
|
||||
return False, obj
|
||||
except StopAsyncIteration:
|
||||
return True, None
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
break
|
||||
yield obj
|
||||
|
||||
|
||||
class ApiRequest:
|
||||
'''
|
||||
api.py调用的封装,主要实现:
|
||||
1. 简化api调用方式
|
||||
2. 实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str = "http://127.0.0.1:7861",
|
||||
timeout: float = 60.0,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
|
||||
def _parse_url(self, url: str) -> str:
|
||||
if (not url.startswith("http")
|
||||
and self.base_url
|
||||
):
|
||||
part1 = self.base_url.strip(" /")
|
||||
part2 = url.strip(" /")
|
||||
return f"{part1}/{part2}"
|
||||
else:
|
||||
return url
|
||||
|
||||
def get(
|
||||
self,
|
||||
url: str,
|
||||
params: Union[Dict, List[Tuple], bytes] = None,
|
||||
retry: int = 3,
|
||||
**kwargs: Any,
|
||||
) -> Union[httpx.Response, None]:
|
||||
url = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
while retry > 0:
|
||||
try:
|
||||
return httpx.get(url, params, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
async def aget(
|
||||
self,
|
||||
url: str,
|
||||
params: Union[Dict, List[Tuple], bytes] = None,
|
||||
retry: int = 3,
|
||||
**kwargs: Any,
|
||||
) -> Union[httpx.Response, None]:
|
||||
rl = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
async with httpx.AsyncClient() as client:
|
||||
while retry > 0:
|
||||
try:
|
||||
return await client.get(url, params, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
def post(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict = None,
|
||||
json: Dict = None,
|
||||
retry: int = 3,
|
||||
stream: bool = False,
|
||||
**kwargs: Any
|
||||
) -> Union[httpx.Response, None]:
|
||||
url = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
while retry > 0:
|
||||
try:
|
||||
# return requests.post(url, data=data, json=json, stream=stream, **kwargs)
|
||||
if stream:
|
||||
return httpx.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
return httpx.post(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
async def apost(
|
||||
self,
|
||||
url: str,
|
||||
data: Dict = None,
|
||||
json: Dict = None,
|
||||
retry: int = 3,
|
||||
**kwargs: Any
|
||||
) -> Union[httpx.Response, None]:
|
||||
rl = self._parse_url(url)
|
||||
kwargs.setdefault("timeout", self.timeout)
|
||||
async with httpx.AsyncClient() as client:
|
||||
while retry > 0:
|
||||
try:
|
||||
return await client.post(url, data=data, json=json, **kwargs)
|
||||
except:
|
||||
retry -= 1
|
||||
|
||||
def _stream2generator(self, response: StreamingResponse):
|
||||
'''
|
||||
将api.py中视图函数返回的StreamingResponse转化为同步生成器
|
||||
'''
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
return iter_over_async(response.body_iterator, loop)
|
||||
|
||||
def chat_fastchat(
|
||||
self,
|
||||
messages: List[Dict],
|
||||
stream: bool = True,
|
||||
model: str = LLM_MODEL,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 1024, # todo:根据message内容自动计算max_tokens
|
||||
no_remote_api=False, # all api view function directly
|
||||
**kwargs: Any,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/fastchat接口
|
||||
'''
|
||||
msg = OpenAiChatMsgIn(**{
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
**kwargs,
|
||||
})
|
||||
|
||||
if no_remote_api:
|
||||
from server.chat.openai_chat import openai_chat
|
||||
response = openai_chat(msg)
|
||||
return self._stream2generator(response)
|
||||
else:
|
||||
data = msg.dict(exclude_unset=True, exclude_none=True)
|
||||
response = self.post(
|
||||
"/chat/fastchat",
|
||||
json=data,
|
||||
stream=stream,
|
||||
)
|
||||
return response
|
||||
|
||||
def chat_chat(
|
||||
self,
|
||||
query: str,
|
||||
no_remote_api: bool = False,
|
||||
):
|
||||
'''
|
||||
对应api.py/chat/chat接口
|
||||
'''
|
||||
if no_remote_api:
|
||||
from server.chat.chat import chat
|
||||
response = chat(query)
|
||||
return self._stream2generator(response)
|
||||
else:
|
||||
response = self.post("/chat/chat", json=f"{query}", stream=True)
|
||||
return response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
api = ApiRequest()
|
||||
# print(api.chat_fastchat(
|
||||
# messages=[{"role": "user", "content": "hello"}]
|
||||
# ))
|
||||
|
||||
with api.chat_chat("你好") as r:
|
||||
for t in r.iter_text(None):
|
||||
print(t)
|
||||
|
||||
r = api.chat_chat("你好", no_remote_api=True)
|
||||
for t in r:
|
||||
print(t)
|
||||
Loading…
Reference in New Issue