diff --git a/init_database.py b/init_database.py
index c2cd1d4..f7f0805 100644
--- a/init_database.py
+++ b/init_database.py
@@ -1,7 +1,7 @@
 import sys
 sys.path.append(".")
 from server.knowledge_base.migrate import create_tables, reset_tables, folder2db, prune_db_docs, prune_folder_files
-from configs.model_config import NLTK_DATA_PATH
+from configs.model_config import NLTK_DATA_PATH, EMBEDDING_MODEL
 import nltk
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
 from datetime import datetime
@@ -62,12 +62,20 @@ if __name__ == "__main__":
         )
     )
     parser.add_argument(
+        "-n",
         "--kb-name",
         type=str,
         nargs="+",
         default=[],
         help=("specify knowledge base names to operate on. default is all folders exist in KB_ROOT_PATH.")
     )
+    parser.add_argument(
+        "-e",
+        "--embed-model",
+        type=str,
+        default=EMBEDDING_MODEL,
+        help=("specify the embedding model to use. default is the EMBEDDING_MODEL set in model_config.")
+    )
 
     if len(sys.argv) <= 1:
         parser.print_help()
@@ -80,11 +88,11 @@ if __name__ == "__main__":
         reset_tables()
         print("database talbes reseted")
         print("recreating all vector stores")
-        folder2db(kb_names=args.kb_name, mode="recreate_vs")
+        folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model)
     elif args.update_in_db:
-        folder2db(kb_names=args.kb_name, mode="update_in_db")
+        folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model)
     elif args.increament:
-        folder2db(kb_names=args.kb_name, mode="increament")
+        folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model)
     elif args.prune_db:
         prune_db_docs(args.kb_name)
     elif args.prune_folder:
diff --git a/requirements_lite.txt b/requirements_lite.txt
index 4ff659e..a22d012 100644
--- a/requirements_lite.txt
+++ b/requirements_lite.txt
@@ -30,7 +30,7 @@ pytest
 # online api libs
 zhipuai
 dashscope>=1.10.0 # qwen
-qianfan
+# qianfan
 # volcengine>=1.0.106 # fangzhou
 
 # uncomment libs if you want to use corresponding vector store
diff --git a/server/api.py b/server/api.py
index d47e786..be795ba 100644
--- a/server/api.py
+++ b/server/api.py
@@ -16,6 +16,7 @@ from server.chat.chat import chat
 from server.chat.openai_chat import openai_chat
 from server.chat.search_engine_chat import search_engine_chat
 from server.chat.completion import completion
+from server.embeddings_api import embed_texts_endpoint
 from server.llm_api import (list_running_models, list_config_models,
                             change_llm_model, stop_llm_model,
                             get_model_config, list_search_engines)
@@ -47,33 +48,34 @@ def create_app(run_mode: str = None):
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    mount_basic_routes(app)
-    if run_mode != "lite":
-        mount_knowledge_routes(app)
+    mount_app_routes(app, run_mode=run_mode)
     return app
 
 
-def mount_basic_routes(app: FastAPI):
+def mount_app_routes(app: FastAPI, run_mode: str = None):
     app.get("/",
             response_model=BaseResponse,
             summary="swagger 文档")(document)
 
-    app.post("/completion",
-             tags=["Completion"],
-             summary="要求llm模型补全(通过LLMChain)")(completion)
-
     # Tag: Chat
     app.post("/chat/fastchat",
              tags=["Chat"],
-             summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)
+             summary="与llm模型对话(直接与fastchat api对话)",
+             )(openai_chat)
 
     app.post("/chat/chat",
              tags=["Chat"],
-             summary="与llm模型对话(通过LLMChain)")(chat)
+             summary="与llm模型对话(通过LLMChain)",
+             )(chat)
 
     app.post("/chat/search_engine_chat",
              tags=["Chat"],
-             summary="与搜索引擎对话")(search_engine_chat)
+             summary="与搜索引擎对话",
+             )(search_engine_chat)
+
+    # 知识库相关接口
+    if run_mode != "lite":
+        mount_knowledge_routes(app)
 
     # LLM模型相关接口
     app.post("/llm_model/list_running_models",
@@ -121,6 +123,17 @@ def mount_basic_routes(app: FastAPI):
             ) -> str:
         return get_prompt_template(type=type, name=name)
 
+    # 其它接口
+    app.post("/other/completion",
+             tags=["Other"],
+             summary="要求llm模型补全(通过LLMChain)",
+             )(completion)
+
+    app.post("/other/embed_texts",
+             tags=["Other"],
+             summary="将文本向量化,支持本地模型和在线模型",
+             )(embed_texts_endpoint)
+
 
 def mount_knowledge_routes(app: FastAPI):
     from server.chat.knowledge_base_chat import knowledge_base_chat
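With the routes above mounted, the embedding service is reachable over HTTP. A minimal client sketch — the host/port are assumptions (127.0.0.1:7861 is only the project default) and "m3e-base" stands in for any configured embedding model:

    # Hedged client sketch for the new POST /other/embed_texts route.
    import httpx

    resp = httpx.post(
        "http://127.0.0.1:7861/other/embed_texts",
        json={
            "texts": ["hello", "world"],
            "embed_model": "m3e-base",  # any local or embeddings-capable online model
            "to_query": False,          # True when the vectors will be used as queries
        },
        timeout=60,
    )
    body = resp.json()                  # BaseResponse shape: {"code": ..., "msg": ..., "data": ...}
    if body["code"] == 200:
        vectors = body["data"]          # List[List[float]], one vector per input text
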
diff --git a/server/embeddings_api.py b/server/embeddings_api.py
new file mode 100644
index 0000000..9b08e21
--- /dev/null
+++ b/server/embeddings_api.py
@@ -0,0 +1,68 @@
+from langchain.docstore.document import Document
+from configs import EMBEDDING_MODEL, logger
+from server.model_workers.base import ApiEmbeddingsParams
+from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
+from fastapi import Body
+from typing import Dict, List
+
+
+online_embed_models = list_online_embed_models()
+
+
+def embed_texts(
+    texts: List[str],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> BaseResponse:
+    '''
+    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models():  # 使用本地Embeddings模型
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=embeddings.embed_documents(texts))
+
+        if embed_model in list_online_embed_models():  # 使用在线API
+            config = get_model_worker_config(embed_model)
+            worker_class = config.get("worker_class")
+            worker = worker_class()
+            if worker_class.can_embedding():
+                params = ApiEmbeddingsParams(texts=texts, to_query=to_query)
+                resp = worker.do_embeddings(params)
+                return BaseResponse(**resp)
+
+        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+
+def embed_texts_endpoint(
+    texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
+    embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
+    to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"),
+) -> BaseResponse:
+    '''
+    对文本进行向量化,返回 BaseResponse(data=List[List[float]])
+    '''
+    return embed_texts(texts=texts, embed_model=embed_model, to_query=to_query)
+
+
+def embed_documents(
+    docs: List[Document],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> Dict:
+    """
+    将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数
+    """
+    texts = [x.page_content for x in docs]
+    metadatas = [x.metadata for x in docs]
+    embeddings = embed_texts(texts=texts, embed_model=embed_model, to_query=to_query).data
+    if embeddings is not None:
+        return {
+            "texts": texts,
+            "embeddings": embeddings,
+            "metadatas": metadatas,
+        }
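The embed_documents helper above packages vectors for VectorStore.add_embeddings. An in-process sketch (the model name is an assumption; any entry from list_embed_models() or list_online_embed_models() works):

    # Hedged sketch: turning Documents into add_embeddings-ready data.
    from langchain.docstore.document import Document
    from server.embeddings_api import embed_documents

    docs = [Document(page_content="hello", metadata={"source": "demo"})]
    data = embed_documents(docs=docs, embed_model="m3e-base", to_query=False)
    # data == {"texts": [...], "embeddings": [[...]], "metadatas": [{...}]}
    # ready for vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
    #                             metadatas=data["metadatas"])

Note that embed_documents implicitly returns None when embedding fails (the `if embeddings is not None` branch falls through), so callers should be prepared for that.
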
== "text-embedding-ada-002": # openai text-embedding-ada-002 - embeddings = OpenAIEmbeddings(model_name=model, # TODO: 支持Azure + embeddings = OpenAIEmbeddings(model_name=model, openai_api_key=get_model_path(model), chunk_size=CHUNK_SIZE) elif 'bge-' in model: diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py index 80e09c1..b7dd6d7 100644 --- a/server/knowledge_base/kb_cache/faiss_cache.py +++ b/server/knowledge_base/kb_cache/faiss_cache.py @@ -1,7 +1,10 @@ from configs import CACHED_VS_NUM from server.knowledge_base.kb_cache.base import * +from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter +from server.utils import load_local_embeddings from server.knowledge_base.utils import get_vs_path -from langchain.vectorstores import FAISS +from langchain.vectorstores.faiss import FAISS +from langchain.schema import Document import os from langchain.schema import Document @@ -38,9 +41,9 @@ class _FaissPool(CachePool): embed_model: str = EMBEDDING_MODEL, embed_device: str = embedding_device(), ) -> FAISS: - embeddings = embeddings_pool.load_embeddings(embed_model, embed_device) - + # TODO: 整个Embeddings加载逻辑有些混乱,待清理 # create an empty vector store + embeddings = EmbeddingsFunAdapter(embed_model) doc = Document(page_content="init", metadata={}) vector_store = FAISS.from_documents([doc], embeddings, normalize_L2=True) ids = list(vector_store.docstore._dict.keys()) @@ -133,7 +136,7 @@ if __name__ == "__main__": def worker(vs_name: str, name: str): vs_name = "samples" time.sleep(random.randint(1, 5)) - embeddings = embeddings_pool.load_embeddings() + embeddings = load_local_embeddings() r = random.randint(1, 3) with kb_faiss_pool.load_vector_store(vs_name).acquire(name) as vs: diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index 9ef9fa4..1a3d551 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import ( from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, EMBEDDING_MODEL, KB_INFO) from server.knowledge_base.utils import ( - get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, + get_kb_path, get_doc_path, KnowledgeFile, list_kbs_from_folder, list_files_from_folder, ) -from server.utils import embedding_device + from typing import List, Union, Dict, Optional +from server.embeddings_api import embed_texts +from server.embeddings_api import embed_documents + class SupportedVSType: FAISS = 'faiss' @@ -48,8 +51,6 @@ class KBService(ABC): self.kb_path = get_kb_path(self.kb_name) self.doc_path = get_doc_path(self.kb_name) self.do_init() - def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings: - return load_embeddings(self.embed_model, embed_device) def save_vector_store(self): ''' @@ -83,6 +84,12 @@ class KBService(ABC): status = delete_kb_from_db(self.kb_name) return status + def _docs_to_embeddings(self, docs: List[Document]) -> Dict: + ''' + 将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数 + ''' + return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False) + def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs): """ 向知识库添加文件 @@ -149,8 +156,7 @@ class KBService(ABC): top_k: int = VECTOR_SEARCH_TOP_K, score_threshold: float = SCORE_THRESHOLD, ): - embeddings = self._load_embeddings() - docs = self.do_search(query, top_k, score_threshold, embeddings) + docs = 
diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py
index 9ef9fa4..1a3d551 100644
--- a/server/knowledge_base/kb_service/base.py
+++ b/server/knowledge_base/kb_service/base.py
@@ -21,12 +21,15 @@ from server.db.repository.knowledge_file_repository import (
 from configs import (kbs_config, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
                      EMBEDDING_MODEL, KB_INFO)
 from server.knowledge_base.utils import (
-    get_kb_path, get_doc_path, load_embeddings, KnowledgeFile,
+    get_kb_path, get_doc_path, KnowledgeFile,
     list_kbs_from_folder, list_files_from_folder,
 )
-from server.utils import embedding_device
+
 from typing import List, Union, Dict, Optional
 
+from server.embeddings_api import embed_texts
+from server.embeddings_api import embed_documents
+
 
 class SupportedVSType:
     FAISS = 'faiss'
@@ -48,8 +51,6 @@ class KBService(ABC):
         self.kb_path = get_kb_path(self.kb_name)
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()
 
-    def _load_embeddings(self, embed_device: str = embedding_device()) -> Embeddings:
-        return load_embeddings(self.embed_model, embed_device)
 
     def save_vector_store(self):
         '''
@@ -83,6 +84,12 @@ class KBService(ABC):
         status = delete_kb_from_db(self.kb_name)
         return status
 
+    def _docs_to_embeddings(self, docs: List[Document]) -> Dict:
+        '''
+        将 List[Document] 转化为 VectorStore.add_embeddings 可以接受的参数
+        '''
+        return embed_documents(docs=docs, embed_model=self.embed_model, to_query=False)
+
     def add_doc(self, kb_file: KnowledgeFile, docs: List[Document] = [], **kwargs):
         """
         向知识库添加文件
@@ -149,8 +156,7 @@ class KBService(ABC):
                     top_k: int = VECTOR_SEARCH_TOP_K,
                     score_threshold: float = SCORE_THRESHOLD,
                     ):
-        embeddings = self._load_embeddings()
-        docs = self.do_search(query, top_k, score_threshold, embeddings)
+        docs = self.do_search(query, top_k, score_threshold)
         return docs
 
     def get_doc_by_id(self, id: str) -> Optional[Document]:
@@ -346,24 +352,26 @@ def get_kb_file_details(kb_name: str) -> List[Dict]:
 
 
 class EmbeddingsFunAdapter(Embeddings):
-
-    def __init__(self, embeddings: Embeddings):
-        self.embeddings = embeddings
+    def __init__(self, embed_model: str = EMBEDDING_MODEL):
+        self.embed_model = embed_model
 
     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        return normalize(self.embeddings.embed_documents(texts))
+        embeddings = embed_texts(texts=texts, embed_model=self.embed_model, to_query=False).data
+        return normalize(embeddings)
 
     def embed_query(self, text: str) -> List[float]:
-        query_embed = self.embeddings.embed_query(text)
+        embeddings = embed_texts(texts=[text], embed_model=self.embed_model, to_query=True).data
+        query_embed = embeddings[0]
         query_embed_2d = np.reshape(query_embed, (1, -1))  # 将一维数组转换为二维数组
         normalized_query_embed = normalize(query_embed_2d)
         return normalized_query_embed[0].tolist()  # 将结果转换为一维数组并返回
 
-    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
-        return await normalize(self.embeddings.aembed_documents(texts))
+    # TODO: 暂不支持异步
+    # async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+    #     return normalize(await self.embeddings.aembed_documents(texts))
 
-    async def aembed_query(self, text: str) -> List[float]:
-        return await normalize(self.embeddings.aembed_query(text))
+    # async def aembed_query(self, text: str) -> List[float]:
+    #     return normalize(await self.embeddings.aembed_query(text))
 
 
 def score_threshold_process(score_threshold, k, docs):
diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py
index a72fcf7..10e26c1 100644
--- a/server/knowledge_base/kb_service/faiss_kb_service.py
+++ b/server/knowledge_base/kb_service/faiss_kb_service.py
@@ -9,10 +9,9 @@ from configs import (
 from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.kb_cache.faiss_cache import kb_faiss_pool, ThreadSafeFaiss
 from server.knowledge_base.utils import KnowledgeFile
-from langchain.embeddings.base import Embeddings
-from typing import List, Dict, Optional
-from langchain.docstore.document import Document
 from server.utils import torch_gc
+from langchain.docstore.document import Document
+from typing import List, Dict, Optional
 
 
 class FaissKBService(KBService):
@@ -58,7 +57,6 @@ class FaissKBService(KBService):
                   query: str,
                   top_k: int,
                   score_threshold: float = SCORE_THRESHOLD,
-                  embeddings: Embeddings = None,
                   ) -> List[Document]:
         with self.load_vector_store().acquire() as vs:
             docs = vs.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
@@ -68,8 +66,11 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    **kwargs,
                    ) -> List[Dict]:
+        data = self._docs_to_embeddings(docs)  # 将向量化单独出来可以减少向量库的锁定时间
+
         with self.load_vector_store().acquire() as vs:
-            ids = vs.add_documents(docs)
+            ids = vs.add_embeddings(text_embeddings=zip(data["texts"], data["embeddings"]),
+                                    metadatas=data["metadatas"])
             if not kwargs.get("not_refresh_vs_cache"):
                 vs.save_local(self.vs_path)
         doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
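The reworked EmbeddingsFunAdapter is now constructed from a model name and routes every call through server.embeddings_api.embed_texts, so local and online embeddings share one code path. A usage sketch (the model name is an assumption):

    from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter

    adapter = EmbeddingsFunAdapter("m3e-base")
    doc_vecs = adapter.embed_documents(["a chunk of text"])  # embeds with to_query=False
    query_vec = adapter.embed_query("a user question")       # embeds with to_query=True
    # both results are L2-normalized, matching the normalize_L2=True FAISS setup
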
diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py
index af87553..7257494 100644
--- a/server/knowledge_base/kb_service/milvus_kb_service.py
+++ b/server/knowledge_base/kb_service/milvus_kb_service.py
@@ -1,8 +1,7 @@
 from typing import List, Dict, Optional
 
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import Milvus
+from langchain.vectorstores.milvus import Milvus
 
 from configs import kbs_config
 
@@ -46,10 +45,8 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS
 
-    def _load_milvus(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
-        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(embeddings),
+    def _load_milvus(self):
+        self.milvus = Milvus(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
 
     def do_init(self):
@@ -60,8 +57,8 @@ class MilvusKBService(KBService):
             self.milvus.col.release()
             self.milvus.col.drop()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_milvus(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_milvus()
         return score_threshold_process(score_threshold, top_k, self.milvus.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
diff --git a/server/knowledge_base/kb_service/pg_kb_service.py b/server/knowledge_base/kb_service/pg_kb_service.py
index 9c17e80..66ad9c1 100644
--- a/server/knowledge_base/kb_service/pg_kb_service.py
+++ b/server/knowledge_base/kb_service/pg_kb_service.py
@@ -1,28 +1,22 @@
 import json
 from typing import List, Dict, Optional
 
-from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
-from langchain.vectorstores import PGVector
-from langchain.vectorstores.pgvector import DistanceStrategy
+from langchain.vectorstores.pgvector import PGVector, DistanceStrategy
 from sqlalchemy import text
 
 from configs import kbs_config
 
 from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \
     score_threshold_process
-from server.knowledge_base.utils import load_embeddings, KnowledgeFile
-from server.utils import embedding_device as get_embedding_device
+from server.knowledge_base.utils import KnowledgeFile
 
 
 class PGKBService(KBService):
     pg_vector: PGVector
 
-    def _load_pg_vector(self, embedding_device: str = get_embedding_device(), embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(_embeddings),
+    def _load_pg_vector(self):
+        self.pg_vector = PGVector(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                                   collection_name=self.kb_name,
                                   distance_strategy=DistanceStrategy.EUCLIDEAN,
                                   connection_string=kbs_config.get("pg").get("connection_uri"))
@@ -57,8 +51,8 @@ class PGKBService(KBService):
                 '''))
             connect.commit()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_pg_vector(embeddings=embeddings)
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_pg_vector()
         return score_threshold_process(score_threshold, top_k,
                                        self.pg_vector.similarity_search_with_score(query, top_k))
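After this change, callers of do_search no longer thread an Embeddings object through; each service builds its own adapter from self.embed_model. A hedged end-to-end sketch (the "samples" knowledge base is assumed to exist, as in the project's bundled demo):

    from server.knowledge_base.kb_service.base import KBServiceFactory

    kb = KBServiceFactory.get_service_by_name("samples")
    if kb is not None:
        docs = kb.search_docs("hello", top_k=3)  # no embeddings argument anymore
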
diff --git a/server/knowledge_base/kb_service/zilliz_kb_service.py b/server/knowledge_base/kb_service/zilliz_kb_service.py
index 679e7d9..8bf10ef 100644
--- a/server/knowledge_base/kb_service/zilliz_kb_service.py
+++ b/server/knowledge_base/kb_service/zilliz_kb_service.py
@@ -43,11 +43,9 @@ class ZillizKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.ZILLIZ
 
-    def _load_zilliz(self, embeddings: Embeddings = None):
-        if embeddings is None:
-            embeddings = self._load_embeddings()
+    def _load_zilliz(self):
         zilliz_args = kbs_config.get("zilliz")
-        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(embeddings),
+        self.zilliz = Zilliz(embedding_function=EmbeddingsFunAdapter(self.embed_model),
                              collection_name=self.kb_name, connection_args=zilliz_args)
 
     def do_init(self):
@@ -59,8 +57,8 @@ class ZillizKBService(KBService):
             self.zilliz.col.release()
             self.zilliz.col.drop()
 
-    def do_search(self, query: str, top_k: int, score_threshold: float, embeddings: Embeddings):
-        self._load_zilliz(embeddings=EmbeddingsFunAdapter(embeddings))
+    def do_search(self, query: str, top_k: int, score_threshold: float):
+        self._load_zilliz()
         return score_threshold_process(score_threshold, top_k, self.zilliz.similarity_search_with_score(query, top_k))
 
     def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 3981e81..ef3f001 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,8 +1,4 @@
 import os
-import sys
-
-sys.path.append("/home/congyin/Code/Project_Langchain_0814/Langchain-Chatchat")
-from transformers import AutoTokenizer
 from configs import (
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
@@ -22,7 +18,6 @@ from langchain.docstore.document import Document
 from langchain.text_splitter import TextSplitter
 from pathlib import Path
 import json
-from concurrent.futures import ThreadPoolExecutor
 from server.utils import run_in_thread_pool, embedding_device, get_model_worker_config
 import io
 from typing import List, Union, Callable, Dict, Optional, Tuple, Generator
@@ -62,14 +57,6 @@ def list_files_from_folder(kb_name: str):
             if os.path.isfile(os.path.join(doc_path, file))]
 
 
-def load_embeddings(model: str = EMBEDDING_MODEL, device: str = embedding_device()):
-    '''
-    从缓存中加载embeddings,可以避免多线程时竞争加载。
-    '''
-    from server.knowledge_base.kb_cache.base import embeddings_pool
-    return embeddings_pool.load_embeddings(model=model, device=device)
-
-
 LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredMarkdownLoader": ['.md'],
                "CustomJSONLoader": [".json"],
@@ -239,6 +226,7 @@ def make_text_splitter(
                 from langchain.text_splitter import CharacterTextSplitter
                 tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
             else:  ## 字符长度加载
+                from transformers import AutoTokenizer
                 tokenizer = AutoTokenizer.from_pretrained(
                     text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
                     trust_remote_code=True)
@@ -358,7 +346,6 @@ def files2docs_in_thread(
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
         zh_title_enhance: bool = ZH_TITLE_ENHANCE,
-        pool: ThreadPoolExecutor = None,
 ) -> Generator:
     '''
     利用多线程批量将磁盘文件转化成langchain Document.
@@ -396,7 +383,7 @@ def files2docs_in_thread(
         except Exception as e:
             yield False, (kb_name, filename, str(e))
 
-    for result in run_in_thread_pool(func=file2docs, params=kwargs_list, pool=pool):
+    for result in run_in_thread_pool(func=file2docs, params=kwargs_list):
        yield result
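files2docs_in_thread likewise loses its pool parameter and always delegates to run_in_thread_pool. A hedged usage sketch (the file and KB names are assumptions):

    from server.knowledge_base.utils import files2docs_in_thread, KnowledgeFile

    files = [KnowledgeFile(filename="test.txt", knowledge_base_name="samples")]
    for ok, result in files2docs_in_thread(files):
        if ok:
            kb_name, filename, docs = result
            print(filename, "->", len(docs), "docs")
        else:
            print("failed:", result)  # (kb_name, filename, error message)
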
diff --git a/server/llm_api.py b/server/llm_api.py
index 11b5e37..0b78fc2 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -2,6 +2,7 @@ from fastapi import Body
 from configs import logger, log_verbose, LLM_MODEL, HTTPX_DEFAULT_TIMEOUT
 from server.utils import (BaseResponse, fschat_controller_address, list_config_llm_models,
                           get_httpx_client, get_model_worker_config)
+from copy import deepcopy
 
 
 def list_running_models(
@@ -31,16 +32,16 @@ def list_config_models() -> BaseResponse:
     '''
     从本地获取configs中配置的模型列表
     '''
-    configs = list_config_llm_models()
+    configs = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    for config in configs["online"].values():
-        del_keys = set(["worker_class"])
-        for k in config:
-            if "key" in k.lower() or "secret" in k.lower():
-                del_keys.add(k)
-        for k in del_keys:
-            config.pop(k, None)
-
+    for name, config in list_config_llm_models()["online"].items():
+        configs[name] = {}
+        for k, v in config.items():
+            if not (k == "worker_class"
+                    or "key" in k.lower()
+                    or "secret" in k.lower()
+                    or k.lower().endswith("id")):
+                configs[name][k] = v
     return BaseResponse(data=configs)
 
 
@@ -51,14 +52,14 @@ def get_model_config(
     '''
     获取LLM模型配置项(合并后的)
     '''
-    config = get_model_worker_config(model_name=model_name)
+    config = {}
     # 删除ONLINE_MODEL配置中的敏感信息
-    del_keys = set(["worker_class"])
-    for k in config:
-        if "key" in k.lower() or "secret" in k.lower():
-            del_keys.add(k)
-    for k in del_keys:
-        config.pop(k, None)
+    for k, v in get_model_worker_config(model_name=model_name).items():
+        if not (k == "worker_class"
+                or "key" in k.lower()
+                or "secret" in k.lower()
+                or k.lower().endswith("id")):
+            config[k] = v
 
     return BaseResponse(data=config)
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index 95ba7c6..e015ce3 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -180,7 +180,7 @@ class ApiModelWorker(BaseModelWorker):
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
         '''
         执行Embeddings的方法,默认使用模块里面的embed_documents函数。
-        要求返回形式:{"code": int, "embeddings": List[List[float]], "msg": str}
+        要求返回形式:{"code": int, "data": List[List[float]], "msg": str}
         '''
         return {"code": 500, "msg": f"{self.model_names[0]}未实现embeddings功能"}
diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py
index caffcba..fa4bb85 100644
--- a/server/model_workers/minimax.py
+++ b/server/model_workers/minimax.py
@@ -99,7 +99,7 @@ class MiniMaxWorker(ApiModelWorker):
         with get_httpx_client() as client:
             r = client.post(url, headers=headers, json=data).json()
             if embeddings := r.get("vectors"):
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             elif error := r.get("base_resp"):
                 return {"code": error["status_code"], "msg": error["status_msg"]}
diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 657e90d..ad71c0c 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -172,7 +172,7 @@ class QianFanWorker(ApiModelWorker):
             resp = client.post(url, json={"input": params.texts}).json()
             if "error_cdoe" not in resp:
                 embeddings = [x["embedding"] for x in resp.get("data", [])]
-                return {"code": 200, "embeddings": embeddings}
+                return {"code": 200, "data": embeddings}
             else:
                 return {"code": resp["error_code"], "msg": resp["error_msg"]}
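All workers now return vectors under "data" instead of "embeddings", so a successful response can be splatted straight into BaseResponse(**resp). A hedged sketch of the new contract, mirroring embed_texts above ("zhipu-api" stands in for any embeddings-capable online worker):

    from server.utils import get_model_worker_config
    from server.model_workers.base import ApiEmbeddingsParams

    config = get_model_worker_config("zhipu-api")
    worker = config.get("worker_class")()
    resp = worker.do_embeddings(ApiEmbeddingsParams(texts=["hello"], to_query=False))
    if resp["code"] == 200:
        vectors = resp["data"]  # previously resp["embeddings"]
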
resp["error_code"], "msg": resp["error_msg"]} diff --git a/server/model_workers/qwen.py b/server/model_workers/qwen.py index 2340ec2..2cb1ed0 100644 --- a/server/model_workers/qwen.py +++ b/server/model_workers/qwen.py @@ -68,7 +68,7 @@ class QwenWorker(ApiModelWorker): return {"code": resp["status_code"], "msg": resp.message} else: embeddings = [x["embedding"] for x in resp["output"]["embeddings"]] - return {"code": 200, "embeddings": embeddings} + return {"code": 200, "data": embeddings} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index 546dbce..9d48bed 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -59,7 +59,7 @@ class ChatGLMWorker(ApiModelWorker): except Exception as e: return {"code": 500, "msg": f"对文本向量化时出错:{e}"} - return {"code": 200, "embeddings": embeddings} + return {"code": 200, "data": embeddings} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/utils.py b/server/utils.py index 39c1cb3..23713c4 100644 --- a/server/utils.py +++ b/server/utils.py @@ -14,7 +14,6 @@ from langchain.llms import OpenAI, AzureOpenAI, Anthropic import httpx from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union -thread_pool = ThreadPoolExecutor(os.cpu_count()) async def wrap_done(fn: Awaitable, event: asyncio.Event): @@ -368,8 +367,8 @@ def MakeFastAPIOffline( redoc_favicon_url=favicon, ) - # 从model_config中获取模型信息 +# 从model_config中获取模型信息 def list_embed_models() -> List[str]: ''' @@ -432,8 +431,8 @@ def get_model_worker_config(model_name: str = None) -> dict: from server import model_workers config = FSCHAT_MODEL_WORKERS.get("default", {}).copy() - config.update(ONLINE_LLM_MODEL.get(model_name, {})) - config.update(FSCHAT_MODEL_WORKERS.get(model_name, {})) + config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy()) + config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy()) if model_name in ONLINE_LLM_MODEL: @@ -611,21 +610,19 @@ def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]: def run_in_thread_pool( func: Callable, params: List[Dict] = [], - pool: ThreadPoolExecutor = None, ) -> Generator: ''' 在线程池中批量运行任务,并将运行结果以生成器的形式返回。 请确保任务中的所有操作是线程安全的,任务函数请全部使用关键字参数。 ''' tasks = [] - pool = pool or thread_pool + with ThreadPoolExecutor() as pool: + for kwargs in params: + thread = pool.submit(func, **kwargs) + tasks.append(thread) - for kwargs in params: - thread = pool.submit(func, **kwargs) - tasks.append(thread) - - for obj in as_completed(tasks): - yield obj.result() + for obj in as_completed(tasks): # TODO: Ctrl+c无法停止 + yield obj.result() def get_httpx_client( @@ -703,7 +700,6 @@ def get_server_configs() -> Dict: ) from configs.model_config import ( LLM_MODEL, - EMBEDDING_MODEL, HISTORY_LEN, TEMPERATURE, ) @@ -728,3 +724,14 @@ def list_online_embed_models() -> List[str]: if worker_class is not None and worker_class.can_embedding(): ret.append(k) return ret + + +def load_local_embeddings(model: str = None, device: str = embedding_device()): + ''' + 从缓存中加载embeddings,可以避免多线程时竞争加载。 + ''' + from server.knowledge_base.kb_cache.base import embeddings_pool + from configs import EMBEDDING_MODEL + + model = model or EMBEDDING_MODEL + return embeddings_pool.load_embeddings(model=model, device=device) diff --git a/tests/test_online_api.py b/tests/test_online_api.py index 0ddeaee..7c72f3e 100644 --- a/tests/test_online_api.py +++ b/tests/test_online_api.py @@ -16,6 +16,8 @@ for x in 
list_config_llm_models()["online"]: workers.append(x) print(f"all workers to test: {workers}") +# workers = ["qianfan-api"] + @pytest.mark.parametrize("worker", workers) def test_chat(worker): @@ -49,8 +51,8 @@ def test_embeddings(worker): pprint(resp, depth=2) assert resp["code"] == 200 - assert "embeddings" in resp - embeddings = resp["embeddings"] + assert "data" in resp + embeddings = resp["data"] assert isinstance(embeddings, list) and len(embeddings) > 0 assert isinstance(embeddings[0], list) and len(embeddings[0]) > 0 assert isinstance(embeddings[0][0], float) diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index 4df7c9c..2a097f1 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -9,7 +9,7 @@ from typing import Literal, Dict, Tuple from configs import (kbs_config, EMBEDDING_MODEL, DEFAULT_VS_TYPE, CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE) -from server.utils import list_embed_models +from server.utils import list_embed_models, list_online_embed_models import os import time @@ -103,7 +103,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): key="vs_type", ) - embed_models = list_embed_models() + embed_models = list_embed_models() + list_online_embed_models() embed_model = cols[1].selectbox( "Embedding 模型", diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 791edba..4a3963a 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -853,6 +853,26 @@ class ApiRequest: else: return ret_sync() + def embed_texts( + self, + texts: List[str], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, + ) -> List[List[float]]: + ''' + 对文本进行向量化,可选模型包括本地 embed_models 和支持 embeddings 的在线模型 + ''' + data = { + "texts": texts, + "embed_model": embed_model, + "to_query": to_query, + } + resp = self.post( + "/other/embed_texts", + json=data, + ) + return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) + class AsyncApiRequest(ApiRequest): def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):