diff --git a/models/base/base.py b/models/base/base.py
index 1b65b21..c6674c9 100644
--- a/models/base/base.py
+++ b/models/base/base.py
@@ -6,6 +6,7 @@ from queue import Queue
 from threading import Thread
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
 import torch
 import transformers
 
@@ -23,13 +24,12 @@ class ListenerToken:
         self._scores = _scores
 
 
-class AnswerResult:
+class AnswerResult(BaseModel):
     """
    Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
 
 
 class AnswerResultStream:
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
 
         with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
             for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
                 yield answerResult
 
     @abstractmethod
diff --git a/models/chatglm_llm.py b/models/chatglm_llm.py
index c45cf3b..81878ce 100644
--- a/models/chatglm_llm.py
+++ b/models/chatglm_llm.py
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                 generate_with_callback(answer_result)
             self.checkPoint.clear_torch_cache()
         else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             generate_with_callback(answer_result)
diff --git a/models/fastchat_openai_llm.py b/models/fastchat_openai_llm.py
index 398364b..1787281 100644
--- a/models/fastchat_openai_llm.py
+++ b/models/fastchat_openai_llm.py
@@ -1,6 +1,11 @@
 from abc import ABC
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+    Any, Dict, List, Optional, Generator, Collection, Set,
+    Callable,
+    Tuple,
+    Union)
+
 from models.loader import LoaderCheckPoint
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+
+from openai import (
+    ChatCompletion
+)
+
+import openai
+import logging
 import torch
 import transformers
 
+logger = logging.getLogger(__name__)
+
 
 def _build_message_template() -> Dict[str, str]:
     """
@@ -25,12 +47,18 @@ def _build_message_template() -> Dict[str, str]:
 # Convert the dialogue history array into message format
 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     build_messages: Collection[Dict[str, str]] = []
+
+    system_build_message = _build_message_template()
+    system_build_message['role'] = 'system'
+    system_build_message['content'] = "You are a helpful assistant."
+    build_messages.append(system_build_message)
+
     for i, (old_query, response) in enumerate(history):
         user_build_message = _build_message_template()
         user_build_message['role'] = 'user'
         user_build_message['content'] = old_query
         system_build_message = _build_message_template()
-        system_build_message['role'] = 'system'
+        system_build_message['role'] = 'assistant'
         system_build_message['content'] = response
         build_messages.append(user_build_message)
         build_messages.append(system_build_message)
@@ -43,6 +71,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
 
 
 class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
+    client: Any
+    max_retries: int = 6
+    """Maximum number of retries to make when the completion call fails."""
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
@@ -108,6 +139,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
 
+    def _create_retry_decorator(self) -> Callable[[Any], Any]:
+        min_seconds = 1
+        max_seconds = 60
+        # Wait 2^x * 1 second between retries, starting at min_seconds
+        # and capping at max_seconds
+        return retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries),
+            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+            retry=(
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            ),
+            before_sleep=before_sleep_log(logger, logging.WARNING),
+        )
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the completion call."""
+        retry_decorator = self._create_retry_decorator()
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            return self.client.create(**kwargs)
+
+        return _completion_with_retry(**kwargs)
+
     def _call(
             self,
             inputs: Dict[str, Any],
@@ -124,29 +184,70 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
         history = inputs[self.history_key]
         streaming = inputs[self.streaming_key]
         prompt = inputs[self.prompt_key]
+        stop = inputs['stop']
         print(f"__call:{prompt}")
         try:
-            import openai
             # Not support yet
             # openai.api_key = "EMPTY"
             openai.api_key = self.api_key
             openai.api_base = self.api_base_url
-        except ImportError:
+            self.client = openai.ChatCompletion
+        except AttributeError:
             raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
             )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
+        msg = build_message_list(prompt, history=history)
 
-        history += [[prompt, completion.choices[0].message.content]]
-        answer_result = AnswerResult()
-        answer_result.history = history
-        answer_result.llm_output = {"answer": completion.choices[0].message.content}
-        generate_with_callback(answer_result)
+        if streaming:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            for stream_resp in self.completion_with_retry(
+                    messages=msg,
+                    **params
+            ):
+                role = stream_resp["choices"][0]["delta"].get("role", "")
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                history += [[prompt, token]]
+                answer_result = AnswerResult()
+                answer_result.history = history
+                answer_result.llm_output = {"answer": token}
+                generate_with_callback(answer_result)
+        else:
+
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            response = self.completion_with_retry(
+                messages=msg,
+                **params
+            )
+            role = response["choices"][0]["message"].get("role", "")
+            content = response["choices"][0]["message"].get("content", "")
+            history += [[prompt, content]]
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": content}
+            generate_with_callback(answer_result)
+
+
+if __name__ == "__main__":
+
+    chain = FastChatOpenAILLMChain()
+
+    chain.set_api_key("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")
+    chain.set_api_base_url("https://api.openai.com/v1")
+    chain.call_model_name("gpt-3.5-turbo")
+
+    answer_result_stream_result = chain({"streaming": False,
+                                         "stop": "",
+                                         "prompt": "你好",
+                                         "history": []
+                                         })
+
+    for answer_result in answer_result_stream_result['answer_result_stream']:
+        resp = answer_result.llm_output["answer"]
+        print(resp)
diff --git a/models/llama_llm.py b/models/llama_llm.py
index 89d21ac..014fd81 100644
--- a/models/llama_llm.py
+++ b/models/llama_llm.py
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             history += [[prompt, reply]]
             answer_result.history = history
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             answer_result.llm_output = {"answer": reply}
             generate_with_callback(answer_result)