commit
569209289b
|
|
@ -5,6 +5,7 @@ from server.utils import get_httpx_client
|
|||
from fastchat import conversation as conv
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class AzureWorker(ApiModelWorker):
|
||||
|
|
@ -39,6 +40,11 @@ class AzureWorker(ApiModelWorker):
|
|||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from fastchat import conversation as conv
|
|||
import sys
|
||||
import json
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
from configs import logger, log_verbose
|
||||
|
||||
def calculate_md5(input_string):
|
||||
md5 = hashlib.md5()
|
||||
|
|
@ -56,6 +56,11 @@ class BaiChuanWorker(ApiModelWorker):
|
|||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:json_data: {json_data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=data) as response:
|
||||
for line in response.iter_lines():
|
||||
|
|
@ -71,8 +76,14 @@ class BaiChuanWorker(ApiModelWorker):
|
|||
else:
|
||||
yield {
|
||||
"error_code": resp["code"],
|
||||
"text": resp["msg"]
|
||||
"text": resp["msg"],
|
||||
"error": {
|
||||
"message": resp["msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
@ -103,4 +114,4 @@ if __name__ == "__main__":
|
|||
sys.modules["fastchat.serve.model_worker"].worker = worker
|
||||
MakeFastAPIOffline(app)
|
||||
uvicorn.run(app, port=21007)
|
||||
# do_request()
|
||||
# do_request()
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from server.model_workers.base import *
|
|||
from fastchat import conversation as conv
|
||||
import sys
|
||||
from typing import List, Literal, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class FangZhouWorker(ApiModelWorker):
|
||||
|
|
@ -46,10 +47,21 @@ class FangZhouWorker(ApiModelWorker):
|
|||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:maas: {maas}')
|
||||
for resp in maas.stream_chat(req):
|
||||
error = resp.error
|
||||
if error.code_n > 0:
|
||||
yield {"error_code": error.code_n, "text": error.message}
|
||||
yield {
|
||||
"error_code": error.code_n,
|
||||
"text": error.message,
|
||||
"error": {
|
||||
"message": error.message,
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
elif chunk := resp.choice.message.content:
|
||||
text += chunk
|
||||
yield {"error_code": 0, "text": text}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import json
|
|||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from server.utils import get_httpx_client
|
||||
from typing import List, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class MiniMaxWorker(ApiModelWorker):
|
||||
|
|
@ -59,6 +60,10 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
# "bot_setting": [],
|
||||
# "role_meta": params.role_meta,
|
||||
}
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
response = client.stream("POST",
|
||||
|
|
@ -69,7 +74,16 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
text = ""
|
||||
for e in r.iter_text():
|
||||
if not e.startswith("data: "): # 真是优秀的返回
|
||||
yield {"error_code": 500, "text": f"minimax返回错误的结果:{e}"}
|
||||
yield {
|
||||
"error_code": 500,
|
||||
"text": f"minimax返回错误的结果:{e}",
|
||||
"error": {
|
||||
"message": f"minimax返回错误的结果:{e}",
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
continue
|
||||
|
||||
data = json.loads(e[6:])
|
||||
|
|
@ -95,13 +109,27 @@ class MiniMaxWorker(ApiModelWorker):
|
|||
"texts": params.texts,
|
||||
"type": "query" if params.to_query else "db",
|
||||
}
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
r = client.post(url, headers=headers, json=data).json()
|
||||
if embeddings := r.get("vectors"):
|
||||
return {"code": 200, "data": embeddings}
|
||||
elif error := r.get("base_resp"):
|
||||
return {"code": error["status_code"], "msg": error["status_msg"]}
|
||||
return {
|
||||
"code": error["status_code"],
|
||||
"msg": error["status_msg"],
|
||||
|
||||
"error": {
|
||||
"message": error["status_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
def get_embeddings(self, params):
|
||||
# TODO: 支持embeddings
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from fastchat import conversation as conv
|
|||
import sys
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from typing import List, Literal, Dict
|
||||
from configs import logger, log_verbose
|
||||
|
||||
MODEL_VERSIONS = {
|
||||
"ernie-bot-4": "completions_pro",
|
||||
|
|
@ -132,6 +133,11 @@ class QianFanWorker(ApiModelWorker):
|
|||
}
|
||||
|
||||
text = ""
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {payload}')
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
logger.info(f'{self.__class__.__name__}:headers: {headers}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
with client.stream("POST", url, headers=headers, json=payload) as response:
|
||||
for line in response.iter_lines():
|
||||
|
|
@ -150,7 +156,13 @@ class QianFanWorker(ApiModelWorker):
|
|||
else:
|
||||
yield {
|
||||
"error_code": resp["error_code"],
|
||||
"text": resp["error_msg"]
|
||||
"text": resp["error_msg"],
|
||||
"error": {
|
||||
"message": resp["error_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
|
|
@ -168,13 +180,25 @@ class QianFanWorker(ApiModelWorker):
|
|||
embed_model = params.embed_model or self.DEFAULT_EMBED_MODEL
|
||||
access_token = get_baidu_access_token(params.api_key, params.secret_key)
|
||||
url = f"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/embeddings/{embed_model}?access_token={access_token}"
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:url: {url}')
|
||||
|
||||
with get_httpx_client() as client:
|
||||
resp = client.post(url, json={"input": params.texts}).json()
|
||||
if "error_cdoe" not in resp:
|
||||
embeddings = [x["embedding"] for x in resp.get("data", [])]
|
||||
return {"code": 200, "data": embeddings}
|
||||
else:
|
||||
return {"code": resp["error_code"], "msg": resp["error_msg"]}
|
||||
return {
|
||||
"code": resp["error_code"],
|
||||
"msg": resp["error_msg"],
|
||||
"error": {
|
||||
"message": resp["error_msg"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
# TODO: qianfan支持续写模型
|
||||
def get_embeddings(self, params):
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from typing import List, Literal, Dict
|
|||
from fastchat import conversation as conv
|
||||
from server.model_workers.base import *
|
||||
from server.model_workers.base import ApiEmbeddingsParams
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class QwenWorker(ApiModelWorker):
|
||||
|
|
@ -31,6 +32,8 @@ class QwenWorker(ApiModelWorker):
|
|||
def do_chat(self, params: ApiChatParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
|
||||
gen = dashscope.Generation()
|
||||
responses = gen.call(
|
||||
|
|
@ -53,12 +56,20 @@ class QwenWorker(ApiModelWorker):
|
|||
yield {
|
||||
"error_code": resp["status_code"],
|
||||
"text": resp["message"],
|
||||
"error": {
|
||||
"message": resp["message"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import dashscope
|
||||
params.load_config(self.model_names[0])
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
result = []
|
||||
i = 0
|
||||
while i < len(params.texts):
|
||||
|
|
@ -69,7 +80,16 @@ class QwenWorker(ApiModelWorker):
|
|||
api_key=params.api_key,
|
||||
)
|
||||
if resp["status_code"] != 200:
|
||||
return {"code": resp["status_code"], "msg": resp.message}
|
||||
return {
|
||||
"code": resp["status_code"],
|
||||
"msg": resp.message,
|
||||
"error": {
|
||||
"message": resp["message"],
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
else:
|
||||
embeddings = [x["embedding"] for x in resp["output"]["embeddings"]]
|
||||
result += embeddings
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from server.model_workers.base import *
|
|||
from fastchat import conversation as conv
|
||||
import sys
|
||||
from typing import List, Dict, Iterator, Literal
|
||||
from configs import logger, log_verbose
|
||||
|
||||
|
||||
class ChatGLMWorker(ApiModelWorker):
|
||||
|
|
@ -29,6 +30,9 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
params.load_config(self.model_names[0])
|
||||
zhipuai.api_key = params.api_key
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:params: {params}')
|
||||
|
||||
response = zhipuai.model_api.sse_invoke(
|
||||
model=params.version,
|
||||
prompt=params.messages,
|
||||
|
|
@ -40,7 +44,16 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
if e.event == "add":
|
||||
yield {"error_code": 0, "text": e.data}
|
||||
elif e.event in ["error", "interrupted"]:
|
||||
yield {"error_code": 500, "text": str(e)}
|
||||
yield {
|
||||
"error_code": 500,
|
||||
"text": str(e),
|
||||
"error": {
|
||||
"message": str(e),
|
||||
"type": "invalid_request_error",
|
||||
"param": None,
|
||||
"code": None,
|
||||
}
|
||||
}
|
||||
|
||||
def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict:
|
||||
import zhipuai
|
||||
|
|
@ -55,7 +68,7 @@ class ChatGLMWorker(ApiModelWorker):
|
|||
if response["code"] == 200:
|
||||
embeddings.append(response["data"]["embedding"])
|
||||
else:
|
||||
return response # dict with code & msg
|
||||
return response # dict with code & msg
|
||||
except Exception as e:
|
||||
return {"code": 500, "msg": f"对文本向量化时出错:{e}"}
|
||||
|
||||
|
|
|
|||
|
|
@ -639,7 +639,10 @@ def get_httpx_client(
|
|||
|
||||
# construct Client
|
||||
kwargs.update(timeout=timeout, proxies=default_proxies)
|
||||
print(kwargs)
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{get_httpx_client.__class__.__name__}:kwargs: {kwargs}')
|
||||
|
||||
if use_async:
|
||||
return httpx.AsyncClient(**kwargs)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -85,6 +85,7 @@ class ApiRequest:
|
|||
) -> Union[httpx.Response, Iterator[httpx.Response], None]:
|
||||
while retry > 0:
|
||||
try:
|
||||
print(kwargs)
|
||||
if stream:
|
||||
return self.client.stream("POST", url, data=data, json=json, **kwargs)
|
||||
else:
|
||||
|
|
@ -745,6 +746,9 @@ class ApiRequest:
|
|||
"controller_address": controller_address,
|
||||
}
|
||||
|
||||
if log_verbose:
|
||||
logger.info(f'{self.__class__.__name__}:data: {data}')
|
||||
|
||||
response = self.post(
|
||||
"/llm_model/list_running_models",
|
||||
json=data,
|
||||
|
|
|
|||
Loading…
Reference in New Issue