"""Shared server utilities.

Contains the common API response models, helpers for bridging sync and async
code, a patch that serves the FastAPI documentation pages from local static
files, and accessors for model / server configuration.
"""
import asyncio
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any, Callable, Dict, Generator, List, Literal, Optional

import pydantic
from fastapi import FastAPI
from pydantic import BaseModel

from configs import (
    EMBEDDING_DEVICE,
    FSCHAT_MODEL_WORKERS,
    LLM_DEVICE,
    LLM_MODEL,
    MODEL_PATH,
    MODEL_ROOT_PATH,
    log_verbose,
    logger,
)

# Module-wide pool used by run_in_thread_pool() when the caller passes none.
thread_pool = ThreadPoolExecutor(os.cpu_count())

# Device names accepted as-is by llm_device() / embedding_device(); any other
# value triggers auto-detection.
_SUPPORTED_DEVICES = ("cuda", "mps", "cpu")


class BaseResponse(BaseModel):
    """Generic API response envelope: status code, message and payload."""

    code: int = pydantic.Field(200, description="API status code")
    msg: str = pydantic.Field("success", description="API status message")
    data: Any = pydantic.Field(None, description="API data")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
            }
        }


class ListResponse(BaseResponse):
    """Response envelope whose payload is a list of names."""

    data: List[str] = pydantic.Field(..., description="List of names")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }


class ChatMessage(BaseModel):
    """A chat answer together with its question, history and sources."""

    question: str = pydantic.Field(..., description="Question text")
    response: str = pydantic.Field(..., description="Response text")
    history: List[List[str]] = pydantic.Field(..., description="History text")
    source_documents: List[str] = pydantic.Field(
        ..., description="List of source documents and their scores"
    )

    class Config:
        schema_extra = {
            "example": {
                "question": "工伤保险如何办理?",
                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
                            "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
                            "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
                            "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
                            "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
                            "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                "history": [
                    [
                        "工伤保险是什么?",
                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
                        "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                    ]
                ],
                "source_documents": [
                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t"
                    "( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                    "出处 [2] ...",
                    "出处 [3] ...",
                ],
            }
        }


def torch_gc():
    """Release cached torch memory on the active accelerator (CUDA or MPS).

    On CUDA the cache is emptied and IPC handles are collected. On MPS the
    cache is emptied when ``torch.mps.empty_cache`` is importable; otherwise
    the failure is logged and swallowed (best effort).
    """
    import torch

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif torch.backends.mps.is_available():
        try:
            from torch.mps import empty_cache
            empty_cache()
        except Exception as e:
            msg = ("如果您使用的是 macOS 建议将 pytorch 版本升级至 2.0.0 或更高版本,"
                   "以支持及时清理 torch 产生的内存占用。")
            logger.error(f'{e.__class__.__name__}: {msg}',
                         exc_info=e if log_verbose else None)


def run_async(cor):
    """Run the awaitable *cor* to completion from synchronous code.

    Uses the current event loop when one is available; otherwise creates a
    new loop and installs it as the current loop so later calls from the same
    thread reuse it instead of creating (and leaking) a fresh loop each time.

    Returns:
        Whatever *cor* returns.
    """
    try:
        loop = asyncio.get_event_loop()
    except RuntimeError:
        # Raised when the calling thread has no current event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    return loop.run_until_complete(cor)


def iter_over_async(ait, loop):
    """Wrap the async iterable *ait* as a synchronous generator.

    Each item is fetched by driving *loop* with ``run_until_complete``, so
    *loop* must not already be running.
    """
    ait = ait.__aiter__()

    async def get_next():
        # Returns (done, item); StopAsyncIteration is mapped to done=True.
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj


def MakeFastAPIOffline(
    app: FastAPI,
    static_dir=Path(__file__).parent / "static",
    static_url="/static-offline-docs",
    docs_url: Optional[str] = "/docs",
    redoc_url: Optional[str] = "/redoc",
) -> None:
    """Patch *app* so its documentation pages do not rely on a CDN.

    Mounts *static_dir* at *static_url* and replaces the Swagger UI and ReDoc
    routes with versions that load their JS/CSS/favicon from that mount.

    Args:
        app: The FastAPI application to patch in place.
        static_dir: Directory holding the documentation assets.
        static_url: URL prefix under which *static_dir* is mounted.
        docs_url: Path of the Swagger UI page; ``None`` skips it.
        redoc_url: Path of the ReDoc page; ``None`` skips it.
    """
    from fastapi import Request
    from fastapi.openapi.docs import (
        get_redoc_html,
        get_swagger_ui_html,
        get_swagger_ui_oauth2_redirect_html,
    )
    from fastapi.staticfiles import StaticFiles
    from starlette.responses import HTMLResponse

    openapi_url = app.openapi_url
    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url

    def remove_route(url: str) -> None:
        """Remove the first route of *app* whose path equals *url* (case-insensitive)."""
        target = url.lower()
        index = None
        for i, r in enumerate(app.routes):
            if r.path.lower() == target:
                index = i
                break
        if index is not None:
            # Pop by the matched index, not by the leaked loop variable.
            app.routes.pop(index)

    # Set up static file mount
    app.mount(
        static_url,
        StaticFiles(directory=Path(static_dir).as_posix()),
        name="static-offline-docs",
    )

    if docs_url is not None:
        remove_route(docs_url)
        # The OAuth2 redirect URL can be disabled (None) on the app; only
        # touch that route when it is configured.
        if swagger_ui_oauth2_redirect_url:
            remove_route(swagger_ui_oauth2_redirect_url)

        # Define the doc and redoc pages, pointing at the right files
        @app.get(docs_url, include_in_schema=False)
        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
            # Default to "" so a missing root_path is not rendered as "None".
            root = request.scope.get("root_path", "")
            favicon = f"{root}{static_url}/favicon.png"
            return get_swagger_ui_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - Swagger UI",
                oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
                swagger_favicon_url=favicon,
            )

        if swagger_ui_oauth2_redirect_url:
            @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
            async def swagger_ui_redirect() -> HTMLResponse:
                return get_swagger_ui_oauth2_redirect_html()

    if redoc_url is not None:
        remove_route(redoc_url)

        @app.get(redoc_url, include_in_schema=False)
        async def redoc_html(request: Request) -> HTMLResponse:
            # Default to "" so a missing root_path is not rendered as "None".
            root = request.scope.get("root_path", "")
            favicon = f"{root}{static_url}/favicon.png"

            return get_redoc_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - ReDoc",
                redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
                with_google_fonts=False,
                redoc_favicon_url=favicon,
            )


# Model information read from model_config
def list_embed_models() -> List[str]:
    """Return the names of the configured embedding models."""
    return list(MODEL_PATH["embed_model"])


def list_llm_models() -> List[str]:
    """Return the names of the configured LLM models."""
    return list(MODEL_PATH["llm_model"])


def get_model_path(model_name: str, type: str = None) -> Optional[str]:
    """Resolve the configured location of *model_name*.

    Args:
        model_name: Key to look up in ``MODEL_PATH``.
        type: Section of ``MODEL_PATH`` to search; when it is not a key of
            ``MODEL_PATH`` all sections are merged and searched together.

    Returns:
        The first existing directory among the candidates listed below, the
        raw configured value when none of them exists, or ``None`` when the
        model has no (non-empty) configured value.
    """
    if type in MODEL_PATH:
        paths = MODEL_PATH[type]
    else:
        paths = {}
        for v in MODEL_PATH.values():
            paths.update(v)

    # Taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, every
    # location checked below is supported.
    if path_str := paths.get(model_name):
        path = Path(path_str)
        if path.is_dir():  # the configured value is itself a directory
            return str(path)

        root_path = Path(MODEL_ROOT_PATH)
        if root_path.is_dir():
            path = root_path / model_name
            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
                return str(path)
            path = root_path / path_str
            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
                return str(path)
            path = root_path / path_str.split("/")[-1]
            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                return str(path)
        # No local directory found: hand back the configured value unchanged.
        return path_str
    return None


# Service information read from server_config
def get_model_worker_config(model_name: str = None) -> dict:
    """Load the configuration of a model worker.

    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name]
    > FSCHAT_MODEL_WORKERS["default"].

    For models listed in ``ONLINE_LLM_MODEL`` the result is flagged with
    ``online_api`` and, when a ``provider`` is configured, ``worker_class`` is
    looked up on ``server.model_workers`` (lookup failures are logged, not
    raised). ``model_path`` and ``device`` are always filled in.
    """
    from configs.model_config import ONLINE_LLM_MODEL
    from configs.server_config import FSCHAT_MODEL_WORKERS
    from server import model_workers

    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(ONLINE_LLM_MODEL.get(model_name, {}))
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))

    # Online model API
    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
        if provider := config.get("provider"):
            try:
                config["worker_class"] = getattr(model_workers, provider)
            except Exception as e:
                msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)

    config["model_path"] = get_model_path(model_name)
    config["device"] = llm_device(config.get("device"))
    return config


def get_all_model_worker_configs() -> dict:
    """Return ``{name: worker config}`` for every worker except ``"default"``."""
    # Iterating the dict directly keeps the configured order deterministic.
    return {
        name: get_model_worker_config(name)
        for name in FSCHAT_MODEL_WORKERS
        if name != "default"
    }


def fschat_controller_address() -> str:
    """Return the base URL of the fastchat controller."""
    from configs.server_config import FSCHAT_CONTROLLER

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    """Return the base URL of the worker serving *model_name*, or ``""``."""
    if model := get_model_worker_config(model_name):
        host = model["host"]
        port = model["port"]
        return f"http://{host}:{port}"
    return ""


def fschat_openai_api_address() -> str:
    """Return the base URL (including ``/v1``) of the OpenAI-compatible API."""
    from configs.server_config import FSCHAT_OPENAI_API

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}/v1"


def api_address() -> str:
    """Return the base URL of the API server."""
    from configs.server_config import API_SERVER

    host = API_SERVER["host"]
    port = API_SERVER["port"]
    return f"http://{host}:{port}"


def webui_address() -> str:
    """Return the base URL of the web UI server."""
    from configs.server_config import WEBUI_SERVER

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    return f"http://{host}:{port}"


def set_httpx_timeout(timeout: float = None):
    """Set httpx's default connect/read/write timeouts.

    The httpx default timeout is 5 seconds, which is not enough when waiting
    for an LLM answer. A falsy *timeout* selects ``HTTPX_DEFAULT_TIMEOUT``.
    """
    import httpx
    from configs.server_config import HTTPX_DEFAULT_TIMEOUT

    timeout = timeout or HTTPX_DEFAULT_TIMEOUT
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


# Automatically detect the device torch can use. In a distributed deployment,
# machines that do not run the LLM do not need torch installed.
def detect_device() -> Literal["cuda", "mps", "cpu"]:
    """Return ``"cuda"`` or ``"mps"`` when torch reports it, else ``"cpu"``."""
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except Exception:
        # Best effort: torch missing or unusable means we fall back to CPU.
        pass
    return "cpu"


def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    """Return the device for the LLM.

    Uses *device*, else ``LLM_DEVICE``; any value outside cuda/mps/cpu is
    replaced by the result of :func:`detect_device`.
    """
    device = device or LLM_DEVICE
    if device not in _SUPPORTED_DEVICES:
        device = detect_device()
    return device


def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    """Return the device for the embedding model.

    Uses *device*, else ``EMBEDDING_DEVICE``; any value outside cuda/mps/cpu
    is replaced by the result of :func:`detect_device`.
    """
    device = device or EMBEDDING_DEVICE
    if device not in _SUPPORTED_DEVICES:
        device = detect_device()
    return device


def run_in_thread_pool(
    func: Callable,
    params: Optional[List[Dict]] = None,
    pool: Optional[ThreadPoolExecutor] = None,
) -> Generator:
    """Run *func* once per kwargs dict in a thread pool and yield the results.

    Results are yielded in completion order, not submission order. Make sure
    everything done inside *func* is thread-safe, and that *func* takes all
    of its arguments as keyword arguments.

    Args:
        func: Callable invoked as ``func(**kwargs)`` for each item of *params*.
        params: List of keyword-argument dicts; ``None`` means no tasks.
        pool: Executor to use; defaults to the module-level ``thread_pool``.

    Yields:
        The return value of each call; an exception raised by a task is
        re-raised when its result is reached.
    """
    pool = pool or thread_pool
    tasks = [pool.submit(func, **kwargs) for kwargs in params or []]

    for future in as_completed(tasks):
        yield future.result()