From a0cb14de230cb2339c3bace385a31eb18bfb27bc Mon Sep 17 00:00:00 2001 From: zqt996 <67185303+zqt996@users.noreply.github.com> Date: Mon, 15 May 2023 19:13:12 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=91=BD=E4=BB=A4=E8=A1=8C?= =?UTF-8?q?=E7=AE=A1=E7=90=86=E8=84=9A=E6=9C=AC=20(#355)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 添加加命令行工具 * 添加加命令行工具 --------- Co-authored-by: zqt <1178747941@qq.com> Co-authored-by: imClumsyPanda --- api.py | 17 +++++++------- cli.bat | 2 ++ cli.py | 59 ++++++++++++++++++++++++++++++++++++++++++++++++ cli.sh | 2 ++ cli_demo.py | 7 +++++- docs/cli.md | 49 ++++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + 7 files changed, 128 insertions(+), 9 deletions(-) create mode 100644 cli.bat create mode 100644 cli.py create mode 100644 cli.sh create mode 100644 docs/cli.md diff --git a/api.py b/api.py index e4917b8..407b301 100644 --- a/api.py +++ b/api.py @@ -85,6 +85,7 @@ def get_vs_path(local_doc_id: str): def get_file_path(local_doc_id: str, doc_name: str): return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name) + async def upload_file( file: UploadFile = File(description="A single binary file"), knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"), @@ -112,6 +113,7 @@ async def upload_file( file_status = "文件上传失败,请重新上传" return BaseResponse(code=500, msg=file_status) + async def upload_files( files: Annotated[ List[UploadFile], File(description="Multiple files as UploadFile") @@ -308,13 +310,9 @@ async def document(): return RedirectResponse(url="/docs") -def main(): +def api_start(host, port): global app global local_doc_qa - parser = argparse.ArgumentParser() - parser.add_argument("--host", type=str, default="0.0.0.0") - parser.add_argument("--port", type=int, default=7861) - args = parser.parse_args() app = FastAPI() # Add CORS middleware to allow all origins @@ -340,7 +338,6 @@ def main(): app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs) app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs) - local_doc_qa = LocalDocQA() local_doc_qa.init_cfg( llm_model=LLM_MODEL, @@ -349,8 +346,12 @@ def main(): llm_history_len=LLM_HISTORY_LEN, top_k=VECTOR_SEARCH_TOP_K, ) - uvicorn.run(app, host=args.host, port=args.port) + uvicorn.run(app, host=host, port=port) if __name__ == "__main__": - main() + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="0.0.0.0") + parser.add_argument("--port", type=int, default=7861) + args = parser.parse_args() + api_start(args.host, args.port) diff --git a/cli.bat b/cli.bat new file mode 100644 index 0000000..7696591 --- /dev/null +++ b/cli.bat @@ -0,0 +1,2 @@ +@echo off +python cli.py %* diff --git a/cli.py b/cli.py new file mode 100644 index 0000000..6cce3ca --- /dev/null +++ b/cli.py @@ -0,0 +1,59 @@ +import click + +from api import api_start as api_start +from configs.model_config import llm_model_dict, embedding_model_dict + + +@click.group() +@click.version_option(version='1.0.0') +@click.pass_context +def cli(ctx): + pass + + +@cli.group() +def llm(): + pass + + +@llm.command(name="ls") +def llm_ls(): + for k in llm_model_dict.keys(): + print(k) + + +@cli.group() +def embedding(): + pass + + +@embedding.command(name="ls") +def embedding_ls(): + for k in embedding_model_dict.keys(): + print(k) + + +@cli.group() +def start(): + pass + + +@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help'])) +@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.') +@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.') +def start_api(ip, port): + api_start(host=ip, port=port) + + +@start.command(name="cli", context_settings=dict(help_option_names=['-h', '--help'])) +def start_cli(): + import cli_demo + cli_demo.main() + + +@start.command(name="webui", context_settings=dict(help_option_names=['-h', '--help'])) +def start_webui(): + import webui + + +cli() diff --git a/cli.sh b/cli.sh new file mode 100644 index 0000000..b01f805 --- /dev/null +++ b/cli.sh @@ -0,0 +1,2 @@ +#!/bin/bash +python cli.py "$@" diff --git a/cli_demo.py b/cli_demo.py index ea5a895..9961faf 100644 --- a/cli_demo.py +++ b/cli_demo.py @@ -8,7 +8,8 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path # Show reply with source text from input document REPLY_WITH_SOURCE = True -if __name__ == "__main__": + +def main(): local_doc_qa = LocalDocQA() local_doc_qa.init_cfg(llm_model=LLM_MODEL, embedding_model=EMBEDDING_MODEL, @@ -41,3 +42,7 @@ if __name__ == "__main__": for inum, doc in enumerate(resp["source_documents"])] print("\n\n" + "\n\n".join(source_text)) + + +if __name__ == "__main__": + main() diff --git a/docs/cli.md b/docs/cli.md new file mode 100644 index 0000000..ff3eb9f --- /dev/null +++ b/docs/cli.md @@ -0,0 +1,49 @@ +## 命令行工具 + +windows cli.bat +linux cli.sh + +## 命令列表 + +### llm 管理 + +llm 支持列表 + +```shell +cli.bat llm ls +``` + +### embedding 管理 + +embedding 支持列表 + +```shell +cli.bat embedding ls +``` + +### start 启动管理 + +查看启动选择 + +```shell +cli.bat start +``` + +启动命令行交互 + +```shell +cli.bat start cli +``` + +启动Web 交互 + +```shell +cli.bat start webui +``` + +启动api服务 + +```shell +cli.bat start api +``` + diff --git a/requirements.txt b/requirements.txt index 34d8df6..601d92b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,4 +18,5 @@ uvicorn peft pypinyin bitsandbytes +click~=8.1.3 tabulate