将 MiniMax 和 千帆 在线 Embedding 改为 10 个文本一批,防止接口数量限制 (#2161)

This commit is contained in:
liunux4odoo 2023-11-24 16:42:20 +08:00 committed by GitHub
parent 76151e884a
commit 824c29a6d2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 40 additions and 30 deletions

View File

@ -100,7 +100,7 @@ kbs_config = {
"index_name": "test_index", "index_name": "test_index",
"user": "", "user": "",
"password": "" "password": ""
} }
} }
# TextSplitter配置项如果你不明白其中的含义就不要修改。 # TextSplitter配置项如果你不明白其中的含义就不要修改。

View File

@ -106,7 +106,7 @@ class MiniMaxWorker(ApiModelWorker):
data = { data = {
"model": params.embed_model or self.DEFAULT_EMBED_MODEL, "model": params.embed_model or self.DEFAULT_EMBED_MODEL,
"texts": params.texts, "texts": [],
"type": "query" if params.to_query else "db", "type": "query" if params.to_query else "db",
} }
if log_verbose: if log_verbose:
@ -115,21 +115,26 @@ class MiniMaxWorker(ApiModelWorker):
logger.info(f'{self.__class__.__name__}:headers: {headers}') logger.info(f'{self.__class__.__name__}:headers: {headers}')
with get_httpx_client() as client: with get_httpx_client() as client:
r = client.post(url, headers=headers, json=data).json() result = []
if embeddings := r.get("vectors"): i = 0
return {"code": 200, "data": embeddings} for texts in params.texts[i:i+10]:
elif error := r.get("base_resp"): data["texts"] = texts
return { r = client.post(url, headers=headers, json=data).json()
"code": error["status_code"], if embeddings := r.get("vectors"):
"msg": error["status_msg"], result += embeddings
elif error := r.get("base_resp"):
"error": { return {
"message": error["status_msg"], "code": error["status_code"],
"type": "invalid_request_error", "msg": error["status_msg"],
"param": None, "error": {
"code": None, "message": error["status_msg"],
"type": "invalid_request_error",
"param": None,
"code": None,
}
} }
} i += 10
return {"code": 200, "data": embeddings}
def get_embeddings(self, params): def get_embeddings(self, params):
# TODO: 支持embeddings # TODO: 支持embeddings

View File

@ -184,21 +184,26 @@ class QianFanWorker(ApiModelWorker):
logger.info(f'{self.__class__.__name__}:url: {url}') logger.info(f'{self.__class__.__name__}:url: {url}')
with get_httpx_client() as client: with get_httpx_client() as client:
resp = client.post(url, json={"input": params.texts}).json() result = []
if "error_cdoe" not in resp: i = 0
embeddings = [x["embedding"] for x in resp.get("data", [])] for texts in params.texts[i:i+10]:
return {"code": 200, "data": embeddings} resp = client.post(url, json={"input": texts}).json()
else: if "error_cdoe" in resp:
return { return {
"code": resp["error_code"], "code": resp["error_code"],
"msg": resp["error_msg"], "msg": resp["error_msg"],
"error": { "error": {
"message": resp["error_msg"], "message": resp["error_msg"],
"type": "invalid_request_error", "type": "invalid_request_error",
"param": None, "param": None,
"code": None, "code": None,
}
} }
} else:
embeddings = [x["embedding"] for x in resp.get("data", [])]
result += embeddings
i += 10
return {"code": 200, "data": result}
# TODO: qianfan支持续写模型 # TODO: qianfan支持续写模型
def get_embeddings(self, params): def get_embeddings(self, params):