diff --git a/README.md b/README.md
index 3918daa..e0a7354 100644
--- a/README.md
+++ b/README.md
@@ -15,6 +15,7 @@
   * [3. 设置配置项](README.md#3.-设置配置项)
   * [4. 知识库初始化与迁移](README.md#4.-知识库初始化与迁移)
   * [5. 启动 API 服务或 Web UI](README.md#5.-启动-API-服务或-Web-UI)
+  * [6. 一键启动](README.md#6.-一键启动)
   * [常见问题](README.md#常见问题)
   * [路线图](README.md#路线图)
   * [项目交流群](README.md#项目交流群)
@@ -248,14 +249,23 @@ gpus=None,
 num_gpus=1,
 max_gpu_memory="20GiB"
 ```
-其中,`gpus` 控制使用的显卡的ID,如果 "0,1";
-`num_gpus` 控制使用的卡数;
+其中,`gpus` 控制使用的显卡的ID,如 "0,1";
+
+`num_gpus` 控制使用的卡数;
 `max_gpu_memory` 控制每个卡使用的显存容量。
 
 ##### 5.1.2 基于命令行脚本 llm_api_launch.py 启动 LLM 服务
 
+**!!!注意:**
+
+**1. llm_api_launch.py 脚本仅适用于 Linux 和 Mac 设备,Windows 平台请使用 WSL;**
+
+**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 `model_config.py` 配置;**
+
+**!!!**
+
 在项目根目录下,执行 [server/llm_api_launch.py](server/llm_api.py) 脚本启动 **LLM 模型**服务:
 
 ```shell
 $ python server/llm_api_launch.py
 ```
@@ -274,7 +284,7 @@ $ python server/llm_api_launch.py --model-path-addresss model1@host1@port1 model2@host2@port2
 $ python server/llm_api_launch.py --gpus 0,1 --num-gpus 2 --max-gpu-memory 10GiB
 ```
-注:以如上方式启动LLM服务会以nohup命令在后台运行 fastchat 服务,如需停止服务,可以运行如下命令,但该脚本**仅适用于linux和mac平台**:
+注:以如上方式启动LLM服务会以nohup命令在后台运行 FastChat 服务,如需停止服务,可以运行如下命令:
 
 ```shell
 $ python server/llm_api_shutdown.py --serve all
 ```
@@ -349,11 +359,17 @@ $ streamlit run webui.py --server.port 666
 
 ---
 
-### 6 一键启动:
+### 6. 一键启动
 
-#### 6.1 api服务一键启动脚本
+⚠️ **注意:**
 
-新增api一键启动脚本,可一键开启fastchat后台服务及本项目提供的langchain api服务,调用示例:
+**1. 一键启动脚本仅适用于 Linux 和 Mac 设备,Windows 平台请使用 WSL;**
+
+**2. 加载非默认模型需要用命令行参数 `--model-path-address` 指定模型,不会读取 `model_config.py` 配置。**
+
+#### 6.1 API 服务一键启动脚本
+
+新增 API 一键启动脚本,可一键开启 FastChat 后台服务及本项目提供的 API 服务,调用示例:
 
 调用默认模型:
 
@@ -372,32 +388,49 @@ $ python server/api_allinone.py --model-path-address model1@host1@port1 model2@h
 ```shell
 python server/api_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
 ```
-其他参数详见各脚本及fastchat服务说明。
+
+其他参数详见各脚本及 FastChat 服务说明。
 
 #### 6.2 webui一键启动脚本
+
 加载本地模型:
+
 ```shell
 $ python webui_allinone.py
 ```
-调用远程api服务:
+调用远程 API 服务:
+
 ```shell
 $ python webui_allinone.py --use-remote-api
 ```
+
 后台运行webui服务:
+
 ```shell
 $ python webui_allinone.py --nohup
 ```
+
 加载多个非默认模型:
+
 ```shell
 $ python webui_allinone.py --model-path-address model1@host1@port1 model2@host2@port2
 ```
+
 多卡启动:
+
+```shell
+$ python webui_allinone.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
 ```
-python webui_alline.py --model-path-address model@host@port --num-gpus 2 --gpus 0,1 --max-gpu-memory 10GiB
-```
+
 其他参数详见各脚本及fastchat服务说明。
 
+上述两个一键启动脚本会后台运行多个服务,如要停止所有服务,可使用 `shutdown_all.sh` 脚本:
+
+```shell
+bash shutdown_all.sh
+```
+
 ## 常见问题
 
 参见 [常见问题](docs/FAQ.md)。
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 1090f85..8771cfc 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -46,7 +46,7 @@ llm_model_dict = {
     "chatglm-6b-int4": {
         "local_model_path": "THUDM/chatglm-6b-int4",
-        "api_base_url": "http://localhost:8001/v1",  # "name"修改为fastchat服务中的"api_base_url"
+        "api_base_url": "http://localhost:8888/v1",  # "name"修改为fastchat服务中的"api_base_url"
         "api_key": "EMPTY"
     },
 
@@ -64,7 +64,7 @@ llm_model_dict = {
     "vicuna-13b-hf": {
         "local_model_path": "",
-        "api_base_url": "http://localhost:8000/v1",  # "name"修改为fastchat服务中的"api_base_url"
+        "api_base_url": "http://localhost:8888/v1",  # "name"修改为fastchat服务中的"api_base_url"
         "api_key": "EMPTY"
     },
 
@@ -85,6 +85,7 @@ llm_model_dict = {
     },
 }
+
 
 # LLM 名称
 LLM_MODEL = "chatglm2-6b"
 
diff --git a/requirements.txt b/requirements.txt
index f2e1d65..646a5c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,11 +1,10 @@ langchain==0.0.257 openai sentence_transformers -fschat==0.2.20 -transformers +fschat==0.2.24 +transformers>=4.31.0 torch~=2.0.0 fastapi~=0.99.1 -fastapi-offline nltk~=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 diff --git a/requirements_api.txt b/requirements_api.txt index f077c94..1e13587 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,11 +1,10 @@ langchain==0.0.257 openai sentence_transformers -fschat==0.2.20 -transformers +fschat==0.2.24 +transformers>=4.31.0 torch~=2.0.0 fastapi~=0.99.1 -fastapi-offline nltk~=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 diff --git a/server/api.py b/server/api.py index 458b1d7..800680c 100644 --- a/server/api.py +++ b/server/api.py @@ -7,15 +7,17 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN import argparse import uvicorn -from server.utils import FastAPIOffline as FastAPI from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, - update_doc, download_doc, recreate_vector_store) -from server.utils import BaseResponse, ListResponse + update_doc, download_doc, recreate_vector_store, + search_docs, DocumentWithScore) +from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline +from typing import List + nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -25,7 +27,8 @@ async def document(): def create_app(): - app = FastAPI() + app = FastAPI(title="Langchain-Chatchat API Server") + MakeFastAPIOffline(app) # Add CORS middleware to allow all origins # 在config.py中设置OPEN_DOMAIN=True,允许跨域 # set OPEN_DOMAIN=True in config.py to allow cross-domain @@ -83,6 +86,12 @@ def create_app(): summary="获取知识库内的文件列表" )(list_docs) + app.post("/knowledge_base/search_docs", + tags=["Knowledge Base Management"], + response_model=List[DocumentWithScore], + summary="搜索知识库" + )(search_docs) + app.post("/knowledge_base/upload_doc", tags=["Knowledge Base Management"], response_model=BaseResponse, diff --git a/server/api_allinone.py b/server/api_allinone.py index b19b6cc..4379d7d 100644 --- a/server/api_allinone.py +++ b/server/api_allinone.py @@ -22,9 +22,7 @@ parser.add_argument("--api-host", type=str, default="0.0.0.0") parser.add_argument("--api-port", type=int, default=7861) parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) -# 初始化消息 -args = parser.parse_args() -args_dict = vars(args) + api_args = ["api-host","api-port","ssl_keyfile","ssl_certfile"] @@ -41,6 +39,11 @@ def run_api(host, port, **kwargs): uvicorn.run(app, host=host, port=port) if __name__ == "__main__": + print("Luanching api_allinone,it would take a while, please be patient...") + print("正在启动api_allinone,LLM服务启动约3-10分钟,请耐心等待...") + # 初始化消息 + args = parser.parse_args() + args_dict = vars(args) launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args) run_api( host=args.api_host, @@ -48,3 +51,5 @@ if __name__ == "__main__": ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, ) + print("Luanching api_allinone done.") + print("api_allinone启动完毕.") diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 
0ecabf0..84c62f0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -1,26 +1,27 @@ from fastapi import Body, Request from fastapi.responses import StreamingResponse from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE, - VECTOR_SEARCH_TOP_K) + VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.chat.utils import wrap_done from server.utils import BaseResponse from langchain.chat_models import ChatOpenAI from langchain import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler -from typing import AsyncIterable +from typing import AsyncIterable, List, Optional import asyncio from langchain.prompts.chat import ChatPromptTemplate -from typing import List, Optional from server.chat.utils import History from server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json import os from urllib.parse import urlencode +from server.knowledge_base.kb_doc_api import search_docs def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), history: List[History] = Body([], description="历史对话", examples=[[ @@ -53,7 +54,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], model_name=LLM_MODEL ) - docs = kb.search_docs(query, top_k) + docs = search_docs(query, knowledge_base_name, top_k, score_threshold) context = "\n".join([doc.page_content for doc in docs]) chat_prompt = ChatPromptTemplate.from_messages( diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index 7b69265..a7ad807 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -1,7 +1,7 @@ from fastapi.responses import StreamingResponse from typing import List import openai -from configs.model_config import llm_model_dict, LLM_MODEL +from configs.model_config import llm_model_dict, LLM_MODEL, logger from pydantic import BaseModel @@ -33,19 +33,23 @@ async def openai_chat(msg: OpenAiChatMsgIn): data = msg.dict() data["streaming"] = True data.pop("stream") - response = openai.ChatCompletion.create(**data) - if msg.stream: - for chunk in response.choices[0].message.content: - print(chunk) - yield chunk - else: - answer = "" - for chunk in response.choices[0].message.content: - answer += chunk - print(answer) - yield(answer) - + try: + response = openai.ChatCompletion.create(**data) + if msg.stream: + for chunk in response.choices[0].message.content: + print(chunk) + yield chunk + else: + answer = "" + for chunk in response.choices[0].message.content: + answer += chunk + print(answer) + yield(answer) + except Exception as e: + print(type(e)) + logger.error(e) + return StreamingResponse( get_response(msg), media_type='text/event-stream', diff --git a/server/db/repository/knowledge_file_repository.py b/server/db/repository/knowledge_file_repository.py index f5e912f..404910f 100644 --- a/server/db/repository/knowledge_file_repository.py +++ b/server/db/repository/knowledge_file_repository.py @@ -27,7 +27,7 @@ def add_doc_to_db(session, kb_file: KnowledgeFile): file_ext=kb_file.ext, kb_name=kb_file.kb_name, document_loader_name=kb_file.document_loader_name, - text_splitter_name=kb_file.text_splitter_name, + 
text_splitter_name=kb_file.text_splitter_name or "SpacyTextSplitter", ) kb.file_count += 1 session.add(new_file) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 3f27fb1..0bf2cb7 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,13 +1,32 @@ import os import urllib from fastapi import File, Form, Body, Query, UploadFile -from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL +from configs.model_config import (DEFAULT_VS_TYPE, EMBEDDING_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD) from server.utils import BaseResponse, ListResponse from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse, FileResponse import json from server.knowledge_base.kb_service.base import KBServiceFactory -from typing import List +from typing import List, Dict +from langchain.docstore.document import Document + + +class DocumentWithScore(Document): + score: float = None + + +def search_docs(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), + ) -> List[DocumentWithScore]: + kb = KBServiceFactory.get_service_by_name(knowledge_base_name) + if kb is None: + return {"code": 404, "msg": f"未找到知识库 {knowledge_base_name}", "docs": []} + docs = kb.search_docs(query, top_k, score_threshold) + data = [DocumentWithScore(**x[0].dict(), score=x[1]) for x in docs] + + return data async def list_docs( diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index ec1c692..d506f63 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -13,7 +13,7 @@ from server.db.repository.knowledge_file_repository import ( list_docs_from_db, get_file_detail, delete_file_from_db ) -from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, +from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EMBEDDING_DEVICE, EMBEDDING_MODEL) from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, @@ -112,9 +112,10 @@ class KBService(ABC): def search_docs(self, query: str, top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, ): embeddings = self._load_embeddings() - docs = self.do_search(query, top_k, embeddings) + docs = self.do_search(query, top_k, score_threshold, embeddings) return docs @abstractmethod diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 0ef820a..5c8376f 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -81,12 +81,13 @@ class FaissKBService(KBService): def do_search(self, query: str, top_k: int, - embeddings: Embeddings, + score_threshold: float = SCORE_THRESHOLD, + embeddings: Embeddings = None, ) -> List[Document]: search_index = load_vector_store(self.kb_name, embeddings=embeddings, tick=_VECTOR_STORE_TICKS.get(self.kb_name)) - docs = search_index.similarity_search(query, k=top_k, score_threshold=SCORE_THRESHOLD) + docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) return docs def 
do_add_doc(self, diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 6f1c392..f9c40c0 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -45,7 +45,8 @@ class MilvusKBService(KBService): def do_drop_kb(self): self.milvus.col.drop() - def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + # todo: support score threshold self._load_milvus(embeddings=embeddings) return self.milvus.similarity_search(query, top_k, score_threshold=SCORE_THRESHOLD) diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py index 82511bb..a3126ec 100644 --- a/server/knowledge_base/kb_service/pg_kb_service.py +++ b/server/knowledge_base/kb_service/pg_kb_service.py @@ -43,7 +43,8 @@ class PGKBService(KBService): ''')) connect.commit() - def do_search(self, query: str, top_k: int, embeddings: Embeddings) -> List[Document]: + def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings) -> List[Document]: + # todo: support score threshold self._load_pg_vector(embeddings=embeddings) return self.pg_vector.similarity_search(query, top_k) diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index 3e8be26..3ab6560 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -102,6 +102,7 @@ class KnowledgeFile: chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, ) + self.text_splitter_name = "SpacyTextSplitter" else: text_splitter_module = importlib.import_module('langchain.text_splitter') TextSplitter = getattr(text_splitter_module, self.text_splitter_name) diff --git a/server/llm_api.py b/server/llm_api.py index 0a7d3b0..ab71b3d 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,16 +1,18 @@ from multiprocessing import Process, Queue +import multiprocessing as mp import sys import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger +from server.utils import MakeFastAPIOffline + host_ip = "0.0.0.0" controller_port = 20001 model_worker_port = 20002 openai_api_port = 8888 base_url = "http://127.0.0.1:{}" -queue = Queue() def set_httpx_timeout(timeout=60.0): @@ -30,33 +32,50 @@ def create_controller_app( controller = Controller(dispatch_method) sys.modules["fastchat.serve.controller"].controller = controller + MakeFastAPIOffline(app) + app.title = "FastChat Controller" return app def create_model_worker_app( + worker_address=base_url.format(model_worker_port), + controller_address=base_url.format(controller_port), model_path=llm_model_dict[LLM_MODEL].get("local_model_path"), - model_names=[LLM_MODEL], device=LLM_DEVICE, + gpus=None, + max_gpu_memory="20GiB", load_8bit=False, + cpu_offloading=None, gptq_ckpt=None, gptq_wbits=16, gptq_groupsize=-1, - gptq_act_order=None, - gpus=None, - num_gpus=1, - max_gpu_memory="20GiB", - cpu_offloading=None, - worker_address=base_url.format(model_worker_port), - controller_address=base_url.format(controller_port), + gptq_act_order=False, + awq_ckpt=None, + awq_wbits=16, + awq_groupsize=-1, + model_names=[LLM_MODEL], + num_gpus=1, # not in fastchat + conv_template=None, limit_worker_concurrency=5, stream_interval=2, no_register=False, ): import 
fastchat.constants fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id - from fastchat.serve import model_worker + from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id import argparse + import threading + import fastchat.serve.model_worker + + # workaround to make program exit with Ctrl+c + # it should be deleted after pr is merged by fastchat + def _new_init_heart_beat(self): + self.register_to_controller() + self.heart_beat_thread = threading.Thread( + target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ) + self.heart_beat_thread.start() + ModelWorker.init_heart_beat = _new_init_heart_beat parser = argparse.ArgumentParser() args = parser.parse_args() @@ -68,12 +87,16 @@ def create_model_worker_app( args.gptq_wbits = gptq_wbits args.gptq_groupsize = gptq_groupsize args.gptq_act_order = gptq_act_order + args.awq_ckpt = awq_ckpt + args.awq_wbits = awq_wbits + args.awq_groupsize = awq_groupsize args.gpus = gpus args.num_gpus = num_gpus args.max_gpu_memory = max_gpu_memory args.cpu_offloading = cpu_offloading args.worker_address = worker_address args.controller_address = controller_address + args.conv_template = conv_template args.limit_worker_concurrency = limit_worker_concurrency args.stream_interval = stream_interval args.no_register = no_register @@ -95,6 +118,12 @@ def create_model_worker_app( groupsize=args.gptq_groupsize, act_order=args.gptq_act_order, ) + awq_config = AWQConfig( + ckpt=args.awq_ckpt or args.model_path, + wbits=args.awq_wbits, + groupsize=args.awq_groupsize, + ) + # torch.multiprocessing.set_start_method('spawn') worker = ModelWorker( controller_addr=args.controller_address, @@ -110,19 +139,21 @@ def create_model_worker_app( load_8bit=args.load_8bit, cpu_offloading=args.cpu_offloading, gptq_config=gptq_config, + awq_config=awq_config, stream_interval=args.stream_interval, + conv_template=args.conv_template, ) sys.modules["fastchat.serve.model_worker"].worker = worker sys.modules["fastchat.serve.model_worker"].args = args sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config + MakeFastAPIOffline(app) + app.title = f"FastChat LLM Server ({LLM_MODEL})" return app def create_openai_api_app( - host=host_ip, - port=openai_api_port, controller_address=base_url.format(controller_port), api_keys=[], ): @@ -141,6 +172,8 @@ def create_openai_api_app( app_settings.controller_address = controller_address app_settings.api_keys = api_keys + MakeFastAPIOffline(app) + app.title = "FastChat OpeanAI API Server" return app @@ -193,6 +226,8 @@ def run_openai_api(q): if __name__ == "__main__": + mp.set_start_method("spawn") + queue = Queue() logger.info(llm_model_dict[LLM_MODEL]) model_path = llm_model_dict[LLM_MODEL]["local_model_path"] @@ -209,15 +244,14 @@ if __name__ == "__main__": ) controller_process.start() - # cuda 没办法用在fork的多进程中 - # model_worker_process = Process( - # target=run_model_worker, - # name=f"model_worker({os.getpid()})", - # args=(queue,), - # # kwargs={"load_8bit": True}, - # daemon=True, - # ) - # model_worker_process.start() + model_worker_process = Process( + target=run_model_worker, + name=f"model_worker({os.getpid()})", + args=(queue,), + # kwargs={"load_8bit": True}, + daemon=True, + ) + model_worker_process.start() openai_api_process = Process( target=run_openai_api, @@ -227,11 +261,14 @@ if __name__ == "__main__": ) openai_api_process.start() - run_model_worker(queue) - - controller_process.join() - # 
model_worker_process.join() - openai_api_process.join() + try: + model_worker_process.join() + controller_process.join() + openai_api_process.join() + except KeyboardInterrupt: + model_worker_process.terminate() + controller_process.terminate() + openai_api_process.terminate() # 服务启动后接口调用示例: # import openai diff --git a/server/llm_api_launch.py b/server/llm_api_launch.py index 0fdccfc..044bab6 100644 --- a/server/llm_api_launch.py +++ b/server/llm_api_launch.py @@ -159,17 +159,7 @@ server_args = ["server-host", "server-port", "allow-credentials", "api-keys", "controller-address" ] -args = parser.parse_args() -# 必须要加http//:,否则InvalidSchema: No connection adapters were found -args = argparse.Namespace(**vars(args), - **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"}) -if args.gpus: - if len(args.gpus.split(",")) < args.num_gpus: - raise ValueError( - f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" - ) - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus # 0,controller, model_worker, openai_api_server # 1, 命令行选项 @@ -211,12 +201,13 @@ def string_args(args, args_list): return args_str -def launch_worker(item,args=args,worker_args=worker_args): +def launch_worker(item,args,worker_args=worker_args): log_name = item.split("/")[-1].split("\\")[-1].replace("-", "_").replace("@", "_").replace(".", "_") # 先分割model-path-address,在传到string_args中分析参数 args.model_path, args.worker_host, args.worker_port = item.split("@") args.worker_address = f"http://{args.worker_host}:{args.worker_port}" print("*" * 80) + print(f"如长时间未启动,请到{LOG_PATH}{log_name}.log下查看日志") worker_str_args = string_args(args, worker_args) print(worker_str_args) worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, f"worker_{log_name}") @@ -225,30 +216,44 @@ def launch_worker(item,args=args,worker_args=worker_args): subprocess.run(worker_check_sh, shell=True, check=True) -def launch_all(args=args, +def launch_all(args, controller_args=controller_args, worker_args=worker_args, server_args=server_args ): + print(f"Launching llm service,logs are located in {LOG_PATH}...") + print(f"开始启动LLM服务,请到{LOG_PATH}下监控各模块日志...") controller_str_args = string_args(args, controller_args) controller_sh = base_launch_sh.format("controller", controller_str_args, LOG_PATH, "controller") controller_check_sh = base_check_sh.format(LOG_PATH, "controller", "controller") subprocess.run(controller_sh, shell=True, check=True) subprocess.run(controller_check_sh, shell=True, check=True) - + print(f"worker启动时间视设备不同而不同,约需3-10分钟,请耐心等待...") if isinstance(args.model_path_address, str): launch_worker(args.model_path_address,args=args,worker_args=worker_args) else: for idx, item in enumerate(args.model_path_address): print(f"开始加载第{idx}个模型:{item}") - launch_worker(item) + launch_worker(item,args=args,worker_args=worker_args) server_str_args = string_args(args, server_args) server_sh = base_launch_sh.format("openai_api_server", server_str_args, LOG_PATH, "openai_api_server") server_check_sh = base_check_sh.format(LOG_PATH, "openai_api_server", "openai_api_server") subprocess.run(server_sh, shell=True, check=True) subprocess.run(server_check_sh, shell=True, check=True) - + print("Launching LLM service done!") + print("LLM服务启动完毕。") if __name__ == "__main__": - launch_all() + args = parser.parse_args() + # 必须要加http//:,否则InvalidSchema: No connection adapters were found + args = argparse.Namespace(**vars(args), + **{"controller-address": f"http://{args.controller_host}:{str(args.controller_port)}"}) + + if 
args.gpus: + if len(args.gpus.split(",")) < args.num_gpus: + raise ValueError( + f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" + ) + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus + launch_all(args=args) diff --git a/server/static/favicon.png b/server/static/favicon.png new file mode 100644 index 0000000..5de8ee8 Binary files /dev/null and b/server/static/favicon.png differ diff --git a/server/utils.py b/server/utils.py index e1a23d1..c0f11a5 100644 --- a/server/utils.py +++ b/server/utils.py @@ -2,14 +2,10 @@ import pydantic from pydantic import BaseModel from typing import List import torch -from fastapi_offline import FastAPIOffline -import fastapi_offline +from fastapi import FastAPI from pathlib import Path import asyncio - - -# patch fastapi_offline to use local static assests -fastapi_offline.core._STATIC_PATH = Path(__file__).parent / "static" +from typing import Any, Optional class BaseResponse(BaseModel): @@ -112,3 +108,81 @@ def iter_over_async(ait, loop): if done: break yield obj + + +def MakeFastAPIOffline( + app: FastAPI, + static_dir = Path(__file__).parent / "static", + static_url = "/static-offline-docs", + docs_url: Optional[str] = "/docs", + redoc_url: Optional[str] = "/redoc", +) -> None: + """patch the FastAPI obj that doesn't rely on CDN for the documentation page""" + from fastapi import Request + from fastapi.openapi.docs import ( + get_redoc_html, + get_swagger_ui_html, + get_swagger_ui_oauth2_redirect_html, + ) + from fastapi.staticfiles import StaticFiles + from starlette.responses import HTMLResponse + + openapi_url = app.openapi_url + swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url + + def remove_route(url: str) -> None: + ''' + remove original route from app + ''' + index = None + for i, r in enumerate(app.routes): + if r.path.lower() == url.lower(): + index = i + break + if isinstance(index, int): + app.routes.pop(i) + + # Set up static file mount + app.mount( + static_url, + StaticFiles(directory=Path(static_dir).as_posix()), + name="static-offline-docs", + ) + + if docs_url is not None: + remove_route(docs_url) + remove_route(swagger_ui_oauth2_redirect_url) + + # Define the doc and redoc pages, pointing at the right files + @app.get(docs_url, include_in_schema=False) + async def custom_swagger_ui_html(request: Request) -> HTMLResponse: + root = request.scope.get("root_path") + favicon = f"{root}{static_url}/favicon.png" + return get_swagger_ui_html( + openapi_url=f"{root}{openapi_url}", + title=app.title + " - Swagger UI", + oauth2_redirect_url=swagger_ui_oauth2_redirect_url, + swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js", + swagger_css_url=f"{root}{static_url}/swagger-ui.css", + swagger_favicon_url=favicon, + ) + + @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False) + async def swagger_ui_redirect() -> HTMLResponse: + return get_swagger_ui_oauth2_redirect_html() + + if redoc_url is not None: + remove_route(redoc_url) + + @app.get(redoc_url, include_in_schema=False) + async def redoc_html(request: Request) -> HTMLResponse: + root = request.scope.get("root_path") + favicon = f"{root}{static_url}/favicon.png" + + return get_redoc_html( + openapi_url=f"{root}{openapi_url}", + title=app.title + " - ReDoc", + redoc_js_url=f"{root}{static_url}/redoc.standalone.js", + with_google_fonts=False, + redoc_favicon_url=favicon, + ) diff --git a/shutdown_all.sh b/shutdown_all.sh new file mode 100644 index 0000000..961260d --- /dev/null +++ b/shutdown_all.sh @@ -0,0 +1 @@ +ps -eo pid,user,cmd|grep -P 
'server/api.py|webui.py|fastchat.serve'|grep -v grep|awk '{print $1}'|xargs kill -9 \ No newline at end of file diff --git a/webui.py b/webui.py index d84da42..99db3f6 100644 --- a/webui.py +++ b/webui.py @@ -13,7 +13,11 @@ import os api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) if __name__ == "__main__": - st.set_page_config("Langchain-Chatchat WebUI", initial_sidebar_state="expanded") + st.set_page_config( + "Langchain-Chatchat WebUI", + os.path.join("img", "chatchat_icon_blue_square_v2.png"), + initial_sidebar_state="expanded", + ) if not chat_box.chat_inited: st.toast( diff --git a/webui_allinone.py b/webui_allinone.py index b347387..2992ae5 100644 --- a/webui_allinone.py +++ b/webui_allinone.py @@ -34,27 +34,31 @@ parser.add_argument("--theme.secondaryBackgroundColor",type=str,default='"#f5f5f parser.add_argument("--theme.textColor",type=str,default='"#000000"') web_args = ["server.port","theme.base","theme.primaryColor","theme.secondaryBackgroundColor","theme.textColor"] -args = parser.parse_args() -def launch_api(args=args,args_list=api_args,log_name=None): - print("launch api ...") +def launch_api(args,args_list=api_args,log_name=None): + print("Launching api ...") + print("启动API服务...") if not log_name: log_name = f"{LOG_PATH}api_{args.api_host}_{args.api_port}" print(f"logs on api are written in {log_name}") + print(f"API日志位于{log_name}下,如启动异常请查看日志") args_str = string_args(args,args_list) api_sh = "python server/{script} {args_str} >{log_name}.log 2>&1 &".format( script="api.py",args_str=args_str,log_name=log_name) subprocess.run(api_sh, shell=True, check=True) print("launch api done!") + print("启动API服务完毕.") - -def launch_webui(args=args,args_list=web_args,log_name=None): +def launch_webui(args,args_list=web_args,log_name=None): print("Launching webui...") + print("启动webui服务...") if not log_name: log_name = f"{LOG_PATH}webui" - print(f"logs on api are written in {log_name}") + args_str = string_args(args,args_list) if args.nohup: + print(f"logs on api are written in {log_name}") + print(f"webui服务日志位于{log_name}下,如启动异常请查看日志") webui_sh = "streamlit run webui.py {args_str} >{log_name}.log 2>&1 &".format( args_str=args_str,log_name=log_name) else: @@ -62,12 +66,18 @@ def launch_webui(args=args,args_list=web_args,log_name=None): args_str=args_str) subprocess.run(webui_sh, shell=True, check=True) print("launch webui done!") + print("启动webui服务完毕.") if __name__ == "__main__": print("Starting webui_allineone.py, it would take a while, please be patient....") + print(f"开始启动webui_allinone,启动LLM服务需要约3-10分钟,请耐心等待,如长时间未启动,请到{LOG_PATH}下查看日志...") + args = parser.parse_args() + + print("*"*80) if not args.use_remote_api: launch_all(args=args,controller_args=controller_args,worker_args=worker_args,server_args=server_args) launch_api(args=args,args_list=api_args) launch_webui(args=args,args_list=web_args) print("Start webui_allinone.py done!") + print("感谢耐心等待,启动webui_allinone完毕。") \ No newline at end of file diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index a39a196..a317aba 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -76,7 +76,7 @@ def dialogue_page(api: ApiRequest): key="selected_kb", ) kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3) - # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True) + score_threshold = st.number_input("知识匹配分数阈值:", 0.0, 1.0, float(SCORE_THRESHOLD), 0.01) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, 
disabled=True) elif dialogue_mode == "搜索引擎问答": @@ -98,6 +98,9 @@ def dialogue_page(api: ApiRequest): text = "" r = api.chat_chat(prompt, history) for t in r: + if error_msg := check_error_msg(t): # check whether error occured + st.error(error_msg) + break text += t chat_box.update_msg(text) chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 @@ -108,7 +111,9 @@ def dialogue_page(api: ApiRequest): Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) @@ -120,6 +125,8 @@ def dialogue_page(api: ApiRequest): ]) text = "" for d in api.search_engine_chat(prompt, search_engine, se_top_k): + if error_msg := check_error_msg(d): # check whether error occured + st.error(error_msg) text += d["answer"] chat_box.update_msg(text, 0) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index f963a67..f059475 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -49,21 +49,27 @@ def file_exists(kb: str, selected_rows: List) -> Tuple[str, str]: def knowledge_base_page(api: ApiRequest): try: - kb_list = get_kb_details() + kb_list = {x["kb_name"]: x for x in get_kb_details()} except Exception as e: st.error("获取知识库信息错误,请检查是否已按照 `README.md` 中 `4 知识库初始化与迁移` 步骤完成初始化或迁移,或是否为数据库连接错误。") st.stop() - kb_names = [x["kb_name"] for x in kb_list] + kb_names = list(kb_list.keys()) if "selected_kb_name" in st.session_state and st.session_state["selected_kb_name"] in kb_names: selected_kb_index = kb_names.index(st.session_state["selected_kb_name"]) else: selected_kb_index = 0 + def format_selected_kb(kb_name: str) -> str: + if kb := kb_list.get(kb_name): + return f"{kb_name} ({kb['vs_type']} @ {kb['embed_model']})" + else: + return kb_name + selected_kb = st.selectbox( "请选择或新建知识库:", - kb_list + ["新建知识库"], - format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})" if type(s) != str else s, + kb_names + ["新建知识库"], + format_func=format_selected_kb, index=selected_kb_index ) @@ -117,9 +123,7 @@ def knowledge_base_page(api: ApiRequest): st.experimental_rerun() elif selected_kb: - kb = selected_kb["kb_name"] - - + kb = selected_kb # 上传文件 # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) diff --git a/webui_pages/utils.py b/webui_pages/utils.py index d2772d5..3e67ed7 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -6,8 +6,10 @@ from configs.model_config import ( DEFAULT_VS_TYPE, KB_ROOT_PATH, LLM_MODEL, + SCORE_THRESHOLD, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, + logger, ) import httpx import asyncio @@ -24,6 +26,7 @@ from configs.model_config import NLTK_DATA_PATH import nltk nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path + def set_httpx_timeout(timeout=60.0): ''' 设置httpx默认timeout到60秒。 @@ -80,7 +83,8 @@ class ApiRequest: return httpx.stream("GET", url, params=params, **kwargs) else: return httpx.get(url, params=params, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def aget( @@ -100,7 +104,8 @@ class ApiRequest: return await client.stream("GET", url, params=params, 
**kwargs) else: return await client.get(url, params=params, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def post( @@ -121,7 +126,8 @@ class ApiRequest: return httpx.stream("POST", url, data=data, json=json, **kwargs) else: return httpx.post(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def apost( @@ -142,7 +148,8 @@ class ApiRequest: return await client.stream("POST", url, data=data, json=json, **kwargs) else: return await client.post(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def delete( @@ -162,7 +169,8 @@ class ApiRequest: return httpx.stream("DELETE", url, data=data, json=json, **kwargs) else: return httpx.delete(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 async def adelete( @@ -183,7 +191,8 @@ class ApiRequest: return await client.stream("DELETE", url, data=data, json=json, **kwargs) else: return await client.delete(url, data=data, json=json, **kwargs) - except: + except Exception as e: + logger.error(e) retry -= 1 def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): @@ -195,11 +204,14 @@ class ApiRequest: except: loop = asyncio.new_event_loop() - for chunk in iter_over_async(response.body_iterator, loop): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): - yield chunk + try: + for chunk in iter_over_async(response.body_iterator, loop): + if as_json and chunk: + yield json.loads(chunk) + elif chunk.strip(): + yield chunk + except Exception as e: + logger.error(e) def _httpx_stream2generator( self, @@ -209,12 +221,26 @@ class ApiRequest: ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' - with response as r: - for chunk in r.iter_text(None): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): - yield chunk + try: + with response as r: + for chunk in r.iter_text(None): + if as_json and chunk: + yield json.loads(chunk) + elif chunk.strip(): + yield chunk + except httpx.ConnectError as e: + msg = f"无法连接API服务器,请确认已执行python server\\api.py" + logger.error(msg) + logger.error(e) + yield {"code": 500, "errorMsg": msg} + except httpx.ReadTimeout as e: + msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 启动 API 服务或 Web UI')" + logger.error(msg) + logger.error(e) + yield {"code": 500, "errorMsg": msg} + except Exception as e: + logger.error(e) + yield {"code": 500, "errorMsg": str(e)} # 对话相关操作 @@ -287,6 +313,7 @@ class ApiRequest: query: str, knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, + score_threshold: float = SCORE_THRESHOLD, history: List[Dict] = [], stream: bool = True, no_remote_api: bool = None, @@ -301,6 +328,7 @@ class ApiRequest: "query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, + "score_threshold": score_threshold, "history": history, "stream": stream, "local_doc_url": no_remote_api, @@ -353,6 +381,21 @@ class ApiRequest: # 知识库相关操作 + def _check_httpx_json_response( + self, + response: httpx.Response, + errorMsg: str = f"无法连接API服务器,请确认已执行python server\\api.py", + ) -> Dict: + ''' + check whether httpx returns correct data with normal Response. 
+ error in api with streaming support was checked in _httpx_stream2enerator + ''' + try: + return response.json() + except Exception as e: + logger.error(e) + return {"code": 500, "errorMsg": errorMsg or str(e)} + def list_knowledge_bases( self, no_remote_api: bool = None, @@ -369,7 +412,8 @@ class ApiRequest: return response.data else: response = self.get("/knowledge_base/list_knowledge_bases") - return response.json().get("data") + data = self._check_httpx_json_response(response) + return data.get("data", []) def create_knowledge_base( self, @@ -399,7 +443,7 @@ class ApiRequest: "/knowledge_base/create_knowledge_base", json=data, ) - return response.json() + return self._check_httpx_json_response(response) def delete_knowledge_base( self, @@ -421,7 +465,7 @@ class ApiRequest: "/knowledge_base/delete_knowledge_base", json=f"{knowledge_base_name}", ) - return response.json() + return self._check_httpx_json_response(response) def list_kb_docs( self, @@ -443,7 +487,8 @@ class ApiRequest: "/knowledge_base/list_docs", params={"knowledge_base_name": knowledge_base_name} ) - return response.json().get("data") + data = self._check_httpx_json_response(response) + return data.get("data", []) def upload_kb_doc( self, @@ -487,7 +532,7 @@ class ApiRequest: data={"knowledge_base_name": knowledge_base_name, "override": override}, files={"file": (filename, file)}, ) - return response.json() + return self._check_httpx_json_response(response) def delete_kb_doc( self, @@ -517,7 +562,7 @@ class ApiRequest: "/knowledge_base/delete_doc", json=data, ) - return response.json() + return self._check_httpx_json_response(response) def update_kb_doc( self, @@ -540,7 +585,7 @@ class ApiRequest: "/knowledge_base/update_doc", json={"knowledge_base_name": knowledge_base_name, "file_name": file_name}, ) - return response.json() + return self._check_httpx_json_response(response) def recreate_vector_store( self, @@ -572,10 +617,20 @@ class ApiRequest: "/knowledge_base/recreate_vector_store", json=data, stream=True, + timeout=False, ) return self._httpx_stream2generator(response, as_json=True) +def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str: + ''' + return error message if error occured when requests API + ''' + if isinstance(data, dict) and key in data: + return data[key] + return "" + + if __name__ == "__main__": api = ApiRequest(no_remote_api=True)
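The config change above points both example `api_base_url` entries at the FastChat OpenAI-compatible server on port 8888 (the `openai_api_port` set in `server/llm_api.py`). As a quick smoke test once the LLM service is up, here is a minimal client sketch using the pre-1.0 `openai` SDK, matching the `openai.ChatCompletion` usage in `server/chat/openai_chat.py`; the model name `chatglm2-6b` is the `LLM_MODEL` from the example config and is only an assumption if you run a different model.

```python
# Minimal sketch: call the FastChat OpenAI-compatible endpoint started by
# llm_api.py / llm_api_launch.py. Port 8888 and the "EMPTY" key come from the patch;
# adjust the model name if your LLM_MODEL differs.
import openai  # openai < 1.0, same style as server/chat/openai_chat.py

openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8888/v1"

resp = openai.ChatCompletion.create(
    model="chatglm2-6b",
    messages=[{"role": "user", "content": "你好"}],
)
print(resp.choices[0].message.content)
```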
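`server/api.py` now registers `/knowledge_base/search_docs`, exposing the new `search_docs` function from `kb_doc_api.py` and returning a list of `DocumentWithScore`. A hedged sketch of calling it over HTTP with `httpx` (already a project dependency): the host and port assume the API server default used by `api_allinone.py` (port 7861), and the JSON fields mirror the `Body` parameters in the patch.

```python
# Sketch: hit the new search_docs endpoint and print scored hits.
# Note: a lower score means higher relevance; 1 effectively disables filtering.
import httpx

resp = httpx.post(
    "http://127.0.0.1:7861/knowledge_base/search_docs",
    json={
        "query": "你好",
        "knowledge_base_name": "samples",
        "top_k": 3,                # overrides VECTOR_SEARCH_TOP_K
        "score_threshold": 0.5,    # recommended starting point per the docstring
    },
    timeout=60.0,
)
resp.raise_for_status()
for doc in resp.json():           # list of serialized DocumentWithScore objects
    print(doc["score"], doc["page_content"][:80])
```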
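On the WebUI side, `ApiRequest.knowledge_base_chat` gains a `score_threshold` argument, and the new `check_error_msg` helper surfaces the `{"code": 500, "errorMsg": ...}` payloads yielded when the API server is unreachable or times out. Below is a small console sketch following the same call pattern as `webui_pages/dialogue/dialogue.py`; it assumes the API server is already running at the default `http://127.0.0.1:7861` and that each streamed chunk carries `answer`/`docs` keys as the dialogue page expects.

```python
# Sketch: stream a knowledge-base answer outside Streamlit, mirroring dialogue.py.
from webui_pages.utils import ApiRequest, check_error_msg

api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)

text = ""
docs = []
for d in api.knowledge_base_chat("如何部署本项目?", "samples",
                                 top_k=3, score_threshold=0.5, history=[]):
    if error_msg := check_error_msg(d):   # connection/timeout errors arrive as dicts
        print("Error:", error_msg)
        break
    text += d["answer"]                   # streamed answer fragments
    docs = d.get("docs", docs)            # matched knowledge-base snippets
print(text)
print("\n\n".join(docs))
```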