diff --git a/configs/kb_config.py.exmaple b/configs/kb_config.py.exmaple
deleted file mode 100644
index 3ceee3c..0000000
--- a/configs/kb_config.py.exmaple
+++ /dev/null
@@ -1,99 +0,0 @@
-import os
-
-
-# 默认向量库类型。可选:faiss, milvus, pg.
-DEFAULT_VS_TYPE = "faiss"
-
-# 缓存向量库数量(针对FAISS)
-CACHED_VS_NUM = 1
-
-# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
-CHUNK_SIZE = 250
-
-# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
-OVERLAP_SIZE = 50
-
-# 知识库匹配向量数量
-VECTOR_SEARCH_TOP_K = 3
-
-# 知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右
-SCORE_THRESHOLD = 1
-
-# 搜索引擎匹配结题数量
-SEARCH_ENGINE_TOP_K = 3
-
-
-# Bing 搜索必备变量
-# 使用 Bing 搜索需要使用 Bing Subscription Key,需要在azure port中申请试用bing search
-# 具体申请方式请见
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/create-bing-search-service-resource
-# 使用python创建bing api 搜索实例详见:
-# https://learn.microsoft.com/en-us/bing/search-apis/bing-web-search/quickstarts/rest/python
-BING_SEARCH_URL = "https://api.bing.microsoft.com/v7.0/search"
-# 注意不是bing Webmaster Tools的api key,
-
-# 此外,如果是在服务器上,报Failed to establish a new connection: [Errno 110] Connection timed out
-# 是因为服务器加了防火墙,需要联系管理员加白名单,如果公司的服务器的话,就别想了GG
-BING_SUBSCRIPTION_KEY = ""
-
-# 是否开启中文标题加强,以及标题增强的相关配置
-# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
-# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
-ZH_TITLE_ENHANCE = False
-
-
-# 通常情况下不需要更改以下内容
-
-# 知识库默认存储路径
-KB_ROOT_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "knowledge_base")
-if not os.path.exists(KB_ROOT_PATH):
-    os.mkdir(KB_ROOT_PATH)
-
-# 数据库默认存储路径。
-# 如果使用sqlite,可以直接修改DB_ROOT_PATH;如果使用其它数据库,请直接修改SQLALCHEMY_DATABASE_URI。
-DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
-SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
-
-# 可选向量库类型及对应配置
-kbs_config = {
-    "faiss": {
-    },
-    "milvus": {
-        "host": "127.0.0.1",
-        "port": "19530",
-        "user": "",
-        "password": "",
-        "secure": False,
-    },
-    "pg": {
-        "connection_uri": "postgresql://postgres:postgres@127.0.0.1:5432/langchain_chatchat",
-    }
-}
-
-# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
-text_splitter_dict = {
-    "ChineseRecursiveTextSplitter": {
-        "source": "huggingface",  ## 选择tiktoken则使用openai的方法
-        "tokenizer_name_or_path": "gpt2",
-    },
-    "SpacyTextSplitter": {
-        "source": "huggingface",
-        "tokenizer_name_or_path": "",
-    },
-    "RecursiveCharacterTextSplitter": {
-        "source": "tiktoken",
-        "tokenizer_name_or_path": "cl100k_base",
-    },
-    "MarkdownHeaderTextSplitter": {
-        "headers_to_split_on":
-            [
-                ("#", "head1"),
-                ("##", "head2"),
-                ("###", "head3"),
-                ("####", "head4"),
-            ]
-    },
-}
-
-# TEXT_SPLITTER 名称
-TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index 801e6d2..71b1229 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -32,6 +32,7 @@ FSCHAT_OPENAI_API = {
 # fastchat model_worker server
 # 这些模型必须是在model_config.MODEL_PATH或ONLINE_MODEL中正确配置的。
 # 在启动startup.py时,可用通过`--model-worker --model-name xxxx`指定模型,不指定则为LLM_MODEL
+# 必须在这里添加的模型才会出现在WEBUI中可选模型列表里(LLM_MODEL会自动添加)
 FSCHAT_MODEL_WORKERS = {
     # 所有模型共用的默认配置,可在模型专项配置中进行覆盖。
     "default": {
@@ -39,7 +40,8 @@ FSCHAT_MODEL_WORKERS = {
         "port": 20002,
         "device": LLM_DEVICE,
         # False,'vllm',使用的推理加速框架,使用vllm如果出现HuggingFace通信问题,参见doc/FAQ
-        "infer_turbo": "vllm" if sys.platform.startswith("linux") else False,
+        # vllm对一些模型支持还不成熟,暂时默认关闭
+        "infer_turbo": False,
 
         # model_worker多卡加载需要配置的参数
         # "gpus": None, # 使用的GPU,以str的格式指定,如"0,1",如失效请使用CUDA_VISIBLE_DEVICES="0,1"等形式指定
@@ -97,24 +99,24 @@ FSCHAT_MODEL_WORKERS = {
     "zhipu-api": { # 请为每个要运行的在线API设置不同的端口
"port": 21001, }, - "minimax-api": { - "port": 21002, - }, - "xinghuo-api": { - "port": 21003, - }, - "qianfan-api": { - "port": 21004, - }, - "fangzhou-api": { - "port": 21005, - }, - "qwen-api": { - "port": 21006, - }, - "baichuan-api": { - "port": 21007, - }, + # "minimax-api": { + # "port": 21002, + # }, + # "xinghuo-api": { + # "port": 21003, + # }, + # "qianfan-api": { + # "port": 21004, + # }, + # "fangzhou-api": { + # "port": 21005, + # }, + # "qwen-api": { + # "port": 21006, + # }, + # "baichuan-api": { + # "port": 21007, + # }, } # fastchat multi model worker server diff --git a/requirements.txt b/requirements.txt index 68385b7..02bd05d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ -langchain>=0.0.310 -fschat[model_worker]>=0.2.30 +langchain==0.0.313 +langchain-experimental==0.0.30 +fschat[model_worker]==0.2.30 openai sentence_transformers transformers>=4.34 diff --git a/requirements_api.txt b/requirements_api.txt index b5428dd..af4e7e0 100644 --- a/requirements_api.txt +++ b/requirements_api.txt @@ -1,5 +1,6 @@ -langchain>=0.0.310 -fschat[model_worker]>=0.2.30 +langchain==0.0.313 +langchain-experimental==0.0.30 +fschat[model_worker]==0.2.30 openai sentence_transformers>=2.2.2 transformers>=4.34 diff --git a/server/api.py b/server/api.py index ea098d6..0b0692a 100644 --- a/server/api.py +++ b/server/api.py @@ -17,8 +17,10 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, search_docs, DocumentWithScore) -from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model -from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline +from server.llm_api import (list_running_models, list_config_models, + change_llm_model, stop_llm_model, + get_model_config, list_search_engines) +from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, get_server_configs from typing import List nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -139,6 +141,11 @@ def create_app(): summary="列出configs已配置的模型", )(list_config_models) + app.post("/llm_model/get_model_config", + tags=["LLM Model Management"], + summary="获取模型配置(合并后)", + )(get_model_config) + app.post("/llm_model/stop", tags=["LLM Model Management"], summary="停止指定的LLM模型(Model Worker)", @@ -149,6 +156,17 @@ def create_app(): summary="切换指定的LLM模型(Model Worker)", )(change_llm_model) + # 服务器相关接口 + app.post("/server/configs", + tags=["Server State"], + summary="获取服务器原始配置信息", + )(get_server_configs) + + app.post("/server/list_search_engines", + tags=["Server State"], + summary="获取服务器支持的搜索引擎", + )(list_search_engines) + return app diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 87e7098..6ae90e2 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -33,7 +33,6 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。 prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) @@ -74,11 +73,8 
@@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", source_documents = [] for inum, doc in enumerate(docs): filename = os.path.split(doc.metadata["source"])[-1] - if local_doc_url: - url = "file://" + doc.metadata["source"] - else: - parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) - url = f"{request.base_url}knowledge_base/download_doc?" + parameters + parameters = urlencode({"knowledge_base_name": knowledge_base_name, "file_name":filename}) + url = f"{request.base_url}knowledge_base/download_doc?" + parameters text = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{doc.page_content}\n\n""" source_documents.append(text) diff --git a/server/llm_api.py b/server/llm_api.py index b028747..2b1ce45 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,7 +1,7 @@ from fastapi import Body from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT -from server.utils import BaseResponse, fschat_controller_address, list_llm_models, get_httpx_client - +from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models, + get_httpx_client, get_model_worker_config) def list_running_models( @@ -9,19 +9,21 @@ def list_running_models( placeholder: str = Body(None, description="该参数未使用,占位用"), ) -> BaseResponse: ''' - 从fastchat controller获取已加载模型列表 + 从fastchat controller获取已加载模型列表及其配置项 ''' try: controller_address = controller_address or fschat_controller_address() with get_httpx_client() as client: r = client.post(controller_address + "/list_models") - return BaseResponse(data=r.json()["models"]) + models = r.json()["models"] + data = {m: get_model_worker_config(m) for m in models} + return BaseResponse(data=data) except Exception as e: logger.error(f'{e.__class__.__name__}: {e}', exc_info=e if log_verbose else None) return BaseResponse( code=500, - data=[], + data={}, msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") @@ -29,7 +31,38 @@ def list_config_models() -> BaseResponse: ''' 从本地获取configs中配置的模型列表 ''' - return BaseResponse(data=list_llm_models()) + configs = list_config_llm_models() + + # 删除ONLINE_MODEL配置中的敏感信息 + for config in configs["online"].values(): + del_keys = set(["worker_class"]) + for k in config: + if "key" in k.lower() or "secret" in k.lower(): + del_keys.add(k) + for k in del_keys: + config.pop(k, None) + + return BaseResponse(data=configs) + + +def get_model_config( + model_name: str = Body(description="配置中LLM模型的名称"), + placeholder: str = Body(description="占位用,无实际效果") +) -> BaseResponse: + ''' + 获取LLM模型配置项(合并后的) + ''' + config = get_model_worker_config(model_name=model_name) + + # 删除ONLINE_MODEL配置中的敏感信息 + del_keys = set(["worker_class"]) + for k in config: + if "key" in k.lower() or "secret" in k.lower(): + del_keys.add(k) + for k in del_keys: + config.pop(k, None) + + return BaseResponse(data=config) def stop_llm_model( @@ -79,3 +112,9 @@ def change_llm_model( return BaseResponse( code=500, msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") + + +def list_search_engines() -> BaseResponse: + from server.chat.search_engine_chat import SEARCH_ENGINES + + return BaseResponse(data=list(SEARCH_ENGINES)) diff --git a/server/utils.py b/server/utils.py index bae9fe4..f62cccf 100644 --- a/server/utils.py +++ b/server/utils.py @@ -258,17 +258,18 @@ def list_embed_models() -> List[str]: return list(MODEL_PATH["embed_model"]) -def list_llm_models() -> Dict[str, List[str]]: +def list_config_llm_models() -> Dict[str, Dict]: ''' - 
get names of configured llm models with different types. + get configured llm models with different types. return [(model_name, config_type), ...] ''' workers = list(FSCHAT_MODEL_WORKERS) - if "default" in workers: - workers.remove("default") + if LLM_MODEL not in workers: + workers.insert(0, LLM_MODEL) + return { - "local": list(MODEL_PATH["llm_model"]), - "online": list(ONLINE_LLM_MODEL), + "local": MODEL_PATH["llm_model"], + "online": ONLINE_LLM_MODEL, "worker": workers, } @@ -306,7 +307,7 @@ def get_model_worker_config(model_name: str = None) -> dict: 加载model worker的配置项。 优先级:FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"] ''' - from configs.model_config import ONLINE_LLM_MODEL + from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH from configs.server_config import FSCHAT_MODEL_WORKERS from server import model_workers @@ -324,9 +325,10 @@ def get_model_worker_config(model_name: str = None) -> dict: msg = f"在线模型 ‘{model_name}’ 的provider没有正确配置" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) - - config["model_path"] = get_model_path(model_name) - config["device"] = llm_device(config.get("device")) + # 本地模型 + if model_name in MODEL_PATH["llm_model"]: + config["model_path"] = get_model_path(model_name) + config["device"] = llm_device(config.get("device")) return config @@ -449,11 +451,11 @@ def set_httpx_config( # TODO: 简单的清除系统代理不是个好的选择,影响太多。似乎修改代理服务器的bypass列表更好。 # patch requests to use custom proxies instead of system settings - # def _get_proxies(): - # return {} + def _get_proxies(): + return proxies - # import urllib.request - # urllib.request.getproxies = _get_proxies + import urllib.request + urllib.request.getproxies = _get_proxies # 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch @@ -557,3 +559,35 @@ def get_httpx_client( return httpx.AsyncClient(**kwargs) else: return httpx.Client(**kwargs) + + +def get_server_configs() -> Dict: + ''' + 获取configs中的原始配置项,供前端使用 + ''' + from configs.kb_config import ( + DEFAULT_VS_TYPE, + CHUNK_SIZE, + OVERLAP_SIZE, + SCORE_THRESHOLD, + VECTOR_SEARCH_TOP_K, + SEARCH_ENGINE_TOP_K, + ZH_TITLE_ENHANCE, + text_splitter_dict, + TEXT_SPLITTER_NAME, + ) + from configs.model_config import ( + LLM_MODEL, + EMBEDDING_MODEL, + HISTORY_LEN, + TEMPERATURE, + ) + from configs.prompt_config import PROMPT_TEMPLATES + + _custom = { + "controller_address": fschat_controller_address(), + "openai_api_address": fschat_openai_api_address(), + "api_address": api_address(), + } + + return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom} diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py index 3c115f1..400b1c6 100644 --- a/tests/api/test_kb_api_request.py +++ b/tests/api/test_kb_api_request.py @@ -14,7 +14,7 @@ from pprint import pprint api_base_url = api_address() -api: ApiRequest = ApiRequest(api_base_url, no_remote_api=False) +api: ApiRequest = ApiRequest(api_base_url) kb = "kb_for_api_test" diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index 8957981..9b9a4a6 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -32,7 +32,7 @@ def get_running_models(api="/llm_model/list_models"): return [] -def test_running_models(api="/llm_model/list_models"): +def test_running_models(api="/llm_model/list_running_models"): url = api_base_url + api r = requests.post(url) assert r.status_code == 200 @@ -48,7 +48,7 @@ def test_running_models(api="/llm_model/list_models"): # r = requests.post(url, json={""}) -def 
test_change_model(api="/llm_model/change"): +def test_change_model(api="/llm_model/change_model"): url = api_base_url + api running_models = get_running_models() diff --git a/webui.py b/webui.py index 2750c47..776d5e6 100644 --- a/webui.py +++ b/webui.py @@ -22,9 +22,10 @@ if __name__ == "__main__": ) if not chat_box.chat_inited: + running_models = api.list_running_models() st.toast( f"欢迎使用 [`Langchain-Chatchat`](https://github.com/chatchat-space/Langchain-Chatchat) ! \n\n" - f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了." + f"当前运行中的模型`{running_models}`, 您可以开始提问了." ) pages = { diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 2664b30..deadc32 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -2,11 +2,11 @@ import streamlit as st from webui_pages.utils import * from streamlit_chatbox import * from datetime import datetime -from server.chat.search_engine_chat import SEARCH_ENGINES import os -from configs import LLM_MODEL, TEMPERATURE -from server.utils import get_model_worker_config +from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN from typing import List, Dict + + chat_box = ChatBox( assistant_avatar=os.path.join( "img", @@ -15,9 +15,6 @@ chat_box = ChatBox( ) - - - def get_messages_history(history_len: int, content_in_expander: bool = False) -> List[Dict]: ''' 返回消息历史。 @@ -38,6 +35,26 @@ def get_messages_history(history_len: int, content_in_expander: bool = False) -> return chat_box.filter_history(history_len=history_len, filter=filter) +def get_default_llm_model(api: ApiRequest) -> (str, bool): + ''' + 从服务器上获取当前运行的LLM模型,如果本机配置的LLM_MODEL属于本地模型且在其中,则优先返回 + 返回类型为(model_name, is_local_model) + ''' + running_models = api.list_running_models() + + if not running_models: + return "", False + + if LLM_MODEL in running_models: + return LLM_MODEL, True + + local_models = [k for k, v in running_models.items() if not v.get("online_api")] + if local_models: + return local_models[0], True + + return running_models[0], False + + def dialogue_page(api: ApiRequest): chat_box.init_session() @@ -51,7 +68,6 @@ def dialogue_page(api: ApiRequest): if cur_kb: text = f"{text} 当前知识库: `{cur_kb}`。" st.toast(text) - # sac.alert(text, description="descp", type="success", closable=True, banner=True) dialogue_mode = st.selectbox("请选择对话模式:", ["LLM 对话", @@ -65,7 +81,7 @@ def dialogue_page(api: ApiRequest): ) def on_llm_change(): - config = get_model_worker_config(llm_model) + config = api.get_model_config(llm_model) if not config.get("online_api"): # 只有本地model_worker可以切换模型 st.session_state["prev_llm_model"] = llm_model st.session_state["cur_llm_model"] = st.session_state.llm_model @@ -75,15 +91,20 @@ def dialogue_page(api: ApiRequest): return f"{x} (Running)" return x - running_models = api.list_running_models() + running_models = list(api.list_running_models()) available_models = [] config_models = api.list_config_models() - for models in config_models.values(): - for m in models: - if m not in running_models: - available_models.append(m) + worker_models = list(config_models.get("worker", {})) # 仅列出在FSCHAT_MODEL_WORKERS中配置的模型 + for m in worker_models: + if m not in running_models and m != "default": + available_models.append(m) + for k, v in config_models.get("online", {}).items(): # 列出ONLINE_MODELS中直接访问的模型(如GPT) + if not v.get("provider") and k not in running_models: + print(k, v) + available_models.append(k) + llm_models = running_models + available_models - index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL)) + index = 
llm_models.index(st.session_state.get("cur_llm_model", get_default_llm_model(api)[0])) llm_model = st.selectbox("选择LLM模型:", llm_models, index, @@ -92,7 +113,7 @@ def dialogue_page(api: ApiRequest): key="llm_model", ) if (st.session_state.get("prev_llm_model") != llm_model - and not get_model_worker_config(llm_model).get("online_api") + and not api.get_model_config(llm_model).get("online_api") and llm_model not in running_models): with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"): prev_model = st.session_state.get("prev_llm_model") @@ -114,7 +135,7 @@ def dialogue_page(api: ApiRequest): if dialogue_mode == "知识库问答": with st.expander("知识库配置", True): - kb_list = api.list_knowledge_bases(no_remote_api=True) + kb_list = api.list_knowledge_bases() selected_kb = st.selectbox( "请选择知识库:", kb_list, @@ -126,7 +147,7 @@ def dialogue_page(api: ApiRequest): # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) elif dialogue_mode == "搜索引擎问答": - search_engine_list = list(SEARCH_ENGINES.keys()) + search_engine_list = api.list_search_engines() with st.expander("搜索引擎配置", True): search_engine = st.selectbox( label="请选择搜索引擎", diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 6ba4661..3e07766 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -1,12 +1,14 @@ -# 该文件包含webui通用工具,可以被不同的webui使用 +# 该文件封装了对api.py的请求,可以被不同的webui使用 +# 通过ApiRequest和AsyncApiRequest支持同步/异步调用 + + from typing import * from pathlib import Path +# 此处导入的配置为发起请求(如WEBUI)机器上的配置,主要用于为前端设置默认值。分布式部署时可以与服务器上的不同 from configs import ( EMBEDDING_MODEL, DEFAULT_VS_TYPE, - KB_ROOT_PATH, LLM_MODEL, - HISTORY_LEN, TEMPERATURE, SCORE_THRESHOLD, CHUNK_SIZE, @@ -14,59 +16,44 @@ from configs import ( ZH_TITLE_ENHANCE, VECTOR_SEARCH_TOP_K, SEARCH_ENGINE_TOP_K, - FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT, logger, log_verbose, ) import httpx -import asyncio from server.chat.openai_chat import OpenAiChatMsgIn -from fastapi.responses import StreamingResponse import contextlib import json import os from io import BytesIO -from server.utils import run_async, iter_over_async, set_httpx_config, api_address, get_httpx_client +from server.utils import run_async, set_httpx_config, api_address, get_httpx_client -from configs.model_config import NLTK_DATA_PATH -import nltk -nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path from pprint import pprint -KB_ROOT_PATH = Path(KB_ROOT_PATH) set_httpx_config() class ApiRequest: ''' - api.py调用的封装,主要实现: - 1. 简化api调用方式 - 2. 
实现无api调用(直接运行server.chat.*中的视图函数获取结果),无需启动api.py + api.py调用的封装(同步模式),简化api调用方式 ''' def __init__( self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT, - no_remote_api: bool = False, # call api view function directly ): self.base_url = base_url self.timeout = timeout - self.no_remote_api = no_remote_api - self._client = get_httpx_client() - self._aclient = get_httpx_client(use_async=True) - if no_remote_api: - logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。") + self._use_async = False + self._client = None - def _parse_url(self, url: str) -> str: - if (not url.startswith("http") - and self.base_url - ): - part1 = self.base_url.strip(" /") - part2 = url.strip(" /") - return f"{part1}/{part2}" - else: - return url + @property + def client(self): + if self._client is None or self._client.is_closed: + self._client = get_httpx_client(base_url=self.base_url, + use_async=self._use_async, + timeout=self.timeout) + return self._client def get( self, @@ -75,44 +62,19 @@ class ApiRequest: retry: int = 3, stream: bool = False, **kwargs: Any, - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) + ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: - return self._client.stream("GET", url, params=params, **kwargs) + return self.client.stream("GET", url, params=params, **kwargs) else: - return self._client.get(url, params=params, **kwargs) + return self.client.get(url, params=params, **kwargs) except Exception as e: msg = f"error when get {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 - async def aget( - self, - url: str, - params: Union[Dict, List[Tuple], bytes] = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any, - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) - - while retry > 0: - try: - if stream: - return await self._aclient.stream("GET", url, params=params, **kwargs) - else: - return await self._aclient.get(url, params=params, **kwargs) - except Exception as e: - msg = f"error when aget {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 - def post( self, url: str, @@ -121,45 +83,19 @@ class ApiRequest: retry: int = 3, stream: bool = False, **kwargs: Any - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) + ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: - return self._client.stream("POST", url, data=data, json=json, **kwargs) + return self.client.stream("POST", url, data=data, json=json, **kwargs) else: - return self._client.post(url, data=data, json=json, **kwargs) + return self.client.post(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when post {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 - async def apost( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) - - while retry > 0: - try: - if stream: - return await self._client.stream("POST", url, data=data, json=json, **kwargs) - else: - return await self._client.post(url, data=data, json=json, **kwargs) - except Exception as e: - msg = f"error when apost {url}: {e}" - 
logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 - def delete( self, url: str, @@ -168,65 +104,19 @@ class ApiRequest: retry: int = 3, stream: bool = False, **kwargs: Any - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) + ) -> Union[httpx.Response, Iterator[httpx.Response], None]: while retry > 0: try: if stream: - return self._client.stream("DELETE", url, data=data, json=json, **kwargs) + return self.client.stream("DELETE", url, data=data, json=json, **kwargs) else: - return self._client.delete(url, data=data, json=json, **kwargs) + return self.client.delete(url, data=data, json=json, **kwargs) except Exception as e: msg = f"error when delete {url}: {e}" logger.error(f'{e.__class__.__name__}: {msg}', exc_info=e if log_verbose else None) retry -= 1 - async def adelete( - self, - url: str, - data: Dict = None, - json: Dict = None, - retry: int = 3, - stream: bool = False, - **kwargs: Any - ) -> Union[httpx.Response, None]: - url = self._parse_url(url) - kwargs.setdefault("timeout", self.timeout) - - while retry > 0: - try: - if stream: - return await self._aclient.stream("DELETE", url, data=data, json=json, **kwargs) - else: - return await self._aclient.delete(url, data=data, json=json, **kwargs) - except Exception as e: - msg = f"error when adelete {url}: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - retry -= 1 - - def _fastapi_stream2generator(self, response: StreamingResponse, as_json: bool =False): - ''' - 将api.py中视图函数返回的StreamingResponse转化为同步生成器 - ''' - try: - loop = asyncio.get_event_loop() - except: - loop = asyncio.new_event_loop() - - try: - for chunk in iter_over_async(response.body_iterator, loop): - if as_json and chunk: - yield json.loads(chunk) - elif chunk.strip(): - yield chunk - except Exception as e: - msg = f"error when run fastapi router: {e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - def _httpx_stream2generator( self, response: contextlib._GeneratorContextManager, @@ -235,37 +125,117 @@ class ApiRequest: ''' 将httpx.stream返回的GeneratorContextManager转化为普通生成器 ''' - try: - with response as r: - for chunk in r.iter_text(None): - if not chunk: # fastchat api yield empty bytes on start and end - continue - if as_json: - try: - data = json.loads(chunk) - pprint(data, depth=1) - yield data - except Exception as e: - msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - else: - # print(chunk, end="", flush=True) - yield chunk - except httpx.ConnectError as e: - msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" - logger.error(msg) - logger.error(msg) - yield {"code": 500, "msg": msg} - except httpx.ReadTimeout as e: - msg = f"API通信超时,请确认已启动FastChat与API服务(详见RADME '5. 
启动 API 服务或 Web UI')。({e})" - logger.error(msg) - yield {"code": 500, "msg": msg} - except Exception as e: - msg = f"API通信遇到错误:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - yield {"code": 500, "msg": msg} + async def ret_async(response, as_json): + try: + async with response as r: + async for chunk in r.aiter_text(None): + if not chunk: # fastchat api yield empty bytes on start and end + continue + if as_json: + try: + data = json.loads(chunk) + pprint(data, depth=1) + yield data + except Exception as e: + msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + else: + # print(chunk, end="", flush=True) + yield chunk + except httpx.ConnectError as e: + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" + logger.error(msg) + yield {"code": 500, "msg": msg} + except httpx.ReadTimeout as e: + msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 启动 API 服务或 Web UI')。({e})" + logger.error(msg) + yield {"code": 500, "msg": msg} + except Exception as e: + msg = f"API通信遇到错误:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + yield {"code": 500, "msg": msg} + + def ret_sync(response, as_json): + try: + with response as r: + for chunk in r.iter_text(None): + if not chunk: # fastchat api yield empty bytes on start and end + continue + if as_json: + try: + data = json.loads(chunk) + pprint(data, depth=1) + yield data + except Exception as e: + msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + else: + # print(chunk, end="", flush=True) + yield chunk + except httpx.ConnectError as e: + msg = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。({e})" + logger.error(msg) + yield {"code": 500, "msg": msg} + except httpx.ReadTimeout as e: + msg = f"API通信超时,请确认已启动FastChat与API服务(详见Wiki '5. 
启动 API 服务或 Web UI')。({e})" + logger.error(msg) + yield {"code": 500, "msg": msg} + except Exception as e: + msg = f"API通信遇到错误:{e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + yield {"code": 500, "msg": msg} + + if self._use_async: + return ret_async(response, as_json) + else: + return ret_sync(response, as_json) + + def _get_response_value( + self, + response: httpx.Response, + as_json: bool = False, + value_func: Callable = None, + ): + ''' + 转换同步或异步请求返回的响应 + `as_json`: 返回json + `value_func`: 用户可以自定义返回值,该函数接受response或json + ''' + def to_json(r): + try: + return r.json() + except Exception as e: + msg = "API未能返回正确的JSON。" + str(e) + if log_verbose: + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return {"code": 500, "msg": msg} + + if value_func is None: + value_func = (lambda r: r) + + async def ret_async(response): + if as_json: + return value_func(to_json(await response)) + else: + return value_func(await response) + + if self._use_async: + return ret_async(response) + else: + if as_json: + return value_func(to_json(response)) + else: + return value_func(response) + + # 服务器信息 + def get_server_configs(self, **kwargs): + response = self.post("/server/configs", **kwargs) + return self._get_response_value(response, lambda r: r.json()) # 对话相关操作 @@ -275,15 +245,12 @@ class ApiRequest: stream: bool = True, model: str = LLM_MODEL, temperature: float = TEMPERATURE, - max_tokens: int = 1024, # TODO:根据message内容自动计算max_tokens - no_remote_api: bool = None, + max_tokens: int = 1024, **kwargs: Any, ): ''' 对应api.py/chat/fastchat接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api msg = OpenAiChatMsgIn(**{ "messages": messages, "stream": stream, @@ -293,21 +260,16 @@ class ApiRequest: **kwargs, }) - if no_remote_api: - from server.chat.openai_chat import openai_chat - response = run_async(openai_chat(msg)) - return self._fastapi_stream2generator(response) - else: - data = msg.dict(exclude_unset=True, exclude_none=True) - print(f"received input message:") - pprint(data) + data = msg.dict(exclude_unset=True, exclude_none=True) + print(f"received input message:") + pprint(data) - response = self.post( - "/chat/fastchat", - json=data, - stream=True, - ) - return self._httpx_stream2generator(response) + response = self.post( + "/chat/fastchat", + json=data, + stream=True, + ) + return self._httpx_stream2generator(response) def chat_chat( self, @@ -318,14 +280,11 @@ class ApiRequest: temperature: float = TEMPERATURE, max_tokens: int = 1024, prompt_name: str = "llm_chat", - no_remote_api: bool = None, + **kwargs, ): ''' - 对应api.py/chat/chat接口 + 对应api.py/chat/chat接口 #TODO: 考虑是否返回json ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "query": query, "history": history, @@ -339,13 +298,8 @@ class ApiRequest: print(f"received input message:") pprint(data) - if no_remote_api: - from server.chat.chat import chat - response = run_async(chat(**data)) - return self._fastapi_stream2generator(response) - else: - response = self.post("/chat/chat", json=data, stream=True) - return self._httpx_stream2generator(response) + response = self.post("/chat/chat", json=data, stream=True, **kwargs) + return self._httpx_stream2generator(response) def agent_chat( self, @@ -355,14 +309,10 @@ class ApiRequest: model: str = LLM_MODEL, temperature: float = TEMPERATURE, max_tokens: int = 1024, - no_remote_api: bool = None, ): ''' 对应api.py/chat/agent_chat 接口 ''' - if no_remote_api is None: - no_remote_api = 
self.no_remote_api - data = { "query": query, "history": history, @@ -375,13 +325,8 @@ class ApiRequest: print(f"received input message:") pprint(data) - if no_remote_api: - from server.chat.agent_chat import agent_chat - response = run_async(agent_chat(**data)) - return self._fastapi_stream2generator(response) - else: - response = self.post("/chat/agent_chat", json=data, stream=True) - return self._httpx_stream2generator(response) + response = self.post("/chat/agent_chat", json=data, stream=True) + return self._httpx_stream2generator(response) def knowledge_base_chat( self, @@ -395,14 +340,10 @@ class ApiRequest: temperature: float = TEMPERATURE, max_tokens: int = 1024, prompt_name: str = "knowledge_base_chat", - no_remote_api: bool = None, ): ''' 对应api.py/chat/knowledge_base_chat接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "query": query, "knowledge_base_name": knowledge_base_name, @@ -413,24 +354,18 @@ class ApiRequest: "model_name": model, "temperature": temperature, "max_tokens": max_tokens, - "local_doc_url": no_remote_api, "prompt_name": prompt_name, } print(f"received input message:") pprint(data) - if no_remote_api: - from server.chat.knowledge_base_chat import knowledge_base_chat - response = run_async(knowledge_base_chat(**data)) - return self._fastapi_stream2generator(response, as_json=True) - else: - response = self.post( - "/chat/knowledge_base_chat", - json=data, - stream=True, - ) - return self._httpx_stream2generator(response, as_json=True) + response = self.post( + "/chat/knowledge_base_chat", + json=data, + stream=True, + ) + return self._httpx_stream2generator(response, as_json=True) def search_engine_chat( self, @@ -443,14 +378,10 @@ class ApiRequest: temperature: float = TEMPERATURE, max_tokens: int = 1024, prompt_name: str = "knowledge_base_chat", - no_remote_api: bool = None, ): ''' 对应api.py/chat/search_engine_chat接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "query": query, "search_engine_name": search_engine_name, @@ -466,130 +397,74 @@ class ApiRequest: print(f"received input message:") pprint(data) - if no_remote_api: - from server.chat.search_engine_chat import search_engine_chat - response = run_async(search_engine_chat(**data)) - return self._fastapi_stream2generator(response, as_json=True) - else: - response = self.post( - "/chat/search_engine_chat", - json=data, - stream=True, - ) - return self._httpx_stream2generator(response, as_json=True) + response = self.post( + "/chat/search_engine_chat", + json=data, + stream=True, + ) + return self._httpx_stream2generator(response, as_json=True) # 知识库相关操作 - def _check_httpx_json_response( - self, - response: httpx.Response, - errorMsg: str = f"无法连接API服务器,请确认 ‘api.py’ 已正常启动。", - ) -> Dict: - ''' - check whether httpx returns correct data with normal Response. 
- error in api with streaming support was checked in _httpx_stream2enerator - ''' - try: - return response.json() - except Exception as e: - msg = "API未能返回正确的JSON。" + (errorMsg or str(e)) - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) - return {"code": 500, "msg": msg} - def list_knowledge_bases( self, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/list_knowledge_bases接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - - if no_remote_api: - from server.knowledge_base.kb_api import list_kbs - response = list_kbs() - return response.data - else: - response = self.get("/knowledge_base/list_knowledge_bases") - data = self._check_httpx_json_response(response) - return data.get("data", []) + response = self.get("/knowledge_base/list_knowledge_bases") + return self._get_response_value(response, + as_json=True, + value_func=lambda r: r.get("data", [])) def create_knowledge_base( self, knowledge_base_name: str, - vector_store_type: str = "faiss", + vector_store_type: str = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/create_knowledge_base接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model, } - if no_remote_api: - from server.knowledge_base.kb_api import create_kb - response = create_kb(**data) - return response.dict() - else: - response = self.post( - "/knowledge_base/create_knowledge_base", - json=data, - ) - return self._check_httpx_json_response(response) + response = self.post( + "/knowledge_base/create_knowledge_base", + json=data, + ) + return self._get_response_value(response, as_json=True) def delete_knowledge_base( self, knowledge_base_name: str, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/delete_knowledge_base接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - - if no_remote_api: - from server.knowledge_base.kb_api import delete_kb - response = delete_kb(knowledge_base_name) - return response.dict() - else: - response = self.post( - "/knowledge_base/delete_knowledge_base", - json=f"{knowledge_base_name}", - ) - return self._check_httpx_json_response(response) + response = self.post( + "/knowledge_base/delete_knowledge_base", + json=f"{knowledge_base_name}", + ) + return self._get_response_value(response, as_json=True) def list_kb_docs( self, knowledge_base_name: str, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/list_files接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - - if no_remote_api: - from server.knowledge_base.kb_doc_api import list_files - response = list_files(knowledge_base_name) - return response.data - else: - response = self.get( - "/knowledge_base/list_files", - params={"knowledge_base_name": knowledge_base_name} - ) - data = self._check_httpx_json_response(response) - return data.get("data", []) + response = self.get( + "/knowledge_base/list_files", + params={"knowledge_base_name": knowledge_base_name} + ) + return self._get_response_value(response, + as_json=True, + value_func=lambda r: r.get("data", [])) def search_kb_docs( self, @@ -597,14 +472,10 @@ class ApiRequest: knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: int = SCORE_THRESHOLD, - no_remote_api: bool = None, ) -> List: ''' 对应api.py/knowledge_base/search_docs接口 ''' - if no_remote_api is None: - no_remote_api = 
self.no_remote_api - data = { "query": query, "knowledge_base_name": knowledge_base_name, @@ -612,16 +483,11 @@ class ApiRequest: "score_threshold": score_threshold, } - if no_remote_api: - from server.knowledge_base.kb_doc_api import search_docs - return search_docs(**data) - else: - response = self.post( - "/knowledge_base/search_docs", - json=data, - ) - data = self._check_httpx_json_response(response) - return data + response = self.post( + "/knowledge_base/search_docs", + json=data, + ) + return self._get_response_value(response, as_json=True) def upload_kb_docs( self, @@ -634,14 +500,10 @@ class ApiRequest: zh_title_enhance=ZH_TITLE_ENHANCE, docs: Dict = {}, not_refresh_vs_cache: bool = False, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/upload_docs接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - def convert_file(file, filename=None): if isinstance(file, bytes): # raw bytes file = BytesIO(file) @@ -664,29 +526,14 @@ class ApiRequest: "not_refresh_vs_cache": not_refresh_vs_cache, } - if no_remote_api: - from server.knowledge_base.kb_doc_api import upload_docs - from fastapi import UploadFile - from tempfile import SpooledTemporaryFile - - upload_files = [] - for filename, file in files: - temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) - temp_file.write(file.read()) - temp_file.seek(0) - upload_files.append(UploadFile(file=temp_file, filename=filename)) - - response = upload_docs(upload_files, **data) - return response.dict() - else: - if isinstance(data["docs"], dict): - data["docs"] = json.dumps(data["docs"], ensure_ascii=False) - response = self.post( - "/knowledge_base/upload_docs", - data=data, - files=[("files", (filename, file)) for filename, file in files], - ) - return self._check_httpx_json_response(response) + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) + response = self.post( + "/knowledge_base/upload_docs", + data=data, + files=[("files", (filename, file)) for filename, file in files], + ) + return self._get_response_value(response, as_json=True) def delete_kb_docs( self, @@ -694,14 +541,10 @@ class ApiRequest: file_names: List[str], delete_content: bool = False, not_refresh_vs_cache: bool = False, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/delete_docs接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "knowledge_base_name": knowledge_base_name, "file_names": file_names, @@ -709,16 +552,11 @@ class ApiRequest: "not_refresh_vs_cache": not_refresh_vs_cache, } - if no_remote_api: - from server.knowledge_base.kb_doc_api import delete_docs - response = delete_docs(**data) - return response.dict() - else: - response = self.post( - "/knowledge_base/delete_docs", - json=data, - ) - return self._check_httpx_json_response(response) + response = self.post( + "/knowledge_base/delete_docs", + json=data, + ) + return self._get_response_value(response, as_json=True) def update_kb_docs( self, @@ -730,14 +568,10 @@ class ApiRequest: zh_title_enhance=ZH_TITLE_ENHANCE, docs: Dict = {}, not_refresh_vs_cache: bool = False, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/update_docs接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "knowledge_base_name": knowledge_base_name, "file_names": file_names, @@ -748,18 +582,15 @@ class ApiRequest: "docs": docs, "not_refresh_vs_cache": not_refresh_vs_cache, } - if no_remote_api: - from server.knowledge_base.kb_doc_api import update_docs - response = 
update_docs(**data) - return response.dict() - else: - if isinstance(data["docs"], dict): - data["docs"] = json.dumps(data["docs"], ensure_ascii=False) - response = self.post( - "/knowledge_base/update_docs", - json=data, - ) - return self._check_httpx_json_response(response) + + if isinstance(data["docs"], dict): + data["docs"] = json.dumps(data["docs"], ensure_ascii=False) + + response = self.post( + "/knowledge_base/update_docs", + json=data, + ) + return self._get_response_value(response, as_json=True) def recreate_vector_store( self, @@ -770,14 +601,10 @@ class ApiRequest: chunk_size=CHUNK_SIZE, chunk_overlap=OVERLAP_SIZE, zh_title_enhance=ZH_TITLE_ENHANCE, - no_remote_api: bool = None, ): ''' 对应api.py/knowledge_base/recreate_vector_store接口 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "knowledge_base_name": knowledge_base_name, "allow_empty_kb": allow_empty_kb, @@ -788,141 +615,176 @@ class ApiRequest: "zh_title_enhance": zh_title_enhance, } - if no_remote_api: - from server.knowledge_base.kb_doc_api import recreate_vector_store - response = recreate_vector_store(**data) - return self._fastapi_stream2generator(response, as_json=True) - else: - response = self.post( - "/knowledge_base/recreate_vector_store", - json=data, - stream=True, - timeout=None, - ) - return self._httpx_stream2generator(response, as_json=True) + response = self.post( + "/knowledge_base/recreate_vector_store", + json=data, + stream=True, + timeout=None, + ) + return self._httpx_stream2generator(response, as_json=True) # LLM模型相关操作 def list_running_models( self, controller_address: str = None, - no_remote_api: bool = None, ): ''' 获取Fastchat中正运行的模型列表 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "controller_address": controller_address, } - if no_remote_api: - from server.llm_api import list_running_models - return list_running_models(**data).data - else: - r = self.post( - "/llm_model/list_running_models", - json=data, - ) - return r.json().get("data", []) - def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]: - ''' - 获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。 - 如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。 - ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api + response = self.post( + "/llm_model/list_running_models", + json=data, + ) + return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", [])) - if no_remote_api: - from server.llm_api import list_config_models - return list_config_models().data - else: - r = self.post( - "/llm_model/list_config_models", - ) - return r.json().get("data", {}) + def list_config_models(self) -> Dict[str, List[str]]: + ''' + 获取服务器configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。 + ''' + response = self.post( + "/llm_model/list_config_models", + ) + return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) + + def get_model_config( + self, + model_name: str, + ) -> Dict: + ''' + 获取服务器上模型配置 + ''' + data={ + "model_name": model_name, + } + response = self.post( + "/llm_model/get_model_config", + ) + return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) + + def list_search_engines(self) -> List[str]: + ''' + 获取服务器支持的搜索引擎 + ''' + response = self.post( + "/server/list_search_engines", + ) + return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", {})) def stop_llm_model( 
self, model_name: str, controller_address: str = None, - no_remote_api: bool = None, ): ''' 停止某个LLM模型。 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - data = { "model_name": model_name, "controller_address": controller_address, } - if no_remote_api: - from server.llm_api import stop_llm_model - return stop_llm_model(**data).dict() - else: - r = self.post( - "/llm_model/stop", - json=data, - ) - return r.json() + response = self.post( + "/llm_model/stop", + json=data, + ) + return self._get_response_value(response, as_json=True) def change_llm_model( self, model_name: str, new_model_name: str, controller_address: str = None, - no_remote_api: bool = None, ): ''' 向fastchat controller请求切换LLM模型。 ''' - if no_remote_api is None: - no_remote_api = self.no_remote_api - if not model_name or not new_model_name: - return - - running_models = self.list_running_models() - if new_model_name == model_name or new_model_name in running_models: - return { - "code": 200, - "msg": "无需切换" - } - - if model_name not in running_models: return { "code": 500, - "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}" + "msg": f"未指定模型名称" } - config_models = self.list_config_models() - if new_model_name not in config_models.get("local", []): - return { - "code": 500, - "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" + def ret_sync(): + running_models = self.list_running_models() + if new_model_name == model_name or new_model_name in running_models: + return { + "code": 200, + "msg": "无需切换" + } + + if model_name not in running_models: + return { + "code": 500, + "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}" + } + + config_models = self.list_config_models() + if new_model_name not in config_models.get("local", {}): + return { + "code": 500, + "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" + } + + data = { + "model_name": model_name, + "new_model_name": new_model_name, + "controller_address": controller_address, } - data = { - "model_name": model_name, - "new_model_name": new_model_name, - "controller_address": controller_address, - } - - if no_remote_api: - from server.llm_api import change_llm_model - return change_llm_model(**data).dict() - else: - r = self.post( + response = self.post( "/llm_model/change", json=data, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model ) - return r.json() + return self._get_response_value(response, as_json=True) + + async def ret_async(): + running_models = await self.list_running_models() + if new_model_name == model_name or new_model_name in running_models: + return { + "code": 200, + "msg": "无需切换" + } + + if model_name not in running_models: + return { + "code": 500, + "msg": f"指定的模型'{model_name}'没有运行。当前运行模型:{running_models}" + } + + config_models = await self.list_config_models() + if new_model_name not in config_models.get("local", {}): + return { + "code": 500, + "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。" + } + + data = { + "model_name": model_name, + "new_model_name": new_model_name, + "controller_address": controller_address, + } + + response = self.post( + "/llm_model/change", + json=data, + ) + return self._get_response_value(response, as_json=True) + + if self._use_async: + return ret_async() + else: + return ret_sync() + + +class AsyncApiRequest(ApiRequest): + def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT): + super().__init__(base_url, timeout) + self._use_async = True def check_error_msg(data: Union[str, dict, list], key: str = 
"errorMsg") -> str: @@ -950,7 +812,8 @@ def check_success_msg(data: Union[str, dict, list], key: str = "msg") -> str: if __name__ == "__main__": - api = ApiRequest(no_remote_api=True) + api = ApiRequest() + aapi = AsyncApiRequest() # print(api.chat_fastchat( # messages=[{"role": "user", "content": "hello"}]