添加命令行管理脚本 (#355)
* 添加加命令行工具 * 添加加命令行工具 --------- Co-authored-by: zqt <1178747941@qq.com> Co-authored-by: imClumsyPanda <littlepanda0716@gmail.com>
This commit is contained in:
parent
1678392cee
commit
a0cb14de23
17
api.py
17
api.py
|
|
@ -85,6 +85,7 @@ def get_vs_path(local_doc_id: str):
|
|||
def get_file_path(local_doc_id: str, doc_name: str):
|
||||
return os.path.join(UPLOAD_ROOT_PATH, local_doc_id, doc_name)
|
||||
|
||||
|
||||
async def upload_file(
|
||||
file: UploadFile = File(description="A single binary file"),
|
||||
knowledge_base_id: str = Form(..., description="Knowledge Base Name", example="kb1"),
|
||||
|
|
@ -112,6 +113,7 @@ async def upload_file(
|
|||
file_status = "文件上传失败,请重新上传"
|
||||
return BaseResponse(code=500, msg=file_status)
|
||||
|
||||
|
||||
async def upload_files(
|
||||
files: Annotated[
|
||||
List[UploadFile], File(description="Multiple files as UploadFile")
|
||||
|
|
@ -308,13 +310,9 @@ async def document():
|
|||
return RedirectResponse(url="/docs")
|
||||
|
||||
|
||||
def main():
|
||||
def api_start(host, port):
|
||||
global app
|
||||
global local_doc_qa
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=7861)
|
||||
args = parser.parse_args()
|
||||
|
||||
app = FastAPI()
|
||||
# Add CORS middleware to allow all origins
|
||||
|
|
@ -340,7 +338,6 @@ def main():
|
|||
app.get("/local_doc_qa/list_files", response_model=ListDocsResponse)(list_docs)
|
||||
app.delete("/local_doc_qa/delete_file", response_model=BaseResponse)(delete_docs)
|
||||
|
||||
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(
|
||||
llm_model=LLM_MODEL,
|
||||
|
|
@ -349,8 +346,12 @@ def main():
|
|||
llm_history_len=LLM_HISTORY_LEN,
|
||||
top_k=VECTOR_SEARCH_TOP_K,
|
||||
)
|
||||
uvicorn.run(app, host=args.host, port=args.port)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||||
parser.add_argument("--port", type=int, default=7861)
|
||||
args = parser.parse_args()
|
||||
api_start(args.host, args.port)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
import click
|
||||
|
||||
from api import api_start as api_start
|
||||
from configs.model_config import llm_model_dict, embedding_model_dict
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(version='1.0.0')
|
||||
@click.pass_context
|
||||
def cli(ctx):
|
||||
pass
|
||||
|
||||
|
||||
@cli.group()
|
||||
def llm():
|
||||
pass
|
||||
|
||||
|
||||
@llm.command(name="ls")
|
||||
def llm_ls():
|
||||
for k in llm_model_dict.keys():
|
||||
print(k)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def embedding():
|
||||
pass
|
||||
|
||||
|
||||
@embedding.command(name="ls")
|
||||
def embedding_ls():
|
||||
for k in embedding_model_dict.keys():
|
||||
print(k)
|
||||
|
||||
|
||||
@cli.group()
|
||||
def start():
|
||||
pass
|
||||
|
||||
|
||||
@start.command(name="api", context_settings=dict(help_option_names=['-h', '--help']))
|
||||
@click.option('-i', '--ip', default='0.0.0.0', show_default=True, type=str, help='api_server listen address.')
|
||||
@click.option('-p', '--port', default=7861, show_default=True, type=int, help='api_server listen port.')
|
||||
def start_api(ip, port):
|
||||
api_start(host=ip, port=port)
|
||||
|
||||
|
||||
@start.command(name="cli", context_settings=dict(help_option_names=['-h', '--help']))
|
||||
def start_cli():
|
||||
import cli_demo
|
||||
cli_demo.main()
|
||||
|
||||
|
||||
@start.command(name="webui", context_settings=dict(help_option_names=['-h', '--help']))
|
||||
def start_webui():
|
||||
import webui
|
||||
|
||||
|
||||
cli()
|
||||
|
|
@ -8,7 +8,8 @@ nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
|||
# Show reply with source text from input document
|
||||
REPLY_WITH_SOURCE = True
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main():
|
||||
local_doc_qa = LocalDocQA()
|
||||
local_doc_qa.init_cfg(llm_model=LLM_MODEL,
|
||||
embedding_model=EMBEDDING_MODEL,
|
||||
|
|
@ -41,3 +42,7 @@ if __name__ == "__main__":
|
|||
for inum, doc in
|
||||
enumerate(resp["source_documents"])]
|
||||
print("\n\n" + "\n\n".join(source_text))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
## 命令行工具
|
||||
|
||||
windows cli.bat
|
||||
linux cli.sh
|
||||
|
||||
## 命令列表
|
||||
|
||||
### llm 管理
|
||||
|
||||
llm 支持列表
|
||||
|
||||
```shell
|
||||
cli.bat llm ls
|
||||
```
|
||||
|
||||
### embedding 管理
|
||||
|
||||
embedding 支持列表
|
||||
|
||||
```shell
|
||||
cli.bat embedding ls
|
||||
```
|
||||
|
||||
### start 启动管理
|
||||
|
||||
查看启动选择
|
||||
|
||||
```shell
|
||||
cli.bat start
|
||||
```
|
||||
|
||||
启动命令行交互
|
||||
|
||||
```shell
|
||||
cli.bat start cli
|
||||
```
|
||||
|
||||
启动Web 交互
|
||||
|
||||
```shell
|
||||
cli.bat start webui
|
||||
```
|
||||
|
||||
启动api服务
|
||||
|
||||
```shell
|
||||
cli.bat start api
|
||||
```
|
||||
|
||||
|
|
@ -18,4 +18,5 @@ uvicorn
|
|||
peft
|
||||
pypinyin
|
||||
bitsandbytes
|
||||
click~=8.1.3
|
||||
tabulate
|
||||
|
|
|
|||
Loading…
Reference in New Issue