313 lines
11 KiB
Python
313 lines
11 KiB
Python
# 该文件封装了对api.py的请求,可以被不同的webui使用
|
||
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
|
||
|
||
from typing import *
|
||
from pathlib import Path
|
||
# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
|
||
# from configs import (
|
||
# EMBEDDING_MODEL,
|
||
# DEFAULT_VS_TYPE,
|
||
# LLM_MODELS,
|
||
# TEMPERATURE,
|
||
# SCORE_THRESHOLD,
|
||
# CHUNK_SIZE,
|
||
# OVERLAP_SIZE,
|
||
# ZH_TITLE_ENHANCE,
|
||
# FIRST_VECTOR_SEARCH_TOP_K,
|
||
# VECTOR_SEARCH_TOP_K,
|
||
# SEARCH_ENGINE_TOP_K,
|
||
# HTTPX_DEFAULT_TIMEOUT,
|
||
# logger, log_verbose,
|
||
# )
|
||
import httpx
|
||
import contextlib
|
||
import json
|
||
import os
|
||
from io import BytesIO
|
||
# from server.utils import set_httpx_config, api_address, get_httpx_client
|
||
|
||
from pprint import pprint
|
||
|
||
|
||
# from langchain_core._api import deprecated
|
||
|
||
# set_httpx_config()
|
||
|
||
def get_httpx_client(
        use_async: bool = False,
        proxies: Union[str, Dict] = None,
        timeout: float = 300,
        **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    """Create a new synchronous or asynchronous httpx client.

    Args:
        use_async: return an ``httpx.AsyncClient`` when True, otherwise an
            ``httpx.Client``.
        proxies: accepted for signature compatibility only; it is NOT
            forwarded to httpx. TODO(review): wire this up if proxy support
            is needed. The httpx keyword for proxies differs between httpx
            releases, so this needs checking against the installed version.
        timeout: request timeout in seconds, passed to the client constructor.
        **kwargs: any other keyword arguments for the httpx client
            constructor (e.g. ``base_url``).

    Returns:
        A new client instance; the caller is responsible for closing it.
    """
    # ``timeout`` used to be accepted and then dropped, so every client
    # silently used httpx's short default timeout instead of the requested one.
    client_cls = httpx.AsyncClient if use_async else httpx.Client
    return client_cls(timeout=timeout, **kwargs)
|
||
|
||
|
||
class ApiRequest:
    '''
    Synchronous wrapper around the api.py HTTP endpoints that simplifies
    calling them (e.g. from a web UI).

    ``_use_async`` is False here. When it is True, the lazily created client
    and the response helpers switch to their async variants.
    '''

    def __init__(
            self,
            base_url: str = 'http://127.0.0.1:7861',
            # base_url: str = 'http://192.168.0.21:17861',
            timeout: float = 300,
    ):
        self.base_url = base_url
        self.timeout = timeout
        self._use_async = False
        self._client = None

    @staticmethod
    def _log_error(msg: str, exc: Optional[BaseException] = None) -> None:
        '''Log an error that is otherwise handled (swallowed) by the caller.'''
        import logging  # local import: no module-level logger exists in this file

        prefix = f"{type(exc).__name__}: " if exc is not None else ""
        logging.getLogger(__name__).error("%s%s", prefix, msg)

    @property
    def client(self):
        '''
        Return the httpx client bound to ``base_url``.

        The client is created on first use and re-created if it has been
        closed. It is async or sync depending on ``_use_async``.
        '''
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(base_url=self.base_url,
                                            use_async=self._use_async,
                                            timeout=self.timeout)
        return self._client

    def get(
            self,
            url: str,
            params: Union[Dict, List[Tuple], bytes] = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        '''
        Send a GET request, retrying up to ``retry`` times if the call raises.

        Returns the response, or the un-entered ``client.stream(...)`` context
        manager when ``stream`` is True. Returns None if every attempt raised;
        each failure is logged. In async mode the returned value must be
        awaited / entered by the caller, and errors raised at that point are
        not retried here.
        '''
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                self._log_error(f"error when get {url}: {e}", e)
                retry -= 1
        return None

    def post(
            self,
            url: str,
            data: Dict = None,
            json: Dict = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        '''
        Send a POST request, retrying up to ``retry`` times if the call raises.

        ``data`` is sent as form data and ``json`` as a JSON body. The
        parameter name ``json`` shadows the json module inside this method
        only. The return value follows the same rules as ``get``: a response
        or stream context manager, or None after all attempts failed (each
        failure is logged).
        '''
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("POST", url, data=data, json=json, **kwargs)
                return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                self._log_error(f"error when post {url}: {e}", e)
                retry -= 1
        return None

    def _httpx_stream2generator(
            self,
            response: contextlib._GeneratorContextManager,
            as_json: bool = False,
    ):
        '''
        Turn the context manager returned by ``client.stream(...)`` into a
        plain generator. The generator is async or sync, matching
        ``_use_async``.

        Empty text chunks are skipped. Other chunks are yielded as-is or, when
        ``as_json`` is True, decoded as JSON:
        - for ``data: `` chunks the prefix and the last two characters are
          stripped before decoding;
        - ``:`` SSE comment lines are skipped;
        - chunks that fail to decode are logged and skipped.

        If the request/stream raises, the generator yields a single
        ``{"code": 500, "msg": ...}`` dict and stops.
        '''
        skip = object()  # sentinel: this chunk produces nothing to yield

        def decode(chunk):
            '''Return the value to yield for ``chunk``, or ``skip``.'''
            if not chunk:  # fastchat api yields empty bytes on start and end
                return skip
            if not as_json:
                return chunk
            try:
                if chunk.startswith("data: "):
                    return json.loads(chunk[6:-2])
                if chunk.startswith(":"):  # skip SSE comment line
                    return skip
                return json.loads(chunk)
            except Exception as e:
                msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
                self._log_error(msg, e)
                return skip

        def failure(e):
            '''Map a transport error to the error dict yielded to the caller.'''
            if isinstance(e, httpx.ConnectError):
                msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
                self._log_error(msg)
            elif isinstance(e, httpx.ReadTimeout):
                msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})"
                self._log_error(msg)
            else:
                msg = f"API通信遇到错误:{e}"
                self._log_error(msg, e)
            return {"code": 500, "msg": msg}

        async def ret_async(response):
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        value = decode(chunk)
                        if value is not skip:
                            yield value
            except Exception as e:
                yield failure(e)

        def ret_sync(response):
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        value = decode(chunk)
                        if value is not skip:
                            yield value
            except Exception as e:
                yield failure(e)

        if self._use_async:
            return ret_async(response)
        return ret_sync(response)

    def _get_response_value(
            self,
            response: httpx.Response,
            as_json: bool = False,
            value_func: Callable = None,
    ):
        '''
        Convert the response of a sync or async request.

        `as_json`: decode the response body as JSON. On failure the error is
            logged and ``{"code": 500, "msg": ..., "data": None}`` is used
            instead.
        `value_func`: optional callable applied to the response (or to the
            decoded JSON when ``as_json`` is True) to produce the return value.

        In async mode a coroutine is returned; it awaits ``response`` first.
        '''
        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                self._log_error(msg, e)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            return value_func(await response)

        if self._use_async:
            return ret_async(response)
        if as_json:
            return value_func(to_json(response))
        return value_func(response)

    def knowledge_base_chat(
            self,
            query: str,
            knowledge_base_name: str,
            top_k: int = 1,
            score_threshold: float = 2,
            history: Optional[List[Dict]] = None,
            stream: bool = True,
            model: str = "chatglm3-6b",
            temperature: float = 0.01,
            max_tokens: int = None,
            prompt_name: str = "default",
    ):
        '''
        Call the api.py ``/chat/knowledge_base_chat`` endpoint.

        ``history`` defaults to an empty list. It used to be a shared mutable
        default argument. ``stream`` is only forwarded to the server inside
        the payload; the HTTP request itself is always sent in streaming mode.
        The result is the generator from ``_httpx_stream2generator`` with
        JSON decoding enabled.
        '''
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            "history": history if history is not None else [],
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }

        response = self.post(
            "/chat/knowledge_base_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
|
||
|
||
|
||
# class AsyncApiRequest(ApiRequest):
|
||
# def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
|
||
# super().__init__(base_url, timeout)
|
||
# self._use_async = True
|
||
|
||
|
||
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    '''
    Return the error message embedded in an API response, or "" if none.

    Only dict payloads are inspected:
    - if ``key`` is present, its value is returned as-is;
    - otherwise, if a ``"code"`` entry exists and is not 200, ``data["msg"]``
      is returned (raises KeyError if ``"msg"`` is missing).
    Any other payload type (str, list, ...) yields "".
    '''
    if not isinstance(data, dict):
        return ""
    if key in data:
        return data[key]
    failed = "code" in data and data["code"] != 200
    return data["msg"] if failed else ""
|
||
|
||
|
||
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    '''
    Return the success message from an API response, or "" otherwise.

    ``data[key]`` is returned only when ``data`` is a dict that contains both
    ``key`` and ``"code"`` and ``data["code"] == 200``.
    '''
    if not isinstance(data, dict):
        return ""
    if key not in data or "code" not in data:
        return ""
    return data[key] if data["code"] == 200 else ""
|
||
|
||
|
||
if __name__ == "__main__":
    # Manual smoke check: build a client for the default local api.py address.
    # No request is sent here; the underlying httpx client is only created on
    # first access to ``ApiRequest.client``.
    api = ApiRequest(base_url='http://127.0.0.1:7861')
|