Optimize the LLM and embedding model device settings, which can now be set to "auto" for automatic detection; fix: FAISS index not saved when rebuilding a knowledge base (#1330)

* Remove the torch dependency from configs;
* webui now gets the API address from configs automatically (close #1319)
* Bug fix: the FAISS index was not saved when rebuilding a knowledge base
* Optimize the LLM and embedding model device settings; they can be set to "auto" for automatic detection

Author: liunux4odoo (committed by GitHub)
Date: 2023-08-31 17:44:48 +08:00
Parent: 26a9237237
Commit: b1201a5f23
9 changed files with 66 additions and 41 deletions
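
Taken together, the device handling now works like this: configs/model_config.py holds plain strings ("auto", "cuda", "mps" or "cpu"), and the new helpers in server/utils.py resolve them only when a device is actually needed, so machines that never load a model do not need torch. A minimal sketch of the flow, using only names introduced by this commit:

```python
# configs/model_config.py keeps plain strings only (no torch import needed here)
EMBEDDING_DEVICE = "auto"   # or one of "cuda" / "mps" / "cpu"
LLM_DEVICE = "auto"

# server-side code resolves the actual device lazily via the new helpers
from server.utils import embedding_device, llm_device

embed_dev = embedding_device()   # "auto" -> detect_device() -> "cuda" / "mps" / "cpu"
llm_dev = llm_device()           # same fallback logic for the LLM device
```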

.gitignore

@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/


@@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)

-# In a distributed deployment, machines that do not run the LLM do not need torch installed
-def default_device():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-    except:
-        pass
-    return "cpu"

 # Edit the values in the dict below to specify where a local embedding model is stored,
 # e.g. change "text2vec": "GanymedeNil/text2vec-large-chinese" to "text2vec": "User/Downloads/text2vec-large-chinese"
 # Use an absolute path here
@@ -44,8 +31,8 @@ embedding_model_dict = {
 # Name of the selected embedding model
 EMBEDDING_MODEL = "m3e-base"
-# Device the embedding model runs on
-EMBEDDING_DEVICE = default_device()
+# Device the embedding model runs on. "auto" detects it automatically; it can also be set manually to one of "cuda", "mps", "cpu".
+EMBEDDING_DEVICE = "auto"

 llm_model_dict = {
@@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
 # Number of conversation history turns
 HISTORY_LEN = 3
-# Device the LLM runs on
-LLM_DEVICE = default_device()
+# Device the LLM runs on. Same options as the embedding device.
+LLM_DEVICE = "auto"
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")


@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 from typing import List, Union, Dict
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)

     def create_kb(self):


@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device

 _VECTOR_STORE_TICKS = {}
@@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
@@ -74,13 +73,18 @@ class FaissKBService(KBService):
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)

-    def load_vector_store(self):
+    def load_vector_store(self) -> FAISS:
         return load_faiss_vector_store(
             knowledge_base_name=self.kb_name,
             embed_model=self.embed_model,
             tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
         )

+    def save_vector_store(self, vector_store: FAISS = None):
+        vector_store = vector_store or self.load_vector_store()
+        vector_store.save_local(self.vs_path)
+        return vector_store
+
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
@@ -117,11 +121,11 @@
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
+        return vector_store

     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
         vector_store = self.load_vector_store()
         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@@ -133,7 +137,7 @@ class FaissKBService(KBService):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
-        return True
+        return vector_store

     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
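
The new save_vector_store() is the core of the FAISS fix: batch callers can pass not_refresh_vs_cache=True to the per-file operations and persist the in-memory index once at the end, which is exactly how the migration script below uses it. A rough usage sketch, assuming kb is a FaissKBService instance and kb_files is a list of KnowledgeFile objects:

```python
# Hypothetical batch update: skip the per-file save, then write the index once.
for kb_file in kb_files:
    kb.update_doc(kb_file, not_refresh_vs_cache=True)

if kb.vs_type() == SupportedVSType.FAISS:
    kb.save_vector_store()   # persist the in-memory FAISS index to kb.vs_path
    kb.refresh_vs_cache()    # invalidate the cached store so the next load picks up the new index
```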


@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text

-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
+from server.utils import embedding_device as get_embedding_device


 class PGKBService(KBService):
     pg_vector: PGVector

-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)


@@ -69,6 +69,7 @@ def folder2db(
             print(result)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "fill_info_only":
         files = list_files_from_folder(kb_name)
@@ -85,6 +86,7 @@ def folder2db(
             kb.update_doc(kb_file, not_refresh_vs_cache=True)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "increament":
         db_files = kb.list_files()
@@ -102,6 +104,7 @@ def folder2db(
             print(result)
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     else:
         print(f"unspported migrate mode: {mode}")
@@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
     files = list(set(files_in_db) - set(files_in_folder))
     kb_files = file_to_kbfile(kb_name, files)
     for kb_file in kb_files:
-        kb.delete_doc(kb_file)
+        kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+    if kb.vs_type() == SupportedVSType.FAISS:
+        kb.save_vector_store()
+        kb.refresh_vs_cache()
     return kb_files

 def prune_folder_files(kb_name: str):


@@ -4,8 +4,8 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))

-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device

 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,


@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# Automatically detect the device available to torch. In a distributed deployment, machines that do not run the LLM do not need torch installed.
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
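
Both helpers pass a valid explicit value through unchanged and fall back to automatic detection for anything else, including the new "auto" default, so call sites no longer need the EMBEDDING_DEVICE and LLM_DEVICE constants. A short sketch of the resolution behaviour; the concrete results naturally depend on the local hardware:

```python
from server.utils import detect_device, llm_device, embedding_device

llm_device("cuda")      # a valid explicit value is returned as-is -> "cuda"
llm_device("auto")      # any other value falls back to detect_device()
embedding_device()      # no argument -> uses EMBEDDING_DEVICE from configs ("auto" by default)
detect_device()         # "cuda" if available, else "mps", else "cpu" (also "cpu" when torch is not installed)
```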


@@ -14,12 +14,13 @@ except:
     pass

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER, FSCHAT_MODEL_WORKERS,
                                    FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -195,7 +196,7 @@ def run_model_worker(
 ):
     import uvicorn

-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本{langchain.__version__}. fastchat版本{fastchat.__version__}")
     print("\n")
-    print(f"当前LLM模型{LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"当前LLM模型{LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"当前Embbedings模型 {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")