diff --git a/configs/model_config.py.example b/configs/model_config.py.example index f46dad6..b9dd1de 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -1,6 +1,5 @@ import os import logging -import torch # 日志格式 LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s" logger = logging.getLogger() @@ -8,6 +7,19 @@ logger.setLevel(logging.INFO) logging.basicConfig(format=LOG_FORMAT) +# 分布式部署时,不运行LLM的机器上可以不装torch +def default_device(): + try: + import torch + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + except: + pass + return "cpu" + + # 在以下字典中修改属性值,以指定本地embedding模型存储位置 # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese" # 此处请写绝对路径 @@ -33,7 +45,7 @@ embedding_model_dict = { EMBEDDING_MODEL = "m3e-base" # Embedding 模型运行设备 -EMBEDDING_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +EMBEDDING_DEVICE = default_device() llm_model_dict = { @@ -76,7 +88,6 @@ llm_model_dict = { }, } - # LLM 名称 LLM_MODEL = "chatglm2-6b" @@ -84,7 +95,7 @@ LLM_MODEL = "chatglm2-6b" HISTORY_LEN = 3 # LLM 运行设备 -LLM_DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" +LLM_DEVICE = default_device() # 日志存储路径 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs") @@ -166,4 +177,4 @@ BING_SUBSCRIPTION_KEY = "" # 是否开启中文标题加强,以及标题增强的相关配置 # 通过增加标题判断,判断哪些文本为标题,并在metadata中进行标记; # 然后将文本与往上一级的标题进行拼合,实现文本信息的增强。 -ZH_TITLE_ENHANCE = False \ No newline at end of file +ZH_TITLE_ENHANCE = False diff --git a/webui.py b/webui.py index 58fc0e3..0cda9eb 100644 --- a/webui.py +++ b/webui.py @@ -10,8 +10,10 @@ from streamlit_option_menu import option_menu from webui_pages import * import os from configs import VERSION +from server.utils import api_address -api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) + +api = ApiRequest(base_url=api_address()) if __name__ == "__main__": st.set_page_config(