parent 955b0bc211
commit 3dde02be28
@@ -40,6 +40,7 @@ MODEL_PATH = {
         "chatglm2-6b": "THUDM/chatglm2-6b",
+        "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
         "baichuan-7b": "baichuan-inc/Baichuan-7B",
     },
 }

@@ -17,7 +17,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models, list_config_models, change_llm_model, stop_llm_model
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List

@@ -125,10 +125,15 @@ def create_app():
              )(recreate_vector_store)

     # LLM模型相关接口
-    app.post("/llm_model/list_models",
+    app.post("/llm_model/list_running_models",
              tags=["LLM Model Management"],
              summary="列出当前已加载的模型",
-             )(list_llm_models)
+             )(list_running_models)
+
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)

     app.post("/llm_model/stop",
              tags=["LLM Model Management"],

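For reference, a minimal client-side sketch of the two listing endpoints registered above. The base address is an assumption (the port used by api.py may differ per configuration):

import requests

api_base = "http://127.0.0.1:7861"  # assumed default address of api.py; adjust as needed

# models currently loaded by the fastchat model workers
running = requests.post(f"{api_base}/llm_model/list_running_models", json={}).json()["data"]

# models declared in configs, grouped by type ("local" / "online" / "worker")
configured = requests.post(f"{api_base}/llm_model/list_config_models").json()["data"]

print("running:", running)
print("configured:", configured)
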
@@ -1,10 +1,10 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models
 import httpx


-def list_llm_models(
+def list_running_models(
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
     placeholder: str = Body(None, description="该参数未使用,占位用"),
 ) -> BaseResponse:

@@ -24,6 +24,13 @@ def list_llm_models(
             msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")


+def list_config_models() -> BaseResponse:
+    '''
+    从本地获取configs中配置的模型列表
+    '''
+    return BaseResponse(data=list_llm_models())
+
+
 def stop_llm_model(
     model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])

@@ -72,7 +72,10 @@ def request_qianfan_api(
     version_url = config.get("version_url")
     access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
     if not access_token:
-        raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
+        yield {
+            "error_code": 403,
+            "error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
+        }

     url = BASE_URL.format(
         model_version=version_url or MODEL_VERSIONS[version],

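The hunk above swaps a hard RuntimeError for a yielded error payload, so the streaming generator reports the token failure instead of killing the worker. A rough consumer-side sketch (how the generator is obtained is elided here; the variable name is illustrative):

# `stream` stands for the generator returned by request_qianfan_api(...); obtaining it is omitted.
for chunk in stream:
    if chunk.get("error_code"):  # e.g. 403 when no access token could be fetched
        print("qianfan error:", chunk.get("error_msg"))
        break
    print(chunk)  # normal streaming output
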
@@ -5,12 +5,12 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple


 thread_pool = ThreadPoolExecutor(os.cpu_count())

@@ -201,10 +201,25 @@ def MakeFastAPIOffline(

 # 从model_config中获取模型信息
 def list_embed_models() -> List[str]:
     '''
     get names of configured embedding models
     '''
     return list(MODEL_PATH["embed_model"])

-def list_llm_models() -> List[str]:
-    return list(MODEL_PATH["llm_model"])
+def list_llm_models() -> Dict[str, List[str]]:
+    '''
+    get names of configured llm models with different types.
+    return [(model_name, config_type), ...]
+    '''
+    workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in workers:
+        workers.remove("default")
+    return {
+        "local": list(MODEL_PATH["llm_model"]),
+        "online": list(ONLINE_LLM_MODEL),
+        "worker": workers,
+    }


 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:

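With the new signature, list_llm_models() groups configured model names by type instead of returning a flat list. A sketch of the expected shape (the concrete names are illustrative):

from server.utils import list_llm_models

models = list_llm_models()
# roughly:
# {
#     "local":  [...],   # keys of MODEL_PATH["llm_model"], e.g. "chatglm2-6b"
#     "online": [...],   # keys of ONLINE_LLM_MODEL
#     "worker": [...],   # keys of FSCHAT_MODEL_WORKERS minus "default"
# }
all_names = [name for names in models.values() for name in names]
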
@@ -68,6 +68,7 @@ def dialogue_page(api: ApiRequest):
             config = get_model_worker_config(llm_model)
             if not config.get("online_api"):  # 只有本地model_worker可以切换模型
                 st.session_state["prev_llm_model"] = llm_model
+                st.session_state["cur_llm_model"] = st.session_state.llm_model

         def llm_model_format_func(x):
             if x in running_models:

@@ -75,25 +76,32 @@ def dialogue_page(api: ApiRequest):
             return x

         running_models = api.list_running_models()
+        available_models = []
         config_models = api.list_config_models()
-        for x in running_models:
-            if x in config_models:
-                config_models.remove(x)
-        llm_models = running_models + config_models
-        cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
-        index = llm_models.index(cur_model)
+        for models in config_models.values():
+            for m in models:
+                if m not in running_models:
+                    available_models.append(m)
+        llm_models = running_models + available_models
+        index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,
                                  format_func=llm_model_format_func,
                                  on_change=on_llm_change,
-                                 # key="llm_model",
+                                 key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-            and not get_model_worker_config(llm_model).get("online_api")):
+            and not get_model_worker_config(llm_model).get("online_api")
+            and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
-                r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-                st.session_state["cur_llm_model"] = llm_model
+                prev_model = st.session_state.get("prev_llm_model")
+                r = api.change_llm_model(prev_model, llm_model)
                 if msg := check_error_msg(r):
                     st.error(msg)
+                elif msg := check_success_msg(r):
+                    st.success(msg)
+                    st.session_state["prev_llm_model"] = llm_model

         temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

@@ -766,23 +766,31 @@ class ApiRequest:
             "controller_address": controller_address,
         }
         if no_remote_api:
-            from server.llm_api import list_llm_models
-            return list_llm_models(**data).data
+            from server.llm_api import list_running_models
+            return list_running_models(**data).data
         else:
             r = self.post(
-                "/llm_model/list_models",
+                "/llm_model/list_running_models",
                 json=data,
             )
             return r.json().get("data", [])

-    def list_config_models(self):
+    def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
         '''
-        获取configs中配置的模型列表
+        获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
+        如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
         '''
-        models = list(FSCHAT_MODEL_WORKERS.keys())
-        if "default" in models:
-            models.remove("default")
-        return models
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        if no_remote_api:
+            from server.llm_api import list_config_models
+            return list_config_models().data
+        else:
+            r = self.post(
+                "/llm_model/list_config_models",
+            )
+            return r.json().get("data", {})

     def stop_llm_model(
         self,

@@ -828,13 +836,13 @@ class ApiRequest:
         if not model_name or not new_model_name:
             return

-        if new_model_name == model_name:
+        running_models = self.list_running_models()
+        if new_model_name == model_name or new_model_name in running_models:
             return {
                 "code": 200,
-                "msg": "什么都不用做"
+                "msg": "无需切换"
             }

-        running_models = self.list_running_models()
         if model_name not in running_models:
             return {
                 "code": 500,

@@ -842,7 +850,7 @@ class ApiRequest:
             }

         config_models = self.list_config_models()
-        if new_model_name not in config_models:
+        if new_model_name not in config_models.get("local", []):
             return {
                 "code": 500,
                 "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
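
Taken together, a webui-side sketch of how the updated ApiRequest methods can be combined, mirroring the dialogue page logic above (constructing the api object with its defaults is an assumption):

api = ApiRequest()  # assumes the default base_url / no_remote_api settings are suitable

running = api.list_running_models()    # e.g. ["chatglm2-6b"]
configured = api.list_config_models()  # {"local": [...], "online": [...], "worker": [...]}

# local models that are configured but not loaded yet
loadable = [m for m in configured.get("local", []) if m not in running]

if running and loadable:
    r = api.change_llm_model(running[0], loadable[0])
    print(r.get("code"), r.get("msg"))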