避免configs对torch的依赖;webui自动从configs获取api地址(close #1319)
This commit is contained in:
parent
215bc25f5c
commit
4e73e561fd
|
|
@ -1,6 +1,5 @@
|
|||
import os
|
||||
import logging
|
||||
import torch
|
||||
# 日志格式
|
||||
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"
|
||||
logger = logging.getLogger()
|
||||
|
|
@ -8,6 +7,19 @@ logger.setLevel(logging.INFO)
|
|||
logging.basicConfig(format=LOG_FORMAT)
|
||||
|
||||
|
||||
# 分布式部署时,不运行LLM的机器上可以不装torch
|
||||
def default_device():
|
||||
try:
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
if torch.backends.mps.is_available():
|
||||
return "mps"
|
||||
except:
|
||||
pass
|
||||
return "cpu"
|
||||
|
||||
|
||||
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
||||
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
|
||||
# 此处请写绝对路径
|
||||
|
|
@ -33,7 +45,7 @@ embedding_model_dict = {
|
|||
EMBEDDING_MODEL = "m3e-base"
|
||||
|
||||
# Embedding 模型运行设备
|
||||
EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
EMBEDDING_DEVICE = default_device()
|
||||
|
||||
|
||||
llm_model_dict = {
|
||||
|
|
@ -76,7 +88,6 @@ llm_model_dict = {
|
|||
},
|
||||
}
|
||||
|
||||
|
||||
# LLM 名称
|
||||
LLM_MODEL = "chatglm2-6b"
|
||||
|
||||
|
|
@ -84,7 +95,7 @@ LLM_MODEL = "chatglm2-6b"
|
|||
HISTORY_LEN = 3
|
||||
|
||||
# LLM 运行设备
|
||||
LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
|
||||
LLM_DEVICE = default_device()
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||
|
|
@ -166,4 +177,4 @@ BING_SUBSCRIPTION_KEY = ""
|
|||
# 是否开启中文标题加强,以及标题增强的相关配置
|
||||
# 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记;
|
||||
# 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。
|
||||
ZH_TITLE_ENHANCE = False
|
||||
ZH_TITLE_ENHANCE = False
|
||||
|
|
|
|||
4
webui.py
4
webui.py
|
|
@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu
|
|||
from webui_pages import *
|
||||
import os
|
||||
from configs import VERSION
|
||||
from server.utils import api_address
|
||||
|
||||
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False)
|
||||
|
||||
api = ApiRequest(base_url=api_address())
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.set_page_config(
|
||||
|
|
|
|||
Loading…
Reference in New Issue