update server.model_worker.zhipu.py:更新prompt,遵守zhipu的格式要求

This commit is contained in:
hzg0601 2023-09-06 17:33:00 +08:00
parent 5e4bd5c3d3
commit 76c2c61bb2
1 changed file with 12 additions and 12 deletions

View File

@@ -29,11 +29,11 @@ class ChatGLMWorker(ApiModelWorker):
# 这里的是chatglm api的模板其它API的conv_template需要定制
self.conv = conv.Conversation(
name="chatglm-api",
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
system_message="",
messages=[],
roles=["Human", "Assistant"],
sep="\n### ",
stop_str="###",
roles=["user", "assistant"],
sep="",
stop_str="",
)
def generate_stream_gate(self, params):
@@ -41,13 +41,13 @@ class ChatGLMWorker(ApiModelWorker):
super().generate_stream_gate(params)
model=self.version
# if isinstance(params["prompt"], str):
# prompt=self.prompt_collator(content_user=params["prompt"],
# role_user="user") #[{"role": "user", "content": params["prompt"]}]
# else:
# prompt = params["prompt"]
prompt = params["prompt"]
print(prompt)
input_prompt = params["prompt"].replace("user:","").replace("assistant:","")
if isinstance(input_prompt, str):
prompt=self.prompt_collator(content_user=input_prompt,
role_user="user") #[{"role": "user", "content": params["prompt"]}]
else:
prompt = params["prompt"]
temperature=params.get("temperature")
top_p=params.get("top_p")
stream = params.get("stream")
@@ -122,7 +122,7 @@ class ChatGLMWorker(ApiModelWorker):
)
for event in response.events():
if event.event == "add":
# yield event.data
print(f"SSE返回结果 {event.data}")
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
elif event.event == "error" or event.event == "interrupted":
# return event.data