将 MiniMax 和 千帆 在线 Embedding 改为 10 个文本一批,防止接口数量限制 (#2161)
This commit is contained in:
parent
76151e884a
commit
824c29a6d2
|
|
@ -100,7 +100,7 @@ kbs_config = {
|
||||||
"index_name": "test_index",
|
"index_name": "test_index",
|
||||||
"user": "",
|
"user": "",
|
||||||
"password": ""
|
"password": ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
|
# TextSplitter配置项,如果你不明白其中的含义,就不要修改。
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,7 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
"model": params.embed_model or self.DEFAULT_EMBED_MODEL,
|
||||||
"texts": params.texts,
|
"texts": [],
|
||||||
"type": "query" if params.to_query else "db",
|
"type": "query" if params.to_query else "db",
|
||||||
}
|
}
|
||||||
if log_verbose:
|
if log_verbose:
|
||||||
|
|
@ -115,21 +115,26 @@ class MiniMaxWorker(ApiModelWorker):
|
||||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
with get_httpx_client() as client:
|
||||||
r = client.post(url, headers=headers, json=data).json()
|
result = []
|
||||||
if embeddings := r.get("vectors"):
|
i = 0
|
||||||
return {"code": 200, "data": embeddings}
|
for texts in params.texts[i:i+10]:
|
||||||
elif error := r.get("base_resp"):
|
data["texts"] = texts
|
||||||
return {
|
r = client.post(url, headers=headers, json=data).json()
|
||||||
"code": error["status_code"],
|
if embeddings := r.get("vectors"):
|
||||||
"msg": error["status_msg"],
|
result += embeddings
|
||||||
|
elif error := r.get("base_resp"):
|
||||||
"error": {
|
return {
|
||||||
"message": error["status_msg"],
|
"code": error["status_code"],
|
||||||
"type": "invalid_request_error",
|
"msg": error["status_msg"],
|
||||||
"param": None,
|
"error": {
|
||||||
"code": None,
|
"message": error["status_msg"],
|
||||||
|
"type": "invalid_request_error",
|
||||||
|
"param": None,
|
||||||
|
"code": None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
i += 10
|
||||||
|
return {"code": 200, "data": embeddings}
|
||||||
|
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
# TODO: 支持embeddings
|
# TODO: 支持embeddings
|
||||||
|
|
|
||||||
|
|
@ -184,21 +184,26 @@ class QianFanWorker(ApiModelWorker):
|
||||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||||
|
|
||||||
with get_httpx_client() as client:
|
with get_httpx_client() as client:
|
||||||
resp = client.post(url, json={"input": params.texts}).json()
|
result = []
|
||||||
if "error_cdoe" not in resp:
|
i = 0
|
||||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
for texts in params.texts[i:i+10]:
|
||||||
return {"code": 200, "data": embeddings}
|
resp = client.post(url, json={"input": texts}).json()
|
||||||
else:
|
if "error_cdoe" in resp:
|
||||||
return {
|
return {
|
||||||
"code": resp["error_code"],
|
"code": resp["error_code"],
|
||||||
"msg": resp["error_msg"],
|
"msg": resp["error_msg"],
|
||||||
"error": {
|
"error": {
|
||||||
"message": resp["error_msg"],
|
"message": resp["error_msg"],
|
||||||
"type": "invalid_request_error",
|
"type": "invalid_request_error",
|
||||||
"param": None,
|
"param": None,
|
||||||
"code": None,
|
"code": None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
else:
|
||||||
|
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
||||||
|
result += embeddings
|
||||||
|
i += 10
|
||||||
|
return {"code": 200, "data": result}
|
||||||
|
|
||||||
# TODO: qianfan支持续写模型
|
# TODO: qianfan支持续写模型
|
||||||
def get_embeddings(self, params):
|
def get_embeddings(self, params):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue