# Langchain-Chatchat/webui_pages/utils.py


# This file contains common webui utilities that can be shared by the different webui pages.
from typing import *
from pathlib import Path
import os
from configs.model_config import (
    KB_ROOT_PATH,
    LLM_MODEL,
    llm_model_dict,
)
import httpx
import asyncio
from server.chat.openai_chat import OpenAiChatMsgIn
from fastapi.responses import StreamingResponse


def set_httpx_timeout(timeout=60.0):
    '''
    Set the default httpx timeout to 60 seconds.
    The httpx default of 5 seconds is not enough when waiting for an LLM answer.
    '''
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


KB_ROOT_PATH = Path(KB_ROOT_PATH)
set_httpx_timeout()
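
# Usage sketch (an illustration, not part of the original module): instead of
# patching httpx's module-wide defaults above, a timeout can also be passed
# per request; the URL below is only illustrative:
#
#     resp = httpx.get("http://127.0.0.1:7861/docs", timeout=60.0)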

def get_kb_list() -> List[str]:
    '''
    Get the list of knowledge bases.
    '''
    kb_list = os.listdir(KB_ROOT_PATH)
    return [x for x in kb_list if (KB_ROOT_PATH / x).is_dir()]


def get_kb_files(kb: str) -> List[str]:
    '''
    Get all files in a knowledge base (top level of its root directory only).
    '''
    kb = KB_ROOT_PATH / kb / "content"
    if kb.is_dir():
        kb_files = os.listdir(kb)
        return kb_files
    else:
        return []
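
# Usage sketch (assumes at least one knowledge base directory exists under
# KB_ROOT_PATH, each with a "content" subdirectory):
#
#     for kb in get_kb_list():
#         print(kb, get_kb_files(kb))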

def run_async(cor):
    '''
    Run async code in a synchronous environment.
    '''
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        loop = asyncio.new_event_loop()
    return loop.run_until_complete(cor)
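
# Usage sketch: run a coroutine from synchronous code; "fetch" is a
# hypothetical coroutine used only for illustration:
#
#     async def fetch():
#         return 42
#
#     assert run_async(fetch()) == 42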

def iter_over_async(ait, loop):
    '''
    Wrap an async generator as a synchronous generator.
    '''
    ait = ait.__aiter__()

    async def get_next():
        # Fetch one item and flag exhaustion via a (done, obj) pair.
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj
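
# Usage sketch: consume a hypothetical async generator "agen" from sync code:
#
#     async def agen():
#         yield "chunk"
#
#     loop = asyncio.new_event_loop()
#     for chunk in iter_over_async(agen(), loop):
#         print(chunk)  # -> "chunk"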

class ApiRequest:
    '''
    A wrapper for calls to api.py that mainly provides:
    1. a simpler way to call the api
    2. an api-free mode (run the view functions in server.chat.* directly to
       get results) so that api.py does not have to be started
    '''
    def __init__(
        self,
        base_url: str = "http://127.0.0.1:7861",
        timeout: float = 60.0,
    ):
        self.base_url = base_url
        self.timeout = timeout

    def _parse_url(self, url: str) -> str:
        if (not url.startswith("http")
            and self.base_url
        ):
            part1 = self.base_url.strip(" /")
            part2 = url.strip(" /")
            return f"{part1}/{part2}"
        else:
            return url
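
    # Example: relative paths are joined onto base_url, absolute URLs pass
    # through unchanged (values shown assume the default base_url):
    #
    #     api = ApiRequest()
    #     api._parse_url("/chat/chat")            # "http://127.0.0.1:7861/chat/chat"
    #     api._parse_url("http://example.com/x")  # "http://example.com/x"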

    def get(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        **kwargs: Any,
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        while retry > 0:
            try:
                # params is a keyword-only argument in httpx
                return httpx.get(url, params=params, **kwargs)
            except Exception:
                retry -= 1

    async def aget(
        self,
        url: str,
        params: Union[Dict, List[Tuple], bytes] = None,
        retry: int = 3,
        **kwargs: Any,
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        async with httpx.AsyncClient() as client:
            while retry > 0:
                try:
                    return await client.get(url, params=params, **kwargs)
                except Exception:
                    retry -= 1

    def post(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        stream: bool = False,
        **kwargs: Any
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        while retry > 0:
            try:
                # return requests.post(url, data=data, json=json, stream=stream, **kwargs)
                if stream:
                    return httpx.stream("POST", url, data=data, json=json, **kwargs)
                else:
                    return httpx.post(url, data=data, json=json, **kwargs)
            except Exception:
                retry -= 1
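
    # Note: with stream=True the value returned above comes from httpx.stream(),
    # which is a context manager yielding the Response, not a plain Response;
    # callers must use it as "with api.post(..., stream=True) as r: ...".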

    async def apost(
        self,
        url: str,
        data: Dict = None,
        json: Dict = None,
        retry: int = 3,
        **kwargs: Any
    ) -> Union[httpx.Response, None]:
        url = self._parse_url(url)
        kwargs.setdefault("timeout", self.timeout)
        async with httpx.AsyncClient() as client:
            while retry > 0:
                try:
                    return await client.post(url, data=data, json=json, **kwargs)
                except Exception:
                    retry -= 1

    def _stream2generator(self, response: StreamingResponse):
        '''
        Convert the StreamingResponse returned by the view functions in api.py
        into a synchronous generator.
        '''
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
        return iter_over_async(response.body_iterator, loop)
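
    # Note: this reads FastAPI's StreamingResponse.body_iterator directly, so
    # the view function's output is consumed in-process, without an HTTP
    # round trip.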

    def chat_fastchat(
        self,
        messages: List[Dict],
        stream: bool = True,
        model: str = LLM_MODEL,
        temperature: float = 0.7,
        max_tokens: int = 1024,  # TODO: compute max_tokens automatically from the message content
        no_remote_api=False,  # call the api view functions directly instead of the remote server
        **kwargs: Any,
    ):
        '''
        Corresponds to the api.py /chat/fastchat endpoint.
        '''
        msg = OpenAiChatMsgIn(**{
            "messages": messages,
            "stream": stream,
            "model": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs,
        })

        if no_remote_api:
            from server.chat.openai_chat import openai_chat
            response = openai_chat(msg)
            return self._stream2generator(response)
        else:
            data = msg.dict(exclude_unset=True, exclude_none=True)
            response = self.post(
                "/chat/fastchat",
                json=data,
                stream=stream,
            )
            return response
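
    # Usage sketch (assumes api.py is running with a FastChat-compatible model
    # backend; with stream=True the result must be used as a context manager,
    # see the note on post()):
    #
    #     api = ApiRequest()
    #     with api.chat_fastchat([{"role": "user", "content": "hi"}]) as r:
    #         for chunk in r.iter_text(None):
    #             print(chunk)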

    def chat_chat(
        self,
        query: str,
        no_remote_api: bool = False,
    ):
        '''
        Corresponds to the api.py /chat/chat endpoint.
        '''
        if no_remote_api:
            from server.chat.chat import chat
            response = chat(query)
            return self._stream2generator(response)
        else:
            response = self.post("/chat/chat", json=f"{query}", stream=True)
            return response


if __name__ == "__main__":
    api = ApiRequest()

    # print(api.chat_fastchat(
    #     messages=[{"role": "user", "content": "hello"}]
    # ))

    # Remote mode: stream the answer over HTTP from a running api.py server.
    with api.chat_chat("你好") as r:
        for t in r.iter_text(None):
            print(t)

    # Local mode: call the view function in-process, no server required.
    r = api.chat_chat("你好", no_remote_api=True)
    for t in r:
        print(t)