From 76c2c61bb24c9b3ce62daa593b7e01665bd6501d Mon Sep 17 00:00:00 2001 From: hzg0601 Date: Wed, 6 Sep 2023 17:33:00 +0800 Subject: [PATCH] =?UTF-8?q?update=20server.model=5Fworker.zhipu.py:?= =?UTF-8?q?=E6=9B=B4=E6=96=B0prompt=EF=BC=8C=E9=81=B5=E5=AE=88zhipu?= =?UTF-8?q?=E7=9A=84=E6=A0=BC=E5=BC=8F=E8=A6=81=E6=B1=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/model_workers/zhipu.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py index a15b2b2..56f51ca 100644 --- a/server/model_workers/zhipu.py +++ b/server/model_workers/zhipu.py @@ -29,11 +29,11 @@ class ChatGLMWorker(ApiModelWorker): # 这里的是chatglm api的模板,其它API的conv_template需要定制 self.conv = conv.Conversation( name="chatglm-api", - system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。", + system_message="", messages=[], - roles=["Human", "Assistant"], - sep="\n### ", - stop_str="###", + roles=["user", "assistant"], + sep="", + stop_str="", ) def generate_stream_gate(self, params): @@ -41,13 +41,13 @@ class ChatGLMWorker(ApiModelWorker): super().generate_stream_gate(params) model=self.version - # if isinstance(params["prompt"], str): - # prompt=self.prompt_collator(content_user=params["prompt"], - # role_user="user") #[{"role": "user", "content": params["prompt"]}] - # else: - # prompt = params["prompt"] - prompt = params["prompt"] - print(prompt) + input_prompt = params["prompt"].replace("user:","").replace("assistant:","") + + if isinstance(input_prompt, str): + prompt=self.prompt_collator(content_user=input_prompt, + role_user="user") #[{"role": "user", "content": params["prompt"]}] + else: + prompt = params["prompt"] temperature=params.get("temperature") top_p=params.get("top_p") stream = params.get("stream") @@ -122,7 +122,7 @@ class ChatGLMWorker(ApiModelWorker): ) for event in response.events(): if event.event == "add": - # yield event.data + print(f"SSE返回结果 {event.data}") yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0" elif event.event == "error" or event.event == "interrupted": # return event.data