diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index b72f683..672b0bc 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -69,3 +69,164 @@ class ApiModelWorker(BaseModelWorker):
             target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
         )
         self.heart_beat_thread.start()
+
+    # NOTE: the methods below expect `json`, `asyncio` and `typing.List`/`typing.Dict`
+    # to be imported at module level, and `self.zhipuai` to hold the zhipuai SDK.
+    def prompt_collator(self,
+                        content_user: str = None,
+                        role_user: str = "user",
+                        content_assistant: str = None,
+                        role_assistant: str = "assistant",
+                        meta_prompt: List[Dict[str, str]] = [{"role": "system", "content": "你是一个AI工具"}],
+                        use_meta_prompt: bool = False) -> List[Dict[str, str]]:
+        """Assemble an OpenAI-style message list from the given contents."""
+        prompt = []
+        if use_meta_prompt:
+            prompt += meta_prompt
+        if content_user:
+            prompt.append({"role": role_user, "content": content_user})
+        if content_assistant:
+            prompt.append({"role": role_assistant, "content": content_assistant})
+        return prompt
+
+    def create_oneshot(self,
+                       message: List[Dict[str, str]] = [{"role": "user", "content": "你好,你可以做什么"}],
+                       model: str = "chatglm_pro",
+                       top_p: float = 0.7,
+                       temperature: float = 0.9,
+                       **kwargs):
+        """Blocking invoke; returns a single null-terminated JSON chunk."""
+        response = self.zhipuai.model_api.invoke(
+            model=model,
+            prompt=message,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        if response["code"] == 200:
+            result = response["data"]["choices"][-1]["content"]
+            return json.dumps({"error_code": 0, "text": result}, ensure_ascii=False).encode() + b"\0"
+        else:
+            # TODO: confirm the OpenAI-compatible error codes
+            msg = f"error occurred, error code: {response['code']}, error msg: {response['msg']}"
+            print(msg)
+            return json.dumps({"error_code": response["code"], "text": msg}, ensure_ascii=False).encode() + b"\0"
+
+    def create_stream(self,
+                      message: List[Dict[str, str]] = [{"role": "user", "content": "你好,你可以做什么"}],
+                      model: str = "chatglm_pro",
+                      top_p: float = 0.7,
+                      temperature: float = 0.9,
+                      **kwargs):
+        """SSE invoke; yields null-terminated JSON chunks incrementally."""
+        response = self.zhipuai.model_api.sse_invoke(
+            model=model,
+            prompt=message,
+            top_p=top_p,
+            temperature=temperature,
+            incremental=True,
+        )
+        for event in response.events():
+            if event.event == "add":
+                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+            elif event.event in ("error", "interrupted"):
+                # report a non-zero error code so the caller can tell this apart
+                yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+            elif event.event == "finish":
+                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+                print(event.meta)
+            else:
+                print("Something went wrong in ApiModelWorker.create_stream")
+                print(event.data)
+                yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+
+    def create_chat_completion(self,
+                               model: str = "chatglm_pro",
+                               prompt: List[Dict[str, str]] = [{"role": "user", "content": "你好,你可以做什么"}],
+                               top_p: float = 0.7,
+                               temperature: float = 0.9,
+                               stream: bool = False):
+        if stream:
+            return self.create_stream(model=model,
+                                      message=prompt,
+                                      top_p=top_p,
+                                      temperature=temperature)
+        else:
+            return self.create_oneshot(model=model,
+                                       message=prompt,
+                                       top_p=top_p,
+                                       temperature=temperature)
+
+    async def acreate_chat_completion(self,
+                                      prompt: List[Dict[str, str]] = [{"role": "system", "content": "你是一个人工智能助手"},
+                                                                      {"role": "user", "content": "你好。"}],
+                                      model: str = "chatglm_pro",
+                                      top_p: float = 0.7,
+                                      temperature: float = 0.9,
+                                      **kwargs):
+        # async_invoke only submits the task; it is a plain (synchronous) call
+        # in the zhipuai SDK, so it is not awaited here.
+        response = self.zhipuai.model_api.async_invoke(
+            model=model,
+            prompt=prompt,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        if response["code"] == 200:
+            task_id = response["data"]["task_id"]
+            status = "PROCESSING"
+            while status != "SUCCESS":
+                await asyncio.sleep(3)  # avoid busy-polling the task status
+                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
+                status = resp["data"]["task_status"]
+            return resp["data"]["choices"][-1]["content"]
+        else:
+            print(f"error occurred, error code: {response['code']}, error msg: {response['msg']}")
+            return
+
+    def create_completion(self,
+                          prompt: str = "你好",
+                          model: str = "chatglm_pro",
+                          top_p: float = 0.7,
+                          temperature: float = 0.9,
+                          stream: bool = False,
+                          **kwargs):
+        message = self.prompt_collator(content_user=prompt)
+        if stream:
+            return self.create_stream(model=model,
+                                      message=message,
+                                      top_p=top_p,
+                                      temperature=temperature)
+        else:
+            return self.create_oneshot(model=model,
+                                       message=message,
+                                       top_p=top_p,
+                                       temperature=temperature)
+
+    # TODO: consider making this a sync function, since async_invoke needs no awaiting
+    async def acreate_completion(self,
+                                 prompt: str = "你好",
+                                 model: str = "chatglm_pro",
+                                 top_p: float = 0.7,
+                                 temperature: float = 0.9,
+                                 **kwargs):
+        message = self.prompt_collator(content_user=prompt)
+        response = self.zhipuai.model_api.async_invoke(
+            model=model,
+            prompt=message,
+            top_p=top_p,
+            temperature=temperature,
+        )
+        if response["code"] == 200:
+            task_id = response["data"]["task_id"]
+            status = "PROCESSING"
+            while status != "SUCCESS":
+                await asyncio.sleep(3)  # avoid busy-polling the task status
+                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
+                status = resp["data"]["task_status"]
+            return resp["data"]["choices"][-1]["content"]
+        else:
+            print(f"error occurred, error code: {response['code']}, error msg: {response['msg']}")
+            return
\ No newline at end of file
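Note: a minimal usage sketch for the new helpers above (hypothetical, not part of the diff). It assumes `worker` is an already-constructed ApiModelWorker subclass instance whose `zhipuai` attribute is the zhipuai SDK with `api_key` set, which is what the methods expect.

    import json

    # prompt_collator builds an OpenAI-style message list:
    messages = worker.prompt_collator(content_user="你好")  # -> [{"role": "user", "content": "你好"}]

    # Each chunk from create_chat_completion is a null-terminated JSON payload
    # of the form {"error_code": ..., "text": ...}:
    for chunk in worker.create_chat_completion(prompt=messages, stream=True):
        payload = json.loads(chunk.rstrip(b"\0").decode())
        if payload["error_code"] == 0:
            print(payload["text"], end="", flush=True)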
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 82eab5d..3ae7535 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -1,3 +1,4 @@
+import zhipuai
 from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
@@ -36,24 +37,40 @@ class ChatGLMWorker(ApiModelWorker):
     def generate_stream_gate(self, params):
         # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
         from server.utils import get_model_worker_config
-        import zhipuai
         super().generate_stream_gate(params)
         zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
-        response = zhipuai.model_api.sse_invoke(
-            model=self.version,
-            prompt=[{"role": "user", "content": params["prompt"]}],
-            temperature=params.get("temperature"),
-            top_p=params.get("top_p"),
-            incremental=False,
-        )
-        for e in response.events():
-            if e.event == "add":
-                yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
-            # TODO: 更健壮的消息处理
-            # elif e.event == "finish":
-            #     ...
+        model = self.version
+        prompt = self.prompt_collator(content_user=params["prompt"])  # [{"role": "user", "content": params["prompt"]}]
+        temperature = params.get("temperature")
+        top_p = params.get("top_p")
+        stream = params.get("stream")
+
+        if stream:
+            return self.create_stream(model=model,
+                                      message=prompt,
+                                      top_p=top_p,
+                                      temperature=temperature)
+        else:
+            return self.create_oneshot(model=model,
+                                       message=prompt,
+                                       top_p=top_p,
+                                       temperature=temperature)
+
+        # response = zhipuai.model_api.sse_invoke(
+        #     model=self.version,
+        #     prompt=[{"role": "user", "content": params["prompt"]}],
+        #     temperature=params.get("temperature"),
+        #     top_p=params.get("top_p"),
+        #     incremental=False,
+        # )
+        # for e in response.events():
+        #     if e.event == "add":
+        #         yield json.dumps({"error_code": 0, "text": e.data}, ensure_ascii=False).encode() + b"\0"
+        #     # TODO: 更健壮的消息处理
+        #     # elif e.event == "finish":
+        #     #     ...
 
     def get_embeddings(self, params):
         # TODO: 支持embeddings
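For reference, a sketch of how a consumer could decode the output of `generate_stream_gate` after this change (`iter_worker_chunks` is a hypothetical helper, not part of the diff; it follows the same b"\0" delimiter convention the worker methods above emit):

    import json

    def iter_worker_chunks(byte_stream):
        # create_oneshot returns a single bytes object, create_stream a generator;
        # normalize both, then split on the b"\0" terminators and decode each
        # JSON payload.
        if isinstance(byte_stream, bytes):
            byte_stream = [byte_stream]
        buffer = b""
        for piece in byte_stream:
            buffer += piece
            while b"\0" in buffer:
                raw, buffer = buffer.split(b"\0", 1)
                yield json.loads(raw.decode())

    # usage, assuming a configured worker:
    # for msg in iter_worker_chunks(worker.generate_stream_gate(params)):
    #     print(msg["text"])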