commit
295783b549
|
|
@ -7,14 +7,13 @@ import asyncio
|
||||||
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
|
||||||
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
|
MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
|
||||||
logger, log_verbose,
|
logger, log_verbose,
|
||||||
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
|
||||||
import os
|
import os
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
import httpx
|
import httpx
|
||||||
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
|
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union
|
||||||
|
|
||||||
|
|
||||||
thread_pool = ThreadPoolExecutor(os.cpu_count())
|
thread_pool = ThreadPoolExecutor(os.cpu_count())
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -33,12 +32,12 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
||||||
|
|
||||||
|
|
||||||
def get_ChatOpenAI(
|
def get_ChatOpenAI(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
temperature: float,
|
temperature: float,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
callbacks: List[Callable] = [],
|
callbacks: List[Callable] = [],
|
||||||
verbose: bool = True,
|
verbose: bool = True,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatOpenAI:
|
) -> ChatOpenAI:
|
||||||
config = get_model_worker_config(model_name)
|
config = get_model_worker_config(model_name)
|
||||||
model = ChatOpenAI(
|
model = ChatOpenAI(
|
||||||
|
|
@ -68,6 +67,7 @@ class BaseResponse(BaseModel):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class ListResponse(BaseResponse):
|
class ListResponse(BaseResponse):
|
||||||
data: List[str] = pydantic.Field(..., description="List of names")
|
data: List[str] = pydantic.Field(..., description="List of names")
|
||||||
|
|
||||||
|
|
@ -115,6 +115,7 @@ class ChatMessage(BaseModel):
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def torch_gc():
|
def torch_gc():
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
|
|
@ -126,11 +127,12 @@ def torch_gc():
|
||||||
from torch.mps import empty_cache
|
from torch.mps import empty_cache
|
||||||
empty_cache()
|
empty_cache()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg=("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
|
msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
|
||||||
"以支持及时清理 torch 产生的内存占用。")
|
"以支持及时清理 torch 产生的内存占用。")
|
||||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||||
exc_info=e if log_verbose else None)
|
exc_info=e if log_verbose else None)
|
||||||
|
|
||||||
|
|
||||||
def run_async(cor):
|
def run_async(cor):
|
||||||
'''
|
'''
|
||||||
在同步环境中运行异步代码.
|
在同步环境中运行异步代码.
|
||||||
|
|
@ -147,12 +149,14 @@ def iter_over_async(ait, loop):
|
||||||
将异步生成器封装成同步生成器.
|
将异步生成器封装成同步生成器.
|
||||||
'''
|
'''
|
||||||
ait = ait.__aiter__()
|
ait = ait.__aiter__()
|
||||||
|
|
||||||
async def get_next():
|
async def get_next():
|
||||||
try:
|
try:
|
||||||
obj = await ait.__anext__()
|
obj = await ait.__anext__()
|
||||||
return False, obj
|
return False, obj
|
||||||
except StopAsyncIteration:
|
except StopAsyncIteration:
|
||||||
return True, None
|
return True, None
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
done, obj = loop.run_until_complete(get_next())
|
done, obj = loop.run_until_complete(get_next())
|
||||||
if done:
|
if done:
|
||||||
|
|
@ -161,11 +165,11 @@ def iter_over_async(ait, loop):
|
||||||
|
|
||||||
|
|
||||||
def MakeFastAPIOffline(
|
def MakeFastAPIOffline(
|
||||||
app: FastAPI,
|
app: FastAPI,
|
||||||
static_dir = Path(__file__).parent / "static",
|
static_dir=Path(__file__).parent / "static",
|
||||||
static_url = "/static-offline-docs",
|
static_url="/static-offline-docs",
|
||||||
docs_url: Optional[str] = "/docs",
|
docs_url: Optional[str] = "/docs",
|
||||||
redoc_url: Optional[str] = "/redoc",
|
redoc_url: Optional[str] = "/redoc",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
|
"""patch the FastAPI obj that doesn't rely on CDN for the documentation page"""
|
||||||
from fastapi import Request
|
from fastapi import Request
|
||||||
|
|
@ -245,6 +249,7 @@ def list_embed_models() -> List[str]:
|
||||||
'''
|
'''
|
||||||
return list(MODEL_PATH["embed_model"])
|
return list(MODEL_PATH["embed_model"])
|
||||||
|
|
||||||
|
|
||||||
def list_llm_models() -> Dict[str, List[str]]:
|
def list_llm_models() -> Dict[str, List[str]]:
|
||||||
'''
|
'''
|
||||||
get names of configured llm models with different types.
|
get names of configured llm models with different types.
|
||||||
|
|
@ -268,23 +273,23 @@ def get_model_path(model_name: str, type: str = None) -> Optional[str]:
|
||||||
for v in MODEL_PATH.values():
|
for v in MODEL_PATH.values():
|
||||||
paths.update(v)
|
paths.update(v)
|
||||||
|
|
||||||
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
|
if path_str := paths.get(model_name): # 以 "chatglm-6b": "THUDM/chatglm-6b-new" 为例,以下都是支持的路径
|
||||||
path = Path(path_str)
|
path = Path(path_str)
|
||||||
if path.is_dir(): # 任意绝对路径
|
if path.is_dir(): # 任意绝对路径
|
||||||
return str(path)
|
return str(path)
|
||||||
|
|
||||||
root_path = Path(MODEL_ROOT_PATH)
|
root_path = Path(MODEL_ROOT_PATH)
|
||||||
if root_path.is_dir():
|
if root_path.is_dir():
|
||||||
path = root_path / model_name
|
path = root_path / model_name
|
||||||
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
|
if path.is_dir(): # use key, {MODEL_ROOT_PATH}/chatglm-6b
|
||||||
return str(path)
|
return str(path)
|
||||||
path = root_path / path_str
|
path = root_path / path_str
|
||||||
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
|
if path.is_dir(): # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
|
||||||
return str(path)
|
return str(path)
|
||||||
path = root_path / path_str.split("/")[-1]
|
path = root_path / path_str.split("/")[-1]
|
||||||
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
|
if path.is_dir(): # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
|
||||||
return str(path)
|
return str(path)
|
||||||
return path_str # THUDM/chatglm06b
|
return path_str # THUDM/chatglm06b
|
||||||
|
|
||||||
|
|
||||||
# 从server_config中获取服务信息
|
# 从server_config中获取服务信息
|
||||||
|
|
@ -372,7 +377,7 @@ def get_prompt_template(name: str) -> Optional[str]:
|
||||||
'''
|
'''
|
||||||
from configs import prompt_config
|
from configs import prompt_config
|
||||||
import importlib
|
import importlib
|
||||||
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
|
importlib.reload(prompt_config) # TODO: 检查configs/prompt_config.py文件有修改再重新加载
|
||||||
|
|
||||||
return prompt_config.PROMPT_TEMPLATES.get(name)
|
return prompt_config.PROMPT_TEMPLATES.get(name)
|
||||||
|
|
||||||
|
|
@ -380,7 +385,7 @@ def get_prompt_template(name: str) -> Optional[str]:
|
||||||
def set_httpx_config(
|
def set_httpx_config(
|
||||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||||
proxy: Union[str, Dict] = None,
|
proxy: Union[str, Dict] = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
|
设置httpx默认timeout。httpx默认timeout是5秒,在请求LLM回答时不够用。
|
||||||
将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
|
将本项目相关服务加入无代理列表,避免fastchat的服务器请求错误。(windows下无效)
|
||||||
|
|
@ -400,9 +405,9 @@ def set_httpx_config(
|
||||||
proxies[n + "_proxy"] = proxy
|
proxies[n + "_proxy"] = proxy
|
||||||
elif isinstance(proxy, dict):
|
elif isinstance(proxy, dict):
|
||||||
for n in ["http", "https", "all"]:
|
for n in ["http", "https", "all"]:
|
||||||
if p:= proxy.get(n):
|
if p := proxy.get(n):
|
||||||
proxies[n + "_proxy"] = p
|
proxies[n + "_proxy"] = p
|
||||||
elif p:= proxy.get(n + "_proxy"):
|
elif p := proxy.get(n + "_proxy"):
|
||||||
proxies[n + "_proxy"] = p
|
proxies[n + "_proxy"] = p
|
||||||
|
|
||||||
for k, v in proxies.items():
|
for k, v in proxies.items():
|
||||||
|
|
@ -463,9 +468,9 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
|
||||||
|
|
||||||
def run_in_thread_pool(
|
def run_in_thread_pool(
|
||||||
func: Callable,
|
func: Callable,
|
||||||
params: List[Dict] = [],
|
params: List[Dict] = [],
|
||||||
pool: ThreadPoolExecutor = None,
|
pool: ThreadPoolExecutor = None,
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
'''
|
'''
|
||||||
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
在线程池中批量运行任务,并将运行结果以生成器的形式返回。
|
||||||
|
|
@ -483,10 +488,10 @@ def run_in_thread_pool(
|
||||||
|
|
||||||
|
|
||||||
def get_httpx_client(
|
def get_httpx_client(
|
||||||
use_async: bool = False,
|
use_async: bool = False,
|
||||||
proxies: Union[str, Dict] = None,
|
proxies: Union[str, Dict] = None,
|
||||||
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
timeout: float = HTTPX_DEFAULT_TIMEOUT,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[httpx.Client, httpx.AsyncClient]:
|
) -> Union[httpx.Client, httpx.AsyncClient]:
|
||||||
'''
|
'''
|
||||||
helper to get httpx client with default proxies that bypass local addesses.
|
helper to get httpx client with default proxies that bypass local addesses.
|
||||||
|
|
@ -508,9 +513,15 @@ def get_httpx_client(
|
||||||
# get proxies from system envionrent
|
# get proxies from system envionrent
|
||||||
# proxy not str empty string, None, False, 0, [] or {}
|
# proxy not str empty string, None, False, 0, [] or {}
|
||||||
default_proxies.update({
|
default_proxies.update({
|
||||||
"http://": os.environ.get("http_proxy") if len(os.environ.get("http_proxy").strip()) > 0 else None,
|
"http://": (os.environ.get("http_proxy")
|
||||||
"https://": os.environ.get("https_proxy") if len(os.environ.get("https_proxy").strip()) > 0 else None,
|
if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
|
||||||
"all://": os.environ.get("all_proxy") if len(os.environ.get("all_proxy").strip()) > 0 else None,
|
else None),
|
||||||
|
"https://": (os.environ.get("https_proxy")
|
||||||
|
if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
|
||||||
|
else None),
|
||||||
|
"all://": (os.environ.get("all_proxy")
|
||||||
|
if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
|
||||||
|
else None),
|
||||||
})
|
})
|
||||||
for host in os.environ.get("no_proxy", "").split(","):
|
for host in os.environ.get("no_proxy", "").split(","):
|
||||||
if host := host.strip():
|
if host := host.strip():
|
||||||
|
|
@ -530,4 +541,3 @@ def get_httpx_client(
|
||||||
return httpx.AsyncClient(**kwargs)
|
return httpx.AsyncClient(**kwargs)
|
||||||
else:
|
else:
|
||||||
return httpx.Client(**kwargs)
|
return httpx.Client(**kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue