parent 955b0bc211
commit 3dde02be28
@@ -40,6 +40,7 @@ MODEL_PATH = {
         "chatglm2-6b": "THUDM/chatglm2-6b",
         "chatglm2-6b-int4": "THUDM/chatglm2-6b-int4",
         "chatglm2-6b-32k": "THUDM/chatglm2-6b-32k",
+        "baichuan-7b": "baichuan-inc/Baichuan-7B",
     },
 }
 
@@ -17,7 +17,7 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
 from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                               update_docs, download_doc, recreate_vector_store,
                                               search_docs, DocumentWithScore)
-from server.llm_api import list_llm_models, change_llm_model, stop_llm_model
+from server.llm_api import list_running_models,list_config_models, change_llm_model, stop_llm_model
 from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline
 from typing import List
 
@@ -125,10 +125,15 @@ def create_app():
              )(recreate_vector_store)
 
     # LLM模型相关接口
-    app.post("/llm_model/list_models",
+    app.post("/llm_model/list_running_models",
              tags=["LLM Model Management"],
              summary="列出当前已加载的模型",
-             )(list_llm_models)
+             )(list_running_models)
 
+    app.post("/llm_model/list_config_models",
+             tags=["LLM Model Management"],
+             summary="列出configs已配置的模型",
+             )(list_config_models)
+
     app.post("/llm_model/stop",
              tags=["LLM Model Management"],
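The hunk above renames /llm_model/list_models to /llm_model/list_running_models and registers a new /llm_model/list_config_models route. A quick client-side check of both might look like the following sketch (the base URL assumes api.py's default local address; adjust it to your deployment):

import httpx

BASE = "http://127.0.0.1:7861"  # assumed default; change to wherever api.py is served

# Models currently loaded by fastchat workers; both body parameters are optional.
running = httpx.post(f"{BASE}/llm_model/list_running_models", json={}).json()

# Models configured in configs, grouped by type (see list_llm_models in the utils hunk below).
configured = httpx.post(f"{BASE}/llm_model/list_config_models").json()

print(running.get("data"))     # e.g. ["chatglm2-6b"]
print(configured.get("data"))  # e.g. {"local": [...], "online": [...], "worker": [...]}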
@@ -1,10 +1,10 @@
 from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
-from server.utils import BaseResponse, fschat_controller_address
+from server.utils import BaseResponse, fschat_controller_address, list_llm_models
 import httpx
 
 
-def list_llm_models(
+def list_running_models(
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]),
     placeholder: str = Body(None, description="该参数未使用,占位用"),
 ) -> BaseResponse:
@@ -24,6 +24,13 @@ def list_llm_models(
             msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")
 
 
+def list_config_models() -> BaseResponse:
+    '''
+    从本地获取configs中配置的模型列表
+    '''
+    return BaseResponse(data=list_llm_models())
+
+
 def stop_llm_model(
     model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
     controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
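Because list_config_models reads only local configs (it wraps the grouped dict from server.utils.list_llm_models in a BaseResponse), it can be sanity-checked without a running controller. A rough sketch, assuming the project's configs package is importable and that BaseResponse defaults to code=200 on success:

from server.llm_api import list_config_models

resp = list_config_models()
print(resp.code)  # 200 expected (assumption: BaseResponse's default success code)
print(resp.data)  # {"local": [...], "online": [...], "worker": [...]}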
@@ -72,7 +72,10 @@ def request_qianfan_api(
     version_url = config.get("version_url")
     access_token = get_baidu_access_token(config.get("api_key"), config.get("secret_key"))
     if not access_token:
-        raise RuntimeError(f"failed to get access token. have you set the correct api_key and secret key?")
+        yield {
+            "error_code": 403,
+            "error_msg": f"failed to get access token. have you set the correct api_key and secret key?",
+        }
 
     url = BASE_URL.format(
         model_version=version_url or MODEL_VERSIONS[version],
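A behavioral note on the hunk above: request_qianfan_api is a generator, so the old raise would only surface once the stream was consumed; yielding an error dict instead lets callers treat the failure like any other chunk. A consumer sketch (the messages argument is a stand-in, not the function's actual signature):

for chunk in request_qianfan_api(messages):
    if chunk.get("error_code"):
        # e.g. the 403 access-token failure yielded above
        print(chunk["error_msg"])
        break
    # ...otherwise handle a normal qianfan response chunk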
@@ -5,12 +5,12 @@ from fastapi import FastAPI
 from pathlib import Path
 import asyncio
 from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
-                     MODEL_PATH, MODEL_ROOT_PATH,
+                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL,
                      logger, log_verbose,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple
 
 
 thread_pool = ThreadPoolExecutor(os.cpu_count())
@@ -201,10 +201,25 @@ def MakeFastAPIOffline(
 
 # 从model_config中获取模型信息
 def list_embed_models() -> List[str]:
+    '''
+    get names of configured embedding models
+    '''
     return list(MODEL_PATH["embed_model"])
 
-def list_llm_models() -> List[str]:
-    return list(MODEL_PATH["llm_model"])
+def list_llm_models() -> Dict[str, List[str]]:
+    '''
+    get names of configured llm models with different types.
+    return [(model_name, config_type), ...]
+    '''
+    workers = list(FSCHAT_MODEL_WORKERS)
+    if "default" in workers:
+        workers.remove("default")
+    return {
+        "local": list(MODEL_PATH["llm_model"]),
+        "online": list(ONLINE_LLM_MODEL),
+        "worker": workers,
+    }
+
 
 def get_model_path(model_name: str, type: str = None) -> Optional[str]:
     if type in MODEL_PATH:
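After this hunk, list_llm_models reports configured model names grouped by source: keys of MODEL_PATH["llm_model"] (local), keys of ONLINE_LLM_MODEL (online API providers), and FSCHAT_MODEL_WORKERS entries minus the "default" template. The mapping it returns looks like this (model names are illustrative, not taken from the commit):

config_models = {
    "local": ["chatglm2-6b", "chatglm2-6b-32k", "baichuan-7b"],  # MODEL_PATH["llm_model"] keys
    "online": ["zhipu-api", "qianfan-api"],                      # ONLINE_LLM_MODEL keys (illustrative)
    "worker": ["chatglm2-6b"],                                   # FSCHAT_MODEL_WORKERS minus "default"
}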
@@ -68,6 +68,7 @@ def dialogue_page(api: ApiRequest):
             config = get_model_worker_config(llm_model)
             if not config.get("online_api"): # 只有本地model_worker可以切换模型
                 st.session_state["prev_llm_model"] = llm_model
+            st.session_state["cur_llm_model"] = st.session_state.llm_model
 
         def llm_model_format_func(x):
             if x in running_models:
@@ -75,25 +76,32 @@ def dialogue_page(api: ApiRequest):
             return x
 
         running_models = api.list_running_models()
+        available_models = []
         config_models = api.list_config_models()
-        for x in running_models:
-            if x in config_models:
-                config_models.remove(x)
-        llm_models = running_models + config_models
-        cur_model = st.session_state.get("cur_llm_model", LLM_MODEL)
-        index = llm_models.index(cur_model)
+        for models in config_models.values():
+            for m in models:
+                if m not in running_models:
+                    available_models.append(m)
+        llm_models = running_models + available_models
+        index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
         llm_model = st.selectbox("选择LLM模型:",
                                  llm_models,
                                  index,
                                  format_func=llm_model_format_func,
                                  on_change=on_llm_change,
-                                 # key="llm_model",
+                                 key="llm_model",
                                  )
         if (st.session_state.get("prev_llm_model") != llm_model
-                and not get_model_worker_config(llm_model).get("online_api")):
+                and not get_model_worker_config(llm_model).get("online_api")
+                and llm_model not in running_models):
             with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
-                r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
-                st.session_state["cur_llm_model"] = llm_model
+                prev_model = st.session_state.get("prev_llm_model")
+                r = api.change_llm_model(prev_model, llm_model)
+                if msg := check_error_msg(r):
+                    st.error(msg)
+                elif msg := check_success_msg(r):
+                    st.success(msg)
+                    st.session_state["prev_llm_model"] = llm_model
 
         temperature = st.slider("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
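The selector logic above now merges the running models with every configured-but-not-running model across all groups. The nested loop in the hunk is equivalent to this comprehension (an alternative sketch, not what the commit ships):

available_models = [
    m
    for models in config_models.values()  # iterate the "local"/"online"/"worker" groups
    for m in models
    if m not in running_models
]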
@@ -766,23 +766,31 @@ class ApiRequest:
             "controller_address": controller_address,
         }
         if no_remote_api:
-            from server.llm_api import list_llm_models
-            return list_llm_models(**data).data
+            from server.llm_api import list_running_models
+            return list_running_models(**data).data
         else:
             r = self.post(
-                "/llm_model/list_models",
+                "/llm_model/list_running_models",
                 json=data,
             )
             return r.json().get("data", [])
 
-    def list_config_models(self):
+    def list_config_models(self, no_remote_api: bool = None) -> Dict[str, List[str]]:
         '''
-        获取configs中配置的模型列表
+        获取configs中配置的模型列表,返回形式为{"type": [model_name1, model_name2, ...], ...}。
+        如果no_remote_api=True, 从运行ApiRequest的机器上获取;否则从运行api.py的机器上获取。
         '''
-        models = list(FSCHAT_MODEL_WORKERS.keys())
-        if "default" in models:
-            models.remove("default")
-        return models
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        if no_remote_api:
+            from server.llm_api import list_config_models
+            return list_config_models().data
+        else:
+            r = self.post(
+                "/llm_model/list_config_models",
+            )
+            return r.json().get("data", {})
 
     def stop_llm_model(
         self,
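On the client side, ApiRequest.list_config_models now mirrors the server's grouped dict whether it runs in-process (no_remote_api=True) or over HTTP. A usage sketch (constructor defaults assumed):

api = ApiRequest()

config_models = api.list_config_models()       # {"local": [...], "online": [...], "worker": [...]}
local_models = config_models.get("local", [])  # the only group change_llm_model accepts (see below)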
@@ -828,13 +836,13 @@ class ApiRequest:
         if not model_name or not new_model_name:
             return
 
-        if new_model_name == model_name:
+        running_models = self.list_running_models()
+        if new_model_name == model_name or new_model_name in running_models:
             return {
                 "code": 200,
-                "msg": "什么都不用做"
+                "msg": "无需切换"
             }
 
-        running_models = self.list_running_models()
         if model_name not in running_models:
             return {
                 "code": 500,
@@ -842,7 +850,7 @@ class ApiRequest:
             }
 
         config_models = self.list_config_models()
-        if new_model_name not in config_models:
+        if new_model_name not in config_models.get("local", []):
             return {
                 "code": 500,
                 "msg": f"要切换的模型'{new_model_name}'在configs中没有配置。"
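Taken together, the two hunks above give ApiRequest.change_llm_model three early-exit guards ("无需切换" roughly means "no switch needed"). A condensed sketch of the guard order, with return payloads abbreviated (api stands in for an ApiRequest instance):

running_models = api.list_running_models()
config_models = api.list_config_models()

if new_model_name == model_name or new_model_name in running_models:
    ...  # {"code": 200, "msg": "无需切换"}: target already active; nothing to do
elif model_name not in running_models:
    ...  # {"code": 500, ...}: the model being replaced is not running
elif new_model_name not in config_models.get("local", []):
    ...  # {"code": 500, ...}: only locally configured models can be switched in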