llm_api: upgrade to fschat==0.2.20, add support for Baichuan models
This commit is contained in:
parent 51ea717606
commit 463659f0ba
@@ -2,8 +2,9 @@ langchain==0.0.237
 openai
 sentence_transformers
 chromadb
-fschat==0.2.15
+fschat==0.2.20
 transformers
+transformers_stream_generator
 torch~=2.0.0
 fastapi~=0.99.1
 loguru
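transformers_stream_generator is the new dependency. Baichuan checkpoints load custom modeling code via trust_remote_code, and that code relies on the package for streamed generation, which is presumably why it is pinned alongside the fschat bump. A minimal loading sketch, not part of this commit (the model id is only an example):

    # Sketch only; the checkpoint id is illustrative. Any Baichuan model
    # configured in llm_model_dict would behave the same way.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "baichuan-inc/Baichuan-13B-Chat"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)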
@@ -41,13 +41,15 @@ def create_model_worker_app(
     load_8bit=False,
     gptq_ckpt=None,
     gptq_wbits=16,
     gptq_groupsize=-1,
     gptq_act_order=None,
     gpus=None,
     num_gpus=1,
     max_gpu_memory=None,
     cpu_offloading=None,
     worker_address=base_url.format(model_worker_port),
     controller_address=base_url.format(controller_port),
-    limit_model_concurrency=5,
+    limit_worker_concurrency=5,
     stream_interval=2,
     no_register=False,
 ):
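The visible signature change is the rename from limit_model_concurrency to limit_worker_concurrency, matching fschat 0.2.20, where the limit belongs to the worker and is passed positionally to ModelWorker (see the construction further down).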
@@ -57,29 +59,6 @@ def create_model_worker_app(
     from loguru import logger
     logger.add(os.path.join(LOG_PATH, "model_worker.log"), level="INFO")
 
-    if gpus and num_gpus is None:
-        num_gpus = len(gpus.split(','))
-    gptq_config = GptqConfig(
-        ckpt=gptq_ckpt or model_path,
-        wbits=gptq_wbits,
-        groupsize=-1,
-        act_order=None,
-    )
-    worker = ModelWorker(
-        controller_addr=controller_address,
-        worker_addr=worker_address,
-        worker_id=worker_id,
-        no_register=no_register,
-        model_path=model_path,
-        model_names=model_names,
-        device=device,
-        num_gpus=num_gpus,
-        max_gpu_memory=max_gpu_memory,
-        load_8bit=load_8bit,
-        cpu_offloading=cpu_offloading,
-        gptq_config=gptq_config,
-        # limit_worker_concurrency=1,
-    )
     parser = argparse.ArgumentParser()
     args = parser.parse_args()
     args.model_path = model_path
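The GptqConfig and ModelWorker construction is not dropped here; it moves below the argparse block (next hunk) so both can be built from the args.* values instead of raw keyword parameters.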
@@ -88,16 +67,53 @@ def create_model_worker_app(
     args.load_8bit = load_8bit
     args.gptq_ckpt = gptq_ckpt
     args.gptq_wbits = gptq_wbits
     args.gptq_groupsize = gptq_groupsize
     args.gptq_act_order = gptq_act_order
     args.gpus = gpus
     args.num_gpus = num_gpus
     args.max_gpu_memory = max_gpu_memory
     args.cpu_offloading = cpu_offloading
     args.worker_address = worker_address
     args.controller_address = controller_address
-    args.limit_model_concurrency = limit_model_concurrency
+    args.limit_worker_concurrency = limit_worker_concurrency
     args.stream_interval = stream_interval
     args.no_register = no_register
 
+    if args.gpus:
+        if len(args.gpus.split(",")) < args.num_gpus:
+            raise ValueError(
+                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+            )
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+    if gpus and num_gpus is None:
+        num_gpus = len(gpus.split(','))
+        args.num_gpus = num_gpus
+
+    gptq_config = GptqConfig(
+        ckpt=gptq_ckpt or model_path,
+        wbits=args.gptq_wbits,
+        groupsize=args.gptq_groupsize,
+        act_order=args.gptq_act_order,
+    )
+
+    worker = ModelWorker(
+        args.controller_address,
+        args.worker_address,
+        worker_id,
+        args.model_path,
+        args.model_names,
+        args.limit_worker_concurrency,
+        no_register=args.no_register,
+        device=args.device,
+        num_gpus=args.num_gpus,
+        max_gpu_memory=args.max_gpu_memory,
+        load_8bit=args.load_8bit,
+        cpu_offloading=args.cpu_offloading,
+        gptq_config=gptq_config,
+        stream_interval=args.stream_interval,
+    )
+
     sys.modules["fastchat.serve.model_worker"].worker = worker
     sys.modules["fastchat.serve.model_worker"].args = args
+    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
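The closing sys.modules assignments are what make this wrapper work: the route handlers in fastchat.serve.model_worker resolve worker, args, and (new in 0.2.20) gptq_config as module-level globals at request time, so patching the imported module injects the locally built objects. A self-contained sketch of the pattern, using a hypothetical toy module in place of fastchat:

    import sys
    import types

    # Toy stand-in for fastchat.serve.model_worker: the handler reads a
    # module-level global at call time instead of taking it as a parameter.
    mod = types.ModuleType("toy_model_worker")
    mod.worker = None
    sys.modules["toy_model_worker"] = mod

    def handler():
        return sys.modules["toy_model_worker"].worker

    # Same trick as above: patch the global before the app starts serving.
    sys.modules["toy_model_worker"].worker = "locally built ModelWorker"
    print(handler())  # -> locally built ModelWorker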
@@ -142,9 +158,9 @@ def run_controller(q):
     uvicorn.run(app, host=host_ip, port=controller_port)
 
 
-def run_model_worker(q):
+def run_model_worker(q, *args, **kwargs):
     import uvicorn
-    app = create_model_worker_app()
+    app = create_model_worker_app(*args, **kwargs)
 
     @app.on_event("startup")
     async def on_startup():
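run_model_worker now forwards everything after the queue straight into create_model_worker_app, so worker options can be chosen at spawn time. A sketch of the pass-through, assuming this module's run_model_worker is in scope (kwarg values are illustrative):

    from multiprocessing import Process, Queue

    queue = Queue()
    # Each kwarg lands in create_model_worker_app unchanged.
    model_worker_process = Process(
        target=run_model_worker,
        args=(queue,),
        kwargs={"load_8bit": True, "limit_worker_concurrency": 5},
        daemon=True,
    )
    model_worker_process.start()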
@@ -193,7 +209,6 @@ def run_openai_api(q):
 if __name__ == "__main__":
-    logger.info(llm_model_dict[LLM_MODEL])
     model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
     # model_path = "d:\\chatglm\\models\\chatglm-6b"
 
     logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}")
 
@@ -212,6 +227,7 @@ if __name__ == "__main__":
         target=run_model_worker,
         name=f"model_worker({os.getpid()})",
         args=(queue,),
+        # kwargs={"load_8bit": True},
         daemon=True,
     )
     model_worker_process.start()
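The commented kwargs line above is the hook for that pass-through: uncommenting it sends {"load_8bit": True} verbatim through run_model_worker into create_model_worker_app, with no other changes required.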