From b68f7fcdea8584a59f8070504158fbab8135dc1b Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Fri, 27 Oct 2023 13:42:16 +0800
Subject: [PATCH] Switch qianfan-api to raw POST requests; the qianfan SDK
 cannot connect
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/model_workers/qianfan.py | 258 +++++++++++++++++---------------
 webui_pages/utils.py            |   2 +-
 2 files changed, 138 insertions(+), 122 deletions(-)

diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py
index 2160db3..238a5bf 100644
--- a/server/model_workers/qianfan.py
+++ b/server/model_workers/qianfan.py
@@ -1,68 +1,72 @@
 import sys
 from fastchat.conversation import Conversation
 from server.model_workers.base import *
+from server.utils import get_httpx_client
+from cachetools import cached, TTLCache
+import json
 from fastchat import conversation as conv
 import sys
 from server.model_workers.base import ApiEmbeddingsParams
 from typing import List, Literal, Dict
 
-# MODEL_VERSIONS = {
-#     "ernie-bot": "completions",
-#     "ernie-bot-turbo": "eb-instant",
-#     "bloomz-7b": "bloomz_7b1",
-#     "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
-#     "llama2-7b-chat": "llama_2_7b",
-#     "llama2-13b-chat": "llama_2_13b",
-#     "llama2-70b-chat": "llama_2_70b",
-#     "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
-#     "chatglm2-6b-32k": "chatglm2_6b_32k",
-#     "aquilachat-7b": "aquilachat_7b",
-#     # "linly-llama2-ch-7b": "",  # not yet released
-#     # "linly-llama2-ch-13b": "",  # not yet released
-#     # "chatglm2-6b": "",  # not yet released
-#     # "chatglm2-6b-int4": "",  # not yet released
-#     # "falcon-7b": "",  # not yet released
-#     # "falcon-180b-chat": "",  # not yet released
-#     # "falcon-40b": "",  # not yet released
-#     # "rwkv4-world": "",  # not yet released
-#     # "rwkv5-world": "",  # not yet released
-#     # "rwkv4-pile-14b": "",  # not yet released
-#     # "rwkv4-raven-14b": "",  # not yet released
-#     # "open-llama-7b": "",  # not yet released
-#     # "dolly-12b": "",  # not yet released
-#     # "mpt-7b-instruct": "",  # not yet released
-#     # "mpt-30b-instruct": "",  # not yet released
-#     # "OA-Pythia-12B-SFT-4": "",  # not yet released
-#     # "xverse-13b": "",  # not yet released
+MODEL_VERSIONS = {
+    "ernie-bot-4": "completions_pro",
+    "ernie-bot": "completions",
+    "ernie-bot-turbo": "eb-instant",
+    "bloomz-7b": "bloomz_7b1",
+    "qianfan-bloomz-7b-c": "qianfan_bloomz_7b_compressed",
+    "llama2-7b-chat": "llama_2_7b",
+    "llama2-13b-chat": "llama_2_13b",
+    "llama2-70b-chat": "llama_2_70b",
+    "qianfan-llama2-ch-7b": "qianfan_chinese_llama_2_7b",
+    "chatglm2-6b-32k": "chatglm2_6b_32k",
+    "aquilachat-7b": "aquilachat_7b",
+    # "linly-llama2-ch-7b": "",  # not yet released
+    # "linly-llama2-ch-13b": "",  # not yet released
+    # "chatglm2-6b": "",  # not yet released
+    # "chatglm2-6b-int4": "",  # not yet released
+    # "falcon-7b": "",  # not yet released
+    # "falcon-180b-chat": "",  # not yet released
+    # "falcon-40b": "",  # not yet released
+    # "rwkv4-world": "",  # not yet released
+    # "rwkv5-world": "",  # not yet released
+    # "rwkv4-pile-14b": "",  # not yet released
+    # "rwkv4-raven-14b": "",  # not yet released
+    # "open-llama-7b": "",  # not yet released
+    # "dolly-12b": "",  # not yet released
+    # "mpt-7b-instruct": "",  # not yet released
+    # "mpt-30b-instruct": "",  # not yet released
+    # "OA-Pythia-12B-SFT-4": "",  # not yet released
+    # "xverse-13b": "",  # not yet released
 
-#     # # The following are enterprise trials and require a separate application
-#     # "flan-ul2": "",
-#     # "Cerebras-GPT-6.7B": ""
-#     # "Pythia-6.9B": ""
-# }
+    # # The following are enterprise trials and require a separate application
+    # "flan-ul2": "",
+    # "Cerebras-GPT-6.7B": ""
+    # "Pythia-6.9B": ""
+}
 
-# @cached(TTLCache(1, 1800))  # tested: the cached token remains usable; it is currently refreshed every 30 minutes
-# def get_baidu_access_token(api_key: str, secret_key: str) -> str:
-#     """
-#     Generate an auth signature (access token) from the AK/SK pair.
-#     :return: the access_token, or None on error
-#     """
-#     url = "https://aip.baidubce.com/oauth/2.0/token"
"https://aip.baidubce.com/oauth/2.0/token" -# params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} -# try: -# with get_httpx_client() as client: -# return client.get(url, params=params).json().get("access_token") -# except Exception as e: -# print(f"failed to get token from baidu: {e}") +@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 +def get_baidu_access_token(api_key: str, secret_key: str) -> str: + """ + 使用 AK,SK 生成鉴权签名(Access Token) + :return: access_token,或是None(如果错误) + """ + url = "https://aip.baidubce.com/oauth/2.0/token" + params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + try: + with get_httpx_client() as client: + return client.get(url, params=params).json().get("access_token") + except Exception as e: + print(f"failed to get token from baidu: {e}") class QianFanWorker(ApiModelWorker): """ 百度千帆 """ - DEFAULT_EMBED_MODEL = "bge-large-zh" + DEFAULT_EMBED_MODEL = "embedding-v1" def __init__( self, @@ -80,87 +84,98 @@ class QianFanWorker(ApiModelWorker): def do_chat(self, params: ApiChatParams) -> Dict: params.load_config(self.model_names[0]) - import qianfan + # import qianfan - comp = qianfan.ChatCompletion(model=params.version, - endpoint=params.version_url, - ak=params.api_key, - sk=params.secret_key,) - text = "" - for resp in comp.do(messages=params.messages, - temperature=params.temperature, - top_p=params.top_p, - stream=True): - if resp.code == 200: - if chunk := resp.body.get("result"): - text += chunk - yield { - "error_code": 0, - "text": text - } - else: - yield { - "error_code": resp.code, - "text": str(resp.body), - } - - # BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ - # '/{model_version}?access_token={access_token}' - - # access_token = get_baidu_access_token(params.api_key, params.secret_key) - # if not access_token: - # yield { - # "error_code": 403, - # "text": f"failed to get access token. 
-        #     }
-
-        # url = BASE_URL.format(
-        #     model_version=params.version_url or MODEL_VERSIONS[params.version],
-        #     access_token=access_token,
-        # )
-        # payload = {
-        #     "messages": params.messages,
-        #     "temperature": params.temperature,
-        #     "stream": True
-        # }
-        # headers = {
-        #     'Content-Type': 'application/json',
-        #     'Accept': 'application/json',
-        # }
-
+        # comp = qianfan.ChatCompletion(model=params.version,
+        #                               endpoint=params.version_url,
+        #                               ak=params.api_key,
+        #                               sk=params.secret_key,)
         # text = ""
-        # with get_httpx_client() as client:
-        #     with client.stream("POST", url, headers=headers, json=payload) as response:
-        #         for line in response.iter_lines():
-        #             if not line.strip():
-        #                 continue
-        #             if line.startswith("data: "):
-        #                 line = line[6:]
-        #             resp = json.loads(line)
+        # for resp in comp.do(messages=params.messages,
+        #                     temperature=params.temperature,
+        #                     top_p=params.top_p,
+        #                     stream=True):
+        #     if resp.code == 200:
+        #         if chunk := resp.body.get("result"):
+        #             text += chunk
+        #             yield {
+        #                 "error_code": 0,
+        #                 "text": text
+        #             }
+        #     else:
+        #         yield {
+        #             "error_code": resp.code,
+        #             "text": str(resp.body),
+        #         }
 
-        #             if "result" in resp.keys():
-        #                 text += resp["result"]
-        #                 yield {
-        #                     "error_code": 0,
-        #                     "text": text
-        #                 }
-        #             else:
-        #                 yield {
-        #                     "error_code": resp["error_code"],
-        #                     "text": resp["error_msg"]
-        #                 }
+        BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\
+            '/{model_version}?access_token={access_token}'
+
+        access_token = get_baidu_access_token(params.api_key, params.secret_key)
+        if not access_token:
+            yield {
+                "error_code": 403,
+                "text": "failed to get access token. have you set the correct api_key and secret key?",
+            }
+            return
+
+        url = BASE_URL.format(
+            model_version=params.version_url or MODEL_VERSIONS[params.version.lower()],
+            access_token=access_token,
+        )
+        payload = {
+            "messages": params.messages,
+            "temperature": params.temperature,
+            "stream": True
+        }
+        headers = {
+            'Content-Type': 'application/json',
+            'Accept': 'application/json',
+        }
+
+        text = ""
+        with get_httpx_client() as client:
+            with client.stream("POST", url, headers=headers, json=payload) as response:
+                for line in response.iter_lines():
+                    if not line.strip():
+                        continue
+                    if line.startswith("data: "):
+                        line = line[6:]
+                    resp = json.loads(line)
+
+                    if "result" in resp.keys():
+                        text += resp["result"]
+                        yield {
+                            "error_code": 0,
+                            "text": text
+                        }
+                    else:
+                        yield {
+                            "error_code": resp["error_code"],
+                            "text": resp["error_msg"]
+                        }
 
     def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
-        import qianfan
         params.load_config(self.model_names[0])
+        # import qianfan
 
-        embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
-        resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
-        if resp.code == 200:
-            embeddings = [x.embedding for x in resp.body.get("data", [])]
-            return {"code": 200, "embeddings": embeddings}
-        else:
-            return {"code": resp.code, "msg": str(resp.body)}
+        # embed = qianfan.Embedding(ak=params.api_key, sk=params.secret_key)
+        # resp = embed.do(texts = params.texts, model=params.embed_model or self.DEFAULT_EMBED_MODEL)
+        # if resp.code == 200:
+        #     embeddings = [x.embedding for x in resp.body.get("data", [])]
+        #     return {"code": 200, "embeddings": embeddings}
+        # else:
+        #     return {"code": resp.code, "msg": str(resp.body)}
+
+        embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
+        access_token = get_baidu_access_token(params.api_key, params.secret_key)
+        url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
+        with get_httpx_client() as client:
+            resp = client.post(url, json={"input": params.texts}).json()
+            if "error_code" not in resp:
+                embeddings = [x["embedding"] for x in resp.get("data", [])]
+                return {"code": 200, "embeddings": embeddings}
+            else:
+                return {"code": resp["error_code"], "msg": resp["error_msg"]}
 
     # TODO: qianfan supports continuation (completion) models as well
 
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 7a34f56..791edba 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -677,7 +677,7 @@ class ApiRequest:
         return self._get_response_value(response, as_json=True, value_func=lambda r:r.get("data", []))
 
-    def get_default_llm_model(self) -> (str, bool):
+    def get_default_llm_model(self) -> Tuple[str, bool]:
        '''
        Fetch the currently running LLM models from the server; if the locally configured LLM_MODEL is a local model and is among them, return it first.
        The return type is (model_name, is_local_model)
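
--
For checking the raw-POST flow this patch adopts outside the worker, here is a minimal
standalone sketch. It is an illustration under stated assumptions, not part of the patch:
API_KEY and SECRET_KEY are placeholders, "eb-instant" (ernie-bot-turbo) is just one
endpoint name taken from MODEL_VERSIONS, and httpx plus the standard library are the
only requirements. The URLs, payload shapes, and streaming parse mirror the hunks above.

import json

import httpx

# Placeholder credentials: substitute a real Qianfan API key / secret key.
API_KEY = "your-api-key"
SECRET_KEY = "your-secret-key"


def get_access_token(api_key: str, secret_key: str) -> str:
    # Same token exchange as get_baidu_access_token in the patch.
    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {"grant_type": "client_credentials",
              "client_id": api_key, "client_secret": secret_key}
    return httpx.get(url, params=params).json().get("access_token")


def stream_chat(messages, model_version="eb-instant"):
    # Stream a chat completion through the raw endpoint used by do_chat.
    token = get_access_token(API_KEY, SECRET_KEY)
    url = ("https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
           f"{model_version}?access_token={token}")
    payload = {"messages": messages, "stream": True}
    headers = {"Content-Type": "application/json", "Accept": "application/json"}
    with httpx.stream("POST", url, headers=headers, json=payload) as response:
        for line in response.iter_lines():
            if not line.strip():
                continue
            if line.startswith("data: "):  # streaming responses arrive as SSE frames
                line = line[len("data: "):]
            chunk = json.loads(line)
            # Success chunks carry "result"; failures carry "error_code"/"error_msg".
            if "result" in chunk:
                yield chunk["result"]
            else:
                raise RuntimeError(f"{chunk.get('error_code')}: {chunk.get('error_msg')}")


if __name__ == "__main__":
    for piece in stream_chat([{"role": "user", "content": "hello"}]):
        print(piece, end="", flush=True)

The token exchange is the step the qianfan SDK was failing to complete; reusing one
token across requests, as the TTLCache(1, 1800) in the patch does, avoids repeating
the exchange on every call.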