优化LLM和Embedding模型运行设备配置,可设为auto自动检测
This commit is contained in:
parent
80590ef5dc
commit
16fb19d2c3
|
|
@ -3,7 +3,7 @@
|
||||||
logs
|
logs
|
||||||
.idea/
|
.idea/
|
||||||
__pycache__/
|
__pycache__/
|
||||||
knowledge_base/
|
/knowledge_base/
|
||||||
configs/*.py
|
configs/*.py
|
||||||
.vscode/
|
.vscode/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
|
|
|
||||||
|
|
@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
|
||||||
logging.basicConfig(format=LOG_FORMAT)
|
logging.basicConfig(format=LOG_FORMAT)
|
||||||
|
|
||||||
|
|
||||||
# 分布式部署时,不运行LLM的机器上可以不装torch
|
|
||||||
def default_device():
|
|
||||||
try:
|
|
||||||
import torch
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
return "cuda"
|
|
||||||
if torch.backends.mps.is_available():
|
|
||||||
return "mps"
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return "cpu"
|
|
||||||
|
|
||||||
|
|
||||||
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
# 在以下字典中修改属性值,以指定本地embedding模型存储位置
|
||||||
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
|
# 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
|
||||||
# 此处请写绝对路径
|
# 此处请写绝对路径
|
||||||
|
|
@ -44,8 +31,8 @@ embedding_model_dict = {
|
||||||
# 选用的 Embedding 名称
|
# 选用的 Embedding 名称
|
||||||
EMBEDDING_MODEL = "m3e-base"
|
EMBEDDING_MODEL = "m3e-base"
|
||||||
|
|
||||||
# Embedding 模型运行设备
|
# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||||
EMBEDDING_DEVICE = default_device()
|
EMBEDDING_DEVICE = "auto"
|
||||||
|
|
||||||
|
|
||||||
llm_model_dict = {
|
llm_model_dict = {
|
||||||
|
|
@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
|
||||||
# 历史对话轮数
|
# 历史对话轮数
|
||||||
HISTORY_LEN = 3
|
HISTORY_LEN = 3
|
||||||
|
|
||||||
# LLM 运行设备
|
# LLM 运行设备。可选项同Embedding 运行设备。
|
||||||
LLM_DEVICE = default_device()
|
LLM_DEVICE = "auto"
|
||||||
|
|
||||||
# 日志存储路径
|
# 日志存储路径
|
||||||
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
|
||||||
)
|
)
|
||||||
|
|
||||||
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
|
||||||
EMBEDDING_DEVICE, EMBEDDING_MODEL)
|
EMBEDDING_MODEL)
|
||||||
from server.knowledge_base.utils import (
|
from server.knowledge_base.utils import (
|
||||||
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
|
||||||
list_kbs_from_folder, list_files_from_folder,
|
list_kbs_from_folder, list_files_from_folder,
|
||||||
)
|
)
|
||||||
|
from server.utils import embedding_device
|
||||||
from typing import List, Union, Dict
|
from typing import List, Union, Dict
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -45,7 +46,7 @@ class KBService(ABC):
|
||||||
self.doc_path = get_doc_path(self.kb_name)
|
self.doc_path = get_doc_path(self.kb_name)
|
||||||
self.do_init()
|
self.do_init()
|
||||||
|
|
||||||
def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
|
def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
|
||||||
return load_embeddings(self.embed_model, embed_device)
|
return load_embeddings(self.embed_model, embed_device)
|
||||||
|
|
||||||
def create_kb(self):
|
def create_kb(self):
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from configs.model_config import (
|
||||||
KB_ROOT_PATH,
|
KB_ROOT_PATH,
|
||||||
CACHED_VS_NUM,
|
CACHED_VS_NUM,
|
||||||
EMBEDDING_MODEL,
|
EMBEDDING_MODEL,
|
||||||
EMBEDDING_DEVICE,
|
|
||||||
SCORE_THRESHOLD
|
SCORE_THRESHOLD
|
||||||
)
|
)
|
||||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||||
|
|
@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
from typing import List
|
from typing import List
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from server.utils import torch_gc
|
from server.utils import torch_gc, embedding_device
|
||||||
|
|
||||||
|
|
||||||
_VECTOR_STORE_TICKS = {}
|
_VECTOR_STORE_TICKS = {}
|
||||||
|
|
@ -25,7 +24,7 @@ _VECTOR_STORE_TICKS = {}
|
||||||
def load_faiss_vector_store(
|
def load_faiss_vector_store(
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
embed_model: str = EMBEDDING_MODEL,
|
embed_model: str = EMBEDDING_MODEL,
|
||||||
embed_device: str = EMBEDDING_DEVICE,
|
embed_device: str = embedding_device(),
|
||||||
embeddings: Embeddings = None,
|
embeddings: Embeddings = None,
|
||||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||||
) -> FAISS:
|
) -> FAISS:
|
||||||
|
|
|
||||||
|
|
@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
|
||||||
from langchain.vectorstores.pgvector import DistanceStrategy
|
from langchain.vectorstores.pgvector import DistanceStrategy
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
from configs.model_config import kbs_config
|
||||||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
|
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
|
||||||
score_threshold_process
|
score_threshold_process
|
||||||
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
from server.knowledge_base.utils import load_embeddings, KnowledgeFile
|
||||||
|
from server.utils import embedding_device as get_embedding_device
|
||||||
|
|
||||||
class PGKBService(KBService):
|
class PGKBService(KBService):
|
||||||
pg_vector: PGVector
|
pg_vector: PGVector
|
||||||
|
|
||||||
def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
|
def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
|
||||||
_embeddings = embeddings
|
_embeddings = embeddings
|
||||||
if _embeddings is None:
|
if _embeddings is None:
|
||||||
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ import sys
|
||||||
import os
|
import os
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
|
from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
|
||||||
from server.utils import MakeFastAPIOffline, set_httpx_timeout
|
from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
|
||||||
|
|
||||||
|
|
||||||
host_ip = "0.0.0.0"
|
host_ip = "0.0.0.0"
|
||||||
|
|
@ -34,7 +34,7 @@ def create_model_worker_app(
|
||||||
worker_address=base_url.format(model_worker_port),
|
worker_address=base_url.format(model_worker_port),
|
||||||
controller_address=base_url.format(controller_port),
|
controller_address=base_url.format(controller_port),
|
||||||
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
|
model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
|
||||||
device=LLM_DEVICE,
|
device=llm_device(),
|
||||||
gpus=None,
|
gpus=None,
|
||||||
max_gpu_memory="20GiB",
|
max_gpu_memory="20GiB",
|
||||||
load_8bit=False,
|
load_8bit=False,
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,8 @@ import torch
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import asyncio
|
import asyncio
|
||||||
from configs.model_config import LLM_MODEL
|
from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
|
||||||
from typing import Any, Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
|
||||||
class BaseResponse(BaseModel):
|
class BaseResponse(BaseModel):
|
||||||
|
|
@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
|
||||||
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
|
||||||
config.update(llm_model_dict.get(model_name, {}))
|
config.update(llm_model_dict.get(model_name, {}))
|
||||||
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
|
||||||
|
config["device"] = llm_device(config.get("device"))
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
|
||||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
|
||||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
|
||||||
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
|
||||||
|
|
||||||
|
|
||||||
|
# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
|
||||||
|
def detect_device() -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
return "cuda"
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
return "mps"
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return "cpu"
|
||||||
|
|
||||||
|
|
||||||
|
def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
|
device = detect_device()
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
|
def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
|
||||||
|
if device not in ["cuda", "mps", "cpu"]:
|
||||||
|
device = detect_device()
|
||||||
|
return device
|
||||||
|
|
|
||||||
11
startup.py
11
startup.py
|
|
@ -14,12 +14,13 @@ except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||||
from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
|
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
|
||||||
logger
|
logger
|
||||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
|
from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
|
||||||
FSCHAT_OPENAI_API, )
|
FSCHAT_OPENAI_API, )
|
||||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||||
fschat_openai_api_address, set_httpx_timeout)
|
fschat_openai_api_address, set_httpx_timeout,
|
||||||
|
llm_device, embedding_device, get_model_worker_config)
|
||||||
from server.utils import MakeFastAPIOffline, FastAPI
|
from server.utils import MakeFastAPIOffline, FastAPI
|
||||||
import argparse
|
import argparse
|
||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
|
|
@ -195,7 +196,7 @@ def run_model_worker(
|
||||||
):
|
):
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
|
kwargs = get_model_worker_config(model_name)
|
||||||
host = kwargs.pop("host")
|
host = kwargs.pop("host")
|
||||||
port = kwargs.pop("port")
|
port = kwargs.pop("port")
|
||||||
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
||||||
|
|
@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
|
||||||
print(f"项目版本:{VERSION}")
|
print(f"项目版本:{VERSION}")
|
||||||
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
|
print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
|
||||||
pprint(llm_model_dict[LLM_MODEL])
|
pprint(llm_model_dict[LLM_MODEL])
|
||||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
|
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||||
if after_start:
|
if after_start:
|
||||||
print("\n")
|
print("\n")
|
||||||
print(f"服务端运行信息:")
|
print(f"服务端运行信息:")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue