Langchain-Chatchat/api_one.py

158 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import sys
import os
from configs.model_config import (
llm_model_dict,
LLM_MODEL,
EMBEDDING_DEVICE,
LLM_DEVICE,
LOG_PATH,
logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio
# Combine fastchat's three servers into a single app; the mounts are:
# - http://{host_ip}:{port}/controller   corresponds to: python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker corresponds to: python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai       corresponds to: python -m fastchat.serve.openai_api_server ...
host_ip = "0.0.0.0"  # bind address for the combined server
port = 8888  # single port shared by all three mounted sub-apps
base_url = f"http://127.0.0.1:{port}"  # loopback base URL used for worker/controller addresses
def create_controller_app(
    dispatch_method="shortest_queue",
):
    """Build the fastchat controller sub-app.

    Args:
        dispatch_method: worker-dispatch strategy passed to fastchat's
            Controller (default "shortest_queue").

    Returns:
        (app, controller): the controller FastAPI app and the Controller
        instance, so the caller can register workers on it directly.
    """
    from fastchat.serve.controller import app, Controller

    ctrl = Controller(dispatch_method)
    # fastchat's request handlers read the module-global `controller`,
    # so the instance must be injected back into the module.
    sys.modules["fastchat.serve.controller"].controller = ctrl
    logger.info(f"controller dispatch method: {dispatch_method}")
    return app, ctrl
# Captured once at definition time, mirroring the original def-time default
# evaluation (the module later rebinds LLM_MODEL).
_DEFAULT_MODEL_NAMES = (LLM_MODEL,)


def create_model_worker_app(
    model_path,
    model_names=None,
    device=LLM_DEVICE,
    load_8bit=False,
    gptq_ckpt=None,
    gptq_wbits=16,
    gpus='',
    num_gpus=-1,
    max_gpu_memory=-1,
    cpu_offloading=None,
    worker_address=f"{base_url}/model_worker",
    controller_address=f"{base_url}/controller",
    limit_model_concurrency=5,  # NOTE(review): currently unused; kept for interface compatibility
    stream_interval=2,  # NOTE(review): currently unused; kept for interface compatibility
    no_register=True,  # the worker is registered manually in the startup hook instead
):
    """Build the fastchat model-worker sub-app and its ModelWorker.

    Args:
        model_path: local filesystem path of the model to serve (required).
        model_names: names the worker advertises; defaults to [LLM_MODEL].
        device: target device string (e.g. "cuda", "cpu").
        load_8bit: load the model with 8-bit quantization.
        gptq_ckpt / gptq_wbits: GPTQ checkpoint path and weight bits.
        gpus: comma-separated GPU ids; num_gpus: number of GPUs to use.
        max_gpu_memory: per-GPU memory cap forwarded to ModelWorker.
        cpu_offloading: enable CPU offloading.
        worker_address / controller_address: URLs under the combined app.
        no_register: skip fastchat's own HTTP registration (done manually).

    Returns:
        (app, worker): the worker FastAPI sub-app and the ModelWorker.
    """
    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id

    # Avoid a shared mutable default list across calls.
    if model_names is None:
        model_names = list(_DEFAULT_MODEL_NAMES)

    # NOTE(review): with the default num_gpus=-1 this branch never fires;
    # it only applies when a caller explicitly passes num_gpus=None.
    if gpus and num_gpus is None:
        num_gpus = len(gpus.split(','))

    gptq_config = GptqConfig(
        ckpt=gptq_ckpt or model_path,  # fall back to the model path when no GPTQ checkpoint is given
        wbits=gptq_wbits,
        groupsize=-1,
        act_order=None,
    )
    worker = ModelWorker(
        controller_address,
        worker_address,
        worker_id,
        no_register,
        model_path,
        model_names,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
    )
    # fastchat's request handlers read the module-global `worker`,
    # so the instance must be injected back into the module.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    return app, worker
def create_openai_api_app(
    host=host_ip,  # NOTE(review): currently unused; kept for interface compatibility
    port=port,  # NOTE(review): currently unused; kept for interface compatibility
    controller_address=f"{base_url}/controller",
    allow_credentials=None,
    allowed_origins=None,
    allowed_methods=None,
    allowed_headers=None,
    api_keys=None,
):
    """Build the fastchat OpenAI-compatible API sub-app.

    Args:
        controller_address: URL of the controller this API forwards to.
        allow_credentials / allowed_origins / allowed_methods /
        allowed_headers: CORS settings; origins/methods/headers default
            to ["*"] (allow everything).
        api_keys: accepted API keys; defaults to [] (no auth required).

    Returns:
        The configured OpenAI API FastAPI sub-app.
    """
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    # Avoid mutable default arguments; fall back to the permissive defaults.
    allowed_origins = ["*"] if allowed_origins is None else allowed_origins
    allowed_methods = ["*"] if allowed_methods is None else allowed_methods
    allowed_headers = ["*"] if allowed_headers is None else allowed_headers
    api_keys = [] if api_keys is None else api_keys

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=allowed_methods,
        allow_headers=allowed_headers,
    )
    # app_settings is a module-level singleton that the request handlers
    # read; mutating it in place is sufficient. (The original also assigned
    # it into sys.modules under a non-module key, which did nothing useful
    # and polluted sys.modules — removed.)
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys
    return app
# Serve chatglm-6b from its configured local path.
# NOTE(review): this rebinds the LLM_MODEL name imported from configs;
# defaults captured at function-definition time are unaffected.
LLM_MODEL = 'chatglm-6b'
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]

if not model_path:
    logger.error("local_model_path 不能为空")
else:
    logger.info(f"using local model: {model_path}")

# Assemble the three fastchat services as sub-apps of one FastAPI app.
app = FastAPI()
controller_app, controller = create_controller_app()
app.mount("/controller", controller_app)
model_worker_app, worker = create_model_worker_app(model_path)
app.mount("/model_worker", model_worker_app)
openai_api_app = create_openai_api_app()
app.mount("/openai", openai_api_app)


@app.on_event("startup")
async def on_startup():
    """Register the in-process worker with the in-process controller and
    start its heartbeat thread."""
    logger.info("Register to controller")
    # Register directly on the controller object rather than over HTTP —
    # see the note below on why an HTTP self-registration deadlocks.
    controller.register_worker(
        worker.worker_addr,
        True,
        worker.get_status(),
    )
    worker.heart_beat_thread = threading.Thread(
        target=heart_beat_worker, args=(worker,)
    )
    worker.heart_beat_thread.start()
    # Registering over the network hangs (the server is not serving yet
    # during the startup event), so the HTTP-based variant is disabled:
    # async def register():
    #     while True:
    #         try:
    #             worker.register_to_controller()
    #             worker.heart_beat_thread = threading.Thread(
    #                 target=heart_beat_worker, args=(worker,)
    #             )
    #             worker.heart_beat_thread.start()
    #         except:
    #             await asyncio.sleep(1)
    # asyncio.get_event_loop().create_task(register())


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=host_ip, port=port)