diff --git a/server/model_workers/azure.py b/server/model_workers/azure.py index 3feff44..2735c53 100644 --- a/server/model_workers/azure.py +++ b/server/model_workers/azure.py @@ -18,7 +18,7 @@ class AzureWorker(ApiModelWorker): **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 8192) + kwargs.setdefault("context_len", 8000) #TODO 16K模型需要改成16384 super().__init__(**kwargs) self.version = version diff --git a/server/model_workers/qianfan.py b/server/model_workers/qianfan.py index 238a5bf..657e90d 100644 --- a/server/model_workers/qianfan.py +++ b/server/model_workers/qianfan.py @@ -9,7 +9,6 @@ import sys from server.model_workers.base import ApiEmbeddingsParams from typing import List, Literal, Dict - MODEL_VERSIONS = { "ernie-bot-4": "completions_pro", "ernie-bot": "completions", @@ -47,7 +46,7 @@ MODEL_VERSIONS = { } -@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 +@cached(TTLCache(1, 1800)) # 经过测试,缓存的token可以使用,目前每30分钟刷新一次 def get_baidu_access_token(api_key: str, secret_key: str) -> str: """ 使用 AK,SK 生成鉴权签名(Access Token) @@ -69,13 +68,13 @@ class QianFanWorker(ApiModelWorker): DEFAULT_EMBED_MODEL = "embedding-v1" def __init__( - self, - *, - version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", - model_names: List[str] = ["qianfan-api"], - controller_addr: str = None, - worker_addr: str = None, - **kwargs, + self, + *, + version: Literal["ernie-bot", "ernie-bot-turbo"] = "ernie-bot", + model_names: List[str] = ["qianfan-api"], + controller_addr: str = None, + worker_addr: str = None, + **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) kwargs.setdefault("context_len", 16384) @@ -108,8 +107,8 @@ class QianFanWorker(ApiModelWorker): # "text": str(resp.body), # } - BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat'\ - '/{model_version}?access_token={access_token}' + BASE_URL = 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat' \ + '/{model_version}?access_token={access_token}' access_token = get_baidu_access_token(params.api_key, params.secret_key) if not access_token: @@ -131,7 +130,7 @@ class QianFanWorker(ApiModelWorker): 'Content-Type': 'application/json', 'Accept': 'application/json', } - + text = "" with get_httpx_client() as client: with client.stream("POST", url, headers=headers, json=payload) as response: @@ -177,7 +176,6 @@ class QianFanWorker(ApiModelWorker): else: return {"code": resp["error_code"], "msg": resp["error_msg"]} - # TODO: qianfan支持续写模型 def get_embeddings(self, params): # TODO: 支持embeddings @@ -207,4 +205,4 @@ if __name__ == "__main__": ) sys.modules["fastchat.serve.model_worker"].worker = worker MakeFastAPIOffline(app) - uvicorn.run(app, port=21004) \ No newline at end of file + uvicorn.run(app, port=21004) diff --git a/server/model_workers/xinghuo.py b/server/model_workers/xinghuo.py index 6f2fbfb..e866609 100644 --- a/server/model_workers/xinghuo.py +++ b/server/model_workers/xinghuo.py @@ -13,6 +13,7 @@ async def request(appid, api_key, api_secret, Spark_url, domain, question, tempe wsParam = SparkApi.Ws_Param(appid, api_key, api_secret, Spark_url) wsUrl = wsParam.create_url() data = SparkApi.gen_params(appid, domain, question, temperature, max_token) + print(data) async with websockets.connect(wsUrl) as ws: await ws.send(json.dumps(data, ensure_ascii=False)) finish = False @@ -36,7 +37,7 @@ class XingHuoWorker(ApiModelWorker): **kwargs, ): kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 8192) + kwargs.setdefault("context_len", 8000) # TODO: V1模型的最大长度为4000,需要自行修改 super().__init__(**kwargs) self.version = version @@ -45,15 +46,14 @@ class XingHuoWorker(ApiModelWorker): params.load_config(self.model_names[0]) version_mapping = { - "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 2048}, - "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 4096}, - "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8192}, + "v1.5": {"domain": "general", "url": "ws://spark-api.xf-yun.com/v1.1/chat","max_tokens": 4000}, + "v2.0": {"domain": "generalv2", "url": "ws://spark-api.xf-yun.com/v2.1/chat","max_tokens": 8000}, + "v3.0": {"domain": "generalv3", "url": "ws://spark-api.xf-yun.com/v3.1/chat","max_tokens": 8000}, } def get_version_details(version_key): return version_mapping.get(version_key, {"domain": None, "url": None}) - # 使用方法: details = get_version_details(params.version) domain = details["domain"] Spark_url = details["url"]