避免configs对torch的依赖;webui自动从configs获取api地址(close #1319)
This commit is contained in:
parent
215bc25f5c
commit
4e73e561fd
|
|
@ -1,6 +1,5 @@
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
import torch
|
|
||||||
# 日志格式
|
# 日志格式
|
||||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
|
|
@ -8,6 +7,19 @@ logger.setLevel(logging.INFO)
|
||||||
logging.basicConfig(format=LOG_FORMAT)
|
logging.basicConfig(format=LOG_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
|
# 分布式部署时,不运行LLM的机器上可以不装torch
|
||||||
|
def default_device():
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
||||||
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
|
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
|
||||||
# 此处请写绝对路径
|
# 此处请写绝对路径
|
||||||
|
|
@ -33,7 +45,7 @@ embedding_model_dict = {
|
||||||
EMBEDDING_MODEL = "m3e-base"
|
EMBEDDING_MODEL = "m3e-base"
|
||||||
|
|
||||||
# Embedding 模型运行设备
|
# Embedding 模型运行设备
|
||||||
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
EMBEDDING_DEVICE = default_device()
|
||||||
|
|
||||||
|
|
||||||
llm_model_dict = {
|
llm_model_dict = {
|
||||||
|
|
@ -76,7 +88,6 @@ llm_model_dict = {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# LLM 名称
|
# LLM 名称
|
||||||
LLM_MODEL = "chatglm2-6b"
|
LLM_MODEL = "chatglm2-6b"
|
||||||
|
|
||||||
|
|
@ -84,7 +95,7 @@ LLM_MODEL = "chatglm2-6b"
|
||||||
HISTORY_LEN = 3
|
HISTORY_LEN = 3
|
||||||
|
|
||||||
# LLM 运行设备
|
# LLM 运行设备
|
||||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
LLM_DEVICE = default_device()
|
||||||
|
|
||||||
# 日志存储路径
|
# 日志存储路径
|
||||||
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||||
|
|
@ -166,4 +177,4 @@ BING_SUBSCRIPTION_KEY = ""
|
||||||
# 是否开启中文标题加强,以及标题增强的相关配置
|
# 是否开启中文标题加强,以及标题增强的相关配置
|
||||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||||
ZH_TITLE_ENHANCE = False
|
ZH_TITLE_ENHANCE = False
|
||||||
|
|
|
||||||
4
webui.py
4
webui.py
|
|
@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
|
||||||
from webui_pages import *
|
from webui_pages import *
|
||||||
import os
|
import os
|
||||||
from configs import VERSION
|
from configs import VERSION
|
||||||
|
from server.utils import api_address
|
||||||
|
|
||||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
|
||||||
|
api = ApiRequest(base_url=api_address())
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
st.set_page_config(
|
st.set_page_config(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue