避免configs对torch的依赖;webui自动从configs获取api地址(close #1319)

This commit is contained in:
liunux4odoo 2023-08-31 16:08:16 +08:00
parent 215bc25f5c
commit 4e73e561fd
2 changed files with 19 additions and 6 deletions

View File

@ -1,6 +1,5 @@
import os
import logging
import torch
# 日志格式
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
logger = logging.getLogger()
@ -8,6 +7,19 @@ logger.setLevel(logging.INFO)
logging.basicConfig(format=LOG_FORMAT)
# 分布式部署时不运行LLM的机器上可以不装torch
def default_device():
    """Return the best available torch device name: "cuda", "mps" or "cpu".

    torch is imported lazily so that machines which do not run the LLM
    (in a distributed deployment) need not have torch installed at all.
    Any failure — torch missing, a torch build without the mps backend,
    or a CUDA initialisation error — falls back to "cpu".
    """
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    # Exception (not bare except) so KeyboardInterrupt/SystemExit still propagate;
    # everything else means "no usable accelerator", so fall through to cpu.
    except Exception:
        pass
    return "cpu"
# 在以下字典中修改属性值以指定本地embedding模型存储位置
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
# 此处请写绝对路径
@ -33,7 +45,7 @@ embedding_model_dict = {
EMBEDDING_MODEL = "m3e-base"
# Embedding 模型运行设备
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
EMBEDDING_DEVICE = default_device()
llm_model_dict = {
@ -76,7 +88,6 @@ llm_model_dict = {
},
}
# LLM 名称
LLM_MODEL = "chatglm2-6b"
@ -84,7 +95,7 @@ LLM_MODEL = "chatglm2-6b"
HISTORY_LEN = 3
# LLM 运行设备
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
LLM_DEVICE = default_device()
# 日志存储路径
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")

View File

@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
from webui_pages import *
import os
from configs import VERSION
from server.utils import api_address
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
api = ApiRequest(base_url=api_address())
if __name__ == "__main__":
st.set_page_config(