From 22ff073309dd1b33367957a527f3d35ade6a5af0 Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Mon, 11 Sep 2023 20:41:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9Embeddings=E5=92=8CFAISS?= =?UTF-8?q?=E7=BC=93=E5=AD=98=E5=8A=A0=E8=BD=BD=E6=96=B9=E5=BC=8F=EF=BC=8C?= =?UTF-8?q?=E7=9F=A5=E8=AF=86=E5=BA=93=E7=9B=B8=E5=85=B3API=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E6=94=AF=E6=8C=81=E5=A4=9A=E7=BA=BF=E7=A8=8B=E5=B9=B6?= =?UTF-8?q?=E5=8F=91=20(#1434)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修改Embeddings和FAISS缓存加载方式,支持多线程,支持内存FAISS * 知识库相关API接口支持多线程并发 * 根据新的API接口调整ApiRequest和测试用例 * 删除webui.py失效的启动说明 --- server/api.py | 84 +---- server/knowledge_base/kb_api.py | 12 +- server/knowledge_base/kb_cache/base.py | 137 +++++++ server/knowledge_base/kb_cache/faiss_cache.py | 157 ++++++++ server/knowledge_base/kb_doc_api.py | 48 +-- server/knowledge_base/kb_service/base.py | 1 - .../kb_service/faiss_kb_service.py | 108 ++---- server/knowledge_base/migrate.py | 7 +- server/knowledge_base/utils.py | 37 +- server/llm_api.py | 337 ++++-------------- server/utils.py | 2 +- tests/api/test_kb_api_request.py | 4 +- tests/api/test_llm_api.py | 3 +- tests/api/test_stream_chat_api.py | 1 - webui.py | 6 - webui_pages/utils.py | 83 +++-- 16 files changed, 497 insertions(+), 530 deletions(-) create mode 100644 server/knowledge_base/kb_cache/base.py create mode 100644 server/knowledge_base/kb_cache/faiss_cache.py diff --git a/server/api.py b/server/api.py index a6d0827..357a067 100644 --- a/server/api.py +++ b/server/api.py @@ -4,12 +4,11 @@ import os sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import LLM_MODEL, NLTK_DATA_PATH -from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT -from configs import VERSION, logger, log_verbose +from configs import VERSION +from configs.model_config import NLTK_DATA_PATH +from configs.server_config import OPEN_CROSS_DOMAIN import argparse import uvicorn -from fastapi import Body from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, @@ -18,8 +17,8 @@ from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs, update_docs, download_doc, recreate_vector_store, search_docs, DocumentWithScore) -from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address -import httpx +from server.llm_api import list_llm_models, change_llm_model, stop_llm_model +from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline from typing import List nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -126,79 +125,20 @@ def create_app(): )(recreate_vector_store) # LLM模型相关接口 - @app.post("/llm_model/list_models", + app.post("/llm_model/list_models", tags=["LLM Model Management"], - summary="列出当前已加载的模型") - def list_models( - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ) -> BaseResponse: - ''' - 从fastchat controller获取已加载模型列表 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post(controller_address + "/list_models") - return BaseResponse(data=r.json()["models"]) - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if log_verbose 
else None) - return BaseResponse( - code=500, - data=[], - msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") + summary="列出当前已加载的模型", + )(list_llm_models) - @app.post("/llm_model/stop", + app.post("/llm_model/stop", tags=["LLM Model Management"], summary="停止指定的LLM模型(Model Worker)", - ) - def stop_llm_model( - model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ) -> BaseResponse: - ''' - 向fastchat controller请求停止某个LLM模型。 - 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post( - controller_address + "/release_worker", - json={"model_name": model_name}, - ) - return r.json() - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if log_verbose else None) - return BaseResponse( - code=500, - msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") + )(stop_llm_model) - @app.post("/llm_model/change", + app.post("/llm_model/change", tags=["LLM Model Management"], summary="切换指定的LLM模型(Model Worker)", - ) - def change_llm_model( - model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), - new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), - controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) - ): - ''' - 向fastchat controller请求切换LLM模型。 - ''' - try: - controller_address = controller_address or fschat_controller_address() - r = httpx.post( - controller_address + "/release_worker", - json={"model_name": model_name, "new_model_name": new_model_name}, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model - ) - return r.json() - except Exception as e: - logger.error(f'{e.__class__.__name__}: {e}', - exc_info=e if log_verbose else None) - return BaseResponse( - code=500, - msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") + )(change_llm_model) return app diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index a6efa68..c7b703e 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -12,10 +12,10 @@ def list_kbs(): return ListResponse(data=list_kbs_from_db()) -async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), - vector_store_type: str = Body("faiss"), - embed_model: str = Body(EMBEDDING_MODEL), - ) -> BaseResponse: +def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), + vector_store_type: str = Body("faiss"), + embed_model: str = Body(EMBEDDING_MODEL), + ) -> BaseResponse: # Create selected knowledge base if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -38,8 +38,8 @@ async def create_kb(knowledge_base_name: str = Body(..., examples=["samples"]), return BaseResponse(code=200, msg=f"已新增知识库 {knowledge_base_name}") -async def delete_kb( - knowledge_base_name: str = Body(..., examples=["samples"]) +def delete_kb( + knowledge_base_name: str = Body(..., examples=["samples"]) ) -> BaseResponse: # Delete selected knowledge base if not validate_kb_name(knowledge_base_name): diff --git a/server/knowledge_base/kb_cache/base.py b/server/knowledge_base/kb_cache/base.py new file mode 100644 index 0000000..f3e6d65 --- /dev/null +++ b/server/knowledge_base/kb_cache/base.py @@ 
-0,0 +1,137 @@ +from langchain.embeddings.huggingface import HuggingFaceEmbeddings +from langchain.embeddings.openai import OpenAIEmbeddings +from langchain.embeddings import HuggingFaceBgeEmbeddings +from langchain.embeddings.base import Embeddings +from langchain.schema import Document +import threading +from configs.model_config import (CACHED_VS_NUM, EMBEDDING_MODEL, CHUNK_SIZE, + embedding_model_dict, logger, log_verbose) +from server.utils import embedding_device +from contextlib import contextmanager +from collections import OrderedDict +from typing import List, Any, Union, Tuple + + +class ThreadSafeObject: + def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None): + self._obj = obj + self._key = key + self._pool = pool + self._lock = threading.RLock() + self._loaded = threading.Event() + + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self._key}, obj: {self._obj}>" + + @contextmanager + def acquire(self, owner: str = "", msg: str = ""): + owner = owner or f"thread {threading.get_native_id()}" + try: + self._lock.acquire() + if self._pool is not None: + self._pool._cache.move_to_end(self._key) + if log_verbose: + logger.info(f"{owner} 开始操作:{self._key}。{msg}") + yield self._obj + finally: + if log_verbose: + logger.info(f"{owner} 结束操作:{self._key}。{msg}") + self._lock.release() + + def start_loading(self): + self._loaded.clear() + + def finish_loading(self): + self._loaded.set() + + def wait_for_loading(self): + self._loaded.wait() + + @property + def obj(self): + return self._obj + + @obj.setter + def obj(self, val: Any): + self._obj = val + + +class CachePool: + def __init__(self, cache_num: int = -1): + self._cache_num = cache_num + self._cache = OrderedDict() + self.atomic = threading.RLock() + + def keys(self) -> List[str]: + return list(self._cache.keys()) + + def _check_count(self): + if isinstance(self._cache_num, int) and self._cache_num > 0: + while len(self._cache) > self._cache_num: + self._cache.popitem(last=False) + + def get(self, key: str) -> ThreadSafeObject: + if cache := self._cache.get(key): + cache.wait_for_loading() + return cache + + def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject: + self._cache[key] = obj + self._check_count() + return obj + + def pop(self, key: str = None) -> ThreadSafeObject: + if key is None: + return self._cache.popitem(last=False) + else: + return self._cache.pop(key, None) + + def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""): + cache = self.get(key) + if cache is None: + raise RuntimeError(f"请求的资源 {key} 不存在") + elif isinstance(cache, ThreadSafeObject): + self._cache.move_to_end(key) + return cache.acquire(owner=owner, msg=msg) + else: + return cache + + def load_kb_embeddings(self, kb_name: str=None, embed_device: str = embedding_device()) -> Embeddings: + from server.db.repository.knowledge_base_repository import get_kb_detail + + kb_detail = get_kb_detail(kb_name=kb_name) + print(kb_detail) + embed_model = kb_detail.get("embed_model", EMBEDDING_MODEL) + return embeddings_pool.load_embeddings(model=embed_model, device=embed_device) + + +class EmbeddingsPool(CachePool): + def load_embeddings(self, model: str, device: str) -> Embeddings: + self.atomic.acquire() + model = model or EMBEDDING_MODEL + device = device or embedding_device() + key = (model, device) + if not self.get(key): + item = ThreadSafeObject(key, pool=self) + self.set(key, item) + with item.acquire(msg="初始化"): + self.atomic.release() + if model == 
"text-embedding-ada-002": # openai text-embedding-ada-002 + embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) + elif 'bge-' in model: + embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], + model_kwargs={'device': device}, + query_instruction="为这个句子生成表示以用于检索相关文章:") + if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding + embeddings.query_instruction = "" + else: + embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) + item.obj = embeddings + item.finish_loading() + else: + self.atomic.release() + return self.get(key).obj + + +embeddings_pool = EmbeddingsPool(cache_num=1) diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py new file mode 100644 index 0000000..325c7bb --- /dev/null +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -0,0 +1,157 @@ +from server.knowledge_base.kb_cache.base import * +from server.knowledge_base.utils import get_vs_path +from langchain.vectorstores import FAISS +import os + + +class ThreadSafeFaiss(ThreadSafeObject): + def __repr__(self) -> str: + cls = type(self).__name__ + return f"<{cls}: key: {self._key}, obj: {self._obj}, docs_count: {self.docs_count()}>" + + def docs_count(self) -> int: + return len(self._obj.docstore._dict) + + def save(self, path: str, create_path: bool = True): + with self.acquire(): + if not os.path.isdir(path) and create_path: + os.makedirs(path) + ret = self._obj.save_local(path) + logger.info(f"已将向量库 {self._key} 保存到磁盘") + return ret + + def clear(self): + ret = [] + with self.acquire(): + ids = list(self._obj.docstore._dict.keys()) + if ids: + ret = self._obj.delete(ids) + assert len(self._obj.docstore._dict) == 0 + logger.info(f"已将向量库 {self._key} 清空") + return ret + + +class _FaissPool(CachePool): + def new_vector_store( + self, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> FAISS: + embeddings = embeddings_pool.load_embeddings(embed_model, embed_device) + + # create an empty vector store + doc = Document(page_content="init", metadata={}) + vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) + ids = list(vector_store.docstore._dict.keys()) + vector_store.delete(ids) + return vector_store + + def save_vector_store(self, kb_name: str, path: str=None): + if cache := self.get(kb_name): + return cache.save(path) + + def unload_vector_store(self, kb_name: str): + if cache := self.get(kb_name): + self.pop(kb_name) + logger.info(f"成功释放向量库:{kb_name}") + + +class KBFaissPool(_FaissPool): + def load_vector_store( + self, + kb_name: str, + create: bool = True, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> ThreadSafeFaiss: + self.atomic.acquire() + cache = self.get(kb_name) + if cache is None: + item = ThreadSafeFaiss(kb_name, pool=self) + self.set(kb_name, item) + with item.acquire(msg="初始化"): + self.atomic.release() + logger.info(f"loading vector store in '{kb_name}' from disk.") + vs_path = get_vs_path(kb_name) + + if os.path.isfile(os.path.join(vs_path, "index.faiss")): + embeddings = self.load_kb_embeddings(kb_name=kb_name, embed_device=embed_device) + vector_store = FAISS.load_local(vs_path, embeddings, normalize_L2=True) + elif create: + # create an empty vector store + if not os.path.exists(vs_path): + os.makedirs(vs_path) + vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device) + 
vector_store.save_local(vs_path) + else: + raise RuntimeError(f"knowledge base {kb_name} not exist.") + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + return self.get(kb_name) + + +class MemoFaissPool(_FaissPool): + def load_vector_store( + self, + kb_name: str, + embed_model: str = EMBEDDING_MODEL, + embed_device: str = embedding_device(), + ) -> ThreadSafeFaiss: + self.atomic.acquire() + cache = self.get(kb_name) + if cache is None: + item = ThreadSafeFaiss(kb_name, pool=self) + self.set(kb_name, item) + with item.acquire(msg="初始化"): + self.atomic.release() + logger.info(f"loading vector store in '{kb_name}' to memory.") + # create an empty vector store + vector_store = self.new_vector_store(embed_model=embed_model, embed_device=embed_device) + item.obj = vector_store + item.finish_loading() + else: + self.atomic.release() + return self.get(kb_name) + + +kb_faiss_pool = KBFaissPool(cache_num=CACHED_VS_NUM) +memo_faiss_pool = MemoFaissPool() + + +if __name__ == "__main__": + import time, random + from pprint import pprint + + kb_names = ["vs1", "vs2", "vs3"] + # for name in kb_names: + # memo_faiss_pool.load_vector_store(name) + + def worker(vs_name: str, name: str): + vs_name = "samples" + time.sleep(random.randint(1, 5)) + embeddings = embeddings_pool.load_embeddings() + r = random.randint(1, 3) + + with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs: + if r == 1: # add docs + ids = vs.add_texts([f"text added by {name}"], embeddings=embeddings) + pprint(ids) + elif r == 2: # search docs + docs = vs.similarity_search_with_score(f"{name}", top_k=3, score_threshold=1.0) + pprint(docs) + if r == 3: # delete docs + logger.warning(f"清除 {vs_name} by {name}") + kb_faiss_pool.get(vs_name).clear() + + threads = [] + for n in range(1, 30): + t = threading.Thread(target=worker, + kwargs={"vs_name": random.choice(kb_names), "name": f"worker {n}"}, + daemon=True) + t.start() + threads.append(t) + + for t in threads: + t.join() diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index b80ce2d..9074415 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -117,13 +117,13 @@ def _save_files_in_thread(files: List[UploadFile], # yield json.dumps(result, ensure_ascii=False) -async def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), - knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), - override: bool = Form(False, description="覆盖已有文件"), - to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), - docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), - not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: +def upload_docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"), + knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]), + override: bool = Form(False, description="覆盖已有文件"), + to_vector_store: bool = Form(True, description="上传文件后是否进行向量化"), + docs: Json = Form({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Form(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: ''' API接口:上传文件,并/或向量化 ''' @@ -148,7 +148,7 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文 # 对保存的文件进行向量化 if to_vector_store: - result = await update_docs( + result = update_docs( 
knowledge_base_name=knowledge_base_name, file_names=file_names, override_custom_docs=True, @@ -162,11 +162,11 @@ async def upload_docs(files: List[UploadFile] = File(..., description="上传文 return BaseResponse(code=200, msg="文件上传与向量化完成", data={"failed_files": failed_files}) -async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), - file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), - delete_content: bool = Body(False), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), - ) -> BaseResponse: +def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]), + file_names: List[str] = Body(..., examples=[["file_name.md", "test.txt"]]), + delete_content: bool = Body(False), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), + ) -> BaseResponse: if not validate_kb_name(knowledge_base_name): return BaseResponse(code=403, msg="Don't attack me") @@ -196,12 +196,12 @@ async def delete_docs(knowledge_base_name: str = Body(..., examples=["samples"]) return BaseResponse(code=200, msg=f"文件删除完成", data={"failed_files": failed_files}) -async def update_docs( - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]), - override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), - docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), - not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), +def update_docs( + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + file_names: List[str] = Body(..., description="文件名称,支持多文件", examples=["file_name"]), + override_custom_docs: bool = Body(False, description="是否覆盖之前自定义的docs"), + docs: Json = Body({}, description="自定义的docs", examples=[{"test.txt": [Document(page_content="custom doc")]}]), + not_refresh_vs_cache: bool = Body(False, description="暂不保存向量库(用于FAISS)"), ) -> BaseResponse: ''' 更新知识库文档 @@ -302,11 +302,11 @@ def download_doc( return BaseResponse(code=500, msg=f"{kb_file.filename} 读取文件失败") -async def recreate_vector_store( - knowledge_base_name: str = Body(..., examples=["samples"]), - allow_empty_kb: bool = Body(True), - vs_type: str = Body(DEFAULT_VS_TYPE), - embed_model: str = Body(EMBEDDING_MODEL), +def recreate_vector_store( + knowledge_base_name: str = Body(..., examples=["samples"]), + allow_empty_kb: bool = Body(True), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), ): ''' recreate vector store from the content. 
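
Note: with the handlers above switched from async def to plain def, FastAPI dispatches them to its worker threadpool, so several knowledge-base requests can run concurrently while the per-object locks in the new kb_cache classes serialize access to each FAISS store. A minimal sketch of that access pattern follows; the knowledge base name "samples" is only an example and is assumed to already exist under KB_ROOT_PATH.

    # Sketch only: "samples" is an example knowledge base assumed to exist on disk.
    from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool

    # load_vector_store() returns a ThreadSafeFaiss; acquire() holds its RLock for the
    # duration of the with-block and yields the underlying langchain FAISS object, so
    # concurrent request threads cannot mutate the same store at the same time.
    with kb_faiss_pool.load_vector_store("samples").acquire(owner="example") as vs:
        docs = vs.similarity_search_with_score("test query", k=3)
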
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 5cdebfb..d0860f7 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -146,7 +146,6 @@ class KBService(ABC): docs = self.do_search(query, top_k, score_threshold, embeddings) return docs - # TODO: milvus/pg需要实现该方法 def get_doc_by_id(self, id: str) -> Optional[Document]: return None diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index c415b82..6e20acf 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -3,62 +3,16 @@ import shutil from configs.model_config import ( KB_ROOT_PATH, - CACHED_VS_NUM, - EMBEDDING_MODEL, SCORE_THRESHOLD, logger, log_verbose, ) from server.knowledge_base.kb_service.base import KBService, SupportedVSType -from functools import lru_cache -from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile -from langchain.vectorstores import FAISS +from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss +from server.knowledge_base.utils import KnowledgeFile from langchain.embeddings.base import Embeddings from typing import List, Dict, Optional from langchain.docstore.document import Document -from server.utils import torch_gc, embedding_device - - -_VECTOR_STORE_TICKS = {} - - -@lru_cache(CACHED_VS_NUM) -def load_faiss_vector_store( - knowledge_base_name: str, - embed_model: str = EMBEDDING_MODEL, - embed_device: str = embedding_device(), - embeddings: Embeddings = None, - tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. -) -> FAISS: - logger.info(f"loading vector store in '{knowledge_base_name}'.") - vs_path = get_vs_path(knowledge_base_name) - if embeddings is None: - embeddings = load_embeddings(embed_model, embed_device) - - if not os.path.exists(vs_path): - os.makedirs(vs_path) - - if "index.faiss" in os.listdir(vs_path): - search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True) - else: - # create an empty vector store - doc = Document(page_content="init", metadata={}) - search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True) - ids = [k for k, v in search_index.docstore._dict.items()] - search_index.delete(ids) - search_index.save_local(vs_path) - - if tick == 0: # vector store is loaded first time - _VECTOR_STORE_TICKS[knowledge_base_name] = 0 - - return search_index - - -def refresh_vs_cache(kb_name: str): - """ - make vector store cache refreshed when next loading - """ - _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1 - logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}") +from server.utils import torch_gc class FaissKBService(KBService): @@ -74,24 +28,15 @@ class FaissKBService(KBService): def get_kb_path(self): return os.path.join(KB_ROOT_PATH, self.kb_name) - def load_vector_store(self) -> FAISS: - return load_faiss_vector_store( - knowledge_base_name=self.kb_name, - embed_model=self.embed_model, - tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0), - ) + def load_vector_store(self) -> ThreadSafeFaiss: + return kb_faiss_pool.load_vector_store(kb_name=self.kb_name, embed_model=self.embed_model) - def save_vector_store(self, vector_store: FAISS = None): - vector_store = vector_store or self.load_vector_store() - vector_store.save_local(self.vs_path) - return vector_store - - def refresh_vs_cache(self): - 
refresh_vs_cache(self.kb_name) + def save_vector_store(self): + self.load_vector_store().save(self.vs_path) def get_doc_by_id(self, id: str) -> Optional[Document]: - vector_store = self.load_vector_store() - return vector_store.docstore._dict.get(id) + with self.load_vector_store().acquire() as vs: + return vs.docstore._dict.get(id) def do_init(self): self.kb_path = self.get_kb_path() @@ -112,43 +57,38 @@ class FaissKBService(KBService): score_threshold: float = SCORE_THRESHOLD, embeddings: Embeddings = None, ) -> List[Document]: - search_index = self.load_vector_store() - docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) + with self.load_vector_store().acquire() as vs: + docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold) return docs def do_add_doc(self, docs: List[Document], **kwargs, ) -> List[Dict]: - vector_store = self.load_vector_store() - ids = vector_store.add_documents(docs) + with self.load_vector_store().acquire() as vs: + ids = vs.add_documents(docs) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] torch_gc() - if not kwargs.get("not_refresh_vs_cache"): - vector_store.save_local(self.vs_path) - self.refresh_vs_cache() return doc_infos def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs): - vector_store = self.load_vector_store() - - ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] - if len(ids) == 0: - return None - - vector_store.delete(ids) - if not kwargs.get("not_refresh_vs_cache"): - vector_store.save_local(self.vs_path) - self.refresh_vs_cache() - - return vector_store + with self.load_vector_store().acquire() as vs: + ids = [k for k, v in vs.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath] + if len(ids) > 0: + vs.delete(ids) + if not kwargs.get("not_refresh_vs_cache"): + vs.save_local(self.vs_path) + return ids def do_clear_vs(self): + with kb_faiss_pool.atomic: + kb_faiss_pool.pop(self.kb_name) shutil.rmtree(self.vs_path) os.makedirs(self.vs_path) - self.refresh_vs_cache() def exist_doc(self, file_name: str): if super().exist_doc(file_name): diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 100a249..118ae37 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -1,7 +1,6 @@ from configs.model_config import EMBEDDING_MODEL, DEFAULT_VS_TYPE, logger, log_verbose from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, - list_files_from_folder, run_in_thread_pool, - files2docs_in_thread, + list_files_from_folder,files2docs_in_thread, KnowledgeFile,) from server.knowledge_base.kb_service.base import KBServiceFactory, SupportedVSType from server.db.repository.knowledge_file_repository import add_file_to_db @@ -72,7 +71,6 @@ def folder2db( if kb.vs_type() == SupportedVSType.FAISS: kb.save_vector_store() - kb.refresh_vs_cache() elif mode == "fill_info_only": files = list_files_from_folder(kb_name) kb_files = file_to_kbfile(kb_name, files) @@ -89,7 +87,6 @@ def folder2db( if kb.vs_type() == SupportedVSType.FAISS: kb.save_vector_store() - kb.refresh_vs_cache() elif mode == "increament": db_files = kb.list_files() folder_files = list_files_from_folder(kb_name) @@ -107,7 +104,6 @@ def folder2db( if kb.vs_type() == SupportedVSType.FAISS: kb.save_vector_store() - kb.refresh_vs_cache() else: print(f"unspported migrate mode: 
{mode}") @@ -139,7 +135,6 @@ def prune_db_files(kb_name: str): kb.delete_doc(kb_file, not_refresh_vs_cache=True) if kb.vs_type() == SupportedVSType.FAISS: kb.save_vector_store() - kb.refresh_vs_cache() return kb_files def prune_folder_files(kb_name: str): diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index a5b107a..2297d15 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -4,13 +4,13 @@ from langchain.embeddings.openai import OpenAIEmbeddings from langchain.embeddings import HuggingFaceBgeEmbeddings from configs.model_config import ( embedding_model_dict, + EMBEDDING_MODEL, KB_ROOT_PATH, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE, logger, log_verbose, ) -from functools import lru_cache import importlib from text_splitter import zh_title_enhance import langchain.document_loaders @@ -19,25 +19,11 @@ from langchain.text_splitter import TextSplitter from pathlib import Path import json from concurrent.futures import ThreadPoolExecutor -from server.utils import run_in_thread_pool +from server.utils import run_in_thread_pool, embedding_device import io from typing import List, Union, Callable, Dict, Optional, Tuple, Generator -# make HuggingFaceEmbeddings hashable -def _embeddings_hash(self): - if isinstance(self, HuggingFaceEmbeddings): - return hash(self.model_name) - elif isinstance(self, HuggingFaceBgeEmbeddings): - return hash(self.model_name) - elif isinstance(self, OpenAIEmbeddings): - return hash(self.model) - -HuggingFaceEmbeddings.__hash__ = _embeddings_hash -OpenAIEmbeddings.__hash__ = _embeddings_hash -HuggingFaceBgeEmbeddings.__hash__ = _embeddings_hash - - def validate_kb_name(knowledge_base_id: str) -> bool: # 检查是否包含预期外的字符或路径攻击关键字 if "../" in knowledge_base_id: @@ -72,19 +58,12 @@ def list_files_from_folder(kb_name: str): if os.path.isfile(os.path.join(doc_path, file))] -@lru_cache(1) -def load_embeddings(model: str, device: str): - if model == "text-embedding-ada-002": # openai text-embedding-ada-002 - embeddings = OpenAIEmbeddings(openai_api_key=embedding_model_dict[model], chunk_size=CHUNK_SIZE) - elif 'bge-' in model: - embeddings = HuggingFaceBgeEmbeddings(model_name=embedding_model_dict[model], - model_kwargs={'device': device}, - query_instruction="为这个句子生成表示以用于检索相关文章:") - if model == "bge-large-zh-noinstruct": # bge large -noinstruct embedding - embeddings.query_instruction = "" - else: - embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model], model_kwargs={'device': device}) - return embeddings +def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()): + ''' + 从缓存中加载embeddings,可以避免多线程时竞争加载。 + ''' + from server.knowledge_base.kb_cache.base import embeddings_pool + return embeddings_pool.load_embeddings(model=model, device=device) LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'], diff --git a/server/llm_api.py b/server/llm_api.py index e4f012b..72e6b8d 100644 --- a/server/llm_api.py +++ b/server/llm_api.py @@ -1,279 +1,70 @@ -from multiprocessing import Process, Queue -import multiprocessing as mp -import sys -import os - -sys.path.append(os.path.dirname(os.path.dirname(__file__))) -from configs.model_config import llm_model_dict, LLM_MODEL, LOG_PATH, logger, log_verbose -from server.utils import MakeFastAPIOffline, set_httpx_timeout, llm_device +from fastapi import Body +from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT +from server.utils import BaseResponse, fschat_controller_address +import httpx -host_ip = "0.0.0.0" -controller_port = 
20001 -model_worker_port = 20002 -openai_api_port = 8888 -base_url = "http://127.0.0.1:{}" +def list_llm_models( + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) +) -> BaseResponse: + ''' + 从fastchat controller获取已加载模型列表 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post(controller_address + "/list_models") + return BaseResponse(data=r.json()["models"]) + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + data=[], + msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}") -def create_controller_app( - dispatch_method="shortest_queue", +def stop_llm_model( + model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) +) -> BaseResponse: + ''' + 向fastchat controller请求停止某个LLM模型。 + 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name}, + ) + return r.json() + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}") + + +def change_llm_model( + model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]), + new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]), + controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()]) ): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.controller import app, Controller - - controller = Controller(dispatch_method) - sys.modules["fastchat.serve.controller"].controller = controller - - MakeFastAPIOffline(app) - app.title = "FastChat Controller" - return app - - -def create_model_worker_app( - worker_address=base_url.format(model_worker_port), - controller_address=base_url.format(controller_port), - model_path=llm_model_dict[LLM_MODEL].get("local_model_path"), - device=llm_device(), - gpus=None, - max_gpu_memory="20GiB", - load_8bit=False, - cpu_offloading=None, - gptq_ckpt=None, - gptq_wbits=16, - gptq_groupsize=-1, - gptq_act_order=False, - awq_ckpt=None, - awq_wbits=16, - awq_groupsize=-1, - model_names=[LLM_MODEL], - num_gpus=1, # not in fastchat - conv_template=None, - limit_worker_concurrency=5, - stream_interval=2, - no_register=False, -): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.model_worker import app, GptqConfig, AWQConfig, ModelWorker, worker_id - import argparse - import threading - import fastchat.serve.model_worker - - # workaround to make program exit with Ctrl+c - # it should be deleted after pr is merged by fastchat - def _new_init_heart_beat(self): - self.register_to_controller() - self.heart_beat_thread = threading.Thread( - target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True, + ''' + 向fastchat controller请求切换LLM模型。 + ''' + try: + controller_address = controller_address or fschat_controller_address() + r = httpx.post( + controller_address + "/release_worker", + json={"model_name": model_name, "new_model_name": 
new_model_name}, + timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model ) - self.heart_beat_thread.start() - ModelWorker.init_heart_beat = _new_init_heart_beat - - parser = argparse.ArgumentParser() - args = parser.parse_args() - args.model_path = model_path - args.model_names = model_names - args.device = device - args.load_8bit = load_8bit - args.gptq_ckpt = gptq_ckpt - args.gptq_wbits = gptq_wbits - args.gptq_groupsize = gptq_groupsize - args.gptq_act_order = gptq_act_order - args.awq_ckpt = awq_ckpt - args.awq_wbits = awq_wbits - args.awq_groupsize = awq_groupsize - args.gpus = gpus - args.num_gpus = num_gpus - args.max_gpu_memory = max_gpu_memory - args.cpu_offloading = cpu_offloading - args.worker_address = worker_address - args.controller_address = controller_address - args.conv_template = conv_template - args.limit_worker_concurrency = limit_worker_concurrency - args.stream_interval = stream_interval - args.no_register = no_register - - if args.gpus: - if len(args.gpus.split(",")) < args.num_gpus: - raise ValueError( - f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!" - ) - os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus - - if gpus and num_gpus is None: - num_gpus = len(gpus.split(',')) - args.num_gpus = num_gpus - - gptq_config = GptqConfig( - ckpt=gptq_ckpt or model_path, - wbits=args.gptq_wbits, - groupsize=args.gptq_groupsize, - act_order=args.gptq_act_order, - ) - awq_config = AWQConfig( - ckpt=args.awq_ckpt or args.model_path, - wbits=args.awq_wbits, - groupsize=args.awq_groupsize, - ) - - # torch.multiprocessing.set_start_method('spawn') - worker = ModelWorker( - controller_addr=args.controller_address, - worker_addr=args.worker_address, - worker_id=worker_id, - model_path=args.model_path, - model_names=args.model_names, - limit_worker_concurrency=args.limit_worker_concurrency, - no_register=args.no_register, - device=args.device, - num_gpus=args.num_gpus, - max_gpu_memory=args.max_gpu_memory, - load_8bit=args.load_8bit, - cpu_offloading=args.cpu_offloading, - gptq_config=gptq_config, - awq_config=awq_config, - stream_interval=args.stream_interval, - conv_template=args.conv_template, - ) - - sys.modules["fastchat.serve.model_worker"].worker = worker - sys.modules["fastchat.serve.model_worker"].args = args - sys.modules["fastchat.serve.model_worker"].gptq_config = gptq_config - - MakeFastAPIOffline(app) - app.title = f"FastChat LLM Server ({LLM_MODEL})" - return app - - -def create_openai_api_app( - controller_address=base_url.format(controller_port), - api_keys=[], -): - import fastchat.constants - fastchat.constants.LOGDIR = LOG_PATH - from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings - - app.add_middleware( - CORSMiddleware, - allow_credentials=True, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - ) - - app_settings.controller_address = controller_address - app_settings.api_keys = api_keys - - MakeFastAPIOffline(app) - app.title = "FastChat OpeanAI API Server" - return app - - -def run_controller(q): - import uvicorn - app = create_controller_app() - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - q.put(1) - - uvicorn.run(app, host=host_ip, port=controller_port) - - -def run_model_worker(q, *args, **kwargs): - import uvicorn - app = create_model_worker_app(*args, **kwargs) - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - while True: - no = q.get() - if no != 1: - q.put(no) - else: - break - q.put(2) - - uvicorn.run(app, host=host_ip, 
port=model_worker_port) - - -def run_openai_api(q): - import uvicorn - app = create_openai_api_app() - - @app.on_event("startup") - async def on_startup(): - set_httpx_timeout() - while True: - no = q.get() - if no != 2: - q.put(no) - else: - break - q.put(3) - - uvicorn.run(app, host=host_ip, port=openai_api_port) - - -if __name__ == "__main__": - mp.set_start_method("spawn") - queue = Queue() - logger.info(llm_model_dict[LLM_MODEL]) - model_path = llm_model_dict[LLM_MODEL]["local_model_path"] - - logger.info(f"如需查看 llm_api 日志,请前往 {LOG_PATH}") - - if not model_path: - logger.error("local_model_path 不能为空") - else: - controller_process = Process( - target=run_controller, - name=f"controller({os.getpid()})", - args=(queue,), - daemon=True, - ) - controller_process.start() - - model_worker_process = Process( - target=run_model_worker, - name=f"model_worker({os.getpid()})", - args=(queue,), - # kwargs={"load_8bit": True}, - daemon=True, - ) - model_worker_process.start() - - openai_api_process = Process( - target=run_openai_api, - name=f"openai_api({os.getpid()})", - args=(queue,), - daemon=True, - ) - openai_api_process.start() - - try: - model_worker_process.join() - controller_process.join() - openai_api_process.join() - except KeyboardInterrupt: - model_worker_process.terminate() - controller_process.terminate() - openai_api_process.terminate() - -# 服务启动后接口调用示例: -# import openai -# openai.api_key = "EMPTY" # Not support yet -# openai.api_base = "http://localhost:8888/v1" - -# model = "chatglm2-6b" - -# # create a chat completion -# completion = openai.ChatCompletion.create( -# model=model, -# messages=[{"role": "user", "content": "Hello! What is your name?"}] -# ) -# # print the completion -# print(completion.choices[0].message.content) + return r.json() + except Exception as e: + logger.error(f'{e.__class__.__name__}: {e}', + exc_info=e if log_verbose else None) + return BaseResponse( + code=500, + msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}") diff --git a/server/utils.py b/server/utils.py index 163250d..0184309 100644 --- a/server/utils.py +++ b/server/utils.py @@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Literal, Optional, Callable, Generator, Dict, Any -thread_pool = ThreadPoolExecutor() +thread_pool = ThreadPoolExecutor(os.cpu_count()) class BaseResponse(BaseModel): diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py index a0d15ce..3b067b3 100644 --- a/tests/api/test_kb_api_request.py +++ b/tests/api/test_kb_api_request.py @@ -14,7 +14,7 @@ from pprint import pprint api_base_url = api_address() -api: ApiRequest = ApiRequest(api_base_url) +api: ApiRequest = ApiRequest(api_base_url, no_remote_api=True) kb = "kb_for_api_test" @@ -84,7 +84,7 @@ def test_upload_docs(): print(f"\n尝试重新上传知识文件, 覆盖,自定义docs") docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]} - data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)} + data = {"knowledge_base_name": kb, "override": True, "docs": docs} data = api.upload_kb_docs(files, **data) pprint(data) assert data["code"] == 200 diff --git a/tests/api/test_llm_api.py b/tests/api/test_llm_api.py index f348fe7..2c5d039 100644 --- a/tests/api/test_llm_api.py +++ b/tests/api/test_llm_api.py @@ -5,8 +5,9 @@ from pathlib import Path root_path = Path(__file__).parent.parent.parent sys.path.append(str(root_path)) -from configs.server_config import api_address, FSCHAT_MODEL_WORKERS +from configs.server_config 
import FSCHAT_MODEL_WORKERS from configs.model_config import LLM_MODEL, llm_model_dict +from server.utils import api_address from pprint import pprint import random diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index 4c2d5fa..75c995e 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -47,7 +47,6 @@ data = { } - def test_chat_fastchat(api="/chat/fastchat"): url = f"{api_base_url}{api}" data2 = { diff --git a/webui.py b/webui.py index 0cda9eb..2750c47 100644 --- a/webui.py +++ b/webui.py @@ -1,9 +1,3 @@ -# 运行方式: -# 1. 安装必要的包:pip install streamlit-option-menu streamlit-chatbox>=1.1.6 -# 2. 运行本机fastchat服务:python server\llm_api.py 或者 运行对应的sh文件 -# 3. 运行API服务器:python server/api.py。如果使用api = ApiRequest(no_remote_api=True),该步可以跳过。 -# 4. 运行WEB UI:streamlit run webui.py --server.port 7860 - import streamlit as st from webui_pages.utils import * from streamlit_option_menu import option_menu diff --git a/webui_pages/utils.py b/webui_pages/utils.py index cd9ebc0..3145b6f 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -20,6 +20,7 @@ from server.chat.openai_chat import OpenAiChatMsgIn from fastapi.responses import StreamingResponse import contextlib import json +import os from io import BytesIO from server.utils import run_async, iter_over_async, set_httpx_timeout, api_address @@ -475,7 +476,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import create_kb - response = run_async(create_kb(**data)) + response = create_kb(**data) return response.dict() else: response = self.post( @@ -497,7 +498,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import delete_kb - response = run_async(delete_kb(knowledge_base_name)) + response = delete_kb(knowledge_base_name) return response.dict() else: response = self.post( @@ -584,7 +585,7 @@ class ApiRequest: filename = filename or file.name else: # a local path file = Path(file).absolute().open("rb") - filename = filename or file.name + filename = filename or os.path.split(file.name)[-1] return filename, file files = [convert_file(file) for file in files] @@ -602,13 +603,13 @@ class ApiRequest: from tempfile import SpooledTemporaryFile upload_files = [] - for file, filename in files: + for filename, file in files: temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024) temp_file.write(file.read()) temp_file.seek(0) upload_files.append(UploadFile(file=temp_file, filename=filename)) - response = run_async(upload_docs(upload_files, **data)) + response = upload_docs(upload_files, **data) return response.dict() else: if isinstance(data["docs"], dict): @@ -643,7 +644,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import delete_docs - response = run_async(delete_docs(**data)) + response = delete_docs(**data) return response.dict() else: response = self.post( @@ -676,7 +677,7 @@ class ApiRequest: } if no_remote_api: from server.knowledge_base.kb_doc_api import update_docs - response = run_async(update_docs(**data)) + response = update_docs(**data) return response.dict() else: if isinstance(data["docs"], dict): @@ -710,7 +711,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import recreate_vector_store - response = run_async(recreate_vector_store(**data)) + response = recreate_vector_store(**data) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -721,14 +722,30 @@ class ApiRequest: ) return self._httpx_stream2generator(response, 
as_json=True) - def list_running_models(self, controller_address: str = None): + # LLM模型相关操作 + def list_running_models( + self, + controller_address: str = None, + no_remote_api: bool = None, + ): ''' 获取Fastchat中正运行的模型列表 ''' - r = self.post( - "/llm_model/list_models", - ) - return r.json().get("data", []) + if no_remote_api is None: + no_remote_api = self.no_remote_api + + data = { + "controller_address": controller_address, + } + if no_remote_api: + from server.llm_api import list_llm_models + return list_llm_models(**data).data + else: + r = self.post( + "/llm_model/list_models", + json=data, + ) + return r.json().get("data", []) def list_config_models(self): ''' @@ -740,30 +757,43 @@ class ApiRequest: self, model_name: str, controller_address: str = None, + no_remote_api: bool = None, ): ''' 停止某个LLM模型。 注意:由于Fastchat的实现方式,实际上是把LLM模型所在的model_worker停掉。 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + data = { "model_name": model_name, "controller_address": controller_address, } - r = self.post( - "/llm_model/stop", - json=data, - ) - return r.json() + + if no_remote_api: + from server.llm_api import stop_llm_model + return stop_llm_model(**data).dict() + else: + r = self.post( + "/llm_model/stop", + json=data, + ) + return r.json() def change_llm_model( self, model_name: str, new_model_name: str, controller_address: str = None, + no_remote_api: bool = None, ): ''' 向fastchat controller请求切换LLM模型。 ''' + if no_remote_api is None: + no_remote_api = self.no_remote_api + if not model_name or not new_model_name: return @@ -792,12 +822,17 @@ class ApiRequest: "new_model_name": new_model_name, "controller_address": controller_address, } - r = self.post( - "/llm_model/change", - json=data, - timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model - ) - return r.json() + + if no_remote_api: + from server.llm_api import change_llm_model + return change_llm_model(**data).dict() + else: + r = self.post( + "/llm_model/change", + json=data, + timeout=HTTPX_DEFAULT_TIMEOUT, # wait for new worker_model + ) + return r.json() def check_error_msg(data: Union[str, dict, list], key: str = "errorMsg") -> str:
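
For completeness, a usage sketch of the reworked ApiRequest LLM-model helpers. It assumes server/api.py is already running at the address returned by api_address() and that the fastchat controller has at least one model worker; the new model name passed to change_llm_model is illustrative only.

    # Sketch only: assumes server/api.py is running at api_address() and a fastchat
    # controller with at least one model worker is reachable.
    from server.utils import api_address
    from webui_pages.utils import ApiRequest

    api = ApiRequest(api_address(), no_remote_api=False)  # go through the HTTP API

    models = api.list_running_models()  # POSTs /llm_model/list_models, returns the model list
    print(models)

    if models:
        # "chatglm2-6b" is illustrative; the switch is requested via the controller's release_worker.
        print(api.change_llm_model(model_name=models[0], new_model_name="chatglm2-6b"))

With no_remote_api=True (as in tests/api/test_kb_api_request.py above), the same calls go directly to the functions in server.llm_api instead of through HTTP.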