Optimize the runtime device configuration for the LLM and Embedding models; it can now be set to "auto" for automatic detection

liunux4odoo 2023-08-31 17:33:43 +08:00
parent 80590ef5dc
commit 16fb19d2c3
8 changed files with 50 additions and 36 deletions

.gitignore

@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/

configs/model_config.py.example

@@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
 
-# In a distributed deployment, machines that do not run the LLM do not need torch installed
-def default_device():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-    except:
-        pass
-    return "cpu"
-
 # Edit the values in the dict below to point to a locally stored embedding model,
 # e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese" to "text2vec": "User/Downloads/text2vec-large-chinese"
 # Use an absolute path here

@@ -44,8 +31,8 @@ embedding_model_dict = {
 # Name of the selected Embedding model
 EMBEDDING_MODEL = "m3e-base"
 
-# Device on which the Embedding model runs
-EMBEDDING_DEVICE = default_device()
+# Device on which the Embedding model runs. Set it to "auto" to detect automatically, or set it manually to one of "cuda", "mps", "cpu".
+EMBEDDING_DEVICE = "auto"
 
 llm_model_dict = {

@@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
 # Number of history dialogue turns
 HISTORY_LEN = 3
 
-# Device on which the LLM runs
-LLM_DEVICE = default_device()
+# Device on which the LLM runs. Same options as the Embedding device.
+LLM_DEVICE = "auto"
 
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
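
With this change the device settings in the local model config are plain strings rather than a call to default_device(). A minimal sketch of the relevant entries (hedged illustration only; any value other than "auto" below is just an example):

    # in your local copy of the model config
    EMBEDDING_DEVICE = "auto"   # let the server detect "cuda" / "mps" / "cpu" at runtime
    LLM_DEVICE = "auto"         # or pin an explicit device such as "cuda"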

server/knowledge_base/kb_service/base.py

@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 
 from typing import List, Union, Dict

@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
 
     def create_kb(self):

server/knowledge_base/kb_service/faiss_kb_service.py

@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType

@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device
 
 _VECTOR_STORE_TICKS = {}

@@ -25,7 +24,7 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ) -> FAISS:

server/knowledge_base/kb_service/pg_kb_service.py

@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device
 
 
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)

server/llm_api.py

@@ -4,8 +4,8 @@ import sys
 import os
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 
 host_ip = "0.0.0.0"

@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,

server/utils.py

@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
 
 
 class BaseResponse(BaseModel):

@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config

@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# Automatically detect the device available to torch; in a distributed deployment, machines that do not run the LLM do not need torch installed
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
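
The helpers above treat anything that is not exactly "cuda", "mps" or "cpu" as a request for auto-detection, so the "auto" value from the config (or any unrecognised string) falls through to detect_device(). A minimal usage sketch, assuming the project root is on sys.path so that server.utils is importable:

    from server.utils import detect_device, llm_device, embedding_device

    print(detect_device())            # "cuda", "mps" or "cpu", depending on the local torch install
    print(llm_device())               # resolves LLM_DEVICE ("auto" by default) to a concrete device
    print(embedding_device("mps"))    # an explicit, recognised value is passed through unchanged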

startup.py

@@ -14,12 +14,13 @@ except:
     pass
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List

@@ -195,7 +196,7 @@ def run_model_worker(
 ):
     import uvicorn
 
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")

@@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")