diff --git a/.gitignore b/.gitignore
index af50500..b5918ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,4 +4,4 @@ logs
 .idea/
 __pycache__/
 knowledge_base/
-configs/model_config.py
\ No newline at end of file
+configs/*.py
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 8771cfc..be0bebc 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -52,13 +52,13 @@ llm_model_dict = {
 
     "chatglm2-6b": {
         "local_model_path": "THUDM/chatglm2-6b",
-        "api_base_url": "http://localhost:8888/v1",  # "name" is replaced by "api_base_url" from the fastchat service
+        "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
         "api_key": "EMPTY"
     },
 
     "chatglm2-6b-32k": {
         "local_model_path": "THUDM/chatglm2-6b-32k",  # "THUDM/chatglm2-6b-32k",
-        "api_base_url": "http://localhost:8888/v1",  # "name" is replaced by "api_base_url" from the fastchat service
+        "api_base_url": "http://localhost:8888/v1",  # the URL must match server_config.FSCHAT_OPENAI_API of the running fastchat server
         "api_key": "EMPTY"
     },
 
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
new file mode 100644
index 0000000..24ce6b4
--- /dev/null
+++ b/configs/server_config.py.example
@@ -0,0 +1,88 @@
+from .model_config import LLM_MODEL, llm_model_dict, LLM_DEVICE
+
+
+# Whether the API server allows cross-origin (CORS) requests. Defaults to False; set to True to enable it.
+OPEN_CROSS_DOMAIN = False
+
+# Default host that each server binds to
+DEFAULT_BIND_HOST = "127.0.0.1"
+
+# webui.py server
+WEBUI_SERVER = {
+    "host": DEFAULT_BIND_HOST,
+    "port": 8501,
+}
+
+# api.py server
+API_SERVER = {
+    "host": DEFAULT_BIND_HOST,
+    "port": 7861,
+}
+
+# fastchat openai_api server
+FSCHAT_OPENAI_API = {
+    "host": DEFAULT_BIND_HOST,
+    "port": 8888,  # the api_base_url of each model in model_config.llm_model_dict must use this port.
+}
+
+# fastchat model_worker server
+# These models must be configured correctly in model_config.llm_model_dict.
+# When launching startup.py, a model can be selected with `--model-worker --model-name xxxx`; if omitted, LLM_MODEL is used.
+FSCHAT_MODEL_WORKERS = {
+    LLM_MODEL: {
+        "host": DEFAULT_BIND_HOST,
+        "port": 20002,
+        "device": LLM_DEVICE,
+        # todo: parameters required for multi-GPU loading
+        "gpus": None,
+        "num_gpus": 1,
+        # Less commonly used parameters below; configure as needed
+        # "max_gpu_memory": "20GiB",
+        # "load_8bit": False,
+        # "cpu_offloading": None,
+        # "gptq_ckpt": None,
+        # "gptq_wbits": 16,
+        # "gptq_groupsize": -1,
+        # "gptq_act_order": False,
+        # "awq_ckpt": None,
+        # "awq_wbits": 16,
+        # "awq_groupsize": -1,
+        # "model_names": [LLM_MODEL],
+        # "conv_template": None,
+        # "limit_worker_concurrency": 5,
+        # "stream_interval": 2,
+        # "no_register": False,
+    },
+}
+
+
+# fastchat multi model worker server
+FSCHAT_MULTI_MODEL_WORKERS = {
+    # todo
+}
+
+# fastchat controller server
+FSCHAT_CONTROLLER = {
+    "host": DEFAULT_BIND_HOST,
+    "port": 20001,
+    "dispatch_method": "shortest_queue",
+}
+
+
+# Do not change anything below
+def fschat_controller_address() -> str:
+    host = FSCHAT_CONTROLLER["host"]
+    port = FSCHAT_CONTROLLER["port"]
+    return f"http://{host}:{port}"
+
+def fschat_model_worker_address(model_name: str = LLM_MODEL) -> str:
+    if model := FSCHAT_MODEL_WORKERS.get(model_name):
+        host = model["host"]
+        port = model["port"]
+        return f"http://{host}:{port}"
+
+def fschat_openai_api_address() -> str:
+    host = FSCHAT_OPENAI_API["host"]
+    port = FSCHAT_OPENAI_API["port"]
+    return f"http://{host}:{port}"
diff --git a/server/api.py b/server/api.py
index 800680c..c398f15 100644
--- a/server/api.py
+++ b/server/api.py
@@ -4,7 +4,8 @@ import os
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
+from configs.model_config import NLTK_DATA_PATH
+from configs.server_config import OPEN_CROSS_DOMAIN
 import argparse
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
diff --git a/startup.py b/startup.py
new file mode 100644
index 0000000..975bc92
--- /dev/null
+++ b/startup.py
@@ -0,0 +1,364 @@
+from multiprocessing import Process, Queue
+import multiprocessing as mp
+import subprocess
+import sys
+import os
+
+sys.path.append(os.path.dirname(os.path.dirname(__file__)))
+from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
+from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
+                                   FSCHAT_OPENAI_API, fschat_controller_address, fschat_model_worker_address,)
+from server.utils import MakeFastAPIOffline, FastAPI
+import argparse
+from typing import List
+
+
+def set_httpx_timeout(timeout=60.0):
+    import httpx
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
+    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+def create_controller_app(
+        dispatch_method: str,
+) -> FastAPI:
+    import fastchat.constants
+    fastchat.constants.LOGDIR = LOG_PATH
+    from fastchat.serve.controller import app, Controller
+
+    controller = Controller(dispatch_method)
+    sys.modules["fastchat.serve.controller"].controller = controller
+
+    MakeFastAPIOffline(app)
+    app.title = "FastChat Controller"
+    return app
+
+
+def create_model_worker_app(**kwargs) -> FastAPI:
+    import fastchat.constants
+    fastchat.constants.LOGDIR = LOG_PATH
+    from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id
+    import argparse
+    import threading
+    import fastchat.serve.model_worker
+
+    # workaround to make the program exit with Ctrl+C
+    # it should be deleted after the PR is merged by fastchat
+    def _new_init_heart_beat(self):
+        self.register_to_controller()
+        self.heart_beat_thread = threading.Thread(
+            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
+        )
+        self.heart_beat_thread.start()
+    ModelWorker.init_heart_beat = _new_init_heart_beat
+
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args([])
+    # default args. should be deleted after the PR is merged by fastchat
+    args.gpus = None
+    args.max_gpu_memory = "20GiB"
+    args.load_8bit = False
+    args.cpu_offloading = None
+    args.gptq_ckpt = None
+    args.gptq_wbits = 16
+    args.gptq_groupsize = -1
+    args.gptq_act_order = False
+    args.awq_ckpt = None
+    args.awq_wbits = 16
+    args.awq_groupsize = -1
+    args.num_gpus = 1
+    args.model_names = []
+    args.conv_template = None
+    args.limit_worker_concurrency = 5
+    args.stream_interval = 2
+    args.no_register = False
+
+    for k, v in kwargs.items():
+        setattr(args, k, v)
+
+    if args.gpus:
+        if args.num_gpus is None:
+            args.num_gpus = len(args.gpus.split(','))
+        if len(args.gpus.split(",")) < args.num_gpus:
+            raise ValueError(
+                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
+            )
+        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
+
+
+    gptq_config = GptqConfig(
+        ckpt=args.gptq_ckpt or args.model_path,
+        wbits=args.gptq_wbits,
+        groupsize=args.gptq_groupsize,
+        act_order=args.gptq_act_order,
+    )
+    awq_config = AWQConfig(
+        ckpt=args.awq_ckpt or args.model_path,
+        wbits=args.awq_wbits,
+        groupsize=args.awq_groupsize,
+    )
+
+    worker = ModelWorker(
+        controller_addr=args.controller_address,
+        worker_addr=args.worker_address,
+        worker_id=worker_id,
+        model_path=args.model_path,
+        model_names=args.model_names,
+        limit_worker_concurrency=args.limit_worker_concurrency,
+        no_register=args.no_register,
+        device=args.device,
+        num_gpus=args.num_gpus,
+        max_gpu_memory=args.max_gpu_memory,
+        load_8bit=args.load_8bit,
+        cpu_offloading=args.cpu_offloading,
+        gptq_config=gptq_config,
+        awq_config=awq_config,
+        stream_interval=args.stream_interval,
+        conv_template=args.conv_template,
+    )
+
+    sys.modules["fastchat.serve.model_worker"].worker = worker
+    sys.modules["fastchat.serve.model_worker"].args = args
+    sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config
+
+    MakeFastAPIOffline(app)
+    app.title = f"FastChat LLM Server ({LLM_MODEL})"
+    return app
+
+
+def create_openai_api_app(
+        controller_address: str,
+        api_keys: List = [],
+) -> FastAPI:
+    import fastchat.constants
+    fastchat.constants.LOGDIR = LOG_PATH
+    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
+
+    app.add_middleware(
+        CORSMiddleware,
+        allow_credentials=True,
+        allow_origins=["*"],
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    app_settings.controller_address = controller_address
+    app_settings.api_keys = api_keys
+
+    MakeFastAPIOffline(app)
+    app.title = "FastChat OpenAI API Server"
+    return app
+
+
+def _set_app_seq(app: FastAPI, q: Queue, run_seq: int):
+    if run_seq == 1:
+        @app.on_event("startup")
+        async def on_startup():
+            set_httpx_timeout()
+            q.put(run_seq)
+    elif run_seq > 1:
+        @app.on_event("startup")
+        async def on_startup():
+            set_httpx_timeout()
+            while True:
+                no = q.get()
+                if no != run_seq - 1:
+                    q.put(no)
+                else:
+                    break
+            q.put(run_seq)
+
+
+def run_controller(q: Queue, run_seq: int = 1):
+    import uvicorn
+
+    app = create_controller_app(FSCHAT_CONTROLLER.get("dispatch_method"))
+    _set_app_seq(app, q, run_seq)
+
+    host = FSCHAT_CONTROLLER["host"]
+    port = FSCHAT_CONTROLLER["port"]
+    uvicorn.run(app, host=host, port=port)
+
+
+def run_model_worker(model_name: str = LLM_MODEL, q: Queue = None, run_seq: int = 2):
+    import uvicorn
+
+    kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy()
+    host = kwargs.pop("host")
+    port = kwargs.pop("port")
+    model_path = llm_model_dict[model_name].get("local_model_path", "")
+    kwargs["model_path"] = model_path
+    kwargs["model_names"] = [model_name]
+    kwargs["controller_address"] = fschat_controller_address()
+    kwargs["worker_address"] = fschat_model_worker_address()
+
+    app = create_model_worker_app(**kwargs)
+    _set_app_seq(app, q, run_seq)
+
+    uvicorn.run(app, host=host, port=port)
+
+
+def run_openai_api(q: Queue, run_seq: int = 3):
+    import uvicorn
+
+    controller_addr = fschat_controller_address()
+    app = create_openai_api_app(controller_addr)  # todo: API keys are not supported yet.
+    _set_app_seq(app, q, run_seq)
+
+    host = FSCHAT_OPENAI_API["host"]
+    port = FSCHAT_OPENAI_API["port"]
+    uvicorn.run(app, host=host, port=port)
+
+
+def run_api_server(q: Queue, run_seq: int = 4):
+    from server.api import create_app
+    import uvicorn
+
+    app = create_app()
+    _set_app_seq(app, q, run_seq)
+
+    host = API_SERVER["host"]
+    port = API_SERVER["port"]
+
+    uvicorn.run(app, host=host, port=port)
+
+
+def run_webui():
+    from configs.model_config import logger
+    host = WEBUI_SERVER["host"]
+    port = WEBUI_SERVER["port"]
+    p = subprocess.Popen(["streamlit", "run", "webui.py",
+                          "--server.address", host,
+                          "--server.port", str(port)])
+    p.wait()
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--all",
+        action="store_true",
+        help="run fastchat's controller/model_worker/openai_api servers, run api.py and webui.py",
+    )
+    parser.add_argument(
+        "--openai-api",
+        action="store_true",
+        help="run fastchat controller/openai_api servers",
+    )
+    parser.add_argument(
+        "--model-worker",
+        action="store_true",
+        help="run fastchat model_worker server with specified model name. specify --model-name if not using default LLM_MODEL",
+    )
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default=LLM_MODEL,
+        help="specify model name for model worker.",
+    )
+    parser.add_argument(
+        "--api",
+        action="store_true",
+        help="run api.py server",
+    )
+    parser.add_argument(
+        "--webui",
+        action="store_true",
+        help="run webui.py server",
+    )
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    mp.set_start_method("spawn")
+    queue = Queue()
+    args = parse_args()
+    if args.all:
+        args.openai_api = True
+        args.model_worker = True
+        args.api = True
+        args.webui = True
+
+    logger.info("Starting services:")
+    logger.info(f"To view the llm_api logs, go to {LOG_PATH}")
+
+    processes = {}
+
+    if args.openai_api:
+        process = Process(
+            target=run_controller,
+            name=f"controller({os.getpid()})",
+            args=(queue, len(processes) + 1),
+            daemon=True,
+        )
+        process.start()
+        processes["controller"] = process
+
+        process = Process(
+            target=run_openai_api,
+            name=f"openai_api({os.getpid()})",
+            args=(queue, len(processes) + 1),
+            daemon=True,
+        )
+        process.start()
+        processes["openai_api"] = process
+
+    if args.model_worker:
+        process = Process(
+            target=run_model_worker,
+            name=f"model_worker({os.getpid()})",
+            args=(args.model_name, queue, len(processes) + 1),
+            daemon=True,
+        )
+        process.start()
+        processes["model_worker"] = process
+
+    if args.api:
+        process = Process(
+            target=run_api_server,
+            name=f"API Server({os.getpid()})",
+            args=(queue, len(processes) + 1),
+            daemon=True,
+        )
+        process.start()
+        processes["api"] = process
+
+    if args.webui:
+        process = Process(
+            target=run_webui,
+            name=f"WEBUI Server({os.getpid()})",
+            daemon=True,
+        )
+        process.start()
+        processes["webui"] = process
+
+    try:
+        if model_worker_process := processes.get("model_worker"):
+            model_worker_process.join()
+        for name, process in processes.items():
+            if name != "model_worker":
+                process.join()
+    except:
+        if model_worker_process := processes.get("model_worker"):
+            model_worker_process.terminate()
+        for name, process in processes.items():
+            if name != "model_worker":
+                process.terminate()
+
+# Example of calling the API once the services have started:
+# import openai
+# openai.api_key = "EMPTY"  # Not supported yet
+# openai.api_base = "http://localhost:8888/v1"
+
+# model = "chatglm2-6b"
+
+# # create a chat completion
+# completion = openai.ChatCompletion.create(
+#     model=model,
"content": "Hello! What is your name?"}] +# ) +# # print the completion +# print(completion.choices[0].message.content) diff --git a/tests/api/stream_api_test.py b/tests/api/stream_api_test.py index 06a9654..2902c8a 100644 --- a/tests/api/stream_api_test.py +++ b/tests/api/stream_api_test.py @@ -28,4 +28,14 @@ if __name__ == "__main__": for line in response.iter_content(decode_unicode=True): print(line, flush=True) else: - print("Error:", response.status_code) \ No newline at end of file + print("Error:", response.status_code) + + + r = requests.post( + openai_url + "/chat/completions", + json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000}) + data = r.json() + print(f"/chat/completions\n") + print(data) + assert "choices" in data +