diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index ff0a251..cc3a19b 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -88,6 +88,11 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
     },
 
+    # The default settings can be overridden per model as in the example below
+    # "baichuan-7b": {  # uses the IP and port from "default"
+    #     "device": "cpu",
+    # },
+
     "zhipu-api": {  # set a different port for each online API you want to run
         "port": 21001,
     },
diff --git a/docs/FAQ.md b/docs/FAQ.md
index 276d8f4..ad6683b 100644
--- a/docs/FAQ.md
+++ b/docs/FAQ.md
@@ -185,18 +185,10 @@ A16. The pymilvus version must match the installed milvus version, otherwise requests time out; see pymilvus==2.1.
 
 Q16: When using the vllm inference acceleration framework, the model has already been downloaded but HuggingFace connection errors still occur
 
-A16: Modify the prepare_hf_model_weights function in /site-packages/vllm/model_executor/weight_utils.py of your Python environment as shown below:
+A16: Modify the corresponding code of the prepare_hf_model_weights function in /site-packages/vllm/model_executor/weight_utils.py of your Python environment as shown below:
 
 ```python
-def prepare_hf_model_weights(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
-    fall_back_to_pt: bool = True,
-):
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
+
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -225,22 +217,7 @@ def prepare_hf_model_weights(
                                   tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
-    if not use_safetensors:
-        hf_weights_files = [
-            x for x in hf_weights_files if not x.endswith("training_args.bin")
-        ]
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False)
-    if len(hf_weights_files) == 0:
-        raise RuntimeError(
-            f"Cannot find any model weights with `{model_name_or_path}`")
-
-    return hf_folder, hf_weights_files, use_safetensors
 ```
diff --git a/requirements.txt b/requirements.txt
index 0e5d224..3242ea7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,7 +23,7 @@ pathlib
 pytest
 scikit-learn
 numexpr
-
+vllm==0.1.7
 # online api libs
 # zhipuai
 # dashscope>=1.10.0 # qwen
diff --git a/requirements_api.txt b/requirements_api.txt
index fd5228d..674d759 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -23,7 +23,7 @@ pathlib
 pytest
 scikit-learn
 numexpr
-
+vllm==0.1.7
 # online api libs
 # zhipuai
 # dashscope>=1.10.0 # qwen
diff --git a/startup.py b/startup.py
index 01a003a..c4f9e7b 100644
--- a/startup.py
+++ b/startup.py
@@ -75,10 +75,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
+    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    import threading
     import fastchat.serve.model_worker
+    import fastchat.serve.vllm_worker
 
     logger.setLevel(log_level)
     parser = argparse.ArgumentParser()
@@ -89,6 +89,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
 
     # online model APIs
     if worker_class := kwargs.get("worker_class"):
+        from fastchat.serve.model_worker import app
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
@@ -97,15 +98,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     else:
         from configs.model_config import VLLM_MODEL_DICT
         if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
-            from fastchat.serve.vllm_worker import VLLMWorker
+            from fastchat.serve.vllm_worker import VLLMWorker,app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
-            #! ------------- the tokenizer seems to get attached here ------------
-
-            # parser = AsyncEngineArgs.add_cli_args(args)
-            # # args = parser.parse_args()
-
-
+
             args.tokenizer = args.model_path  # if the tokenizer differs from model_path, set it here
             args.tokenizer_mode = 'auto'
             args.trust_remote_code= True
@@ -150,9 +146,11 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 llm_engine = engine,
                 conv_template = args.conv_template,
             )
+            sys.modules["fastchat.serve.vllm_worker"].engine = engine
             sys.modules["fastchat.serve.vllm_worker"].worker = worker
-
+        else:
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
             args.gpus = "1"
             args.max_gpu_memory = "20GiB"
             args.load_8bit = False
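Note on enabling the vllm path above: the new branch in startup.py only runs when the worker's model name appears in VLLM_MODEL_DICT (imported from configs.model_config) and the worker receives infer_turbo == "vllm". A minimal sketch of the corresponding config entries follows; the exact layout of VLLM_MODEL_DICT and the placement of infer_turbo inside FSCHAT_MODEL_WORKERS are assumptions, not part of this diff.

```python
# configs/model_config.py -- sketch; the real VLLM_MODEL_DICT layout may differ
VLLM_MODEL_DICT = {
    "baichuan-7b": "/path/to/baichuan-7b",  # hypothetical local checkpoint path
}

# configs/server_config.py -- sketch; assumes infer_turbo is forwarded to the worker as a kwarg
FSCHAT_MODEL_WORKERS = {
    "default": {
        "host": "127.0.0.1",
        "port": 20002,
        "infer_turbo": "vllm",  # startup.py checks args.infer_turbo == "vllm"
    },
}
```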
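For readers unfamiliar with the sys.modules assignments at the end of the vllm branch: fastchat.serve.vllm_worker serves its FastAPI routes through module-level engine/worker globals, so startup.py patches them to point at the objects it builds. Below is a condensed, standalone sketch of that wiring, assuming vllm==0.1.7 and a FastChat version whose VLLMWorker accepts the keyword arguments shown; the helper name build_vllm_worker and the default values (limit_worker_concurrency=5, no_register=False) are illustrative, not taken from this diff.

```python
import argparse
import sys
import uuid
from typing import List, Optional

from fastchat.serve.vllm_worker import VLLMWorker, app  # `app` exposes the worker's HTTP API
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs


def build_vllm_worker(model_path: str, model_names: List[str],
                      controller_addr: str, worker_addr: str,
                      conv_template: Optional[str] = None):
    # Build an argparse namespace the same way startup.py does: register the vllm
    # engine CLI args with their defaults, then override the fields set in the diff.
    parser = argparse.ArgumentParser()
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args([])          # defaults only; startup.py fills these from kwargs
    args.model = model_path
    args.tokenizer = model_path           # change if the tokenizer lives elsewhere
    args.tokenizer_mode = "auto"
    args.trust_remote_code = True

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # Keyword names follow fastchat.serve.vllm_worker and may differ across versions.
    worker = VLLMWorker(
        controller_addr=controller_addr,
        worker_addr=worker_addr,
        worker_id=str(uuid.uuid4())[:8],
        model_path=model_path,
        model_names=model_names,
        limit_worker_concurrency=5,
        no_register=False,
        llm_engine=engine,
        conv_template=conv_template,
    )
    # vllm_worker's route handlers read these module-level globals, so point them
    # at the engine/worker created here (exactly what startup.py does above).
    sys.modules["fastchat.serve.vllm_worker"].engine = engine
    sys.modules["fastchat.serve.vllm_worker"].worker = worker
    return app, worker
```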