Merge pull request #1581 from hzg0601/dev

Successfully tested the vllm inference framework
Zhi-guo Huang 2023-09-24 02:08:44 +08:00 committed by GitHub
commit 9cbd9f6711
5 changed files with 17 additions and 37 deletions

View File

@@ -88,6 +88,11 @@ FSCHAT_MODEL_WORKERS = {
         # 'disable_log_requests': False
     },
+    # The default configuration can be overridden per model, as in this example:
+    # "baichuan-7b": {  # reuses the IP and port from "default"
+    #     "device": "cpu",
+    # },
     "zhipu-api": {  # assign a different port to each online API you run
         "port": 21001,
     },
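The commented-out example added above relies on per-model entries being merged over the shared `default` entry. A minimal sketch of that merge, assuming a `FSCHAT_MODEL_WORKERS` dict shaped like the config above; the `get_worker_config` helper name is hypothetical, not the project's actual loader:

```python
# Minimal sketch: per-model worker settings override the shared defaults.
# FSCHAT_MODEL_WORKERS mirrors the config above; get_worker_config is a
# hypothetical helper, not the project's real loader.
FSCHAT_MODEL_WORKERS = {
    "default": {"host": "127.0.0.1", "port": 20002, "device": "cuda"},
    "baichuan-7b": {"device": "cpu"},   # inherits default host/port
    "zhipu-api": {"port": 21001},       # online APIs need distinct ports
}

def get_worker_config(model_name: str) -> dict:
    config = dict(FSCHAT_MODEL_WORKERS["default"])            # start from defaults
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))   # per-model keys win
    return config

print(get_worker_config("baichuan-7b"))
# {'host': '127.0.0.1', 'port': 20002, 'device': 'cpu'}
```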

View File

@@ -185,18 +185,10 @@ A16: The pymilvus version must match the installed milvus version, otherwise requests time out; see pymilvus==2.1.
 Q16: When using the vllm inference acceleration framework, the model has already been downloaded but a HuggingFace connectivity error occurs.
-A16: Modify the prepare_hf_model_weights function in the /site-packages/vllm/model_executor/weight_utils.py file of your python environment as follows:
+A16: Modify the prepare_hf_model_weights function in the /site-packages/vllm/model_executor/weight_utils.py file of your python environment to match the corresponding code:
 ```python
-def prepare_hf_model_weights(
-    model_name_or_path: str,
-    cache_dir: Optional[str] = None,
-    use_safetensors: bool = False,
-    fall_back_to_pt: bool = True,
-):
-    # Download model weights from huggingface.
-    is_local = os.path.isdir(model_name_or_path)
-    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
     if not is_local:
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
@@ -225,22 +217,7 @@ def prepare_hf_model_weights(
                 tqdm_class=Disabledtqdm)
     else:
         hf_folder = model_name_or_path
-    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
-    if not use_safetensors:
-        hf_weights_files = [
-            x for x in hf_weights_files if not x.endswith("training_args.bin")
-        ]
-    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
-        return prepare_hf_model_weights(model_name_or_path,
-                                        cache_dir=cache_dir,
-                                        use_safetensors=False,
-                                        fall_back_to_pt=False)
-    if len(hf_weights_files) == 0:
-        raise RuntimeError(
-            f"Cannot find any model weights with `{model_name_or_path}`")
-    return hf_folder, hf_weights_files, use_safetensors
 ```
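The idea behind the FAQ fix is that when `model_name_or_path` is already a local directory, weight files should be resolved by globbing the folder instead of calling out to the HuggingFace Hub. A self-contained sketch of that local-first resolution, simplified from the function above; `find_local_weights` is an illustrative name, not the vllm API:

```python
import glob
import os
from typing import List, Tuple

def find_local_weights(model_name_or_path: str,
                       use_safetensors: bool = False) -> Tuple[str, List[str]]:
    """Resolve weight files from a local folder, mirroring the local branch
    of prepare_hf_model_weights shown above (illustrative only)."""
    if not os.path.isdir(model_name_or_path):
        raise RuntimeError(f"`{model_name_or_path}` is not a local directory")
    pattern = "*.safetensors" if use_safetensors else "*.bin"
    weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
    if not use_safetensors:
        # training_args.bin holds trainer state, not model weights
        weight_files = [f for f in weight_files
                        if not f.endswith("training_args.bin")]
    if not weight_files:
        raise RuntimeError(
            f"Cannot find any model weights under `{model_name_or_path}`")
    return model_name_or_path, weight_files
```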

View File

@@ -23,7 +23,7 @@ pathlib
 pytest
 scikit-learn
 numexpr
+vllm==0.1.7
 # online api libs
 # zhipuai
 # dashscope>=1.10.0 # qwen
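Since the pin is exact, a quick runtime check can confirm the environment matches it. An illustrative guard, assuming the package exposes the usual `__version__` attribute:

```python
# Illustrative guard: fail fast if the environment drifts from the exact pin.
# Assumes vllm exposes __version__ (most PyPI packages do).
import vllm

if vllm.__version__ != "0.1.7":
    raise RuntimeError(f"expected vllm==0.1.7, got {vllm.__version__}")
```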

View File

@@ -23,7 +23,7 @@ pathlib
 pytest
 scikit-learn
 numexpr
+vllm==0.1.7
 # online api libs
 # zhipuai
 # dashscope>=1.10.0 # qwen

View File

@@ -75,10 +75,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     """
     import fastchat.constants
     fastchat.constants.LOGDIR = LOG_PATH
-    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
+    from fastchat.serve.model_worker import worker_id, logger
     import argparse
-    import threading
     import fastchat.serve.model_worker
+    import fastchat.serve.vllm_worker
     logger.setLevel(log_level)
     parser = argparse.ArgumentParser()
@@ -89,6 +89,7 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     # online model APIs
     if worker_class := kwargs.get("worker_class"):
+        from fastchat.serve.model_worker import app
         worker = worker_class(model_names=args.model_names,
                               controller_addr=args.controller_address,
                               worker_addr=args.worker_address)
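For context, the `worker_class` branch lets callers hand an online-API worker type to the factory through kwargs. A hedged sketch of how such a call site might look, assuming kwargs are copied onto `args` as the surrounding code suggests; `ZhipuWorker` is an illustrative stand-in, not one of the project's actual classes:

```python
# Hypothetical call site for the worker_class branch above.
from startup import create_model_worker_app  # this file

class ZhipuWorker:  # minimal stand-in matching the constructor call above
    def __init__(self, model_names, controller_addr, worker_addr):
        self.model_names = model_names
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr

app = create_model_worker_app(
    log_level="INFO",
    worker_class=ZhipuWorker,                # read via kwargs.get("worker_class")
    model_names=["zhipu-api"],
    controller_address="http://127.0.0.1:20001",
    worker_address="http://127.0.0.1:21001",
)
```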
@@ -97,15 +98,10 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
     else:
         from configs.model_config import VLLM_MODEL_DICT
         if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
-            from fastchat.serve.vllm_worker import VLLMWorker
+            from fastchat.serve.vllm_worker import VLLMWorker, app
             from vllm import AsyncLLMEngine
             from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-            #! ------------- the tokenizer seems to get added here ------------
-            # parser = AsyncEngineArgs.add_cli_args(args)
-            # # args = parser.parse_args()
             args.tokenizer = args.model_path  # if the tokenizer differs from model_path, set it here
             args.tokenizer_mode = 'auto'
             args.trust_remote_code = True
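The flags set above feed vllm's engine construction. A sketch of the step these args lead into, based on the vllm 0.1.x-era API, where `AsyncEngineArgs` is a dataclass and `AsyncLLMEngine.from_engine_args` builds the engine; the model path is a placeholder:

```python
# Sketch of the engine construction the args above prepare (vllm 0.1.x API).
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs

engine_args = AsyncEngineArgs(
    model="/path/to/model",        # args.model_path in the code above
    tokenizer="/path/to/model",    # args.tokenizer, defaulted to model_path
    tokenizer_mode="auto",
    trust_remote_code=True,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
```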
@@ -150,9 +146,11 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
                 llm_engine=engine,
                 conv_template=args.conv_template,
             )
+            sys.modules["fastchat.serve.vllm_worker"].engine = engine
             sys.modules["fastchat.serve.vllm_worker"].worker = worker
         else:
+            from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker
             args.gpus = "1"
             args.max_gpu_memory = "20GiB"
             args.load_8bit = False
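The `sys.modules[...]` assignments added above exist because fastchat's route handlers read module-level globals (`worker`, `engine`) rather than taking them as parameters, so the locally built objects must be injected for the imported `app` to see them. A toy demonstration of that pattern, independent of fastchat; all names here are illustrative:

```python
# Toy demonstration of the sys.modules injection pattern used above: a
# module's function reads a module-level global, so a caller can swap that
# global in after import and the function picks up the new object.
import sys
import types

toy = types.ModuleType("toy_worker")
toy.worker = None                      # same shape as vllm_worker's `worker`

def _describe():
    return f"worker is {toy.worker!r}"

toy.describe = _describe
sys.modules["toy_worker"] = toy

# Later, inject the real object, exactly like
# sys.modules["fastchat.serve.vllm_worker"].worker = worker above.
sys.modules["toy_worker"].worker = "locally-built VLLMWorker"
print(toy.describe())  # -> worker is 'locally-built VLLMWorker'
```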