diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 45e9a5f..947782a 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] data["texts"] = texts r = client.post(url, headers=headers, json=data).json() if embeddings := r.get("vectors"): @@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker): } self.logger.error(f"请求 MiniMax API 时发生错误:{data}") return data - i += 10 + i += batch_size return {"code": 200, "data": embeddings} def get_embeddings(self, params): diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 95e200a..2bcce94 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker): with get_httpx_client() as client: result = [] i = 0 - for texts in params.texts[i:i+10]: + batch_size = 10 + while i < len(params.texts): + texts = params.texts[i:i+batch_size] resp = client.post(url, json={"input": texts}).json() - if "error_cdoe" in resp: + if "error_code" in resp: data = { "code": resp["error_code"], "msg": resp["error_msg"], @@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker): else: embeddings = [x["embedding"] for x in resp.get("data", [])] result += embeddings - i += 10 + i += batch_size return {"code": 200, "data": result} # TODO: qianfan支持续写模型