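"""Streamlit helper for managing Xinference model caches.

Point Xinference's model cache at an already-downloaded local copy, register a
tweaked built-in model as a custom model, or remove a bad cache entry, all
through a small web UI (see the sidebar usage notes below).

Assuming this file is saved as xinference_manager.py, run it with:

    streamlit run xinference_manager.py
"""
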
import json
import os

import pandas as pd
import streamlit as st
from xinference.client import Client
import xinference.model.llm as xf_llm
from xinference.model import embedding as xf_embedding
from xinference.model import image as xf_image
from xinference.model import rerank as xf_rerank
from xinference.model import audio as xf_audio
from xinference.constants import XINFERENCE_CACHE_DIR


model_types = ["LLM", "embedding", "image", "rerank", "audio"]
model_name_suffix = "-custom"
# Per-type cache functions exposed by the xinference model subpackages.
cache_methods = {
    "LLM": xf_llm.llm_family.cache,
    "embedding": xf_embedding.core.cache,
    "image": xf_image.core.cache,
    "rerank": xf_rerank.core.cache,
    "audio": xf_audio.core.cache,
}


@st.cache_resource
def get_client(url: str):
    # Cache one Client per endpoint URL across Streamlit reruns.
    return Client(url)


def get_cache_dir(
    model_type: str,
    model_name: str,
    model_format: str = "",
    model_size: str = "",
):
    # Mirror xinference's cache layout: LLM cache dirs encode format and
    # size; other model types are cached under the bare model name.
    if model_type == "LLM":
        dir_name = f"{model_name}-{model_format}-{model_size}b"
    else:
        dir_name = f"{model_name}"
    return os.path.join(XINFERENCE_CACHE_DIR, dir_name)


def get_meta_path(
    model_type: str,
    cache_dir: str,
    model_format: str,
    model_hub: str = "huggingface",
    model_quant: str = "none",
):
    # xinference marks a finished download with a marker file inside the
    # cache dir; for LLMs its name depends on format, hub and quantization.
    if model_type == "LLM":
        return xf_llm.llm_family._get_meta_path(
            cache_dir=cache_dir,
            model_format=model_format,
            model_hub=model_hub,
            quantization=model_quant,
        )
    else:
        return os.path.join(cache_dir, "__valid_download")


def list_running_models():
    # Flatten client.list_models() (UID -> spec dict) into a DataFrame.
    models = client.list_models()
    columns = [
        "UID",
        "type",
        "name",
        "ability",
        "size",
        "quant",
        "max_tokens",
    ]
    data = []
    for k, v in models.items():
        item = dict(
            UID=k,
            type=v["model_type"],
            name=v["model_name"],
        )
        if v["model_type"] == "LLM":
            item.update(
                ability=v["model_ability"],
                size=str(v["model_size_in_billions"]) + "B",
                quant=v["quantization"],
                max_tokens=v["context_length"],
            )
        elif v["model_type"] == "embedding":
            item.update(
                max_tokens=v["max_tokens"],
            )
        data.append(item)
    df = pd.DataFrame(data, columns=columns)
    df.index += 1  # 1-based row numbers for display
    return df


def get_model_registrations():
    # Collect every registration (built-in and custom) grouped by type.
    data = {}
    for model_type in model_types:
        data[model_type] = {}
        for m in client.list_model_registrations(model_type):
            data[model_type][m["model_name"]] = {"is_builtin": m["is_builtin"]}
            reg = client.get_model_registration(model_type, m["model_name"])
            data[model_type][m["model_name"]]["reg"] = reg
    return data


with st.sidebar:
    st.subheader("Start the XF (Xinference) service with the xinference or "
                 "xinference-local command first, then set the service "
                 "endpoint below.")
    xf_url = st.text_input("Xinference endpoint", "http://127.0.0.1:9997")
    st.divider()
    st.markdown(
        "### Usage\n\n"
        "- Scenario 1: I have already downloaded a model and do not want "
        "XF to download its built-in copy again\n\n"
        "- Action: select the model, fill in the local model path, then "
        "click 'Set model cache'\n\n"
        "- Scenario 2: I want to tweak a built-in XF model without writing "
        "a model registration file from scratch\n\n"
        "- Action: select the model, fill in the local model path, then "
        "click 'Register as custom model'\n\n"
        "- Scenario 3: I accidentally set a wrong model path\n\n"
        "- Action: click 'Delete model cache', or enter the correct path "
        "and click 'Set model cache' again\n\n"
    )
    client = get_client(xf_url)


st.subheader("当前运行的模型:")
|
||
st.dataframe(list_running_models())
|
||
|
||
st.subheader("配置模型路径:")
|
||
regs = get_model_registrations()
|
||
cols = st.columns([3, 4, 3, 2, 2])
|
||
|
||
model_type = cols[0].selectbox("Model type:", model_types)

model_names = list(regs[model_type].keys())
model_name = cols[1].selectbox("Model name:", model_names)

cur_reg = regs[model_type][model_name]["reg"]
model_format = None
model_quant = None

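# For LLMs, the cascading selectboxes below narrow the registration down to a
# single model spec: format -> size -> quantization. Other model types carry
# exactly one spec, parsed directly from the registration.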
if model_type == "LLM":
    cur_family = xf_llm.LLMFamilyV1.parse_obj(cur_reg)
    cur_spec = None
    model_formats = []
    for spec in cur_reg["model_specs"]:
        if spec["model_format"] not in model_formats:
            model_formats.append(spec["model_format"])
    # Default to the pytorch format when it is available.
    index = 0
    if "pytorch" in model_formats:
        index = model_formats.index("pytorch")
    model_format = cols[2].selectbox("Model format:", model_formats, index)

    model_sizes = []
    for spec in cur_reg["model_specs"]:
        if (spec["model_format"] == model_format
                and spec["model_size_in_billions"] not in model_sizes):
            model_sizes.append(spec["model_size_in_billions"])
    model_size = cols[3].selectbox("Model size:", model_sizes, format_func=lambda x: f"{x}B")

    model_quants = []
    for spec in cur_reg["model_specs"]:
        if (spec["model_format"] == model_format
                and spec["model_size_in_billions"] == model_size):
            model_quants = spec["quantizations"]
    model_quant = cols[4].selectbox("Quantization:", model_quants)
    if model_quant == "none":
        model_quant = None

    # Pick the parsed spec matching the selected format and size.
    for i, spec in enumerate(cur_reg["model_specs"]):
        if (spec["model_format"] == model_format
                and spec["model_size_in_billions"] == model_size):
            cur_spec = cur_family.model_specs[i]
            break
    cache_dir = get_cache_dir(model_type, model_name, model_format, model_size)
elif model_type == "embedding":
    cur_spec = xf_embedding.core.EmbeddingModelSpec.parse_obj(cur_reg)
    cache_dir = get_cache_dir(model_type, model_name)
elif model_type == "image":
    cur_spec = xf_image.core.ImageModelFamilyV1.parse_obj(cur_reg)
    cache_dir = get_cache_dir(model_type, model_name)
elif model_type == "rerank":
    cur_spec = xf_rerank.core.RerankModelSpec.parse_obj(cur_reg)
    cache_dir = get_cache_dir(model_type, model_name)
elif model_type == "audio":
    cur_spec = xf_audio.core.AudioModelFamilyV1.parse_obj(cur_reg)
    cache_dir = get_cache_dir(model_type, model_name)

meta_file = get_meta_path(
    model_type=model_type,
    model_format=model_format,
    cache_dir=cache_dir,
    model_quant=model_quant)

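# In xinference the cache dir is normally a symlink into the real model
# directory, and the meta file above marks a completed download. The status
# report below reads the recorded revision from that meta file.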
if os.path.isdir(cache_dir):
    try:
        with open(meta_file, encoding="utf-8") as fp:
            meta = json.load(fp)
        revision = meta.get("revision", meta.get("model_revision"))
    except Exception:  # meta file missing or unreadable
        revision = None

    if revision is None:
        revision = "None"

    if cur_spec.model_revision and cur_spec.model_revision != revision:
        revision += " (does not match the XF built-in revision)"
    else:
        revision += " (as expected)"

    text = (f"Model is cached.\n\n"
            f"Cache path: {cache_dir}\n\n"
            # [4:] presumably strips the '\\?\' long-path prefix that
            # os.readlink returns for links created on Windows
            f"Source path: {os.readlink(cache_dir)[4:]}\n\n"
            f"Revision: {revision}"
            )
else:
    text = "Model is not cached yet"

st.divider()
st.markdown(text)
st.divider()

local_path = st.text_input("Absolute path to the local model:")
cols = st.columns(3)

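# Actions: 'Set model cache' points xinference's cache at the local path and
# writes the download marker; 'Delete model cache' removes that cache link;
# 'Register as custom model' re-registers the built-in definition under a
# '-custom' name pointing at the local path.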
if cols[0].button("Set model cache"):
    if os.path.isabs(local_path) and os.path.isdir(local_path):
        # The embedding spec has no model_uri field, so write through
        # __dict__ to bypass pydantic's attribute validation.
        cur_spec.__dict__["model_uri"] = local_path
        if os.path.isdir(cache_dir):
            os.rmdir(cache_dir)  # drop the stale cache link first
        if model_type == "LLM":
            cache_methods[model_type](cur_family, cur_spec, model_quant)
            xf_llm.llm_family._generate_meta_file(meta_file, cur_family, cur_spec, model_quant)
        else:
            cache_methods[model_type](cur_spec)
            if cur_spec.model_revision:
                # Write the revision marker so xinference treats the cache
                # as a valid, up-to-date download.
                for hub in ["huggingface", "modelscope"]:
                    meta_file = get_meta_path(
                        model_type=model_type,
                        model_format=model_format,
                        model_hub=hub,
                        cache_dir=cache_dir,
                        model_quant=model_quant)
                    with open(meta_file, "w", encoding="utf-8") as fp:
                        json.dump({"revision": cur_spec.model_revision}, fp)
        st.rerun()
    else:
        st.error("You must enter an existing absolute path")

if cols[1].button("Delete model cache"):
    if os.path.isdir(cache_dir):
        # cache_dir is a link into the real model directory, so removing
        # it does not delete the model files themselves.
        os.rmdir(cache_dir)

if cols[2].button("Register as custom model"):
    if os.path.isabs(local_path) and os.path.isdir(local_path):
        cur_spec.model_uri = local_path
        cur_spec.model_revision = None
        if model_type == "LLM":
            cur_family.model_name = f"{cur_family.model_name}{model_name_suffix}"
            cur_family.model_family = "other"
            # pydantic v1 API, matching parse_obj above; under pydantic v2
            # use model_dump_json(indent=2) instead (no ensure_ascii kwarg).
            model_definition = cur_family.json(indent=2, ensure_ascii=False)
        else:
            cur_spec.model_name = f"{cur_spec.model_name}{model_name_suffix}"
            model_definition = cur_spec.json(indent=2, ensure_ascii=False)
        client.register_model(
            model_type=model_type,
            model=model_definition,
            persist=True,
        )
        st.rerun()
    else:
        st.error("You must enter an existing absolute path")