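"""Launcher for the full service stack.

Starts the FastChat controller, model worker, and OpenAI-compatible API
server, plus this project's API server (api.py) and Streamlit web UI
(webui.py), each in its own process. A multiprocessing Queue enforces the
startup order between them.
"""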
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                   FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address)
from server.utils import MakeFastAPIOffline, FastAPI
import argparse
from typing import List


def set_httpx_timeout(timeout=60.0):
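    """Globally raise httpx's default connect/read/write timeouts.

    The FastChat components exchange long-running HTTP requests, and httpx's
    short default timeout is not enough for LLM inference. Note this pokes at
    httpx's private _config module, so it may break across httpx upgrades.
    """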
    import httpx
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


def create_controller_app(
        dispatch_method: str,
) -> FastAPI:
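    """Build the FastChat controller as a FastAPI app.

    The controller tracks registered model workers and dispatches incoming
    requests among them according to ``dispatch_method``.
    """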
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.controller import app, Controller

    controller = Controller(dispatch_method)
    # fastchat's route handlers read the module-level `controller` global,
    # so install our instance there.
    sys.modules["fastchat.serve.controller"].controller = controller

    MakeFastAPIOffline(app)
    app.title = "FastChat Controller"
    return app


def create_model_worker_app(**kwargs) -> FastAPI:
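    """Build a FastChat model worker as a FastAPI app.

    Starts from fastchat's default arguments, overrides them with ``kwargs``
    (model_path, device, controller/worker addresses, ...), then instantiates
    the ModelWorker that loads and serves the model.
    """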
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
    import argparse
    import threading
    import fastchat.serve.model_worker

    # Workaround to make the program exit on Ctrl+C: run the heartbeat thread
    # as a daemon. Remove once the corresponding PR is merged into fastchat.
    def _new_init_heart_beat(self):
        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()

    ModelWorker.init_heart_beat = _new_init_heart_beat

    parser = argparse.ArgumentParser()
    args = parser.parse_args([])
    # Default args; remove once the corresponding PR is merged into fastchat.
    args.gpus = None
    args.max_gpu_memory = "20GiB"
    args.load_8bit = False
    args.cpu_offloading = None
    args.gptq_ckpt = None
    args.gptq_wbits = 16
    args.gptq_groupsize = -1
    args.gptq_act_order = False
    args.awq_ckpt = None
    args.awq_wbits = 16
    args.awq_groupsize = -1
    args.num_gpus = 1
    args.model_names = []
    args.conv_template = None
    args.limit_worker_concurrency = 5
    args.stream_interval = 2
    args.no_register = False

    # Caller-supplied options override the defaults above.
    for k, v in kwargs.items():
        setattr(args, k, v)

    if args.gpus:
        if args.num_gpus is None:
            args.num_gpus = len(args.gpus.split(','))
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        worker_id=worker_id,
        model_path=args.model_path,
        model_names=args.model_names,
        limit_worker_concurrency=args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
    )

    # fastchat's route handlers read these module-level globals at request time.
    sys.modules["fastchat.serve.model_worker"].worker = worker
    sys.modules["fastchat.serve.model_worker"].args = args
    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

    MakeFastAPIOffline(app)
    app.title = f"FastChat LLM Server ({LLM_MODEL})"
    return app


def create_openai_api_app(
        controller_address: str,
        api_keys: List = [],
) -> FastAPI:
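    """Build FastChat's OpenAI-compatible API server as a FastAPI app.

    Requests are forwarded to the controller at ``controller_address``; a
    non-empty ``api_keys`` list restricts access to those keys.
    """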
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys

    MakeFastAPIOffline(app)
    app.title = "FastChat OpenAI API Server"
    return app


def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
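    """Attach a startup hook that sequences server launches across processes.

    The shared queue acts as a token passer: server number ``run_seq`` waits
    until it reads the token ``run_seq - 1``, then publishes its own number
    so the next server may proceed. Sequence numbers start at 1.
    """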
    if run_seq == 1:
        @app.on_event("startup")
        async def on_startup():
            set_httpx_timeout()
            q.put(run_seq)
    elif run_seq > 1:
        @app.on_event("startup")
        async def on_startup():
            set_httpx_timeout()
            while True:
                no = q.get()
                if no != run_seq - 1:
                    q.put(no)  # not our turn yet; put the token back
                else:
                    break
            q.put(run_seq)


def run_controller(q: Queue, run_seq: int = 1):
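    """Start the FastChat controller under uvicorn (step 1 of the sequence)."""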
    import uvicorn

    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
    _set_app_seq(app, q, run_seq)

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    uvicorn.run(app, host=host, port=port)


def run_model_worker(
        model_name: str = LLM_MODEL,
        controller_address: str = "",
        q: Queue = None,
        run_seq: int = 2,
):
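    """Start a FastChat model worker under uvicorn (step 2 by default).

    Worker options come from FSCHAT_MODEL_WORKERS and the model path from
    llm_model_dict; the worker registers itself with the controller.
    """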
    import uvicorn

    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
    host = kwargs.pop("host")
    port = kwargs.pop("port")
    model_path = llm_model_dict[model_name].get("local_model_path", "")
    kwargs["model_path"] = model_path
    kwargs["model_names"] = [model_name]
    kwargs["controller_address"] = controller_address or fschat_controller_address()
    kwargs["worker_address"] = fschat_model_worker_address()

    app = create_model_worker_app(**kwargs)
    _set_app_seq(app, q, run_seq)

    uvicorn.run(app, host=host, port=port)


def run_openai_api(q: Queue, run_seq: int = 3):
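    """Start the OpenAI-compatible API server under uvicorn (step 3 by default)."""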
    import uvicorn

    controller_addr = fschat_controller_address()
    app = create_openai_api_app(controller_addr)  # TODO: API keys are not supported yet
    _set_app_seq(app, q, run_seq)

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    uvicorn.run(app, host=host, port=port)


def run_api_server(q: Queue, run_seq: int = 4):
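    """Start this project's API server (server/api.py) under uvicorn (step 4 by default)."""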
    from server.api import create_app
    import uvicorn

    app = create_app()
    _set_app_seq(app, q, run_seq)

    host = API_SERVER["host"]
    port = API_SERVER["port"]

    uvicorn.run(app, host=host, port=port)


def run_webui(q: Queue, run_seq: int = 5):
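    """Launch the Streamlit web UI as the final step of the sequence.

    Streamlit runs via subprocess rather than as a FastAPI app, so the
    queue hand-off from _set_app_seq is inlined here.
    """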
    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]

    # Wait for our turn in the startup sequence (same token passing as
    # _set_app_seq).
    while True:
        no = q.get()
        if no != run_seq - 1:
            q.put(no)
        else:
            break
    q.put(run_seq)

    p = subprocess.Popen(["streamlit", "run", "webui.py",
                          "--server.address", host,
                          "--server.port", str(port)])
    p.wait()


def parse_args() -> argparse.Namespace:
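    """Parse the command-line flags that select which servers to launch."""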
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--all-webui",
        action="store_true",
        help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
        dest="all_webui",
    )
    parser.add_argument(
        "--all-api",
        action="store_true",
        help="run fastchat's controller/model_worker/openai_api servers and api.py, but not webui.py",
        dest="all_api",
    )
    parser.add_argument(
        "--llm-api",
        action="store_true",
        help="run fastchat's controller/model_worker/openai_api servers only",
        dest="llm_api",
    )
    parser.add_argument(
        "-o",
        "--openai-api",
        action="store_true",
        help="run fastchat controller/openai_api servers",
        dest="openai_api",
    )
    parser.add_argument(
        "-m",
        "--model-worker",
        action="store_true",
        help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
        dest="model_worker",
    )
    parser.add_argument(
        "-n",
        "--model-name",
        type=str,
        default=LLM_MODEL,
        help="specify model name for model worker.",
        dest="model_name",
    )
    parser.add_argument(
        "-c",
        "--controller",
        type=str,
        help="specify controller address the worker is registered to. default is server_config.FSCHAT_CONTROLLER",
        dest="controller_address",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
        dest="api",
    )
    parser.add_argument(
        "-w",
        "--webui",
        action="store_true",
        help="run webui.py server",
        dest="webui",
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    mp.set_start_method("spawn")
    queue = Queue()
    args = parse_args()

    # Expand the umbrella flags into individual server switches.
    if args.all_webui:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.webui = True
    elif args.all_api:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.webui = False
    elif args.llm_api:
        args.openai_api = True
        args.model_worker = True
        args.api = False
        args.webui = False

    logger.info("Starting servers:")
    logger.info(f"To view llm_api logs, please go to {LOG_PATH}")

    processes = {}

    if args.openai_api:
        process = Process(
            target=run_controller,
            name=f"controller({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["controller"] = process

        process = Process(
            target=run_openai_api,
            name=f"openai_api({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["openai_api"] = process

    if args.model_worker:
        process = Process(
            target=run_model_worker,
            name=f"model_worker({os.getpid()})",
            args=(args.model_name, args.controller_address, queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["model_worker"] = process

    if args.api:
        process = Process(
            target=run_api_server,
            name=f"API Server({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["api"] = process

    if args.webui:
        process = Process(
            target=run_webui,
            name=f"WEBUI Server({os.getpid()})",
            args=(queue,),
            daemon=True,
        )
        process.start()
        processes["webui"] = process

    try:
        # Block on the model worker first; it is the longest-running child.
        if model_worker_process := processes.get("model_worker"):
            model_worker_process.join()
        for name, process in processes.items():
            if name != "model_worker":
                process.join()
    except:  # bare except so Ctrl+C (KeyboardInterrupt) also reaches the cleanup
        if model_worker_process := processes.get("model_worker"):
            model_worker_process.terminate()
        for name, process in processes.items():
            if name != "model_worker":
                process.terminate()

    # Example client call once the servers are up:
    # import openai
    # openai.api_key = "EMPTY"  # API keys are not supported yet
    # openai.api_base = "http://localhost:8888/v1"

    # model = "chatglm2-6b"

    # # create a chat completion
    # completion = openai.ChatCompletion.create(
    #     model=model,
    #     messages=[{"role": "user", "content": "Hello! What is your name?"}]
    # )
    # # print the completion
    # print(completion.choices[0].message.content)