Add api_one.py: merge the three fastchat servers into a single process. api.py may be merged in later as well.
Going from 3 processes to 1 may cost a little performance, but it lowers the deployment barrier for individual users.
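
Once merged, one port serves all three fastchat endpoints, so a client only needs the OpenAI-compatible route under /openai. A minimal sketch, assuming the defaults from the code below (host 127.0.0.1, port 8888, model "chatglm-6b") and that the worker has finished registering:

import requests

resp = requests.post(
    "http://127.0.0.1:8888/openai/v1/chat/completions",
    json={
        "model": "chatglm-6b",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.json())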
liunux4odoo 2023-07-28 16:33:42 +08:00
parent 97ee4686a1
commit 70c6870776
2 changed files with 159 additions and 1 deletions

.gitignore (3 changed lines)

@@ -2,4 +2,5 @@
*.log.*
logs
.idea/
__pycache__/
knowledge_base/

api_one.py (new file, 157 lines)

@@ -0,0 +1,157 @@
import sys
import os
from configs.model_config import (
llm_model_dict,
LLM_MODEL,
EMBEDDING_DEVICE,
LLM_DEVICE,
LOG_PATH,
logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio
# Combine fastchat's three servers into a single FastAPI app:
# - http://{host_ip}:{port}/controller   replaces: python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker replaces: python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai       replaces: python -m fastchat.serve.openai_api_server ...
host_ip = "0.0.0.0"
port = 8888
base_url = f"http://127.0.0.1:{port}"  # loopback URL the mounted services use to reach each other

def create_controller_app(
    dispatch_method="shortest_queue",
):
    from fastchat.serve.controller import app, Controller

    controller = Controller(dispatch_method)
    # fastchat's controller routes reference a module-level `controller`,
    # so inject this instance into the imported module.
    sys.modules["fastchat.serve.controller"].controller = controller
    logger.info(f"controller dispatch method: {dispatch_method}")
    return app, controller

def create_model_worker_app(
    model_path,
    model_names=[LLM_MODEL],
    device=LLM_DEVICE,
    load_8bit=False,
    gptq_ckpt=None,
    gptq_wbits=16,
    gpus="",
    num_gpus=None,
    max_gpu_memory=None,
    cpu_offloading=None,
    worker_address=f"{base_url}/model_worker",
    controller_address=f"{base_url}/controller",
    limit_model_concurrency=5,
    stream_interval=2,
    no_register=True,  # registered manually in on_startup instead
):
    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
    from fastchat.serve import model_worker

    if gpus and num_gpus is None:
        num_gpus = len(gpus.split(","))
    gptq_config = GptqConfig(
        ckpt=gptq_ckpt or model_path,
        wbits=gptq_wbits,
        groupsize=-1,
        act_order=None,
    )
    worker = ModelWorker(
        controller_address,
        worker_address,
        worker_id,
        no_register,
        model_path,
        model_names,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
    )
    # as with the controller, fastchat's model_worker routes reference a
    # module-level `worker`; inject this instance into the imported module.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    return app, worker

def create_openai_api_app(
    host=host_ip,
    port=port,
    controller_address=f"{base_url}/controller",
    allow_credentials=None,
    allowed_origins=["*"],
    allowed_methods=["*"],
    allowed_headers=["*"],
    api_keys=[],
):
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=allowed_methods,
        allow_headers=allowed_headers,
    )
    # app_settings is a module-level object in openai_api_server;
    # mutating it in place is enough for the mounted app to see the values.
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys
    return app

LLM_MODEL = "chatglm-6b"
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]

if not model_path:
    logger.error("local_model_path must not be empty")
    sys.exit(1)

logger.info(f"using local model: {model_path}")

app = FastAPI()

controller_app, controller = create_controller_app()
app.mount("/controller", controller_app)

model_worker_app, worker = create_model_worker_app(model_path)
app.mount("/model_worker", model_worker_app)

openai_api_app = create_openai_api_app()
app.mount("/openai", openai_api_app)

@app.on_event("startup")
async def on_startup():
    # the worker was created with no_register=True; register it here
    # directly on the in-process controller object.
    logger.info("Register to controller")
    controller.register_worker(
        worker.worker_addr,
        True,
        worker.get_status(),
    )
    worker.heart_beat_thread = threading.Thread(
        target=heart_beat_worker, args=(worker,)
    )
    worker.heart_beat_thread.start()

    # registering over the network (HTTP) hangs, hence the in-process call above:
    # async def register():
    #     while True:
    #         try:
    #             worker.register_to_controller()
    #             worker.heart_beat_thread = threading.Thread(
    #                 target=heart_beat_worker, args=(worker,)
    #             )
    #             worker.heart_beat_thread.start()
    #         except:
    #             await asyncio.sleep(1)
    # asyncio.get_event_loop().create_task(register())

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host=host_ip, port=port)
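
A quick way to check that all three mounts came up once the process is running. This is a sketch under the defaults above; /list_models is fastchat's controller endpoint (a POST route) and /v1/models is the OpenAI-compatible model listing:

import requests

# controller mount: list the model names registered by workers
print(requests.post("http://127.0.0.1:8888/controller/list_models").json())

# openai mount: the same information through the OpenAI-compatible API
print(requests.get("http://127.0.0.1:8888/openai/v1/models").json())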