Improve LLM and Embedding model device configuration, allowing "auto" for automatic detection; fix: FAISS index not saved when rebuilding the knowledge base (#1330)

* Avoid a torch dependency in configs;
* webui now obtains the API address from configs automatically (close #1319)
* Bug fix: the FAISS vector store was not saved when rebuilding the knowledge base
* Improve LLM and Embedding model device configuration; it can be set to "auto" for automatic detection
liunux4odoo authored 2023-08-31 17:44:48 +08:00, committed by GitHub
parent 26a9237237
commit b1201a5f23
9 changed files with 66 additions and 41 deletions

.gitignore

@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/

configs/model_config.py

@@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
-# In distributed deployments, machines that do not run the LLM can skip installing torch
-def default_device():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-    except:
-        pass
-    return "cpu"
 # Modify the values in the dictionary below to point to a locally stored embedding model
 # e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese" to "text2vec": "User/Downloads/text2vec-large-chinese"
 # Use an absolute path here
@@ -44,8 +31,8 @@ embedding_model_dict = {
 # Name of the selected Embedding model
 EMBEDDING_MODEL = "m3e-base"
-# Device on which the Embedding model runs
-EMBEDDING_DEVICE = default_device()
+# Device on which the Embedding model runs. "auto" enables automatic detection; it can also be set manually to one of "cuda", "mps", "cpu".
+EMBEDDING_DEVICE = "auto"
 llm_model_dict = {
@@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
 # Number of history dialogue turns
 HISTORY_LEN = 3
-# Device on which the LLM runs
-LLM_DEVICE = default_device()
+# Device on which the LLM runs. Options are the same as for the Embedding device.
+LLM_DEVICE = "auto"
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
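With both settings on "auto", callers resolve the concrete device through the llm_device() / embedding_device() helpers added to server/utils.py later in this commit; a minimal usage sketch (the printed values depend on the local torch installation):

    from server.utils import embedding_device, llm_device

    # "auto" (or any value other than "cuda"/"mps"/"cpu") falls back to detect_device()
    print(embedding_device())      # e.g. "cuda" on a GPU machine, otherwise "mps" or "cpu"
    print(llm_device("mps"))       # explicit, recognized values are passed through unchanged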

server/knowledge_base/kb_service/base.py

@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 from typing import List, Union, Dict
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
     def create_kb(self):

server/knowledge_base/kb_service/faiss_kb_service.py

@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device
 _VECTOR_STORE_TICKS = {}
@@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
@@ -74,13 +73,18 @@ class FaissKBService(KBService):
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)
-    def load_vector_store(self):
+    def load_vector_store(self) -> FAISS:
         return load_faiss_vector_store(
             knowledge_base_name=self.kb_name,
             embed_model=self.embed_model,
             tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
         )
+    def save_vector_store(self, vector_store: FAISS = None):
+        vector_store = vector_store or self.load_vector_store()
+        vector_store.save_local(self.vs_path)
+        return vector_store
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
@@ -117,11 +121,11 @@ class FaissKBService(KBService):
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
+        return vector_store
     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
         vector_store = self.load_vector_store()
         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@@ -133,7 +137,7 @@ class FaissKBService(KBService):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
-        return True
+        return vector_store
     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
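The new save_vector_store() method lets callers that batch many updates (passing not_refresh_vs_cache=True per file) persist the FAISS index to disk once at the end instead of after every file; a sketch of the intended pattern, assuming an existing FaissKBService instance kb and a list of KnowledgeFile objects kb_files:

    # Defer per-file persistence, then save and refresh once (kb and kb_files are assumed to exist).
    for kb_file in kb_files:
        kb.update_doc(kb_file, not_refresh_vs_cache=True)
    kb.save_vector_store()   # write the FAISS index to kb.vs_path once
    kb.refresh_vs_cache()    # then invalidate the cached in-memory store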

server/knowledge_base/kb_service/pg_kb_service.py

@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device
 class PGKBService(KBService):
     pg_vector: PGVector
-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)

server/knowledge_base/migrate.py

@@ -69,6 +69,7 @@ def folder2db(
             print(result)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "fill_info_only":
         files = list_files_from_folder(kb_name)
@@ -85,6 +86,7 @@ def folder2db(
             kb.update_doc(kb_file, not_refresh_vs_cache=True)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "increament":
         db_files = kb.list_files()
@@ -102,6 +104,7 @@ def folder2db(
             print(result)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     else:
         print(f"unspported migrate mode: {mode}")
@@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
     files = list(set(files_in_db) - set(files_in_folder))
     kb_files = file_to_kbfile(kb_name, files)
     for kb_file in kb_files:
-        kb.delete_doc(kb_file)
+        kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+    if kb.vs_type() == SupportedVSType.FAISS:
+        kb.save_vector_store()
+        kb.refresh_vs_cache()
     return kb_files
 def prune_folder_files(kb_name: str):

server/llm_api.py

@@ -4,8 +4,8 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,

server/utils.py

@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
 class BaseResponse(BaseModel):
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# Automatically detect the device available to torch. In distributed deployments, machines that do not run the LLM can skip installing torch.
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
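get_model_worker_config() now routes the merged "device" value through llm_device(), so per-model overrides in FSCHAT_MODEL_WORKERS can also use "auto"; a rough illustration (the model name and override values below are made up):

    from server.utils import get_model_worker_config

    # Hypothetical override in configs/server_config.py; "auto" is normalized inside get_model_worker_config():
    # FSCHAT_MODEL_WORKERS = {
    #     "default": {"host": "127.0.0.1", "port": 20002},
    #     "chatglm2-6b": {"device": "auto"},
    # }
    config = get_model_worker_config("chatglm2-6b")
    print(config["device"])   # "cuda", "mps", or "cpu", never "auto"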

startup.py

@@ -14,12 +14,13 @@ except:
     pass
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -195,7 +196,7 @@ def run_model_worker(
 ):
     import uvicorn
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"当前LLM模型{LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")