import sys
import os
from configs.model_config import (
    llm_model_dict,
    LLM_MODEL,
    EMBEDDING_DEVICE,
    LLM_DEVICE,
    LOG_PATH,
    logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio

# Combine fastchat's three servers into a single FastAPI app:
# - http://{host_ip}:{port}/controller corresponds to python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker corresponds to python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai corresponds to python -m fastchat.serve.openai_api_server ...


host_ip = "0.0.0.0"
port = 8888
base_url = f"http://127.0.0.1:{port}"
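
# Usage sketch (not from the original file): once the combined app is running,
# the mounted OpenAI-compatible endpoint can be queried with the legacy
# openai v0.x client; the /openai/v1 path layout is assumed to match
# fastchat's openai_api_server.
#
#   import openai
#   openai.api_base = f"{base_url}/openai/v1"
#   openai.api_key = "EMPTY"  # api_keys is left empty below
#   resp = openai.ChatCompletion.create(
#       model="chatglm-6b",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(resp.choices[0].message.content)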


def create_controller_app(
    dispatch_method="shortest_queue",
):
    from fastchat.serve.controller import app, Controller

    controller = Controller(dispatch_method)
    # fastchat's controller routes read a module-global `controller`, so patch
    # our instance into the imported module for the mounted app to use.
    sys.modules["fastchat.serve.controller"].controller = controller
    logger.info(f"controller dispatch method: {dispatch_method}")
    return app, controller


def create_model_worker_app(
    model_path,
    model_names=[LLM_MODEL],
    device=LLM_DEVICE,
    load_8bit=False,
    gptq_ckpt=None,
    gptq_wbits=16,
    gpus='',
    num_gpus=-1,
    max_gpu_memory=-1,
    cpu_offloading=None,
    worker_address=f"{base_url}/model_worker",
    controller_address=f"{base_url}/controller",
    limit_model_concurrency=5,
    stream_interval=2,
    no_register=True,  # registered manually in the startup event below
):
    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
    from fastchat.serve import model_worker

    # num_gpus defaults to -1, so this branch only applies when num_gpus=None
    # is passed explicitly together with a gpus string like "0,1".
    if gpus and num_gpus is None:
        num_gpus = len(gpus.split(','))
    gptq_config = GptqConfig(
        ckpt=gptq_ckpt or model_path,
        wbits=gptq_wbits,
        groupsize=-1,
        act_order=None,
    )
    # Positional arguments follow fastchat's ModelWorker constructor.
    worker = ModelWorker(
        controller_address,
        worker_address,
        worker_id,
        no_register,
        model_path,
        model_names,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
    )
    # Same trick as the controller: point fastchat's module-global worker at ours.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    return app, worker
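
# Usage sketch (hypothetical values): the module-level call further below
# passes only model_path; an explicit variant might look like
#
#   create_model_worker_app(
#       model_path,
#       model_names=["chatglm-6b"],
#       device="cuda",
#       load_8bit=True,  # 8-bit quantization for smaller GPUs
#   )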


def create_openai_api_app(
    host=host_ip,
    port=port,
    controller_address=f"{base_url}/controller",
    allow_credentials=None,
    allowed_origins=["*"],
    allowed_methods=["*"],
    allowed_headers=["*"],
    api_keys=[],
):
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=allowed_methods,
        allow_headers=allowed_headers,
    )
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys
    # Patch the attribute on the module itself, matching the controller/worker
    # patches above (app_settings is a module attribute, not a module).
    sys.modules["fastchat.serve.openai_api_server"].app_settings = app_settings

    return app
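
# Usage sketch (hypothetical values): the defaults above are fully permissive;
# a locked-down variant might look like
#
#   create_openai_api_app(
#       allowed_origins=["http://localhost:3000"],
#       api_keys=["some-secret-key"],
#   )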


LLM_MODEL = 'chatglm-6b'  # override the value imported from configs
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]

if not model_path:
    logger.error("local_model_path must not be empty")
    sys.exit(1)  # the ModelWorker below cannot load a model without it
else:
    logger.info(f"using local model: {model_path}")

app = FastAPI()

controller_app, controller = create_controller_app()
app.mount("/controller", controller_app)

model_worker_app, worker = create_model_worker_app(model_path)
app.mount("/model_worker", model_worker_app)

openai_api_app = create_openai_api_app()
app.mount("/openai", openai_api_app)
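
# Sketch (assumption, untested): a second model could in principle be mounted
# under its own path in the same way, e.g.
#
#   app2, worker2 = create_model_worker_app(
#       other_model_path,                    # hypothetical
#       model_names=["other-model"],         # hypothetical
#       worker_address=f"{base_url}/model_worker2",
#   )
#   app.mount("/model_worker2", app2)
#
# though fastchat's model_worker module holds a single global worker, so the
# sys.modules patch would only point at the most recently created one.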


@app.on_event("startup")
async def on_startup():
    # Register the worker with the controller in-process (no_register=True
    # above skips fastchat's own HTTP registration).
    logger.info("Register to controller")
    controller.register_worker(
        worker.worker_addr,
        True,
        worker.get_status(),
    )
    worker.heart_beat_thread = threading.Thread(
        target=heart_beat_worker, args=(worker,)
    )
    worker.heart_beat_thread.start()

    # Registering over the network hangs, hence the in-process call above:
    # async def register():
    #     while True:
    #         try:
    #             worker.register_to_controller()
    #             worker.heart_beat_thread = threading.Thread(
    #                 target=heart_beat_worker, args=(worker,)
    #             )
    #             worker.heart_beat_thread.start()
    #         except:
    #             await asyncio.sleep(1)
    # asyncio.get_event_loop().create_task(register())
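
# Quick-check sketch (routes assumed from fastchat's servers): with the app
# running, the mounts can be probed from the command line, e.g.
#
#   curl http://127.0.0.1:8888/openai/v1/models
#   curl -X POST http://127.0.0.1:8888/controller/list_models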


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=host_ip, port=port)