From 84a4141dec1b8c5da5b62d3ad885927791f6d3ee Mon Sep 17 00:00:00 2001
From: glide-the <2533736852@qq.com>
Date: Wed, 4 Oct 2023 14:02:57 +0800
Subject: [PATCH] Fix the issue where the proxy is empty
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/utils.py | 82 +++++++++++++++++++++++++++----------------------
 1 file changed, 46 insertions(+), 36 deletions(-)

diff --git a/server/utils.py b/server/utils.py
index 6f975e3..52d34b9 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -7,14 +7,13 @@ import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
-                    FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
+                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from langchain.chat_models import ChatOpenAI
 import httpx
 from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
-
 
 thread_pool = ThreadPoolExecutor(os.cpu_count())
 
@@ -33,12 +32,12 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 
 
 def get_ChatOpenAI(
-    model_name: str,
-    temperature: float,
-    streaming: bool = True,
-    callbacks: List[Callable] = [],
-    verbose: bool = True,
-    **kwargs: Any,
+        model_name: str,
+        temperature: float,
+        streaming: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
 ) -> ChatOpenAI:
     config = get_model_worker_config(model_name)
     model = ChatOpenAI(
@@ -68,6 +67,7 @@ class BaseResponse(BaseModel):
         }
     }
 
+
 class ListResponse(BaseResponse):
     data: List[str] = pydantic.Field(..., description="List of names")
 
@@ -115,6 +115,7 @@ class ChatMessage(BaseModel):
         }
     }
 
+
 def torch_gc():
     import torch
     if torch.cuda.is_available():
@@ -126,11 +127,12 @@ def torch_gc():
             from torch.mps import empty_cache
             empty_cache()
         except Exception as e:
-            msg=("If you are using macOS, consider upgrading pytorch to 2.0.0 or higher "
-                 "so that the memory occupied by torch is freed promptly.")
+            msg = ("If you are using macOS, consider upgrading pytorch to 2.0.0 or higher "
+                   "so that the memory occupied by torch is freed promptly.")
             logger.error(f'{e.__class__.__name__}: {msg}',
                          exc_info=e if log_verbose else None)
 
+
 def run_async(cor):
     '''
     Run async code in a synchronous environment.
@@ -147,12 +149,14 @@ def iter_over_async(ait, loop):
     Wrap an async generator into a synchronous generator.
     '''
     ait = ait.__aiter__()
+
     async def get_next():
         try:
             obj = await ait.__anext__()
             return False, obj
         except StopAsyncIteration:
             return True, None
+
     while True:
         done, obj = loop.run_until_complete(get_next())
         if done:
@@ -161,11 +165,11 @@
 
 
 def MakeFastAPIOffline(
-    app: FastAPI,
-    static_dir = Path(__file__).parent / "static",
-    static_url = "/static-offline-docs",
-    docs_url: Optional[str] = "/docs",
-    redoc_url: Optional[str] = "/redoc",
+        app: FastAPI,
+        static_dir=Path(__file__).parent / "static",
+        static_url="/static-offline-docs",
+        docs_url: Optional[str] = "/docs",
+        redoc_url: Optional[str] = "/redoc",
 ) -> None:
     """patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
     from fastapi import Request
@@ -245,6 +249,7 @@ def list_embed_models() -> List[str]:
     '''
     return list(MODEL_PATH["embed_model"])
 
+
 def list_llm_models() -> Dict[str, List[str]]:
     '''
     get names of configured llm models with different types.
@@ -268,23 +273,23 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     for v in MODEL_PATH.values():
         paths.update(v)
 
-    if path_str := paths.get(model_name): # take "chatglm-6b": "THUDM/chatglm-6b-new" as an example; all of the paths below are supported
+    if path_str := paths.get(model_name):  # take "chatglm-6b": "THUDM/chatglm-6b-new" as an example; all of the paths below are supported
         path = Path(path_str)
-        if path.is_dir(): # any absolute path
+        if path.is_dir():  # any absolute path
             return str(path)
 
         root_path = Path(MODEL_ROOT_PATH)
         if root_path.is_dir():
             path = root_path / model_name
-            if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
+            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
                 return str(path)
             path = root_path / path_str
-            if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
+            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
                 return str(path)
             path = root_path / path_str.split("/")[-1]
-            if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
+            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                 return str(path)
-        return path_str # THUDM/chatglm06b
+        return path_str  # THUDM/chatglm06b
 
 
 # Get service information from server_config
@@ -372,7 +377,7 @@ def get_prompt_template(name: str) -> Optional[str]:
     '''
     from configs import prompt_config
     import importlib
-    importlib.reload(prompt_config) # TODO: reload only if configs/prompt_config.py has been modified
+    importlib.reload(prompt_config)  # TODO: reload only if configs/prompt_config.py has been modified
     return prompt_config.PROMPT_TEMPLATES.get(name)
 
 
@@ -380,7 +385,7 @@
 def set_httpx_config(
         timeout: float = HTTPX_DEFAULT_TIMEOUT,
         proxy: Union[str, Dict] = None,
-    ):
+):
     '''
     Set the httpx default timeout. The httpx default timeout is 5 seconds, which is not enough when requesting an LLM answer.
     Add this project's services to the no-proxy list to avoid fastchat server request errors. (Has no effect on Windows.)
@@ -400,9 +405,9 @@ def set_httpx_config(
             proxies[n + "_proxy"] = proxy
     elif isinstance(proxy, dict):
         for n in ["http", "https", "all"]:
-            if p:= proxy.get(n):
+            if p := proxy.get(n):
                 proxies[n + "_proxy"] = p
-            elif p:= proxy.get(n + "_proxy"):
+            elif p := proxy.get(n + "_proxy"):
                 proxies[n + "_proxy"] = p
 
     for k, v in proxies.items():
@@ -463,9 +468,9 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
 
 
 def run_in_thread_pool(
-    func: Callable,
-    params: List[Dict] = [],
-    pool: ThreadPoolExecutor = None,
+        func: Callable,
+        params: List[Dict] = [],
+        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     Run tasks in a thread pool in batches and yield the results as a generator.
@@ -483,10 +488,10 @@ def run_in_thread_pool(
 
 
 def get_httpx_client(
-    use_async: bool = False,
-    proxies: Union[str, Dict] = None,
-    timeout: float = HTTPX_DEFAULT_TIMEOUT,
-    **kwargs,
+        use_async: bool = False,
+        proxies: Union[str, Dict] = None,
+        timeout: float = HTTPX_DEFAULT_TIMEOUT,
+        **kwargs,
 ) -> Union[httpx.Client, httpx.AsyncClient]:
     '''
     helper to get httpx client with default proxies that bypass local addresses.
@@ -508,9 +513,15 @@ def get_httpx_client(
     # get proxies from system environment
     # a proxy that is not a non-empty str ("", None, False, 0, [] or {}) is ignored
     default_proxies.update({
-        "http://": os.environ.get("http_proxy") if len(os.environ.get("http_proxy").strip()) > 0 else None,
-        "https://": os.environ.get("https_proxy") if len(os.environ.get("https_proxy").strip()) > 0 else None,
-        "all://": os.environ.get("all_proxy") if len(os.environ.get("all_proxy").strip()) > 0 else None,
+        "http://": (os.environ.get("http_proxy")
+                    if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
+                    else None),
+        "https://": (os.environ.get("https_proxy")
+                     if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
+                     else None),
+        "all://": (os.environ.get("all_proxy")
+                   if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
+                   else None),
     })
     for host in os.environ.get("no_proxy", "").split(","):
         if host := host.strip():
@@ -530,4 +541,3 @@ def get_httpx_client(
         return httpx.AsyncClient(**kwargs)
     else:
         return httpx.Client(**kwargs)
-
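
Note on the fix: the substance of this patch is the get_httpx_client hunk. The old
code called .strip() on os.environ.get("http_proxy") without checking for None first,
so an unset proxy variable raised AttributeError: 'NoneType' object has no attribute
'strip'. The patched expression only dereferences the value once it is known to be
truthy. A minimal standalone sketch of the same guard (env_proxy is an illustrative
helper name, not part of server/utils.py):

    import os

    def env_proxy(name: str):
        # mirror the patched guard: return the proxy URL only when the
        # variable is both set and non-blank; never call .strip() on None
        value = os.environ.get(name)
        return value if value and value.strip() else None

    os.environ.pop("http_proxy", None)            # unset: pre-patch code raised here
    assert env_proxy("http_proxy") is None
    os.environ["http_proxy"] = "   "              # blank: treated as no proxy
    assert env_proxy("http_proxy") is None
    os.environ["http_proxy"] = "http://127.0.0.1:7890"   # illustrative address
    assert env_proxy("http_proxy") == "http://127.0.0.1:7890"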
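
The set_httpx_config hunks are mostly whitespace normalization, but the dict branch
they touch is worth spelling out: a proxy argument may be keyed either as
"http"/"https"/"all" or as "http_proxy"/"https_proxy"/"all_proxy", and the
walrus-operator guards skip empty values. A simplified sketch of just that mapping
(proxies_from_arg is a hypothetical name; the real function also writes the result
into os.environ and builds a no-proxy list):

    from typing import Dict, Union

    def proxies_from_arg(proxy: Union[str, Dict, None]) -> Dict[str, str]:
        proxies = {}
        if isinstance(proxy, str) and proxy:
            # a single URL applies to http, https and all traffic
            for n in ["http", "https", "all"]:
                proxies[n + "_proxy"] = proxy
        elif isinstance(proxy, dict):
            for n in ["http", "https", "all"]:
                # accept "http" or "http_proxy" style keys; the := guard
                # drops None and empty strings, matching the patched code
                if p := proxy.get(n):
                    proxies[n + "_proxy"] = p
                elif p := proxy.get(n + "_proxy"):
                    proxies[n + "_proxy"] = p
        return proxies

    print(proxies_from_arg({"http": "http://127.0.0.1:7890", "all_proxy": ""}))
    # {'http_proxy': 'http://127.0.0.1:7890'}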