update server.model_worker.zhipu.py/base.py: add streaming responses

hzg0601 2023-09-06 11:16:16 +08:00
parent 6f039cfdeb
commit 5e4bd5c3d3
2 changed files with 155 additions and 149 deletions
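The substance of this commit is that the worker no longer returns only a single finished answer: create_stream (in the diffs below) wraps zhipuai's SSE interface and yields one JSON object per event, encoded as UTF-8 and terminated with a null byte. The following is a minimal, hypothetical client-side sketch of how such null-delimited chunks could be consumed; the consume_stream helper and its wiring are assumptions and are not part of this commit — only the chunk format (json.dumps(...).encode() + b"\0") comes from the code below.

import json

def consume_stream(chunks):
    # chunks: an iterable of bytes, e.g. the generator returned by create_stream(...)
    for chunk in chunks:
        for piece in chunk.split(b"\0"):
            if not piece:
                continue
            data = json.loads(piece.decode("utf-8"))
            if data.get("error_code", 0) == 0:
                # incremental text from an "add"/"finish" event
                print(data["text"], end="", flush=True)
            else:
                print(f"\n[error {data['error_code']}] {data['text']}")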

server/model_workers/base.py

@@ -88,145 +88,3 @@ class ApiModelWorker(BaseModelWorker):
prompt.append(prompt_dict)
return prompt
def create_oneshot(self,
message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs
):
response = self.zhipuai.model_api.invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
result = response["data"]["choices"][-1]["content"]
return json.dumps({"error_code": 0, "text": result}, ensure_ascii=False).encode() + b"\0"
else:
# TODO: confirm openai's error codes
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return json.dumps({"error_code": response['code'],
"text": f"error occurred, error code:{response['code']},error msg:{response['msg']}"
},
ensure_ascii=False).encode() + b"\0"
def create_stream(self,
message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs
):
response = self.zhipuai.model_api.sse_invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature,
incremental = True
)
for event in response.events():
if event.event == "add":
# yield event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
elif event.event == "error" or event.event == "interrupted":
# return event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
elif event.event == "finish":
# yield event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
print(event.meta)
else:
print("Something get wrong with ZhipuAPILoader.create_chat_completion_stream")
print(event.data)
yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
def create_chat_completion(self,
model: str = "chatglm_pro",
prompt:List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
top_p:float=0.7,
temperature:float=0.9,
stream:bool=False):
if stream:
return self.create_stream(model=model,
message=prompt,
top_p=top_p,
temperature=temperature)
else:
return self.create_oneshot(model=model,
message=prompt,
top_p=top_p,
temperature=temperature)
async def acreate_chat_completion(self,
prompt: List[Dict[str,str]]=[{"role":"system","content":"你是一个人工智能助手"},
{"role":"user","content":"你好。"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs):
response = await self.zhipuai.model_api.async_invoke(
model = model,
prompt = prompt,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
task_id = response['data']['task_id']
status = "PROCESSING"
while status != "SUCCESS":
# await asyncio.sleep(3) #
resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
status = resp['data']['task_status']
return resp['data']['choices'][-1]['content']
else:
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return
def create_completion(self,
prompt:str="你好",
model:str="chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
stream:bool=False,
**kwargs):
message = self.prompt_collator(content_user=prompt)
if stream:
return self.create_stream(model=model,
message=message,
top_p=top_p,
temperature=temperature)
else:
return self.create_oneshot(model=model,
message=message,
top_p=top_p,
temperature=temperature)
#? make it a sync function?
async def acreate_completion(self,
prompt:str="你好",
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs):
message = self.prompt_collator(content_user=prompt)
response = self.zhipuai.model_api.async_invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
task_id = response['data']['task_id']
status = "PROCESSING"
while status != "SUCCESS":
# await asyncio.sleep(3) #
resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
status = resp['data']['task_status']
return resp['data']['choices'][-1]['content']
else:
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return

server/model_workers/zhipu.py

@@ -3,7 +3,7 @@ from server.model_workers.base import ApiModelWorker
from fastchat import conversation as conv
import sys
import json
from typing import List, Literal, Dict
class ChatGLMWorker(ApiModelWorker):
@@ -23,7 +23,9 @@ class ChatGLMWorker(ApiModelWorker):
kwargs.setdefault("context_len", 32768) kwargs.setdefault("context_len", 32768)
super().__init__(**kwargs) super().__init__(**kwargs)
self.version = version self.version = version
self.zhipuai = zhipuai
from server.utils import get_model_worker_config
self.zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
# This is the conv template for the chatglm api; other APIs need their own customized conv_template
self.conv = conv.Conversation(
name="chatglm-api",
@@ -36,13 +38,16 @@ class ChatGLMWorker(ApiModelWorker):
def generate_stream_gate(self, params):
# TODO: support the stream parameter, maintain request_id; the prompt passed in is also problematic
from server.utils import get_model_worker_config
super().generate_stream_gate(params)
zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
model=self.version
# if isinstance(params["prompt"], str):
# prompt=self.prompt_collator(content_user=params["prompt"],
# role_user="user") #[{"role": "user", "content": params["prompt"]}]
# else:
# prompt = params["prompt"]
prompt = params["prompt"]
print(prompt)
temperature=params.get("temperature")
top_p=params.get("top_p")
stream = params.get("stream")
@@ -77,6 +82,149 @@ class ChatGLMWorker(ApiModelWorker):
print("embedding") print("embedding")
print(params) print(params)
def create_oneshot(self,
message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs
):
response = self.zhipuai.model_api.invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
result = response["data"]["choices"][-1]["content"]
return json.dumps({"error_code": 0, "text": result}, ensure_ascii=False).encode() + b"\0"
else:
# TODO: confirm openai's error codes
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return json.dumps({"error_code": response['code'],
"text": f"error occurred, error code:{response['code']},error msg:{response['msg']}"
},
ensure_ascii=False).encode() + b"\0"
def create_stream(self,
message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs
):
response = self.zhipuai.model_api.sse_invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature,
incremental = True
)
for event in response.events():
if event.event == "add":
# yield event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
elif event.event == "error" or event.event == "interrupted":
# return event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
elif event.event == "finish":
# yield event.data
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
print(event.meta)
else:
print("Something get wrong with ZhipuAPILoader.create_chat_completion_stream")
print(event.data)
yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
def create_chat_completion(self,
model: str = "chatglm_pro",
prompt:List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
top_p:float=0.7,
temperature:float=0.9,
stream:bool=False):
if stream:
return self.create_stream(model=model,
message=prompt,
top_p=top_p,
temperature=temperature)
else:
return self.create_oneshot(model=model,
message=prompt,
top_p=top_p,
temperature=temperature)
async def acreate_chat_completion(self,
prompt: List[Dict[str,str]]=[{"role":"system","content":"你是一个人工智能助手"},
{"role":"user","content":"你好。"}],
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs):
response = await self.zhipuai.model_api.async_invoke(
model = model,
prompt = prompt,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
task_id = response['data']['task_id']
status = "PROCESSING"
while status != "SUCCESS":
# await asyncio.sleep(3) #
resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
status = resp['data']['task_status']
return resp['data']['choices'][-1]['content']
else:
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return
def create_completion(self,
prompt:str="你好",
model:str="chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
stream:bool=False,
**kwargs):
message = self.prompt_collator(content_user=prompt)
if stream:
return self.create_stream(model=model,
message=message,
top_p=top_p,
temperature=temperature)
else:
return self.create_oneshot(model=model,
message=message,
top_p=top_p,
temperature=temperature)
#? make it a sync function?
async def acreate_completion(self,
prompt:str="你好",
model:str = "chatglm_pro",
top_p:float=0.7,
temperature:float=0.9,
**kwargs):
message = self.prompt_collator(content_user=prompt)
response = self.zhipuai.model_api.async_invoke(
model = model,
prompt = message,
top_p = top_p,
temperature = temperature
)
if response["code"] == 200:
task_id = response['data']['task_id']
status = "PROCESSING"
while status != "SUCCESS":
# await asyncio.sleep(3) #
resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
status = resp['data']['task_status']
return resp['data']['choices'][-1]['content']
else:
print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
return
if __name__ == "__main__":
import uvicorn
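For reference, create_chat_completion(stream=True) simply returns the create_stream generator, so callers iterate over it directly. Below is a rough usage sketch under assumptions: stream_demo is a hypothetical helper, the worker instance is assumed to be constructed elsewhere (this diff does not show the worker's startup wiring), and only the iteration and chunk-decoding pattern reflect the code above.

import json

def stream_demo(worker):
    # `worker` is assumed to be an already-initialized ChatGLMWorker; its construction is outside this diff.
    for chunk in worker.create_chat_completion(
        model="chatglm_pro",
        prompt=[{"role": "user", "content": "你好,你可以做什么"}],
        stream=True,
    ):
        # drop the trailing null byte, then decode the JSON payload
        data = json.loads(chunk.rstrip(b"\0"))
        print(data["text"], end="", flush=True)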