Add fastchat typewriter (streaming) output
parent 5cbb86a823
commit c389f1a33a
@@ -6,6 +6,7 @@ from queue import Queue
 from threading import Thread
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.loader import LoaderCheckPoint
+from pydantic import BaseModel
 import torch
 import transformers
 
@@ -23,13 +24,12 @@ class ListenerToken:
         self._scores = _scores
 
 
-class AnswerResult:
+class AnswerResult(BaseModel):
     """
     Message entity
     """
     history: List[List[str]] = []
     llm_output: Optional[dict] = None
-    listenerToken: ListenerToken = None
 
 
 class AnswerResultStream:
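With this hunk, AnswerResult becomes a plain pydantic model that carries only data (history plus llm_output); the listenerToken hook is dropped. A minimal sketch of how such a result object can be built and serialized, assuming pydantic v1 as imported in the hunk above:

    from typing import List, Optional
    from pydantic import BaseModel

    class AnswerResult(BaseModel):
        """Message entity (mirrors the definition in the hunk above)."""
        history: List[List[str]] = []
        llm_output: Optional[dict] = None

    # Being a BaseModel, a result can be validated and serialized directly,
    # which is convenient when streaming answers back over an API.
    result = AnswerResult(history=[["hi", "hello"]], llm_output={"answer": "hello"})
    print(result.dict())  # {'history': [['hi', 'hello']], 'llm_output': {'answer': 'hello'}}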
@@ -167,8 +167,6 @@ class BaseAnswer(ABC):
 
         with generate_with_streaming(inputs=inputs, run_manager=run_manager) as generator:
             for answerResult in generator:
-                if answerResult.listenerToken:
-                    output = answerResult.listenerToken.input_ids
                 yield answerResult
 
     @abstractmethod
@@ -94,8 +94,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
                 answer_result = AnswerResult()
                 answer_result.history = history
                 answer_result.llm_output = {"answer": stream_resp}
-                if listenerQueue.listenerQueue.__len__() > 0:
-                    answer_result.listenerToken = listenerQueue.listenerQueue.pop()
                 generate_with_callback(answer_result)
                 self.checkPoint.clear_torch_cache()
         else:
@@ -114,8 +112,6 @@ class ChatGLMLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             answer_result.history = history
             answer_result.llm_output = {"answer": response}
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
 
             generate_with_callback(answer_result)
 
@@ -1,6 +1,11 @@
 from abc import ABC
 from langchain.chains.base import Chain
-from typing import Any, Dict, List, Optional, Generator, Collection
+from typing import (
+    Any, Dict, List, Optional, Generator, Collection, Set,
+    Callable,
+    Tuple,
+    Union)
+
 from models.loader import LoaderCheckPoint
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from models.base import (BaseAnswer,
@@ -8,9 +13,26 @@ from models.base import (BaseAnswer,
                          AnswerResult,
                          AnswerResultStream,
                          AnswerResultQueueSentinelTokenListenerQueue)
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)
+from pydantic import Extra, Field, root_validator
+
+from openai import (
+    ChatCompletion
+)
+
+import openai
+import logging
 import torch
 import transformers
 
+logger = logging.getLogger(__name__)
+
 
 def _build_message_template() -> Dict[str, str]:
     """
@@ -25,12 +47,18 @@ def _build_message_template() -> Dict[str, str]:
 # Convert the conversation history array into text (message) format
 def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
     build_messages: Collection[Dict[str, str]] = []
+
+    system_build_message = _build_message_template()
+    system_build_message['role'] = 'system'
+    system_build_message['content'] = "You are a helpful assistant."
+    build_messages.append(system_build_message)
+
     for i, (old_query, response) in enumerate(history):
         user_build_message = _build_message_template()
         user_build_message['role'] = 'user'
         user_build_message['content'] = old_query
         system_build_message = _build_message_template()
-        system_build_message['role'] = 'system'
+        system_build_message['role'] = 'assistant'
         system_build_message['content'] = response
         build_messages.append(user_build_message)
         build_messages.append(system_build_message)
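After this change, build_message_list emits an OpenAI-style chat message list: a fixed system prompt first, then alternating user/assistant turns taken from history (prior model replies are now tagged 'assistant' rather than 'system'), and presumably the current query appended as a final user message outside the hunk shown. A rough illustration of the resulting shape for one past exchange, assuming _build_message_template returns an empty {"role": "", "content": ""} dict:

    # Hypothetical input: one past exchange; the caller appends the new query separately.
    history = [["What is fastchat?", "fastchat serves chat models behind an OpenAI-compatible API."]]

    # Expected shape of build_message_list(query, history) under the assumptions above:
    expected = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is fastchat?"},
        {"role": "assistant", "content": "fastchat serves chat models behind an OpenAI-compatible API."},
        # ... followed by the current query as a final "user" message (not shown in this hunk).
    ]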
@@ -43,6 +71,9 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
 
 
 class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
+    client: Any
+    """Timeout for requests to OpenAI completion API. Default is 600 seconds."""
+    max_retries: int = 6
     api_base_url: str = "http://localhost:8000/v1"
     model_name: str = "chatglm-6b"
     max_token: int = 10000
@@ -108,6 +139,35 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
     def call_model_name(self, model_name):
         self.model_name = model_name
 
+    def _create_retry_decorator(self) -> Callable[[Any], Any]:
+        min_seconds = 1
+        max_seconds = 60
+        # Wait 2^x * 1 second between each retry starting with
+        # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
+        return retry(
+            reraise=True,
+            stop=stop_after_attempt(self.max_retries),
+            wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
+            retry=(
+                retry_if_exception_type(openai.error.Timeout)
+                | retry_if_exception_type(openai.error.APIError)
+                | retry_if_exception_type(openai.error.APIConnectionError)
+                | retry_if_exception_type(openai.error.RateLimitError)
+                | retry_if_exception_type(openai.error.ServiceUnavailableError)
+            ),
+            before_sleep=before_sleep_log(logger, logging.WARNING),
+        )
+
+    def completion_with_retry(self, **kwargs: Any) -> Any:
+        """Use tenacity to retry the completion call."""
+        retry_decorator = self._create_retry_decorator()
+
+        @retry_decorator
+        def _completion_with_retry(**kwargs: Any) -> Any:
+            return self.client.create(**kwargs)
+
+        return _completion_with_retry(**kwargs)
+
     def _call(
             self,
             inputs: Dict[str, Any],
@@ -124,29 +184,70 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
         history = inputs[self.history_key]
         streaming = inputs[self.streaming_key]
         prompt = inputs[self.prompt_key]
+        stop = inputs['stop']
         print(f"__call:{prompt}")
         try:
-            import openai
             # Not support yet
             # openai.api_key = "EMPTY"
             openai.api_key = self.api_key
             openai.api_base = self.api_base_url
-        except ImportError:
+            self.client = openai.ChatCompletion
+        except AttributeError:
             raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
             )
-        # create a chat completion
-        completion = openai.ChatCompletion.create(
-            model=self.model_name,
-            messages=build_message_list(prompt)
-        )
-        print(f"response:{completion.choices[0].message.content}")
-        print(f"+++++++++++++++++++++++++++++++++++")
-
-        history += [[prompt, completion.choices[0].message.content]]
+        msg = build_message_list(prompt, history=history)
+        if streaming:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            for stream_resp in self.completion_with_retry(
+                    messages=msg,
+                    **params
+            ):
+                role = stream_resp["choices"][0]["delta"].get("role", "")
+                token = stream_resp["choices"][0]["delta"].get("content", "")
+                history += [[prompt, token]]
                 answer_result = AnswerResult()
                 answer_result.history = history
-        answer_result.llm_output = {"answer": completion.choices[0].message.content}
+                answer_result.llm_output = {"answer": token}
                 generate_with_callback(answer_result)
+        else:
+            params = {"stream": streaming,
+                      "model": self.model_name,
+                      "stop": stop}
+            response = self.completion_with_retry(
+                messages=msg,
+                **params
+            )
+            role = response["choices"][0]["message"].get("role", "")
+            content = response["choices"][0]["message"].get("content", "")
+            history += [[prompt, content]]
+            answer_result = AnswerResult()
+            answer_result.history = history
+            answer_result.llm_output = {"answer": content}
+            generate_with_callback(answer_result)
 
 
+if __name__ == "__main__":
+    chain = FastChatOpenAILLMChain()
+
+    chain.set_api_key("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")
+    chain.set_api_base_url("https://api.openai.com/v1")
+    chain.call_model_name("gpt-3.5-turbo")
+
+    answer_result_stream_result = chain({"streaming": False,
+                                         "stop": "",
+                                         "prompt": "你好",
+                                         "history": []
+                                         })
+
+    for answer_result in answer_result_stream_result['answer_result_stream']:
+        resp = answer_result.llm_output["answer"]
+        print(resp)
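The __main__ block added above exercises the non-streaming path. To get the typewriter effect this commit targets, the same chain can be called with streaming set to True and the incremental tokens printed as they arrive; a minimal sketch under assumed values (the import path, API key, base URL, and model name below are placeholders, not confirmed by this diff):

    # Assumed import path for the chain defined in the hunks above.
    from models.fastchat_openai_llm import FastChatOpenAILLMChain

    chain = FastChatOpenAILLMChain()
    chain.set_api_key("EMPTY")                           # placeholder key for a local server (assumption)
    chain.set_api_base_url("http://localhost:8000/v1")   # local OpenAI-compatible fastchat endpoint (assumption)
    chain.call_model_name("chatglm-6b")

    result = chain({"streaming": True, "stop": "", "prompt": "你好", "history": []})
    for answer_result in result['answer_result_stream']:
        # In streaming mode each AnswerResult carries one delta token in llm_output["answer"];
        # printing without a newline produces the typewriter-style output.
        print(answer_result.llm_output["answer"], end="", flush=True)
    print()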
@@ -186,7 +186,5 @@ class LLamaLLMChain(BaseAnswer, Chain, ABC):
             answer_result = AnswerResult()
             history += [[prompt, reply]]
             answer_result.history = history
-            if listenerQueue.listenerQueue.__len__() > 0:
-                answer_result.listenerToken = listenerQueue.listenerQueue.pop()
             answer_result.llm_output = {"answer": reply}
             generate_with_callback(answer_result)