diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index a15b2b2..56f51ca 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -29,11 +29,11 @@ class ChatGLMWorker(ApiModelWorker):
         # 这里的是chatglm api的模板,其它API的conv_template需要定制
         self.conv = conv.Conversation(
             name="chatglm-api",
-            system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
+            system_message="",
             messages=[],
-            roles=["Human", "Assistant"],
-            sep="\n### ",
-            stop_str="###",
+            roles=["user", "assistant"],
+            sep="",
+            stop_str="",
         )
 
     def generate_stream_gate(self, params):
@@ -41,13 +41,15 @@ class ChatGLMWorker(ApiModelWorker):
         super().generate_stream_gate(params)
         model=self.version
-        # if isinstance(params["prompt"], str):
-        #     prompt=self.prompt_collator(content_user=params["prompt"],
-        #                                 role_user="user") #[{"role": "user", "content": params["prompt"]}]
-        # else:
-        #     prompt = params["prompt"]
-        prompt = params["prompt"]
-        print(prompt)
+        raw_prompt = params["prompt"]
+        if isinstance(raw_prompt, str):
+            # Strip the role tags the generic conversation template prepends
+            # before collating into the message list the zhipu API expects.
+            cleaned = raw_prompt.replace("user:", "").replace("assistant:", "")
+            prompt = self.prompt_collator(content_user=cleaned,
+                                          role_user="user")
+        else:
+            # Already a structured message list — pass through unchanged.
+            prompt = raw_prompt
         temperature=params.get("temperature")
         top_p=params.get("top_p")
         stream = params.get("stream")
 
@@ -122,7 +124,6 @@ class ChatGLMWorker(ApiModelWorker):
         )
         for event in response.events():
             if event.event == "add":
-                # yield event.data
                 yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
             elif event.event == "error" or event.event == "interrupted":
                 # return event.data