From 824c29a6d2de7e43a0f641b2f8f1361102d21cac Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 24 Nov 2023 16:42:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86=20MiniMax=20=E5=92=8C=20=E5=8D=83?= =?UTF-8?q?=E5=B8=86=20=E5=9C=A8=E7=BA=BF=20Embedding=20=E6=94=B9=E4=B8=BA?= =?UTF-8?q?=2010=20=E4=B8=AA=E6=96=87=E6=9C=AC=E4=B8=80=E6=89=B9=EF=BC=8C?= =?UTF-8?q?=E9=98=B2=E6=AD=A2=E6=8E=A5=E5=8F=A3=E6=95=B0=E9=87=8F=E9=99=90?= =?UTF-8?q?=E5=88=B6=20(#2161)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- configs/kb_config.py.example | 2 +- server/model_workers/minimax.py | 35 +++++++++++++++++++-------------- server/model_workers/qianfan.py | 33 ++++++++++++++++++------------- 3 files changed, 40 insertions(+), 30 deletions(-) diff --git a/configs/kb_config.py.example b/configs/kb_config.py.example index 04e09ec..9c727b7 100644 --- a/configs/kb_config.py.example +++ b/configs/kb_config.py.example @@ -100,7 +100,7 @@ kbs_config = { "index_name": "test_index", "user": "", "password": "" - } + } } # TextSplitter配置项,如果你不明白其中的含义,就不要修改。 diff --git a/server/model_workers/minimax.py b/server/model_workers/minimax.py index 47d6099..220ed58 100644 --- a/server/model_workers/minimax.py +++ b/server/model_workers/minimax.py @@ -106,7 +106,7 @@ class MiniMaxWorker(ApiModelWorker): data = { "model": params.embed_model or self.DEFAULT_EMBED_MODEL, - "texts": params.texts, + "texts": [], "type": "query" if params.to_query else "db", } if log_verbose: @@ -115,21 +115,26 @@ class MiniMaxWorker(ApiModelWorker): logger.info(f'{self.__class__.__name__}:headers: {headers}') with get_httpx_client() as client: - r = client.post(url, headers=headers, json=data).json() - if embeddings := r.get("vectors"): - return {"code": 200, "data": embeddings} - elif error := r.get("base_resp"): - return { - "code": error["status_code"], - "msg": error["status_msg"], - - "error": { - "message": error["status_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, + result = [] + i = 0 + for texts in params.texts[i:i+10]: + data["texts"] = texts + r = client.post(url, headers=headers, json=data).json() + if embeddings := r.get("vectors"): + result += embeddings + elif error := r.get("base_resp"): + return { + "code": error["status_code"], + "msg": error["status_msg"], + "error": { + "message": error["status_msg"], + "type": "invalid_request_error", + "param": None, + "code": None, + } } - } + i += 10 + return {"code": 200, "data": embeddings} def get_embeddings(self, params): # TODO: 支持embeddings diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index dacb8d1..2ab39d1 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -184,21 +184,26 @@ class QianFanWorker(ApiModelWorker): logger.info(f'{self.__class__.__name__}:url: {url}') with get_httpx_client() as client: - resp = client.post(url, json={"input": params.texts}).json() - if "error_cdoe" not in resp: - embeddings = [x["embedding"] for x in resp.get("data", [])] - return {"code": 200, "data": embeddings} - else: - return { - "code": resp["error_code"], - "msg": resp["error_msg"], - "error": { - "message": resp["error_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, + result = [] + i = 0 + for texts in params.texts[i:i+10]: + resp = client.post(url, json={"input": texts}).json() + if "error_cdoe" in resp: + return { + "code": resp["error_code"], + "msg": resp["error_msg"], + "error": { + "message": resp["error_msg"], + "type": "invalid_request_error", + "param": None, + "code": None, + } } - } + else: + embeddings = [x["embedding"] for x in resp.get("data", [])] + result += embeddings + i += 10 + return {"code": 200, "data": result} # TODO: qianfan支持续写模型 def get_embeddings(self, params):