From 5ac77e50893aeaf3c1ccae2ca92e372260e44686 Mon Sep 17 00:00:00 2001 From: zty <93087391+alanlaye617@users.noreply.github.com> Date: Thu, 30 Nov 2023 17:28:22 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D:=20MiniMax=E5=92=8C=E5=8D=83?= =?UTF-8?q?=E5=B8=86=E5=9C=A8=E7=BA=BFembedding=E6=A8=A1=E5=9E=8B=E5=88=86?= =?UTF-8?q?=E6=89=B9=E8=AF=B7=E6=B1=82=E7=9A=84bug=20(#2208)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 修复: MiniMax和千帆在线embedding模型分批请求的bug * 修改了一处typo --- server/model_workers/minimax.py | 6 ++++-- server/model_workers/qianfan.py | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 45e9a5f..947782a 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] data["texts"] = texts r = client.post(url, headers=headers, json=data).json() if embeddings := r.get("vectors"): @@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker): } self.logger.error(f"请求 MiniMax API 时发生错误:{data}") return data - i += 10 + i += batch_size return {"code": 200, "data": embeddings} def get_embeddings(self, params): diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 95e200a..2bcce94 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] resp = client.post(url, json={"input": texts}).json() - if "error_cdoe" in resp: + if "error_code" in resp: data = { "code": resp["error_code"], "msg": resp["error_msg"], @@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker): else: embeddings = [x["embedding"] for x in resp.get("data", [])] result += embeddings - i += 10 + i += batch_size return {"code": 200, "data": result} # TODO: qianfan支持续写模型