修复: MiniMax和千帆在线embedding模型分批请求的bug (#2208)

* 修复: MiniMax和千帆在线embedding模型分批请求的bug

* 修改了一处typo
This commit is contained in:
zty 2023-11-30 17:28:22 +08:00 committed by GitHub
parent 8b70b1db7e
commit 5ac77e5089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 5 deletions

View File

@ -119,7 +119,9 @@ class MiniMaxWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
for texts in params.texts[i:i+10]:
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
data["texts"] = texts
r = client.post(url, headers=headers, json=data).json()
if embeddings := r.get("vectors"):
@ -137,7 +139,7 @@ class MiniMaxWorker(ApiModelWorker):
}
self.logger.error(f"请求 MiniMax API 时发生错误:{data}")
return data
i += 10
i += batch_size
return {"code": 200, "data": embeddings}
def get_embeddings(self, params):

View File

@ -188,9 +188,11 @@ class QianFanWorker(ApiModelWorker):
with get_httpx_client() as client:
result = []
i = 0
for texts in params.texts[i:i+10]:
batch_size = 10
while i < len(params.texts):
texts = params.texts[i:i+batch_size]
resp = client.post(url, json={"input": texts}).json()
if "error_cdoe" in resp:
if "error_code" in resp:
data = {
"code": resp["error_code"],
"msg": resp["error_msg"],
@ -206,7 +208,7 @@ class QianFanWorker(ApiModelWorker):
else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
i += 10
i += batch_size
return {"code": 200, "data": result}
# TODO: qianfan支持续写模型