Add FastChat typewriter (streaming) output

parent 5cbb86a823
commit c389f1a33a
@@ -6,6 +6,7 @@ from queue import Queue
from threading import Thread
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.loader import LoaderCheckPoint
from pydantic import BaseModel
import torch
import transformers

@@ -23,13 +24,12 @@ class ListenerToken:
        self._scores = _scores


class AnswerResult:
class AnswerResult(BaseModel):
    """
    Message entity
    """
    history: List[List[str]] = []
    llm_output: Optional[dict] = None
    listenerToken: ListenerToken = None


class AnswerResultStream:
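Note on the hunk above: once AnswerResult derives from pydantic's BaseModel, the class-level annotations become fields with defaults. A minimal standalone sketch (trimmed, listenerToken omitted; assuming pydantic v1, as pulled in by langchain at the time):

from typing import List, Optional
from pydantic import BaseModel

class AnswerResult(BaseModel):
    history: List[List[str]] = []          # each entry is [query, response]
    llm_output: Optional[dict] = None

result = AnswerResult()                    # defaults apply: history=[], llm_output=None
result.history = [["你好", "Hello!"]]      # attribute assignment, as the chains below do
result.llm_output = {"answer": "Hello!"}
print(result.dict())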
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):

        with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
            for answerResult in generator:
                if answerResult.listenerToken:
                    output = answerResult.listenerToken.input_ids
                yield answerResult

    @abstractmethod

@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": stream_resp}
                if listenerQueue.listenerQueue.__len__() > 0:
                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                generate_with_callback(answer_result)
                self.checkPoint.clear_torch_cache()
        else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": response}
            if listenerQueue.listenerQueue.__len__() > 0:
                answer_result.listenerToken = listenerQueue.listenerQueue.pop()

            generate_with_callback(answer_result)

@@ -1,6 +1,11 @@
from abc import ABC
from langchain.chains.base import Chain
from typing import Any, Dict, List, Optional, Generator, Collection
from typing import (
    Any, Dict, List, Optional, Generator, Collection, Set,
    Callable,
    Tuple,
    Union)

from models.loader import LoaderCheckPoint
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,

@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)
from pydantic import Extra, Field, root_validator

from openai import (
    ChatCompletion
)

import openai
import logging
import torch
import transformers

logger = logging.getLogger(__name__)


def _build_message_template() -> Dict[str, str]:
    """

@@ -25,12 +47,18 @@ def _build_message_template() -> Dict[str, str]:
# Convert the conversation history into the message-list format
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
    build_messages: Collection[Dict[str, str]] = []

    system_build_message = _build_message_template()
    system_build_message['role'] = 'system'
    system_build_message['content'] = "You are a helpful assistant."
    build_messages.append(system_build_message)

    for i, (old_query, response) in enumerate(history):
        user_build_message = _build_message_template()
        user_build_message['role'] = 'user'
        user_build_message['content'] = old_query
        system_build_message = _build_message_template()
        system_build_message['role'] = 'system'
        system_build_message['role'] = 'assistant'
        system_build_message['content'] = response
        build_messages.append(user_build_message)
        build_messages.append(system_build_message)

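For orientation, a sketch of the OpenAI-style message list this helper builds for a one-turn history; the current query is presumably appended as a final user message in the part of the function not shown in this hunk, and the example query here is illustrative only:

# build_message_list("今天天气如何", history=[["你好", "Hello!"]]) produces roughly:
[
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "你好"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "今天天气如何"},   # the query itself, appended outside this hunk
]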
@@ -43,6 +71,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,


class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
    client: Any
    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
    max_retries: int = 6
    api_base_url: str = "http://localhost:8000/v1"
    model_name: str = "chatglm-6b"
    max_token: int = 10000
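The api_base_url default above points at FastChat's OpenAI-compatible REST server, which is what makes the streaming ("typewriter") path of this chain possible. A hedged standalone sketch of talking to such a server with the openai 0.x client; the model name, URL, and the "EMPTY" key mirror the class defaults and are assumptions about the local deployment:

import openai

openai.api_key = "EMPTY"                       # FastChat's server does not check the key
openai.api_base = "http://localhost:8000/v1"   # matches the class default api_base_url

resp = openai.ChatCompletion.create(
    model="chatglm-6b",                        # matches the class default model_name
    messages=[{"role": "user", "content": "你好"}],
)
print(resp["choices"][0]["message"]["content"])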
@@ -108,6 +139,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
    def call_model_name(self, model_name):
        self.model_name = model_name

    def _create_retry_decorator(self) -> Callable[[Any], Any]:
        min_seconds = 1
        max_seconds = 60
        # Wait 2^x * 1 second between retries,
        # starting at min_seconds and capping at max_seconds
        return retry(
            reraise=True,
            stop=stop_after_attempt(self.max_retries),
            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
            retry=(
                retry_if_exception_type(openai.error.Timeout)
                | retry_if_exception_type(openai.error.APIError)
                | retry_if_exception_type(openai.error.APIConnectionError)
                | retry_if_exception_type(openai.error.RateLimitError)
                | retry_if_exception_type(openai.error.ServiceUnavailableError)
            ),
            before_sleep=before_sleep_log(logger, logging.WARNING),
        )

    def completion_with_retry(self, **kwargs: Any) -> Any:
        """Use tenacity to retry the completion call."""
        retry_decorator = self._create_retry_decorator()

        @retry_decorator
        def _completion_with_retry(**kwargs: Any) -> Any:
            return self.client.create(**kwargs)

        return _completion_with_retry(**kwargs)

    def _call(
            self,
            inputs: Dict[str, Any],

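The two retry helpers above follow the usual tenacity pattern: exponential backoff between attempts (1 s, 2 s, 4 s, ... capped at 60 s), give up after max_retries attempts, and re-raise the last error. A self-contained sketch of the same pattern outside the chain; the Flaky exception and call_remote function are illustrative stand-ins, not project code:

import logging
from tenacity import (before_sleep_log, retry, retry_if_exception_type,
                      stop_after_attempt, wait_exponential)

logger = logging.getLogger(__name__)

class Flaky(Exception):
    """Stand-in for the openai.error types listed above."""

@retry(
    reraise=True,
    stop=stop_after_attempt(6),
    wait=wait_exponential(multiplier=1, min=1, max=60),
    retry=retry_if_exception_type(Flaky),
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def call_remote() -> str:
    # in the chain, this body is self.client.create(**kwargs)
    raise Flaky("transient error")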
@@ -124,29 +184,70 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
        history = inputs[self.history_key]
        streaming = inputs[self.streaming_key]
        prompt = inputs[self.prompt_key]
        stop = inputs['stop']
        print(f"__call:{prompt}")
        try:

            import openai
            # Not support yet
            # openai.api_key = "EMPTY"
            openai.api_key = self.api_key
            openai.api_base = self.api_base_url
        except ImportError:
            self.client = openai.ChatCompletion
        except AttributeError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`."
            )
        # create a chat completion
        completion = openai.ChatCompletion.create(
            model=self.model_name,
            messages=build_message_list(prompt)
        )
        print(f"response:{completion.choices[0].message.content}")
        print(f"+++++++++++++++++++++++++++++++++++")
        msg = build_message_list(prompt, history=history)

        history += [[prompt, completion.choices[0].message.content]]
        if streaming:
            params = {"stream": streaming,
                      "model": self.model_name,
                      "stop": stop}
            for stream_resp in self.completion_with_retry(
                    messages=msg,
                    **params
            ):
                role = stream_resp["choices"][0]["delta"].get("role", "")
                token = stream_resp["choices"][0]["delta"].get("content", "")
                history += [[prompt, token]]
                answer_result = AnswerResult()
                answer_result.history = history
                answer_result.llm_output = {"answer": completion.choices[0].message.content}
                answer_result.llm_output = {"answer": token}
                generate_with_callback(answer_result)
        else:

            params = {"stream": streaming,
                      "model": self.model_name,
                      "stop": stop}
            response = self.completion_with_retry(
                messages=msg,
                **params
            )
            role = response["choices"][0]["message"].get("role", "")
            content = response["choices"][0]["message"].get("content", "")
            history += [[prompt, content]]
            answer_result = AnswerResult()
            answer_result.history = history
            answer_result.llm_output = {"answer": content}
            generate_with_callback(answer_result)


if __name__ == "__main__":

    chain = FastChatOpenAILLMChain()

    chain.set_api_key("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")
    chain.set_api_base_url("https://api.openai.com/v1")
    chain.call_model_name("gpt-3.5-turbo")

    answer_result_stream_result = chain({"streaming": False,
                                         "stop": "",
                                         "prompt": "你好",
                                         "history": []
                                         })

    for answer_result in answer_result_stream_result['answer_result_stream']:
        resp = answer_result.llm_output["answer"]
        print(resp)

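The __main__ example above runs with streaming=False; the typewriter behaviour this commit adds comes from the streaming branch, where each AnswerResult carries only the newly generated delta token in llm_output["answer"]. A sketch of consuming it that way, assuming the same chain instance and output key as in the example above:

    answer_result_stream_result = chain({"streaming": True,
                                         "stop": "",
                                         "prompt": "你好",
                                         "history": []
                                         })

    for answer_result in answer_result_stream_result['answer_result_stream']:
        # each item holds one delta token, so print without a newline for the typewriter effect
        print(answer_result.llm_output["answer"], end="", flush=True)
    print()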
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
            answer_result = AnswerResult()
            history += [[prompt, reply]]
            answer_result.history = history
            if listenerQueue.listenerQueue.__len__() > 0:
                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
            answer_result.llm_output = {"answer": reply}
            generate_with_callback(answer_result)