diff --git a/.gitignore b/.gitignore
index c4178a9..a7ef90f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index b9dd1de..5466f7f 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
 
 
-# 分布式部署时,不运行LLM的机器上可以不装torch
-def default_device():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-    except:
-        pass
-    return "cpu"
-
-
 # 在以下字典中修改属性值,以指定本地embedding模型存储位置
 # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
 # 此处请写绝对路径
@@ -44,8 +31,8 @@ embedding_model_dict = {
 # 选用的 Embedding 名称
 EMBEDDING_MODEL = "m3e-base"
 
-# Embedding 模型运行设备
-EMBEDDING_DEVICE = default_device()
+# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+EMBEDDING_DEVICE = "auto"
 
 
 llm_model_dict = {
@@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
 # 历史对话轮数
 HISTORY_LEN = 3
 
-# LLM 运行设备
-LLM_DEVICE = default_device()
+# LLM 运行设备。可选项同Embedding 运行设备。
+LLM_DEVICE = "auto"
 
 # 日志存储路径
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index ca8d1ae..79b1518 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs.model_config import (kbs_config,
                                   VECTOR_SEARCH_TOP_K,
                                   SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 
 from typing import List, Union, Dict
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
 
     def create_kb(self):
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index d2e8525..f17b2da 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device
 
 
 _VECTOR_STORE_TICKS = {}
@@ -25,7 +24,7 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ) -> FAISS:
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 4b52625..3e3dd52 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-
+from server.utils import embedding_device as get_embedding_device
 
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)
diff --git a/server/llm_api.py b/server/llm_api.py
index 7ef5891..d9667e4 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,8 +4,8 @@ import sys
 import os
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 
 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,
diff --git a/server/utils.py b/server/utils.py
index 167b672..d716582 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
 
 
 class BaseResponse(BaseModel):
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
 
 
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
diff --git a/startup.py b/startup.py
index 64a3bcc..07630d9 100644
--- a/startup.py
+++ b/startup.py
@@ -14,12 +14,13 @@ except:
     pass
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER,
                                    FSCHAT_MODEL_WORKERS, FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -195,7 +196,7 @@ def run_model_worker(
 ):
     import uvicorn
 
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")