update server.model_worker.zhipu.py:更新prompt,遵守zhipu的格式要求
This commit is contained in:
parent
5e4bd5c3d3
commit
76c2c61bb2
|
|
@ -29,11 +29,11 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
# 这里的是chatglm api的模板,其它API的conv_template需要定制
|
# 这里的是chatglm api的模板,其它API的conv_template需要定制
|
||||||
self.conv = conv.Conversation(
|
self.conv = conv.Conversation(
|
||||||
name="chatglm-api",
|
name="chatglm-api",
|
||||||
system_message="你是一个聪明、对人类有帮助的人工智能,你可以对人类提出的问题给出有用、详细、礼貌的回答。",
|
system_message="",
|
||||||
messages=[],
|
messages=[],
|
||||||
roles=["Human", "Assistant"],
|
roles=["user", "assistant"],
|
||||||
sep="\n### ",
|
sep="",
|
||||||
stop_str="###",
|
stop_str="",
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate_stream_gate(self, params):
|
def generate_stream_gate(self, params):
|
||||||
|
|
@ -41,13 +41,13 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
super().generate_stream_gate(params)
|
super().generate_stream_gate(params)
|
||||||
|
|
||||||
model=self.version
|
model=self.version
|
||||||
# if isinstance(params["prompt"], str):
|
input_prompt = params["prompt"].replace("user:","").replace("assistant:","")
|
||||||
# prompt=self.prompt_collator(content_user=params["prompt"],
|
|
||||||
# role_user="user") #[{"role": "user", "content": params["prompt"]}]
|
if isinstance(input_prompt, str):
|
||||||
# else:
|
prompt=self.prompt_collator(content_user=input_prompt,
|
||||||
# prompt = params["prompt"]
|
role_user="user") #[{"role": "user", "content": params["prompt"]}]
|
||||||
|
else:
|
||||||
prompt = params["prompt"]
|
prompt = params["prompt"]
|
||||||
print(prompt)
|
|
||||||
temperature=params.get("temperature")
|
temperature=params.get("temperature")
|
||||||
top_p=params.get("top_p")
|
top_p=params.get("top_p")
|
||||||
stream = params.get("stream")
|
stream = params.get("stream")
|
||||||
|
|
@ -122,7 +122,7 @@ class ChatGLMWorker(ApiModelWorker):
|
||||||
)
|
)
|
||||||
for event in response.events():
|
for event in response.events():
|
||||||
if event.event == "add":
|
if event.event == "add":
|
||||||
# yield event.data
|
print(f"SSE返回结果 {event.data}")
|
||||||
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
|
yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
|
||||||
elif event.event == "error" or event.event == "interrupted":
|
elif event.event == "error" or event.event == "interrupted":
|
||||||
# return event.data
|
# return event.data
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue