diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index 672b0bc..59c5d79 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -88,145 +88,3 @@ class ApiModelWorker(BaseModelWorker):
             prompt.append(prompt_dict)
         return prompt
-
-    def create_oneshot(self,
-                    message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
-                    model:str = "chatglm_pro",
-                    top_p:float=0.7,
-                    temperature:float=0.9,
-                    **kwargs
-                    ):
-        response = self.zhipuai.model_api.invoke(
-            model = model,
-            prompt = message,
-            top_p = top_p,
-            temperature = temperature
-        )
-        if response["code"] == 200:
-            result = response["data"]["choices"][-1]["content"]
-            return json.dumps({"error_code": 0, "text": result}, ensure_ascii=False).encode() + b"\0"
-        else:
-            #TODO 确认openai的error code
-            print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
-            return json.dumps({"error_code": response['code'],
-                               "text": f"error occurred, error code:{response['code']},error msg:{response['msg']}"
-                               },
-                              ensure_ascii=False).encode() + b"\0"
-
-    def create_stream(self,
-                    message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
-                    model:str = "chatglm_pro",
-                    top_p:float=0.7,
-                    temperature:float=0.9,
-                    **kwargs
-                    ):
-        response = self.zhipuai.model_api.sse_invoke(
-            model = model,
-            prompt = message,
-            top_p = top_p,
-            temperature = temperature,
-            incremental = True
-        )
-        for event in response.events():
-            if event.event == "add":
-                # yield event.data
-                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
-            elif event.event == "error" or event.event == "interrupted":
-                # return event.data
-                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
-            elif event.event == "finish":
-                # yield event.data
-                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
-                print(event.meta)
-            else:
-                print("Something get wrong with ZhipuAPILoader.create_chat_completion_stream")
-                print(event.data)
-                yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
-
-    def create_chat_completion(self,
-                        model: str = "chatglm_pro",
-                        prompt:List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
-                        top_p:float=0.7,
-                        temperature:float=0.9,
-                        stream:bool=False):
-
-        if stream:
-            return self.create_stream(model=model,
-                                message=prompt,
-                                top_p=top_p,
-                                temperature=temperature)
-        else:
-            return self.create_oneshot(model=model,
-                            message=prompt,
-                            top_p=top_p,
-                            temperature=temperature)
-
-    async def acreate_chat_completion(self,
-                        prompt: List[Dict[str,str]]=[{"role":"system","content":"你是一个人工智能助手"},
-                                                     {"role":"user","content":"你好。"}],
-                        model:str = "chatglm_pro",
-                        top_p:float=0.7,
-                        temperature:float=0.9,
-                        **kwargs):
-        response = await self.zhipuai.model_api.async_invoke(
-            model = model,
-            prompt = prompt,
-            top_p = top_p,
-            temperature = temperature
-        )
-
-        if response["code"] == 200:
-            task_id = response['data']['task_id']
-            status = "PROCESSING"
-            while status != "SUCCESS":
-                # await asyncio.sleep(3) #
-                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
-                status = resp['data']['task_status']
-            return resp['data']['choices'][-1]['content']
-        else:
-            print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
-            return
-
-    def create_completion(self,
-                    prompt:str="你好",
-                    model:str="chatglm_pro",
-                    top_p:float=0.7,
-                    temperature:float=0.9,
-                    stream:bool=False,
-                    **kwargs):
-        message = self.prompt_collator(content_user=prompt)
-        if stream:
-            return self.create_stream(model=model,
-                                message=message,
-                                top_p=top_p,
-                                temperature=temperature)
-        else:
-            return self.create_oneshot(model=model,
-                            message=message,
-                            top_p=top_p,
-                            temperature=temperature)
-    #? make it a sync function?
-    async def acreate_completion(self,
-                           prompt:str="你好",
-                           model:str = "chatglm_pro",
-                           top_p:float=0.7,
-                           temperature:float=0.9,
-                           **kwargs):
-        message = self.prompt_collator(content_user=prompt)
-        response = self.zhipuai.model_api.async_invoke(
-            model = model,
-            prompt = message,
-            top_p = top_p,
-            temperature = temperature
-        )
-
-        if response["code"] == 200:
-            task_id = response['data']['task_id']
-            status = "PROCESSING"
-            while status != "SUCCESS":
-                # await asyncio.sleep(3) #
-                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
-                status = resp['data']['task_status']
-            return resp['data']['choices'][-1]['content']
-        else:
-            print(f"error occurred, error code:{response['code']},error msg:{response['msg']}")
-            return
\ No newline at end of file
diff --git a/server/model_workers/zhipu.py b/server/model_workers/zhipu.py
index 3ae7535..a15b2b2 100644
--- a/server/model_workers/zhipu.py
+++ b/server/model_workers/zhipu.py
@@ -3,7 +3,7 @@ from server.model_workers.base import ApiModelWorker
 from fastchat import conversation as conv
 import sys
 import json
+import asyncio
-from typing import List, Literal
+from typing import List, Literal, Dict


 class ChatGLMWorker(ApiModelWorker):
@@ -23,7 +23,9 @@ class ChatGLMWorker(ApiModelWorker):
         kwargs.setdefault("context_len", 32768)
         super().__init__(**kwargs)
         self.version = version
-
+        self.zhipuai = zhipuai
+        from server.utils import get_model_worker_config
+        self.zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
         # 这里的是chatglm api的模板,其它API的conv_template需要定制
         self.conv = conv.Conversation(
             name="chatglm-api",
@@ -36,13 +38,16 @@ class ChatGLMWorker(ApiModelWorker):

     def generate_stream_gate(self, params):
         # TODO: 支持stream参数,维护request_id,传过来的prompt也有问题
-        from server.utils import get_model_worker_config
-        super().generate_stream_gate(params)
-        zhipuai.api_key = get_model_worker_config("chatglm-api").get("api_key")
-        model=self.version,
-        prompt=self.prompt_collator(content_user=params["prompt"]) #[{"role": "user", "content": params["prompt"]}]
+        model=self.version
+        # if isinstance(params["prompt"], str):
+        #     prompt=self.prompt_collator(content_user=params["prompt"],
+        #                                 role_user="user")  #[{"role": "user", "content": params["prompt"]}]
+        # else:
+        #     prompt = params["prompt"]
+        prompt = params["prompt"]
+        print(prompt)
         temperature=params.get("temperature")
         top_p=params.get("top_p")
         stream = params.get("stream")
@@ -77,6 +82,149 @@ class ChatGLMWorker(ApiModelWorker):
         print("embedding")
         print(params)

+    def create_oneshot(self,
+                    message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
+                    model:str = "chatglm_pro",
+                    top_p:float=0.7,
+                    temperature:float=0.9,
+                    **kwargs
+                    ):
+        response = self.zhipuai.model_api.invoke(
+            model = model,
+            prompt = message,
+            top_p = top_p,
+            temperature = temperature
+        )
+        if response["code"] == 200:
+            result = response["data"]["choices"][-1]["content"]
+            return json.dumps({"error_code": 0, "text": result}, ensure_ascii=False).encode() + b"\0"
+        else:
+            # TODO: confirm how these codes map onto OpenAI-style error codes
+            print(f"error occurred, error code: {response['code']}, error msg: {response['msg']}")
+            return json.dumps({"error_code": response['code'],
+                               "text": f"error occurred, error code: {response['code']}, error msg: {response['msg']}"
+                               },
+                              ensure_ascii=False).encode() + b"\0"
+
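+    # Streaming variant: re-encodes each zhipuai SSE event as a NUL-terminated
+    # JSON chunk, the wire format the fastchat worker protocol expects.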
+    def create_stream(self,
+                    message: List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
+                    model:str = "chatglm_pro",
+                    top_p:float=0.7,
+                    temperature:float=0.9,
+                    **kwargs
+                    ):
+        response = self.zhipuai.model_api.sse_invoke(
+            model = model,
+            prompt = message,
+            top_p = top_p,
+            temperature = temperature,
+            incremental = True
+        )
+        for event in response.events():
+            if event.event == "add":
+                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+            elif event.event == "error" or event.event == "interrupted":
+                # report stream errors with a non-zero error code rather than as success
+                yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+            elif event.event == "finish":
+                yield json.dumps({"error_code": 0, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+                print(event.meta)  # meta carries request/usage info on finish
+            else:
+                print("Something went wrong in ChatGLMWorker.create_stream")
+                print(event.data)
+                yield json.dumps({"error_code": 1, "text": event.data}, ensure_ascii=False).encode() + b"\0"
+
+    def create_chat_completion(self,
+                        model: str = "chatglm_pro",
+                        prompt:List[Dict[str,str]]=[{"role":"user","content":"你好,你可以做什么"}],
+                        top_p:float=0.7,
+                        temperature:float=0.9,
+                        stream:bool=False):
+        if stream:
+            return self.create_stream(model=model,
+                                message=prompt,
+                                top_p=top_p,
+                                temperature=temperature)
+        else:
+            return self.create_oneshot(model=model,
+                            message=prompt,
+                            top_p=top_p,
+                            temperature=temperature)
+
+    async def acreate_chat_completion(self,
+                        prompt: List[Dict[str,str]]=[{"role":"system","content":"你是一个人工智能助手"},
+                                                     {"role":"user","content":"你好。"}],
+                        model:str = "chatglm_pro",
+                        top_p:float=0.7,
+                        temperature:float=0.9,
+                        **kwargs):
+        # zhipuai's async_invoke submits a server-side task and returns immediately,
+        # so it is not awaited; the result is polled for below.
+        response = self.zhipuai.model_api.async_invoke(
+            model = model,
+            prompt = prompt,
+            top_p = top_p,
+            temperature = temperature
+        )
+
+        if response["code"] == 200:
+            task_id = response['data']['task_id']
+            status = "PROCESSING"
+            while status != "SUCCESS":
+                await asyncio.sleep(3)  # back off between polls
+                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
+                status = resp['data']['task_status']
+            return resp['data']['choices'][-1]['content']
+        else:
+            print(f"error occurred, error code: {response['code']}, error msg: {response['msg']}")
+            return
+
+    def create_completion(self,
+                    prompt:str="你好",
+                    model:str="chatglm_pro",
+                    top_p:float=0.7,
+                    temperature:float=0.9,
+                    stream:bool=False,
+                    **kwargs):
+        message = self.prompt_collator(content_user=prompt)
+        if stream:
+            return self.create_stream(model=model,
+                                message=message,
+                                top_p=top_p,
+                                temperature=temperature)
+        else:
+            return self.create_oneshot(model=model,
+                            message=message,
+                            top_p=top_p,
+                            temperature=temperature)
+
+    # TODO: reconsider whether acreate_completion should be a sync function
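+    # Async variant of create_completion: wraps the bare prompt into chat messages,
+    # submits a zhipuai async task, then polls until the task completes.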
+    async def acreate_completion(self,
+                           prompt:str="你好",
+                           model:str = "chatglm_pro",
+                           top_p:float=0.7,
+                           temperature:float=0.9,
+                           **kwargs):
+        message = self.prompt_collator(content_user=prompt)
+        response = self.zhipuai.model_api.async_invoke(
+            model = model,
+            prompt = message,
+            top_p = top_p,
+            temperature = temperature
+        )
+
+        if response["code"] == 200:
+            task_id = response['data']['task_id']
+            status = "PROCESSING"
+            while status != "SUCCESS":
+                await asyncio.sleep(3)  # back off between polls
+                resp = self.zhipuai.model_api.query_async_invoke_result(task_id)
+                status = resp['data']['task_status']
+            return resp['data']['choices'][-1]['content']
+        else:
+            print(f"error occurred, error code: {response['code']}, error msg: {response['msg']}")
+            return
+

 if __name__ == "__main__":
     import uvicorn