313 lines
11 KiB
Python
313 lines
11 KiB
Python
|
|
# 该文件封装了对api.py的请求,可以被不同的webui使用
|
|||
|
|
# 通过ApiRequest和AsyncApiRequest支持同步/异步调用
|
|||
|
|
|
|||
|
|
from typing import *
|
|||
|
|
from pathlib import Path
|
|||
|
|
# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同
|
|||
|
|
# from configs import (
|
|||
|
|
# EMBEDDING_MODEL,
|
|||
|
|
# DEFAULT_VS_TYPE,
|
|||
|
|
# LLM_MODELS,
|
|||
|
|
# TEMPERATURE,
|
|||
|
|
# SCORE_THRESHOLD,
|
|||
|
|
# CHUNK_SIZE,
|
|||
|
|
# OVERLAP_SIZE,
|
|||
|
|
# ZH_TITLE_ENHANCE,
|
|||
|
|
# FIRST_VECTOR_SEARCH_TOP_K,
|
|||
|
|
# VECTOR_SEARCH_TOP_K,
|
|||
|
|
# SEARCH_ENGINE_TOP_K,
|
|||
|
|
# HTTPX_DEFAULT_TIMEOUT,
|
|||
|
|
# logger, log_verbose,
|
|||
|
|
# )
|
|||
|
|
import contextlib
import json
import logging
import os
from io import BytesIO

import httpx
|
|||
|
|
# from server.utils import set_httpx_config, api_address, get_httpx_client
|
|||
|
|
|
|||
|
|
from pprint import pprint
|
|||
|
|
|
|||
|
|
|
|||
|
|
# from langchain_core._api import deprecated
|
|||
|
|
|
|||
|
|
# set_httpx_config()
|
|||
|
|
|
|||
|
|
def get_httpx_client(
        use_async: bool = False,
        proxies: Optional[Union[str, Dict]] = None,
        timeout: float = 300,
        **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    """Create a synchronous or asynchronous httpx client.

    Args:
        use_async: return an ``httpx.AsyncClient`` when True, otherwise an
            ``httpx.Client``.
        proxies: accepted for signature compatibility only; it is NOT
            forwarded to the client.
        timeout: timeout (seconds) configured on the returned client.
        **kwargs: forwarded unchanged to the httpx client constructor
            (e.g. ``base_url``).

    Returns:
        A new, unopened httpx client instance.
    """
    # Forward the timeout explicitly. Previously it was silently dropped, so
    # httpx fell back to its own short default, which is too small for
    # long-running streamed responses.
    kwargs["timeout"] = timeout
    client_cls = httpx.AsyncClient if use_async else httpx.Client
    return client_cls(**kwargs)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ApiRequest:
    """Synchronous wrapper around the api.py HTTP endpoints.

    Simplifies calling the API server. A single httpx client is created
    lazily, reused across calls, and recreated if it has been closed.
    """

    # Used to report failures that are otherwise turned into ``None`` return
    # values or ``{"code": 500, ...}`` dicts for the caller.
    _logger = logging.getLogger(__name__)

    def __init__(
            self,
            base_url: str = 'http://127.0.0.1:7861',
            # base_url: str = 'http://192.168.0.21:17861',
            timeout: float = 300,
    ):
        """Store connection settings; no network activity happens here.

        `base_url`: root URL of the API server.
        `timeout`: timeout (seconds) passed to the httpx client factory.
        """
        self.base_url = base_url
        self.timeout = timeout
        # Selects httpx.AsyncClient vs httpx.Client in `client` and the
        # async/sync code paths of the response helpers below.
        self._use_async = False
        self._client = None

    @property
    def client(self):
        """The underlying httpx client, (re)created on first use or after close."""
        if self._client is None or self._client.is_closed:
            self._client = get_httpx_client(base_url=self.base_url,
                                            use_async=self._use_async,
                                            timeout=self.timeout)
        return self._client

    def get(
            self,
            url: str,
            params: Optional[Union[Dict, List[Tuple], bytes]] = None,
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any,
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """Send a GET request, retrying up to `retry` times if it raises.

        With `stream=True` the context manager returned by
        ``client.stream`` is returned unentered. Errors that occur when it is
        later entered are therefore not retried here.

        Returns None when every attempt raised (each failure is logged).
        """
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("GET", url, params=params, **kwargs)
                return self.client.get(url, params=params, **kwargs)
            except Exception as e:
                msg = f"error when get {url}: {e}"
                self._logger.error("%s: %s", e.__class__.__name__, msg)
                retry -= 1
        return None

    def post(
            self,
            url: str,
            data: Optional[Dict] = None,
            json: Optional[Dict] = None,  # NOTE: shadows the json module inside this method
            retry: int = 3,
            stream: bool = False,
            **kwargs: Any
    ) -> Union[httpx.Response, Iterator[httpx.Response], None]:
        """Send a POST request, retrying up to `retry` times if it raises.

        `data` / `json` are forwarded to httpx as form data / JSON body.
        With `stream=True` the context manager returned by
        ``client.stream`` is returned unentered. Errors that occur when it is
        later entered are therefore not retried here.

        Returns None when every attempt raised (each failure is logged).
        """
        while retry > 0:
            try:
                if stream:
                    return self.client.stream("POST", url, data=data, json=json, **kwargs)
                return self.client.post(url, data=data, json=json, **kwargs)
            except Exception as e:
                msg = f"error when post {url}: {e}"
                self._logger.error("%s: %s", e.__class__.__name__, msg)
                retry -= 1
        return None

    def _decode_stream_chunk(self, chunk: str, as_json: bool) -> Tuple[bool, Any]:
        """Convert one raw text chunk of a streamed response into a value.

        Returns ``(True, value)`` when `value` should be yielded. Returns
        ``(False, None)`` when the chunk is skipped: it is empty, it is an SSE
        comment line, or its JSON could not be decoded (that case is logged).
        """
        if not chunk:  # fastchat api yields empty bytes on start and end
            return False, None
        if not as_json:
            return True, chunk
        try:
            if chunk.startswith("data: "):
                # json.loads ignores surrounding whitespace, so the trailing
                # SSE blank line need not be sliced off. The old `[6:-2]`
                # slice cut into the JSON when the chunk lacked "\n\n".
                return True, json.loads(chunk[len("data: "):])
            if chunk.startswith(":"):  # skip sse comment line
                return False, None
            return True, json.loads(chunk)
        except ValueError as e:  # json.JSONDecodeError is a ValueError
            msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
            self._logger.error("%s: %s", e.__class__.__name__, msg)
            return False, None

    def _stream_error(self, e: Exception) -> Dict:
        """Log a streaming failure and build the error dict yielded to callers."""
        if isinstance(e, httpx.ConnectError):
            msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})"
        elif isinstance(e, httpx.ReadTimeout):
            msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})"
        else:
            msg = f"API通信遇到错误:{e}"
        self._logger.error("%s: %s", e.__class__.__name__, msg)
        return {"code": 500, "msg": msg}

    def _httpx_stream2generator(
            self,
            response: contextlib._GeneratorContextManager,
            as_json: bool = False,
    ):
        """Turn the context manager returned by httpx ``stream`` into a generator.

        Yields text chunks, or JSON-decoded values when `as_json` is True.
        Errors are not raised. Instead a final ``{"code": 500, "msg": ...}``
        dict is yielded. In async mode an async generator is returned.
        """

        async def ret_async(response, as_json):
            try:
                async with response as r:
                    async for chunk in r.aiter_text(None):
                        keep, value = self._decode_stream_chunk(chunk, as_json)
                        if keep:
                            yield value
            except Exception as e:
                yield self._stream_error(e)

        def ret_sync(response, as_json):
            try:
                with response as r:
                    for chunk in r.iter_text(None):
                        keep, value = self._decode_stream_chunk(chunk, as_json)
                        if keep:
                            yield value
            except Exception as e:
                yield self._stream_error(e)

        if self._use_async:
            return ret_async(response, as_json)
        return ret_sync(response, as_json)

    def _get_response_value(
            self,
            response: httpx.Response,
            as_json: bool = False,
            value_func: Optional[Callable] = None,
    ):
        """Convert the response of a sync or async request into a return value.

        `as_json`: decode the response body as JSON first. On failure an
            error dict ``{"code": 500, "msg": ..., "data": None}`` is used.
        `value_func`: optional callable applied to the response (or to the
            decoded JSON); defaults to the identity function.

        In async mode `response` is awaited and a coroutine is returned.
        """

        def to_json(r):
            try:
                return r.json()
            except Exception as e:
                msg = "API未能返回正确的JSON。" + str(e)
                self._logger.error("%s: %s", e.__class__.__name__, msg)
                return {"code": 500, "msg": msg, "data": None}

        if value_func is None:
            value_func = (lambda r: r)

        async def ret_async(response):
            if as_json:
                return value_func(to_json(await response))
            return value_func(await response)

        if self._use_async:
            return ret_async(response)
        if as_json:
            return value_func(to_json(response))
        return value_func(response)

    def knowledge_base_chat(
            self,
            query: str,
            knowledge_base_name: str,
            top_k: int = 1,
            score_threshold: float = 2,
            history: Optional[List[Dict]] = None,
            stream: bool = True,
            model: str = "chatglm3-6b",
            temperature: float = 0.01,
            max_tokens: Optional[int] = None,
            prompt_name: str = "default",
    ):
        """Call the ``/chat/knowledge_base_chat`` endpoint of api.py.

        The request is always sent as a streaming POST. The result is the
        generator produced by ``_httpx_stream2generator`` with
        ``as_json=True``. `stream` is only forwarded in the request body.

        `history`: previous conversation turns; None means an empty history.
        `model`: sent to the server as ``model_name``.
        """
        data = {
            "query": query,
            "knowledge_base_name": knowledge_base_name,
            "top_k": top_k,
            "score_threshold": score_threshold,
            # Fresh list per call instead of a shared mutable default.
            "history": history if history is not None else [],
            "stream": stream,
            "model_name": model,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "prompt_name": prompt_name,
        }

        response = self.post(
            "/chat/knowledge_base_chat",
            json=data,
            stream=True,
        )
        return self._httpx_stream2generator(response, as_json=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
# class AsyncApiRequest(ApiRequest):
|
|||
|
|
# def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
|
|||
|
|
# super().__init__(base_url, timeout)
|
|||
|
|
# self._use_async = True
|
|||
|
|
|
|||
|
|
|
|||
|
|
def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
    """Return the error message carried by an API response, or "" if none.

    Only dict responses are inspected; strings and lists always yield "".
    For a dict, ``data[key]`` is returned when present. Otherwise, if the dict
    has a ``"code"`` other than 200, its ``"msg"`` is returned. A generic
    message is used when ``"msg"`` is missing.
    """
    if not isinstance(data, dict):
        return ""
    if key in data:
        return data[key]
    if "code" in data and data["code"] != 200:
        # A failing response without "msg" must still report an error
        # rather than raise KeyError.
        return data.get("msg", f"API error, code: {data['code']}")
    return ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str:
    """Return the success message of an API response, or "" otherwise.

    A response counts as successful only when it is a dict containing both
    `key` and ``"code"``, with ``code == 200``. In that case ``data[key]``
    is returned.
    """
    if not isinstance(data, dict) or key not in data:
        return ""
    return data[key] if "code" in data and data["code"] == 200 else ""
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Manual smoke check: build a wrapper with its default settings.
    # Construction only stores configuration; no request is sent.
    api = ApiRequest()
|