New features: 1. Support the vllm inference acceleration framework; 2. Update the list of supported models
parent 89aed8e675
commit 810145c5fb
@@ -40,11 +40,49 @@ MODEL_PATH = {
"chatglm2-6b": "THUDM/chatglm2-6b",
"chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
"chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",

"baichuan2-13b": "baichuan-inc/Baichuan-13B-Chat",
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Chat",

"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan2-7b-chat":"baichuan-inc/Baichuan2-7B-Chat",
"baichuan2-13b-chat":"baichuan-inc/Baichuan2-13B-Chat",
"baichuan2-7b":"baichuan-inc/Baichuan2-7B-Base",
"baichuan2-13b":"baichuan-inc/Baichuan2-13B-Base",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',

"aquila-7b":"BAAI/Aquila-7B",
"aquilachat-7b":"BAAI/AquilaChat-7B",

"internlm-7b":"internlm/internlm-7b",
"internlm-chat-7b":"internlm/internlm-chat-7b",

"falcon-7b":"tiiuae/falcon-7b",
"falcon-40b":"tiiuae/falcon-40b",
"falcon-rw-7b":"tiiuae/falcon-rw-7b",

"gpt2":"gpt2",
"gpt2-xl":"gpt2-xl",

"gpt-j-6b":"EleutherAI/gpt-j-6b",
"gpt4all-j":"nomic-ai/gpt4all-j",
"gpt-neox-20b":"EleutherAI/gpt-neox-20b",
"pythia-12b":"EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b":"databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",

"Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
"open_llama_13b":"openlm-research/open_llama_13b",
"vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
"koala":"young-geng/koala",

"mpt-7b":"mosaicml/mpt-7b",
"mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
"mpt-30b":"mosaicml/mpt-30b",
"opt-66b":"facebook/opt-66b",
"opt-iml-max-30b":"facebook/opt-iml-max-30b",

"Qwen-7B":"Qwen/Qwen-7B",
"Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",
},
}
@@ -55,7 +93,7 @@ EMBEDDING_MODEL = "m3e-base"
EMBEDDING_DEVICE = "auto"

# LLM name
-LLM_MODEL = "chatglm2-6b"
+LLM_MODEL = "baichuan-7b"

# Device to run the LLM on. "auto" detects the device automatically; it can also be set manually to one of "cuda", "mps" or "cpu".
LLM_DEVICE = "auto"
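For orientation, the two settings above work together: `LLM_MODEL` must be one of the keys of `MODEL_PATH` (or of the new `VLLM_MODEL_DICT` below), and each value can be either a Hugging Face repo id or a local directory. A minimal sketch, assuming the nested "llm_model" sub-dict and using an illustrative local path that is not part of this commit:

```python
# Sketch only -- mirrors configs/model_config.py after this commit.
# The "llm_model" sub-dict and the local path are illustrative assumptions.
MODEL_PATH = {
    "llm_model": {
        "baichuan-7b": "baichuan-inc/Baichuan-7B",   # Hugging Face repo id
        "chatglm2-6b": "/data/models/chatglm2-6b",   # or a local directory
    },
}

LLM_MODEL = "baichuan-7b"   # must be one of the keys above
LLM_DEVICE = "auto"         # "auto", "cuda", "mps" or "cpu"
```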
@@ -140,3 +178,48 @@ ONLINE_LLM_MODEL = {

# Storage path for nltk models
NLTK_DATA_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "nltk_data")


VLLM_MODEL_DICT = {
"aquila-7b":"BAAI/Aquila-7B",
"aquilachat-7b":"BAAI/AquilaChat-7B",

"baichuan-7b": "baichuan-inc/Baichuan-7B",
"baichuan-13b": "baichuan-inc/Baichuan-13B",
'baichuan-13b-chat':'baichuan-inc/Baichuan-13B-Chat',
# Note: the bloom family ships its tokenizer separately from the model, so although vllm supports it, it is incompatible with the fschat framework
# "bloom":"bigscience/bloom",
# "bloomz":"bigscience/bloomz",
# "bloomz-560m":"bigscience/bloomz-560m",
# "bloomz-7b1":"bigscience/bloomz-7b1",
# "bloomz-1b7":"bigscience/bloomz-1b7",

"internlm-7b":"internlm/internlm-7b",
"internlm-chat-7b":"internlm/internlm-chat-7b",
"falcon-7b":"tiiuae/falcon-7b",
"falcon-40b":"tiiuae/falcon-40b",
"falcon-rw-7b":"tiiuae/falcon-rw-7b",
"gpt2":"gpt2",
"gpt2-xl":"gpt2-xl",
"gpt-j-6b":"EleutherAI/gpt-j-6b",
"gpt4all-j":"nomic-ai/gpt4all-j",
"gpt-neox-20b":"EleutherAI/gpt-neox-20b",
"pythia-12b":"EleutherAI/pythia-12b",
"oasst-sft-4-pythia-12b-epoch-3.5":"OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",
"dolly-v2-12b":"databricks/dolly-v2-12b",
"stablelm-tuned-alpha-7b":"stabilityai/stablelm-tuned-alpha-7b",
"Llama-2-13b-hf":"meta-llama/Llama-2-13b-hf",
"Llama-2-70b-hf":"meta-llama/Llama-2-70b-hf",
"open_llama_13b":"openlm-research/open_llama_13b",
"vicuna-13b-v1.3":"lmsys/vicuna-13b-v1.3",
"koala":"young-geng/koala",
"mpt-7b":"mosaicml/mpt-7b",
"mpt-7b-storywriter":"mosaicml/mpt-7b-storywriter",
"mpt-30b":"mosaicml/mpt-30b",
"opt-66b":"facebook/opt-66b",
"opt-iml-max-30b":"facebook/opt-iml-max-30b",

"Qwen-7B":"Qwen/Qwen-7B",
"Qwen-7B-Chat":"Qwen/Qwen-7B-Chat",

}
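`VLLM_MODEL_DICT` acts as a whitelist: in the `startup.py` changes further down, a worker only takes the vLLM path when the configured model name appears in this dict and `infer_turbo` is set to `"vllm"`. A minimal sketch of that gating logic, assuming the config names shown in this commit:

```python
# Sketch of the gating logic (see the startup.py hunk below); not verbatim code.
from configs.model_config import LLM_MODEL, VLLM_MODEL_DICT

def should_use_vllm(infer_turbo) -> bool:
    # vLLM is used only for whitelisted models and only when explicitly enabled.
    return infer_turbo == "vllm" and LLM_MODEL in VLLM_MODEL_DICT
```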
@@ -9,7 +9,7 @@ HTTPX_DEFAULT_TIMEOUT = 300.0
OPEN_CROSS_DOMAIN = False

# Default bind host for each server. If it is changed to "0.0.0.0", the host of every XX_SERVER below must be updated as well.
-DEFAULT_BIND_HOST = "127.0.0.1"
+DEFAULT_BIND_HOST = "0.0.0.0"

# webui.py server
WEBUI_SERVER = {
@@ -38,13 +38,14 @@ FSCHAT_MODEL_WORKERS = {
"host": DEFAULT_BIND_HOST,
"port": 20002,
"device": LLM_DEVICE,
+"infer_turbo": "vllm" # False or 'vllm': the inference acceleration framework to use; if vllm runs into HuggingFace communication problems, see docs/FAQ

-# Parameters required for multi-GPU loading
+# Parameters required for multi-GPU loading in model_worker
# "gpus": None, # GPUs to use, specified as a str, e.g. "0,1"
# "num_gpus": 1, # number of GPUs to use
# "max_gpu_memory": "20GiB", # maximum GPU memory to use per GPU

-# The following are rarely used parameters; configure them as needed
+# The following are rarely used model_worker parameters; configure them as needed
# "load_8bit": False, # enable 8-bit quantization
# "cpu_offloading": None,
# "gptq_ckpt": None,
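Taken together with the model_config changes, a minimal worker entry that turns the new vLLM backend on (or off) might look like the sketch below; only `infer_turbo` and the port come from this diff, the rest mirrors the surrounding defaults:

```python
# Illustrative sketch of the worker config, not a verbatim excerpt of this commit.
FSCHAT_MODEL_WORKERS = {
    "default": {
        "host": DEFAULT_BIND_HOST,
        "port": 20002,
        "device": LLM_DEVICE,
        # Set to False to keep the plain fastchat model_worker instead of vLLM.
        "infer_turbo": "vllm",
    },
}
```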
@@ -60,9 +61,32 @@ FSCHAT_MODEL_WORKERS = {
# "stream_interval": 2,
# "no_register": False,
# "embed_in_truncate": False,
},
"baichuan-7b": { # uses the IP and port from "default"
"device": "cpu",

# The following are vllm_worker configuration parameters; note that vllm requires a GPU

# tokenizer = model_path # add here if the tokenizer differs from model_path
# 'tokenizer_mode':'auto',
# 'trust_remote_code':True,
# 'download_dir':None,
# 'load_format':'auto',
# 'dtype':'auto',
# 'seed':0,
# 'worker_use_ray':False,
# 'pipeline_parallel_size':1,
# 'tensor_parallel_size':1,
# 'block_size':16,
# 'swap_space':4 , # GiB
# 'gpu_memory_utilization':0.90,
# 'max_num_batched_tokens':2560,
# 'max_num_seqs':256,
# 'disable_log_stats':False,
# 'conv_template':None,
# 'limit_worker_concurrency':5,
# 'no_register':False,
# 'num_gpus': 1
# 'engine_use_ray': False,
# 'disable_log_requests': False

},
"zhipu-api": { # set a different port for each online API you want to run
"port": 21001,
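The commented-out names above match the attributes that the new `startup.py` code copies onto `args` before building vLLM's `AsyncEngineArgs`, so uncommenting one overrides the corresponding default. A hedged sketch for a hypothetical 2-GPU machine (example values only, not recommendations from this commit):

```python
# Sketch: overriding a few vLLM engine parameters for a machine with 2 GPUs.
FSCHAT_MODEL_WORKERS = {
    "default": {
        "device": LLM_DEVICE,
        "infer_turbo": "vllm",
        "num_gpus": 2,
        "tensor_parallel_size": 2,       # shard the model across both GPUs
        "gpu_memory_utilization": 0.90,  # fraction of VRAM vLLM may claim
    },
}
```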
docs/FAQ.md
@@ -107,7 +107,7 @@ embedding_model_dict = {

Q9: While running `python cli_demo.py`, GPU memory is exhausted and "OutOfMemoryError: CUDA out of memory" is raised

A9: Lower the values of `VECTOR_SEARCH_TOP_K` and `LLM_HISTORY_LEN`, e.g. `VECTOR_SEARCH_TOP_K = 5` and `LLM_HISTORY_LEN = 2`; the `prompt` built by concatenating `query` and `context` becomes shorter, which reduces memory usage. Alternatively, enable quantization by modifying the `LOAD_IN_8BIT` parameter in [configs/model_config.py](../configs/model_config.py)

---
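A9 in code form, as a minimal sketch; the file each setting lives in is assumed to be configs/model_config.py, and the values are just the examples given in the answer:

```python
# Sketch of the OOM mitigations described in A9.
VECTOR_SEARCH_TOP_K = 5   # fewer retrieved chunks -> shorter context in the prompt
LLM_HISTORY_LEN = 2       # keep less chat history in the prompt
LOAD_IN_8BIT = True       # load the LLM with 8-bit quantization
```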
@@ -171,7 +171,6 @@ Q14: After modifying the path in the config, loading text2vec-large-chinese still reports `WARN

A14: Try a different embedding, e.g. text2vec-base-chinese; in [configs/model_config.py](../configs/model_config.py), change the `text2vec-base` parameter to a local path (either an absolute or a relative path works)


---

Q15: Creating tables with the pg vector store fails with an error
@@ -182,4 +181,66 @@ A15: You need to manually install the corresponding vector extension (connect to pg and run CREATE EXTENSION IF

Q16: pymilvus connection timeout

A16: The pymilvus version must match the milvus version, otherwise connections time out; pymilvus==2.1.3 is a known-good reference

Q16: When using the vllm inference acceleration framework, the model is already downloaded but HuggingFace communication problems still occur

A16: Modify the prepare_hf_model_weights function in /site-packages/vllm/model_executor/weight_utils.py of your Python environment as follows:

```python
def prepare_hf_model_weights(
    model_name_or_path: str,
    cache_dir: Optional[str] = None,
    use_safetensors: bool = False,
    fall_back_to_pt: bool = True,
):
    # Download model weights from huggingface.
    is_local = os.path.isdir(model_name_or_path)
    allow_patterns = "*.safetensors" if use_safetensors else "*.bin"
    if not is_local:
        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        model_path_temp = os.path.join(
            os.getenv("HOME"),
            ".cache/huggingface/hub",
            "models--" + model_name_or_path.replace("/", "--"),
            "snapshots/",
        )
        downloaded = False
        if os.path.exists(model_path_temp):
            temp_last_dir = os.listdir(model_path_temp)[-1]
            model_path_temp = os.path.join(model_path_temp, temp_last_dir)
            base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
            files = glob.glob(base_pattern)
            if len(files) > 0:
                downloaded = True

        if downloaded:
            hf_folder = model_path_temp
        else:
            with get_lock(model_name_or_path, cache_dir):
                hf_folder = snapshot_download(model_name_or_path,
                                              allow_patterns=allow_patterns,
                                              cache_dir=cache_dir,
                                              tqdm_class=Disabledtqdm)
    else:
        hf_folder = model_name_or_path
    hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
    if not use_safetensors:
        hf_weights_files = [
            x for x in hf_weights_files if not x.endswith("training_args.bin")
        ]

    if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
        return prepare_hf_model_weights(model_name_or_path,
                                        cache_dir=cache_dir,
                                        use_safetensors=False,
                                        fall_back_to_pt=False)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors

```
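As a convenience (not part of the original answer), the file that A16 asks you to patch can be located from the active Python environment like this:

```python
# Print the path of the vllm file to edit for the A16 workaround.
import os
import vllm

print(os.path.join(os.path.dirname(vllm.__file__), "model_executor", "weight_utils.py"))
```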
startup.py
@@ -58,6 +58,21 @@ def create_controller_app(


def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
    """
    kwargs may contain the following fields:
        host:
        port:
        model_names: [`model_name`]
        controller_address:
        worker_address:

        For online_api:
            online_api: True
            worker_class: `provider`
        For offline models:
            model_path: `model_name_or_path`, a huggingface repo-id or a local path
            device: `LLM_DEVICE`
    """
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id, logger
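To make the docstring concrete, a hypothetical call for an offline model is sketched below; the addresses, ports and the `infer_turbo` flag are placeholders, not values taken from this commit:

```python
# Hypothetical invocation of create_model_worker_app for an offline model.
app = create_model_worker_app(
    log_level="INFO",
    host="127.0.0.1",
    port=20002,
    model_names=["baichuan-7b"],
    model_path="baichuan-inc/Baichuan-7B",  # huggingface repo-id or local path
    device="cuda",
    controller_address="http://127.0.0.1:20001",
    worker_address="http://127.0.0.1:20002",
    infer_turbo=False,                      # or "vllm" to take the new vLLM path
)
```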
@@ -66,104 +81,142 @@ def create_model_worker_app(log_level: str = "INFO", **kwargs) -> FastAPI:
import fastchat.serve.model_worker
logger.setLevel(log_level)

# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
)
self.heart_beat_thread.start()

ModelWorker.init_heart_beat = _new_init_heart_beat

parser = argparse.ArgumentParser()
args = parser.parse_args([])
# default args. should be deleted after pr is merged by fastchat
args.gpus = None
args.max_gpu_memory = "20GiB"
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.num_gpus = 1
args.model_names = []
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False

for k, v in kwargs.items():
setattr(args, k, v)

if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

# Online model API
if worker_class := kwargs.get("worker_class"):
worker = worker_class(model_names=args.model_names,
controller_addr=args.controller_address,
worker_addr=args.worker_address)
sys.modules["fastchat.serve.model_worker"].worker = worker
# Local model
else:
# workaround to make program exit with Ctrl+c
# it should be deleted after pr is merged by fastchat
def _new_init_heart_beat(self):
self.register_to_controller()
self.heart_beat_thread = threading.Thread(
target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
from configs.model_config import VLLM_MODEL_DICT
if kwargs["model_names"][0] in VLLM_MODEL_DICT and args.infer_turbo == "vllm":
from fastchat.serve.vllm_worker import VLLMWorker
from vllm import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs,EngineArgs
#! ------------- the tokenizer seems to be added at this point ------------

# parser = AsyncEngineArgs.add_cli_args(args)
# # args = parser.parse_args()


args.tokenizer = args.model_path # add here if the tokenizer differs from model_path
args.tokenizer_mode = 'auto'
args.trust_remote_code= True
args.download_dir= None
args.load_format = 'auto'
args.dtype = 'auto'
args.seed = 0
args.worker_use_ray = False
args.pipeline_parallel_size = 1
args.tensor_parallel_size = 1
args.block_size = 16
args.swap_space = 4 # GiB
args.gpu_memory_utilization = 0.90
args.max_num_batched_tokens = 2560
args.max_num_seqs = 256
args.disable_log_stats = False
args.conv_template = None
args.limit_worker_concurrency = 5
args.no_register = False
args.num_gpus = 1
args.engine_use_ray = False
args.disable_log_requests = False
if args.model_path:
args.model = args.model_path
if args.num_gpus > 1:
args.tensor_parallel_size = args.num_gpus

for k, v in kwargs.items():
setattr(args, k, v)

engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)

worker = VLLMWorker(
controller_addr = args.controller_address,
worker_addr = args.worker_address,
worker_id = worker_id,
model_path = args.model_path,
model_names = args.model_names,
limit_worker_concurrency = args.limit_worker_concurrency,
no_register = args.no_register,
llm_engine = engine,
conv_template = args.conv_template,
)
sys.modules["fastchat.serve.vllm_worker"].worker = worker

else:
# default args. should be deleted after pr is merged by fastchat
args.gpus = "1"
args.max_gpu_memory = "20GiB"
args.load_8bit = False
args.cpu_offloading = None
args.gptq_ckpt = None
args.gptq_wbits = 16
args.gptq_groupsize = -1
args.gptq_act_order = False
args.awq_ckpt = None
args.awq_wbits = 16
args.awq_groupsize = -1
args.num_gpus = 1
args.model_names = []
args.conv_template = None
args.limit_worker_concurrency = 5
args.stream_interval = 2
args.no_register = False
args.embed_in_truncate = False
for k, v in kwargs.items():
setattr(args, k, v)
if args.gpus:
if args.num_gpus is None:
args.num_gpus = len(args.gpus.split(','))
if len(args.gpus.split(",")) < args.num_gpus:
raise ValueError(
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)
self.heart_beat_thread.start()

ModelWorker.init_heart_beat = _new_init_heart_beat
worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
wbits=args.gptq_wbits,
groupsize=args.gptq_groupsize,
act_order=args.gptq_act_order,
)
awq_config = AWQConfig(
ckpt=args.awq_ckpt or args.model_path,
wbits=args.awq_wbits,
groupsize=args.awq_groupsize,
)

worker = ModelWorker(
controller_addr=args.controller_address,
worker_addr=args.worker_address,
worker_id=worker_id,
model_path=args.model_path,
model_names=args.model_names,
limit_worker_concurrency=args.limit_worker_concurrency,
no_register=args.no_register,
device=args.device,
num_gpus=args.num_gpus,
max_gpu_memory=args.max_gpu_memory,
load_8bit=args.load_8bit,
cpu_offloading=args.cpu_offloading,
gptq_config=gptq_config,
awq_config=awq_config,
stream_interval=args.stream_interval,
conv_template=args.conv_template,
embed_in_truncate=args.embed_in_truncate,
)
sys.modules["fastchat.serve.model_worker"].args = args
sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

sys.modules["fastchat.serve.model_worker"].worker = worker
sys.modules["fastchat.serve.model_worker"].worker = worker

MakeFastAPIOffline(app)
app.title = f"FastChat LLM Server ({args.model_names[0]})"