Langchain-Chatchat/models/fastchat_openai_llm.py

254 lines
8.3 KiB
Python

from abc import ABC
from langchain.chains.base import Chain
from typing import (
Any, Dict, List, Optional, Generator, Collection, Set,
Callable,
Tuple,
Union)
from models.loader import LoaderCheckPoint
from langchain.callbacks.manager import CallbackManagerForChainRun
from models.base import (BaseAnswer,
RemoteRpcModel,
AnswerResult,
AnswerResultStream,
AnswerResultQueueSentinelTokenListenerQueue)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from pydantic import Extra, Field, root_validator
from openai import (
ChatCompletion
)
import openai
import logging
import torch
import transformers
logger = logging.getLogger(__name__)
def _build_message_template() -> Dict[str, str]:
"""
:return: 结构
"""
return {
"role": "",
"content": "",
}
# 将历史对话数组转换为文本格式
def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str, str]]:
build_messages: Collection[Dict[str, str]] = []
system_build_message = _build_message_template()
system_build_message['role'] = 'system'
system_build_message['content'] = "You are a helpful assistant."
build_messages.append(system_build_message)
for i, (old_query, response) in enumerate(history):
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = old_query
system_build_message = _build_message_template()
system_build_message['role'] = 'assistant'
system_build_message['content'] = response
build_messages.append(user_build_message)
build_messages.append(system_build_message)
user_build_message = _build_message_template()
user_build_message['role'] = 'user'
user_build_message['content'] = query
build_messages.append(user_build_message)
return build_messages
class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
client: Any
"""Timeout for requests to OpenAI completion API. Default is 600 seconds."""
max_retries: int = 6
api_base_url: str = "http://localhost:8000/v1"
model_name: str = "chatglm-6b"
max_token: int = 10000
temperature: float = 0.01
top_p = 0.9
checkPoint: LoaderCheckPoint = None
# history = []
history_len: int = 10
api_key: str = ""
streaming_key: str = "streaming" #: :meta private:
history_key: str = "history" #: :meta private:
prompt_key: str = "prompt" #: :meta private:
output_key: str = "answer_result_stream" #: :meta private:
def __init__(self,
checkPoint: LoaderCheckPoint = None,
# api_base_url:str="http://localhost:8000/v1",
# model_name:str="chatglm-6b",
# api_key:str=""
):
super().__init__()
self.checkPoint = checkPoint
@property
def _chain_type(self) -> str:
return "LLamaLLMChain"
@property
def _check_point(self) -> LoaderCheckPoint:
return self.checkPoint
@property
def input_keys(self) -> List[str]:
"""Will be whatever keys the prompt expects.
:meta private:
"""
return [self.prompt_key]
@property
def output_keys(self) -> List[str]:
"""Will always return text key.
:meta private:
"""
return [self.output_key]
@property
def _api_key(self) -> str:
pass
@property
def _api_base_url(self) -> str:
return self.api_base_url
def set_api_key(self, api_key: str):
self.api_key = api_key
def set_api_base_url(self, api_base_url: str):
self.api_base_url = api_base_url
def call_model_name(self, model_name):
self.model_name = model_name
def _create_retry_decorator(self) -> Callable[[Any], Any]:
min_seconds = 1
max_seconds = 60
# Wait 2^x * 1 second between each retry starting with
# 4 seconds, then up to 10 seconds, then 10 seconds afterwards
return retry(
reraise=True,
stop=stop_after_attempt(self.max_retries),
wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
retry=(
retry_if_exception_type(openai.error.Timeout)
| retry_if_exception_type(openai.error.APIError)
| retry_if_exception_type(openai.error.APIConnectionError)
| retry_if_exception_type(openai.error.RateLimitError)
| retry_if_exception_type(openai.error.ServiceUnavailableError)
),
before_sleep=before_sleep_log(logger, logging.WARNING),
)
def completion_with_retry(self, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = self._create_retry_decorator()
@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
return self.client.create(**kwargs)
return _completion_with_retry(**kwargs)
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Generator]:
generator = self.generatorAnswer(inputs=inputs, run_manager=run_manager)
return {self.output_key: generator}
def _generate_answer(self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
generate_with_callback: AnswerResultStream = None) -> None:
history = inputs[self.history_key]
streaming = inputs[self.streaming_key]
prompt = inputs[self.prompt_key]
stop = inputs['stop']
print(f"__call:{prompt}")
try:
# Not support yet
# openai.api_key = "EMPTY"
openai.api_key = self.api_key
openai.api_base = self.api_base_url
self.client = openai.ChatCompletion
except AttributeError:
raise ValueError(
"`openai` has no `ChatCompletion` attribute, this is likely "
"due to an old version of the openai package. Try upgrading it "
"with `pip install --upgrade openai`."
)
msg = build_message_list(prompt, history=history)
if streaming:
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
for stream_resp in self.completion_with_retry(
messages=msg,
**params
):
role = stream_resp["choices"][0]["delta"].get("role", "")
token = stream_resp["choices"][0]["delta"].get("content", "")
history += [[prompt, token]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": token}
generate_with_callback(answer_result)
else:
params = {"stream": streaming,
"model": self.model_name,
"stop": stop}
response = self.completion_with_retry(
messages=msg,
**params
)
role = response["choices"][0]["message"].get("role", "")
content = response["choices"][0]["message"].get("content", "")
history += [[prompt, content]]
answer_result = AnswerResult()
answer_result.history = history
answer_result.llm_output = {"answer": content}
generate_with_callback(answer_result)
if __name__ == "__main__":
chain = FastChatOpenAILLMChain()
chain.set_api_key("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")
chain.set_api_base_url("https://api.openai.com/v1")
chain.call_model_name("gpt-3.5-turbo")
answer_result_stream_result = chain({"streaming": False,
"stop": "",
"prompt": "你好",
"history": []
})
for answer_result in answer_result_stream_result['answer_result_stream']:
resp = answer_result.llm_output["answer"]
print(resp)