Add api_one.py, which merges fastchat's three servers into a single process. We may fold api.py in later as well. Going from 3 processes to 1 may cost a little performance, but it lowers the deployment barrier for individual users.
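
As a quick smoke test (a sketch, assuming the defaults in api_one.py: port 8888 and the chatglm-6b model), the merged OpenAI-compatible endpoint can be exercised like this:

import requests

# POST to the openai_api_server mounted under /openai in the merged process.
resp = requests.post(
    "http://127.0.0.1:8888/openai/v1/chat/completions",
    json={
        "model": "chatglm-6b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())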
This commit is contained in:
parent
97ee4686a1
commit
70c6870776
@ -2,4 +2,5 @@
 *.log.*
 logs
 .idea/
 __pycache__/
+knowledge_base/
@ -0,0 +1,157 @@
import sys
import os
from configs.model_config import (
    llm_model_dict,
    LLM_MODEL,
    EMBEDDING_DEVICE,
    LLM_DEVICE,
    LOG_PATH,
    logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio


# Merge fastchat's three servers into a single process. Each mounted sub-app
# corresponds to one of the standalone commands:
# - http://{host_ip}:{port}/controller corresponds to python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker corresponds to python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai corresponds to python -m fastchat.serve.openai_api_server ...
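#
# Sanity checks once the merged server is running (a sketch; the paths follow the
# mounts below, and /list_models and /v1/models are fastchat's usual routes):
#   curl -X POST http://127.0.0.1:8888/controller/list_models
#   curl http://127.0.0.1:8888/openai/v1/models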


host_ip = "0.0.0.0"
port = 8888
base_url = f"http://127.0.0.1:{port}"


def create_controller_app(
    dispatch_method="shortest_queue",
):
    from fastchat.serve.controller import app, Controller

    controller = Controller(dispatch_method)
    sys.modules["fastchat.serve.controller"].controller = controller
    logger.info(f"controller dispatch method: {dispatch_method}")
    return app, controller
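
# fastchat's controller routes read the module-global `controller`, so the instance
# created above is injected into the imported module; mounting the returned `app`
# then serves those routes from this process. The same pattern repeats for the
# model worker below.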

def create_model_worker_app(
    model_path,
    model_names=[LLM_MODEL],
    device=LLM_DEVICE,
    load_8bit=False,
    gptq_ckpt=None,
    gptq_wbits=16,
    gpus='',
    num_gpus=-1,
    max_gpu_memory=-1,
    cpu_offloading=None,
    worker_address=f"{base_url}/model_worker",
    controller_address=f"{base_url}/controller",
    limit_model_concurrency=5,
    stream_interval=2,
    no_register=True,  # register manually during app startup
):
    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
    from fastchat.serve import model_worker

    if gpus and num_gpus is None:
        num_gpus = len(gpus.split(','))
    gptq_config = GptqConfig(
        ckpt=gptq_ckpt or model_path,
        wbits=gptq_wbits,
        groupsize=-1,
        act_order=None,
    )
    worker = ModelWorker(
        controller_address,
        worker_address,
        worker_id,
        no_register,
        model_path,
        model_names,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    return app, worker
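
# Illustrative call (hypothetical values): a two-GPU launch might look like
#   create_model_worker_app(model_path, gpus="0,1", num_gpus=2)
# Note that the `num_gpus is None` guard above only fires when None is passed
# explicitly (the default is -1), and unlike fastchat's own CLI entry point,
# nothing here sets CUDA_VISIBLE_DEVICES.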

def create_openai_api_app(
    host=host_ip,
    port=port,
    controller_address=f"{base_url}/controller",
    allow_credentials=None,
    allowed_origins=["*"],
    allowed_methods=["*"],
    allowed_headers=["*"],
    api_keys=[],
):
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=allowed_methods,
        allow_headers=allowed_headers,
    )
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys
    sys.modules["fastchat.serve.openai_api_server"].app_settings = app_settings

    return app
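
# `app_settings` is a module-level singleton in fastchat.serve.openai_api_server, so
# mutating the imported reference above already updates the shared object; the final
# sys.modules attribute assignment just makes that explicit.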

LLM_MODEL = 'chatglm-6b'  # override the configured default for this entry point
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]


if not model_path:
    logger.error("local_model_path must not be empty")
else:
    logger.info(f"using local model: {model_path}")

app = FastAPI()

controller_app, controller = create_controller_app()
app.mount("/controller", controller_app)

model_worker_app, worker = create_model_worker_app(model_path)
app.mount("/model_worker", model_worker_app)

openai_api_app = create_openai_api_app()
app.mount("/openai", openai_api_app)


@app.on_event("startup")
async def on_startup():
    logger.info("Register to controller")
    controller.register_worker(
        worker.worker_addr,
        True,
        worker.get_status(),
    )
    worker.heart_beat_thread = threading.Thread(
        target=heart_beat_worker, args=(worker,)
    )
    worker.heart_beat_thread.start()
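
    # Why a direct call: startup hooks run before uvicorn starts accepting
    # connections, so an HTTP registration against this same server (as in the
    # commented-out variant below) would block forever waiting for a response.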
    # Registering through a network request hangs the process:
    # async def register():
    #     while True:
    #         try:
    #             worker.register_to_controller()
    #             worker.heart_beat_thread = threading.Thread(
    #                 target=heart_beat_worker, args=(worker,)
    #             )
    #             worker.heart_beat_thread.start()
    #         except:
    #             await asyncio.sleep(1)
    # asyncio.get_event_loop().create_task(register())


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=host_ip, port=port)
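
# Usage sketch: run `python api_one.py`, then point any OpenAI-compatible client at
# http://127.0.0.1:8888/openai/v1 (model name "chatglm-6b").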