diff --git a/.gitignore b/.gitignore
index c4178a9..a7ef90f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@
 logs
 .idea/
 __pycache__/
-knowledge_base/
+/knowledge_base/
 configs/*.py
 .vscode/
 .pytest_cache/
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index b9dd1de..5466f7f 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -7,19 +7,6 @@ logger.setLevel(logging.INFO)
 logging.basicConfig(format=LOG_FORMAT)
 
 
-# 分布式部署时,不运行LLM的机器上可以不装torch
-def default_device():
-    try:
-        import torch
-        if torch.cuda.is_available():
-            return "cuda"
-        if torch.backends.mps.is_available():
-            return "mps"
-    except:
-        pass
-    return "cpu"
-
-
 # 在以下字典中修改属性值,以指定本地embedding模型存储位置
 # 如将 "text2vec": "GanymedeNil/text2vec-large-chinese" 修改为 "text2vec": "User/Downloads/text2vec-large-chinese"
 # 此处请写绝对路径
@@ -44,8 +31,8 @@ embedding_model_dict = {
 # 选用的 Embedding 名称
 EMBEDDING_MODEL = "m3e-base"
 
-# Embedding 模型运行设备
-EMBEDDING_DEVICE = default_device()
+# Embedding 模型运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
+EMBEDDING_DEVICE = "auto"
 
 
 llm_model_dict = {
@@ -94,8 +81,8 @@ LLM_MODEL = "chatglm2-6b"
 # 历史对话轮数
 HISTORY_LEN = 3
 
-# LLM 运行设备
-LLM_DEVICE = default_device()
+# LLM 运行设备。可选项同Embedding 运行设备。
+LLM_DEVICE = "auto"
 
 # 日志存储路径
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index ca8d1ae..79b1518 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -18,11 +18,12 @@ from server.db.repository.knowledge_file_repository import (
 )
 
 from configs.model_config import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
-                                  EMBEDDING_DEVICE, EMBEDDING_MODEL)
+                                  EMBEDDING_MODEL)
 from server.knowledge_base.utils import (
     get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
+from server.utils import embedding_device
 from typing import List, Union, Dict
 
 
@@ -45,7 +46,7 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
-    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
         return load_embeddings(self.embed_model, embed_device)
 
     def create_kb(self):
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index 3601b57..f17b2da 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -5,7 +5,6 @@ from configs.model_config import (
     KB_ROOT_PATH,
     CACHED_VS_NUM,
     EMBEDDING_MODEL,
-    EMBEDDING_DEVICE,
     SCORE_THRESHOLD
 )
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
@@ -15,7 +14,7 @@ from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from typing import List
 from langchain.docstore.document import Document
-from server.utils import torch_gc
+from server.utils import torch_gc, embedding_device
 
 
 _VECTOR_STORE_TICKS = {}
@@ -25,10 +24,10 @@ _VECTOR_STORE_TICKS = {}
 def load_faiss_vector_store(
         knowledge_base_name: str,
         embed_model: str = EMBEDDING_MODEL,
-        embed_device: str = EMBEDDING_DEVICE,
+        embed_device: str = embedding_device(),
         embeddings: Embeddings = None,
         tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
-):
+) -> FAISS:
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
@@ -74,13 +73,18 @@ class FaissKBService(KBService):
     def get_kb_path(self):
         return os.path.join(KB_ROOT_PATH, self.kb_name)
 
-    def load_vector_store(self):
+    def load_vector_store(self) -> FAISS:
         return load_faiss_vector_store(
             knowledge_base_name=self.kb_name,
             embed_model=self.embed_model,
             tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
         )
 
+    def save_vector_store(self, vector_store: FAISS = None):
+        vector_store = vector_store or self.load_vector_store()
+        vector_store.save_local(self.vs_path)
+        return vector_store
+
     def refresh_vs_cache(self):
         refresh_vs_cache(self.kb_name)
 
@@ -117,11 +121,11 @@ class FaissKBService(KBService):
         if not kwargs.get("not_refresh_vs_cache"):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
+        return vector_store
 
     def do_delete_doc(self,
                       kb_file: KnowledgeFile,
                       **kwargs):
-        embeddings = self._load_embeddings()
         vector_store = self.load_vector_store()
 
         ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
@@ -133,7 +137,7 @@ class FaissKBService(KBService):
             vector_store.save_local(self.vs_path)
             self.refresh_vs_cache()
 
-        return True
+        return vector_store
 
     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 4b52625..3e3dd52 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -6,16 +6,16 @@ from langchain.vectorstores import PGVector
 from langchain.vectorstores.pgvector import DistanceStrategy
 from sqlalchemy import text
 
-from configs.model_config import EMBEDDING_DEVICE, kbs_config
+from configs.model_config import kbs_config
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
 from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-
+from server.utils import embedding_device as get_embedding_device
 
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
+    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
         _embeddings = embeddings
         if _embeddings is None:
             _embeddings = load_embeddings(self.embed_model, embedding_device)
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index af506e2..4285b79 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -69,6 +69,7 @@ def folder2db(
             print(result)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "fill_info_only":
         files = list_files_from_folder(kb_name)
@@ -85,6 +86,7 @@ def folder2db(
             kb.update_doc(kb_file, not_refresh_vs_cache=True)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     elif mode == "increament":
         db_files = kb.list_files()
@@ -102,6 +104,7 @@ def folder2db(
             print(result)
 
         if kb.vs_type() == SupportedVSType.FAISS:
+            kb.save_vector_store()
             kb.refresh_vs_cache()
     else:
         print(f"unspported migrate mode: {mode}")
@@ -131,7 +134,10 @@ def prune_db_files(kb_name: str):
     files = list(set(files_in_db) - set(files_in_folder))
     kb_files = file_to_kbfile(kb_name, files)
     for kb_file in kb_files:
-        kb.delete_doc(kb_file)
+        kb.delete_doc(kb_file, not_refresh_vs_cache=True)
+    if kb.vs_type() == SupportedVSType.FAISS:
+        kb.save_vector_store()
+        kb.refresh_vs_cache()
     return kb_files
 
 def prune_folder_files(kb_name: str):
diff --git a/server/llm_api.py b/server/llm_api.py
index 7ef5891..d9667e4 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -4,8 +4,8 @@ import sys
 import os
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, logger
-from server.utils import MakeFastAPIOffline, set_httpx_timeout
+from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger
+from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device
 
 
 host_ip = "0.0.0.0"
@@ -34,7 +34,7 @@ def create_model_worker_app(
         worker_address=base_url.format(model_worker_port),
         controller_address=base_url.format(controller_port),
         model_path=llm_model_dict[LLM_MODEL].get("local_model_path"),
-        device=LLM_DEVICE,
+        device=llm_device(),
         gpus=None,
         max_gpu_memory="20GiB",
         load_8bit=False,
diff --git a/server/utils.py b/server/utils.py
index 167b672..d716582 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -5,8 +5,8 @@ import torch
 from fastapi import FastAPI
 from pathlib import Path
 import asyncio
-from configs.model_config import LLM_MODEL
-from typing import Any, Optional
+from configs.model_config import LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE
+from typing import Literal, Optional
 
 
 class BaseResponse(BaseModel):
@@ -201,6 +201,7 @@ def get_model_worker_config(model_name: str = LLM_MODEL) -> dict:
     config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
     config.update(llm_model_dict.get(model_name, {}))
     config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}))
+    config["device"] = llm_device(config.get("device"))
     return config
 
 
@@ -256,3 +257,28 @@ def set_httpx_timeout(timeout: float = None):
     httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
     httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout
+
+
+# 自动检查torch可用的设备。分布式部署时,不运行LLM的机器上可以不装torch
+def detect_device() -> Literal["cuda", "mps", "cpu"]:
+    try:
+        import torch
+        if torch.cuda.is_available():
+            return "cuda"
+        if torch.backends.mps.is_available():
+            return "mps"
+    except:
+        pass
+    return "cpu"
+
+
+def llm_device(device: str = LLM_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
+
+
+def embedding_device(device: str = EMBEDDING_DEVICE) -> Literal["cuda", "mps", "cpu"]:
+    if device not in ["cuda", "mps", "cpu"]:
+        device = detect_device()
+    return device
diff --git a/startup.py b/startup.py
index 64a3bcc..07630d9 100644
--- a/startup.py
+++ b/startup.py
@@ -14,12 +14,13 @@ except:
     pass
 
 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
-from configs.model_config import EMBEDDING_DEVICE, EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LLM_DEVICE, LOG_PATH, \
+from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
     logger
 from configs.server_config import (WEBUI_SERVER, API_SERVER, OPEN_CROSS_DOMAIN, FSCHAT_CONTROLLER,
                                    FSCHAT_MODEL_WORKERS, FSCHAT_OPENAI_API, )
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
-                          fschat_openai_api_address, set_httpx_timeout)
+                          fschat_openai_api_address, set_httpx_timeout,
+                          llm_device, embedding_device, get_model_worker_config)
 from server.utils import MakeFastAPIOffline, FastAPI
 import argparse
 from typing import Tuple, List
@@ -195,7 +196,7 @@ def run_model_worker(
 ):
     import uvicorn
 
-    kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
+    kwargs = get_model_worker_config(model_name)
     host = kwargs.pop("host")
     port = kwargs.pop("port")
     model_path = llm_model_dict[model_name].get("local_model_path", "")
@@ -331,9 +332,9 @@ def dump_server_info(after_start=False):
     print(f"项目版本:{VERSION}")
     print(f"langchain版本:{langchain.__version__}. fastchat版本:{fastchat.__version__}")
    print("\n")
-    print(f"当前LLM模型:{LLM_MODEL} @ {LLM_DEVICE}")
+    print(f"当前LLM模型:{LLM_MODEL} @ {llm_device()}")
     pprint(llm_model_dict[LLM_MODEL])
-    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {EMBEDDING_DEVICE}")
+    print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
     if after_start:
         print("\n")
         print(f"服务端运行信息:")
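
A minimal usage sketch of how the new "auto" device setting is resolved at runtime, based on the llm_device()/embedding_device() helpers this patch adds to server/utils.py (it assumes the repository root is on sys.path; the printed values depend on the hardware available):

```python
# Any value outside {"cuda", "mps", "cpu"} (including the new default "auto")
# falls through to detect_device(), which probes torch for CUDA/MPS support
# and returns "cpu" when torch is missing or no accelerator is found.
from server.utils import llm_device, embedding_device

print(llm_device("cuda"))   # an explicit valid value is passed through unchanged
print(llm_device("auto"))   # auto-detected: "cuda", "mps" or "cpu"
print(embedding_device())   # defaults to EMBEDDING_DEVICE from configs ("auto" in the new example config)
```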