add startup.py: start specified servers with one command. see python startup.py --help
parent f29a877bd0
commit f92b002342
.gitignore
@@ -4,4 +4,4 @@ logs
.idea/
__pycache__/
knowledge_base/
-configs/model_config.py
+configs/*.py

configs/model_config.py
@@ -52,13 +52,13 @@ llm_model_dict = {
    "chatglm2-6b": {
        "local_model_path": "THUDM/chatglm2-6b",
-       "api_base_url": "http://localhost:8888/v1",  # "name" was changed to "api_base_url" as used by the fastchat service
+       "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
        "api_key": "EMPTY"
    },

    "chatglm2-6b-32k": {
        "local_model_path": "THUDM/chatglm2-6b-32k",  # "THUDM/chatglm2-6b-32k",
-       "api_base_url": "http://localhost:8888/v1",  # "name" was changed to "api_base_url" as used by the fastchat service
+       "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
        "api_key": "EMPTY"
    },

configs/server_config.py
@@ -0,0 +1,88 @@
from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE


# Whether the API allows cross-origin requests. Defaults to False; set to True to enable it.
OPEN_CROSS_DOMAIN = False

# Default host that every server binds to
DEFAULT_BIND_HOST = "127.0.0.1"

# webui.py server
WEBUI_SERVER = {
    "host": DEFAULT_BIND_HOST,
    "port": 8501,
}

# api.py server
API_SERVER = {
    "host": DEFAULT_BIND_HOST,
    "port": 7861,
}

# fastchat openai_api server
FSCHAT_OPENAI_API = {
    "host": DEFAULT_BIND_HOST,
    "port": 8888,  # the api_base_url configured for each model in model_config.llm_model_dict must match this
}

# fastchat model_worker server
# These models must be configured correctly in model_config.llm_model_dict.
# When launching startup.py, the model can be chosen with `--model-worker --model-name xxxx`; it defaults to LLM_MODEL.
FSCHAT_MODEL_WORKERS = {
    LLM_MODEL: {
        "host": DEFAULT_BIND_HOST,
        "port": 20002,
        "device": LLM_DEVICE,
        # todo: parameters that need to be configured for multi-GPU loading
        "gpus": None,
        "numgpus": 1,
        # The following are rarely used parameters; configure them as needed.
        # "max_gpu_memory": "20GiB",
        # "load_8bit": False,
        # "cpu_offloading": None,
        # "gptq_ckpt": None,
        # "gptq_wbits": 16,
        # "gptq_groupsize": -1,
        # "gptq_act_order": False,
        # "awq_ckpt": None,
        # "awq_wbits": 16,
        # "awq_groupsize": -1,
        # "model_names": [LLM_MODEL],
        # "conv_template": None,
        # "limit_worker_concurrency": 5,
        # "stream_interval": 2,
        # "no_register": False,
    },
}
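
Illustration only, not committed code: per the comment above, serving another model from llm_model_dict, e.g. "chatglm2-6b-32k", would mean adding a second entry keyed by that model name; the port below is an assumed free port.

#     "chatglm2-6b-32k": {
#         "host": DEFAULT_BIND_HOST,
#         "port": 20003,  # assumed free port
#         "device": LLM_DEVICE,
#     },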


# fastchat multi model worker server
FSCHAT_MULTI_MODEL_WORKERS = {
    # todo
}

# fastchat controller server
FSCHAT_CONTROLLER = {
    "host": DEFAULT_BIND_HOST,
    "port": 20001,
    "dispatch_method": "shortest_queue",
}


# Do not change anything below this line.
def fschat_controller_address() -> str:
    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
    if model := FSCHAT_MODEL_WORKERS.get(model_name):
        host = model["host"]
        port = model["port"]
        return f"http://{host}:{port}"


def fschat_openai_api_address() -> str:
    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}"
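
A minimal sketch (illustrative, not part of the committed file) of how these helpers tie the two config files together: the api_base_url of each model in model_config.llm_model_dict is expected to point at the FSCHAT_OPENAI_API port defined above, which can be sanity-checked like this.

from urllib.parse import urlparse

from configs.model_config import llm_model_dict
from configs.server_config import FSCHAT_OPENAI_API

for name, cfg in llm_model_dict.items():
    url = cfg.get("api_base_url", "")
    if url and urlparse(url).port != FSCHAT_OPENAI_API["port"]:
        print(f"{name}: {url} does not use the FSCHAT_OPENAI_API port {FSCHAT_OPENAI_API['port']}")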

server/api.py
@@ -4,7 +4,8 @@ import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

-from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
+from configs.model_config import NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware

startup.py
@@ -0,0 +1,364 @@
from multiprocessing import Process, Queue
import multiprocessing as mp
import subprocess
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                   FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,)
from server.utils import MakeFastAPIOffline, FastAPI
import argparse
from typing import Tuple, List


def set_httpx_timeout(timeout=60.0):
    # Patch httpx's module-level default timeout (a private attribute), so that
    # slow LLM responses are not cut off by httpx's short default timeout.
    import httpx
    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout


def create_controller_app(
        dispatch_method: str,
) -> FastAPI:
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.controller import app, Controller

    controller = Controller(dispatch_method)
    sys.modules["fastchat.serve.controller"].controller = controller

    MakeFastAPIOffline(app)
    app.title = "FastChat Controller"
    return app


def create_model_worker_app(**kwargs) -> FastAPI:
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
    import argparse
    import threading
    import fastchat.serve.model_worker

    # workaround to make the program exit with Ctrl+C
    # it should be deleted after the pr is merged by fastchat
    def _new_init_heart_beat(self):
        self.register_to_controller()
        # daemon=True lets the heartbeat thread die with the main process instead of blocking exit
        self.heart_beat_thread = threading.Thread(
            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
        )
        self.heart_beat_thread.start()
    ModelWorker.init_heart_beat = _new_init_heart_beat

    parser = argparse.ArgumentParser()
    args = parser.parse_args([])
    # default args. should be deleted after the pr is merged by fastchat
    args.gpus = None
    args.max_gpu_memory = "20GiB"
    args.load_8bit = False
    args.cpu_offloading = None
    args.gptq_ckpt = None
    args.gptq_wbits = 16
    args.gptq_groupsize = -1
    args.gptq_act_order = False
    args.awq_ckpt = None
    args.awq_wbits = 16
    args.awq_groupsize = -1
    args.num_gpus = 1
    args.model_names = []
    args.conv_template = None
    args.limit_worker_concurrency = 5
    args.stream_interval = 2
    args.no_register = False

    # overwrite the defaults with whatever the caller passed in
    for k, v in kwargs.items():
        setattr(args, k, v)

    if args.gpus:
        if args.num_gpus is None:
            args.num_gpus = len(args.gpus.split(','))
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )

    worker = ModelWorker(
        controller_addr=args.controller_address,
        worker_addr=args.worker_address,
        worker_id=worker_id,
        model_path=args.model_path,
        model_names=args.model_names,
        limit_worker_concurrency=args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
    )

    sys.modules["fastchat.serve.model_worker"].worker = worker
    sys.modules["fastchat.serve.model_worker"].args = args
    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config

    MakeFastAPIOffline(app)
    app.title = f"FastChat LLM Server ({LLM_MODEL})"
    return app


def create_openai_api_app(
        controller_address: str,
        api_keys: List = [],
) -> FastAPI:
    import fastchat.constants
    fastchat.constants.LOGDIR = LOG_PATH
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings

    app.add_middleware(
        CORSMiddleware,
        allow_credentials=True,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )

    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys

    MakeFastAPIOffline(app)
    app.title = "FastChat OpenAI API Server"
    return app


def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
    # Register a startup hook that makes the servers come up in order: each app waits
    # until the previous run_seq number appears on the shared queue, then puts its own.
    if run_seq == 1:
        @app.on_event("startup")
        async def on_startup():
            set_httpx_timeout()
            q.put(run_seq)
    elif run_seq > 1:
        @app.on_event("startup")
        async def on_startup():
            set_httpx_timeout()
            while True:
                no = q.get()
                if no != run_seq - 1:
                    q.put(no)
                else:
                    break
            q.put(run_seq)
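
For reference, a standalone sketch (not part of this commit) of the queue handshake that _set_app_seq implements: each stage blocks until it sees the previous stage's number on the shared queue, re-queueing anything that is not its turn, then publishes its own number.

from multiprocessing import Process, Queue

def stage(q: Queue, run_seq: int):
    if run_seq > 1:
        while True:
            no = q.get()
            if no != run_seq - 1:
                q.put(no)  # not the number we are waiting for; put it back
            else:
                break
    print(f"stage {run_seq} is up")
    q.put(run_seq)  # signal the next stage

if __name__ == "__main__":
    q = Queue()
    procs = [Process(target=stage, args=(q, i)) for i in (1, 2, 3)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()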


def run_controller(q: Queue, run_seq: int = 1):
    import uvicorn

    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
    _set_app_seq(app, q, run_seq)

    host = FSCHAT_CONTROLLER["host"]
    port = FSCHAT_CONTROLLER["port"]
    uvicorn.run(app, host=host, port=port)


def run_model_worker(model_name: str = LLM_MODEL, q: Queue = None, run_seq: int = 2):
    import uvicorn

    kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy()
    host = kwargs.pop("host")
    port = kwargs.pop("port")
    model_path = llm_model_dict[model_name].get("local_model_path", "")
    kwargs["model_path"] = model_path
    kwargs["model_names"] = [model_name]
    kwargs["controller_address"] = fschat_controller_address()
    kwargs["worker_address"] = fschat_model_worker_address()

    app = create_model_worker_app(**kwargs)
    _set_app_seq(app, q, run_seq)

    uvicorn.run(app, host=host, port=port)


def run_openai_api(q: Queue, run_seq: int = 3):
    import uvicorn

    controller_addr = fschat_controller_address()
    app = create_openai_api_app(controller_addr)  # todo: api keys are not supported yet
    _set_app_seq(app, q, run_seq)

    host = FSCHAT_OPENAI_API["host"]
    port = FSCHAT_OPENAI_API["port"]
    uvicorn.run(app, host=host, port=port)


def run_api_server(q: Queue, run_seq: int = 4):
    from server.api import create_app
    import uvicorn

    app = create_app()
    _set_app_seq(app, q, run_seq)

    host = API_SERVER["host"]
    port = API_SERVER["port"]

    uvicorn.run(app, host=host, port=port)


def run_webui():
    from configs.model_config import logger
    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    p = subprocess.Popen(["streamlit", "run", "webui.py",
                          "--server.address", host,
                          "--server.port", str(port)])
    p.wait()


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--all",
        action="store_true",
        help="run fastchat's controller/model_worker/openai_api servers, plus api.py and webui.py",
    )
    parser.add_argument(
        "--openai-api",
        action="store_true",
        help="run fastchat controller/openai_api servers",
    )
    parser.add_argument(
        "--model-worker",
        action="store_true",
        help="run fastchat model_worker server with the specified model name; specify --model-name if not using the default LLM_MODEL",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        default=LLM_MODEL,
        help="specify model name for the model worker",
    )
    parser.add_argument(
        "--api",
        action="store_true",
        help="run api.py server",
    )
    parser.add_argument(
        "--webui",
        action="store_true",
        help="run webui.py server",
    )
    args = parser.parse_args()
    return args
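
Illustrative invocations (not part of the committed file) that follow from the flags defined above:

# python startup.py --all
# python startup.py --openai-api --model-worker            # fastchat stack only, using the default LLM_MODEL
# python startup.py --model-worker --model-name chatglm2-6b-32k
# python startup.py --api --webui                          # assumes the fastchat services are already running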


if __name__ == "__main__":
    mp.set_start_method("spawn")
    queue = Queue()
    args = parse_args()
    if args.all:
        args.openai_api = True
        args.model_worker = True
        args.api = True
        args.webui = True

    logger.info("Starting servers:")
    logger.info(f"To inspect the llm_api logs, see {LOG_PATH}")

    processes = {}

    if args.openai_api:
        process = Process(
            target=run_controller,
            name=f"controller({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["controller"] = process

        process = Process(
            target=run_openai_api,
            name=f"openai_api({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["openai_api"] = process

    if args.model_worker:
        process = Process(
            target=run_model_worker,
            name=f"model_worker({os.getpid()})",
            args=(args.model_name, queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["model_worker"] = process

    if args.api:
        process = Process(
            target=run_api_server,
            name=f"API Server({os.getpid()})",
            args=(queue, len(processes) + 1),
            daemon=True,
        )
        process.start()
        processes["api"] = process

    if args.webui:
        process = Process(
            target=run_webui,
            name=f"WEBUI Server({os.getpid()})",
            daemon=True,
        )
        process.start()
        processes["webui"] = process

    try:
        # wait for the model worker first (if any), then the remaining processes
        if model_worker_process := processes.get("model_worker"):
            model_worker_process.join()
        for name, process in processes.items():
            if name != "model_worker":
                process.join()
    except:
        if model_worker_process := processes.get("model_worker"):
            model_worker_process.terminate()
        for name, process in processes.items():
            if name != "model_worker":
                process.terminate()

    # Example of calling the API after the services are up:
    # import openai
    # openai.api_key = "EMPTY"  # Not supported yet
    # openai.api_base = "http://localhost:8888/v1"

    # model = "chatglm2-6b"

    # # create a chat completion
    # completion = openai.ChatCompletion.create(
    #     model=model,
    #     messages=[{"role": "user", "content": "Hello! What is your name?"}]
    # )
    # # print the completion
    # print(completion.choices[0].message.content)

@@ -28,4 +28,14 @@ if __name__ == "__main__":
        for line in response.iter_content(decode_unicode=True):
            print(line, flush=True)
    else:
        print("Error:", response.status_code)

    r = requests.post(
        openai_url + "/chat/completions",
        json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000})
    data = r.json()
    print(f"/chat/completions\n")
    print(data)
    assert "choices" in data