diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 1908414..9100ad4 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -96,6 +96,7 @@ ONLINE_LLM_MODEL = { "version": "qwen-turbo", # 可选包括 "qwen-turbo", "qwen-plus" "api_key": "", # 请在阿里云控制台模型服务灵积API-KEY管理页面创建 "provider": "QwenWorker", + "embed_model": "text-embedding-v1" # embedding 模型名称 }, # 百川 API,申请方式请参考 https://www.baichuan-ai.com/home#api-enter diff --git a/server/embeddings_api.py b/server/embeddings_api.py index 80cd289..661d97b 100644 --- a/server/embeddings_api.py +++ b/server/embeddings_api.py @@ -5,32 +5,32 @@ from server.utils import BaseResponse, get_model_worker_config, list_embed_model from fastapi import Body from typing import Dict, List - online_embed_models = list_online_embed_models() def embed_texts( - texts: List[str], - embed_model: str = EMBEDDING_MODEL, - to_query: bool = False, + texts: List[str], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, ) -> BaseResponse: ''' 对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]]) TODO: 也许需要加入缓存机制,减少 token 消耗 ''' try: - if embed_model in list_embed_models(): # 使用本地Embeddings模型 + if embed_model in list_embed_models(): # 使用本地Embeddings模型 from server.utils import load_local_embeddings embeddings = load_local_embeddings(model=embed_model) return BaseResponse(data=embeddings.embed_documents(texts)) - if embed_model in list_online_embed_models(): # 使用在线API + if embed_model in list_online_embed_models(): # 使用在线API config = get_model_worker_config(embed_model) worker_class = config.get("worker_class") + embed_model = config.get("embed_model") worker = worker_class() if worker_class.can_embedding(): - params = ApiEmbeddingsParams(texts=texts, to_query=to_query) + params = ApiEmbeddingsParams(texts=texts, to_query=to_query, embed_model=embed_model) resp = worker.do_embeddings(params) return BaseResponse(**resp) @@ -39,10 +39,12 @@ def embed_texts( logger.error(e) return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") + def embed_texts_endpoint( - texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), - embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"), - to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"), + texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), + embed_model: str = Body(EMBEDDING_MODEL, + description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"), + to_query: bool = Body(False, description="向量是否用于查询。有些模型如Minimax对存储/查询的向量进行了区分优化。"), ) -> BaseResponse: ''' 对文本进行向量化,返回 BaseResponse(data=List[List[float]]) @@ -51,9 +53,9 @@ def embed_texts_endpoint( def embed_documents( - docs: List[Document], - embed_model: str = EMBEDDING_MODEL, - to_query: bool = False, + docs: List[Document], + embed_model: str = EMBEDDING_MODEL, + to_query: bool = False, ) -> Dict: """ 将 List[Document] 向量化,转化为 VectorStore.add_embeddings 可以接受的参数 diff --git a/server/model_workers/base.py b/server/model_workers/base.py index 7b456a9..88affb4 100644 --- a/server/model_workers/base.py +++ b/server/model_workers/base.py @@ -113,6 +113,9 @@ class ApiModelWorker(BaseModelWorker): sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + self.context_len = context_len self.semaphore = asyncio.Semaphore(self.limit_worker_concurrency) self.version = None