Start up using the model_config defaults.
llama_llm.py: remove streaming output; base.py, shared.py: remove redundant code; fastchat_llm.py: business implementation.
commit a1b1b78108
parent 78e940f0a7
@@ -62,11 +62,33 @@ llm_model_dict = {
        "pretrained_model_name": "fnlp/moss-moon-003-sft",
        "local_model_path": None,
        "provides": "MOSSLLM"
+    },
+    "vicuna-13b-hf": {
+        "name": "vicuna-13b-hf",
+        "pretrained_model_name": "vicuna-13b-hf",
+        "local_model_path": None,
+        "provides": "LLamaLLM"
+    },
+    "fastChat": {
+        "name": "fastChat",
+        "pretrained_model_name": "fastChat",
+        "local_model_path": None,
+        "provides": "FastChatLLM"
    }
}

# LLM model name
LLM_MODEL = "chatglm-6b"

+# If you need to load a local model, pass the ` --no-remote-model` flag, or change the parameter below to `True`
+NO_REMOTE_MODEL = False
+# Load the model quantized to 8 bit
+LOAD_IN_8BIT = False
+# Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
+BF16 = False
+# Where local models are stored
+MODEL_DIR = "model/"
+# Where local LoRA weights are stored
+LORA_DIR = "loras/"

# LLM LoRA path; empty by default. If you have one, set it to the folder path directly.
LLM_LORA_PATH = ""
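Note: the point of these new constants is that the CLI parser changed further down in this diff now takes its defaults from configs.model_config instead of hard-coded literals. A small, hypothetical sanity check (not part of the commit) of how the defaults surface:

# Hypothetical check: the parser in models.loader.args is built with argv=[],
# so DEFAULT_ARGS should reproduce the model_config defaults above.
from models.loader.args import DEFAULT_ARGS

assert DEFAULT_ARGS["model"] == "chatglm-6b"      # LLM_MODEL
assert DEFAULT_ARGS["model_dir"] == "model/"      # MODEL_DIR
assert DEFAULT_ARGS["no_remote_model"] is False   # NO_REMOTE_MODEL
assert DEFAULT_ARGS["load_in_8bit"] is False      # LOAD_IN_8BIT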
@@ -1 +0,0 @@
from .fastchat_api import *
@@ -1,261 +0,0 @@
"""
Conversation prompt template.

Now we support
- Vicuna
- Koala
- OpenAssistant/oasst-sft-1-pythia-12b
- StabilityAI/stablelm-tuned-alpha-7b
- databricks/dolly-v2-12b
- THUDM/chatglm-6b
- Alpaca/LLaMa
"""

import dataclasses
from enum import auto, Enum
from typing import List, Tuple, Any


class SeparatorStyle(Enum):
    """Different separator style."""

    SINGLE = auto()
    TWO = auto()
    DOLLY = auto()
    OASST_PYTHIA = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""

    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None

    # Used for gradio server
    skip_next: bool = False
    conv_id: Any = None

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += self.sep + " " + role + ": " + message
                else:
                    ret += self.sep + " " + role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ":\n" + message + seps[i % 2]
                    if i % 2 == 1:
                        ret += "\n\n"
                else:
                    ret += role + ":\n"
            return ret
        elif self.sep_style == SeparatorStyle.OASST_PYTHIA:
            ret = self.system
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            conv_id=self.conv_id,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "conv_id": self.conv_id,
        }


conv_one_shot = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        (
            "Human",
            "What are the key differences between renewable and non-renewable energy sources?",
        ),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


conv_koala_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_dolly = Conversation(
    system="Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n",
    roles=("### Instruction", "### Response"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.DOLLY,
    sep="\n\n",
    sep2="### End",
)

conv_oasst = Conversation(
    system="",
    roles=("<|prompter|>", "<|assistant|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="<|endoftext|>",
)

conv_stablelm = Conversation(
    system="""<|SYSTEM|># StableLM Tuned (Alpha version)
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
- StableLM will refuse to participate in anything that could harm a human.
""",
    roles=("<|USER|>", "<|ASSISTANT|>"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.OASST_PYTHIA,
    sep="",
)

conv_templates = {
    "conv_one_shot": conv_one_shot,
    "vicuna_v1.1": conv_vicuna_v1_1,
    "koala_v1": conv_koala_v1,
    "dolly": conv_dolly,
    "oasst": conv_oasst,
}


def get_default_conv_template(model_name):
    model_name = model_name.lower()
    if "vicuna" in model_name or "output" in model_name:
        return conv_vicuna_v1_1
    elif "koala" in model_name:
        return conv_koala_v1
    elif "dolly-v2" in model_name:
        return conv_dolly
    elif "oasst" in model_name and "pythia" in model_name:
        return conv_oasst
    elif "stablelm" in model_name:
        return conv_stablelm
    return conv_one_shot


def compute_skip_echo_len(model_name, conv, prompt):
    model_name = model_name.lower()
    if "chatglm" in model_name:
        skip_echo_len = len(conv.messages[-2][1]) + 1
    elif "dolly-v2" in model_name:
        special_toks = ["### Instruction:", "### Response:", "### End"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "oasst" in model_name and "pythia" in model_name:
        special_toks = ["<|prompter|>", "<|assistant|>", "<|endoftext|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    elif "stablelm" in model_name:
        special_toks = ["<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"]
        skip_echo_len = len(prompt)
        for tok in special_toks:
            skip_echo_len -= prompt.count(tok) * len(tok)
    else:
        skip_echo_len = len(prompt) + 1 - prompt.count("</s>") * 3
    return skip_echo_len


if __name__ == "__main__":
    print(default_conversation.get_prompt())
@@ -1,459 +0,0 @@
"""Wrapper around FastChat APIs."""
from __future__ import annotations

import logging
import sys
import warnings
from typing import (
    AbstractSet,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    List,
    Literal,
    Mapping,
    Optional,
    Set,
    Tuple,
    Union,
)

from pydantic import Extra, Field, root_validator
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from langchain.llms.base import BaseLLM
from langchain.schema import Generation, LLMResult
from langchain.utils import get_from_dict_or_env

import requests
import json


logger = logging.getLogger(__name__)
FAST_CHAT_API = "http://localhost:21002/worker_generate_stream"


def _streaming_response_template() -> Dict[str, Any]:
    """
    :return: response structure
    """
    return {
        "text": "",
        "error_code": 0,
    }


def _update_response(response: Dict[str, Any], stream_response: Dict[str, Any]) -> None:
    """Update response from the stream response."""
    response["text"] += stream_response["text"]
    response["error_code"] += stream_response["error_code"]


class BaseFastChat(BaseLLM):
    """Wrapper around FastChat large language models."""

    model_name: str = "text-davinci-003"
    """Model name to use."""
    temperature: float = 0.7
    """What sampling temperature to use."""
    max_new_tokens: int = 200
    stop: int = 20
    batch_size: int = 20
    """Maximum number of retries to make when generating."""
    streaming: bool = False
    """Penalizes repeated tokens."""
    n: int = 1
    """Whether to stream the results or not."""
    allowed_special: Union[Literal["all"], AbstractSet[str]] = set()
    """Set of special tokens that are allowed."""
    disallowed_special: Union[Literal["all"], Collection[str]] = "all"
    """Set of special tokens that are not allowed."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.ignore

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transfered to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling FastChat API."""
        normal_params = {
            "model": self.model_name,
            "prompt": '',
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
        }

        return {**normal_params}

    def _generate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The full LLM output.

        Example:
            .. code-block:: python

                response = fastchat.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}
        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]

            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    async def _agenerate(
        self, prompts: List[str], stop: Optional[List[str]] = None
    ) -> LLMResult:
        """Call out to FastChat's endpoint async with k unique prompts."""
        params = self._invocation_params
        sub_prompts = self.get_sub_prompts(params, prompts)
        choices = []
        token_usage: Dict[str, int] = {}

        headers = {"User-Agent": "fastchat Client"}
        for _prompts in sub_prompts:

            params["prompt"] = _prompts[0]
            if stop is not None:
                if "stop" in params:
                    raise ValueError("`stop` found in both the input and default params.")
                params["stop"] = stop

            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        self.callback_manager.on_llm_new_token(
                            output,
                            verbose=self.verbose,
                            logprobs=data["error_code"],
                        )
                        _update_response(response_template, data)
                choices.append(response_template)
            else:
                response_template = _streaming_response_template()
                response = requests.post(
                    FAST_CHAT_API,
                    headers=headers,
                    json=params,
                    stream=True,
                )
                for stream_resp in response.iter_lines(
                    chunk_size=8192, decode_unicode=False, delimiter=b"\0"
                ):
                    if stream_resp:
                        data = json.loads(stream_resp.decode("utf-8"))
                        skip_echo_len = len(_prompts[0])
                        output = data["text"][skip_echo_len:].strip()
                        data["text"] = output
                        _update_response(response_template, data)

                choices.append(response_template)

        return self.create_llm_result(choices, prompts, token_usage)

    def get_sub_prompts(
        self,
        params: Dict[str, Any],
        prompts: List[str],
    ) -> List[List[str]]:
        """Get the sub prompts for llm call."""
        if params["max_new_tokens"] == -1:
            if len(prompts) != 1:
                raise ValueError(
                    "max_new_tokens set to -1 not supported for multiple inputs."
                )
            params["max_new_tokens"] = self.max_new_tokens_for_prompt(prompts[0])
        # append pload
        sub_prompts = [
            prompts[i: i + self.batch_size]
            for i in range(0, len(prompts), self.batch_size)
        ]

        return sub_prompts

    def create_llm_result(
        self, choices: Any, prompts: List[str], token_usage: Dict[str, int]
    ) -> LLMResult:
        """Create the LLMResult from the choices and prompts."""
        generations = []
        for i, _ in enumerate(prompts):
            sub_choices = choices[i * self.n: (i + 1) * self.n]
            generations.append(
                [
                    Generation(
                        text=choice["text"],
                        generation_info=dict(
                            finish_reason='over',
                            logprobs=choice["text"],
                        ),
                    )
                    for choice in sub_choices
                ]
            )
        llm_output = {"token_usage": token_usage, "model_name": self.model_name}
        return LLMResult(generations=generations, llm_output=llm_output)

    def stream(self, prompt: str, stop: Optional[List[str]] = None) -> Generator:
        """Call FastChat with streaming flag and return the resulting generator.

        BETA: this is a beta feature while we figure out the right abstraction.
        Once that happens, this interface could change.

        Args:
            prompt: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            A generator representing the stream of tokens from OpenAI.

        Example:
            .. code-block:: python

                generator = fastChat.stream("Tell me a joke.")
                for token in generator:
                    yield token
        """
        params = self._invocation_params
        params["prompt"] = prompt
        if stop is not None:
            if "stop" in params:
                raise ValueError("`stop` found in both the input and default params.")
            params["stop"] = stop

        headers = {"User-Agent": "fastchat Client"}
        response = requests.post(
            FAST_CHAT_API,
            headers=headers,
            json=params,
            stream=True,
        )
        for stream_resp in response.iter_lines(
            chunk_size=8192, decode_unicode=False, delimiter=b"\0"
        ):
            if stream_resp:
                data = json.loads(stream_resp.decode("utf-8"))
                skip_echo_len = len(_prompts[0])
                output = data["text"][skip_echo_len:].strip()
                data["text"] = output
                yield data

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        return self._default_params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_name": self.model_name}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fastChat"

    def get_num_tokens(self, text: str) -> int:
        """Calculate num tokens with tiktoken package."""
        # tiktoken NOT supported for Python < 3.8
        if sys.version_info[1] < 8:
            return super().get_num_tokens(text)
        try:
            import tiktoken
        except ImportError:
            raise ValueError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate get_num_tokens. "
                "Please install it with `pip install tiktoken`."
            )

        enc = tiktoken.encoding_for_model(self.model_name)

        tokenized_text = enc.encode(
            text,
            allowed_special=self.allowed_special,
            disallowed_special=self.disallowed_special,
        )

        # calculate the number of tokens in the encoded text
        return len(tokenized_text)

    def modelname_to_contextsize(self, modelname: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a model.

        Args:
            modelname: The modelname we want to know the context size for.

        Returns:
            The maximum context size

        Example:
            .. code-block:: python

                max_new_tokens = openai.modelname_to_contextsize("text-davinci-003")
        """
        model_token_mapping = {
            "vicuna-13b": 2049,
            "koala": 2049,
            "dolly-v2": 2049,
            "oasst": 2049,
            "stablelm": 2049,
        }

        context_size = model_token_mapping.get(modelname, None)

        if context_size is None:
            raise ValueError(
                f"Unknown model: {modelname}. Please provide a valid OpenAI model name."
                "Known models are: " + ", ".join(model_token_mapping.keys())
            )

        return context_size

    def max_new_tokens_for_prompt(self, prompt: str) -> int:
        """Calculate the maximum number of tokens possible to generate for a prompt.

        Args:
            prompt: The prompt to pass into the model.

        Returns:
            The maximum number of tokens to generate for a prompt.

        Example:
            .. code-block:: python

                max_new_tokens = openai.max_token_for_prompt("Tell me a joke.")
        """
        num_tokens = self.get_num_tokens(prompt)

        # get max context size for model by name
        max_size = self.modelname_to_contextsize(self.model_name)
        return max_size - num_tokens


class FastChat(BaseFastChat):
    """Wrapper around OpenAI large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``OPENAI_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the openai.create call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain.llms import OpenAI
            openai = FastChat(model_name="vicuna")
    """

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        return {**{"model": self.model_name}, **super()._invocation_params}
@@ -1,3 +1,4 @@
from .chatglm_llm import ChatGLM
-# from .llama_llm import LLamaLLM
+from .llama_llm import LLamaLLM
from .moss_llm import MOSSLLM
+from .fastchat_llm import FastChatLLM
@@ -175,15 +175,6 @@ class BaseAnswer(ABC):
        def generate_with_streaming(**kwargs):
            return Iteratorize(generate_with_callback, kwargs)

-        """
-        eos_token_id is a designated token (for example "</s>") used to mark the end of a sequence.
-        During text generation the generator keeps emitting tokens until it produces this special eos_token_id, which means the sequence is complete.
-        In Hugging Face Transformers models, eos_token_id is added to the input automatically by the tokenizer.
-        If the model emits eos_token_id while generating, generation stops and the generated sequence is returned.
-        """
-        eos_token_ids = [
-            self._check_point.tokenizer.eos_token_id] if self._check_point.tokenizer.eos_token_id is not None else []
-
        with generate_with_streaming(prompt=prompt, history=history, streaming=streaming) as generator:
            for answerResult in generator:
                if answerResult.listenerToken:
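Note: the deleted comment restates standard Hugging Face behaviour. A generic sketch of eos_token_id stopping generation (the "gpt2" checkpoint is only a placeholder, not something this repository uses):

# Generic illustration: generate() stops as soon as the model emits eos_token_id.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=32,
                            eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))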
@@ -1,223 +0,0 @@
# import gc
import traceback
from queue import Queue
# from threading import Thread
# import threading
from typing import Optional, List, Dict, Any, TypeVar, Deque
from collections import deque
import torch
import transformers

from models.extensions.thread_with_exception import ThreadWithException
import models.shared as shared


K = TypeVar('K')
V = TypeVar('V')

class LimitedLengthDict(Dict[K, V]):
    def __init__(self, maxlen=None, *args, **kwargs):
        self.maxlen = maxlen
        self._keys: Deque[K] = deque()
        super().__init__(*args, **kwargs)

    def __setitem__(self, key: K, value: V):
        if key not in self:
            if self.maxlen is not None and len(self) >= self.maxlen:
                oldest_key = self._keys.popleft()
                if oldest_key in self:
                    del self[oldest_key]
            self._keys.append(key)
        super().__setitem__(key, value)


class FixedLengthQueue:
    # list of stop sequences
    stop_sequence: Optional[str] = []
    # buffer length
    max_length: int = 0
    # buffer container
    queue: deque = None
    # input container
    queue_in: LimitedLengthDict[int, str] = {}
    # output container
    queue_out: Dict[int, str] = {}

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup could go here
        return instance

    def __init__(self, stop_sequence):
        if stop_sequence is None:
            self.stop_sequence = []
            self.max_length = 0
        elif isinstance(stop_sequence, str):
            self.stop_sequence = [stop_sequence]
            self.max_length = 1
        else:
            self.stop_sequence = stop_sequence
            self.max_length = len(''.join(stop_sequence))

        self.queue = deque(maxlen=self.max_length)
        self.queue.clear()
        self.queue_in.clear()
        self.queue_out.clear()

    def add(self, index, item):
        self.queue_in[index] = item

    def _add_out(self, index, item):
        self.queue_out[index] = item

    def put_replace_out(self, index):
        return self.queue_out[index]

    def contains_replace_sequence(self):
        """
        Replace characters.
        :return:
        """

        for key, value in self.queue_in.items():

            word_index = value.rfind(":")
            if word_index != -1:
                value = value.replace(":", ":")

            word_index = value.rfind("[")
            if word_index != -1:
                value = value.replace("[", "")

            word_index = value.rfind("]")
            if word_index != -1:
                value = value.replace("]", "")

            self._add_out(key, value)

    def contains_stop_sequence(self):
        # check a fixed-size slice of the accumulated output
        self.queue.clear()
        last_three_keys = list(self.queue_out.keys())[-self.max_length:]
        joined_queue = ''.join([self.queue_out[key] for key in last_three_keys])
        for char in joined_queue:
            self.queue.append(char)

        joined_queue = ''.join(self.queue)
        # Initialize a variable to store the index of the last found stop string
        last_stop_str_index = -1

        # Iterate through the stop string list
        for stop_word in self.stop_sequence:
            # Find the last occurrence of the stop string in the output
            stop_word_index = joined_queue.rfind(stop_word)

            # If the stop string is found, compare the index with the previously found index
            if stop_word_index != -1 and stop_word_index > last_stop_str_index:
                last_stop_str_index = stop_word_index

        # Handle the last found stop string index here
        return last_stop_str_index

    def __repr__(self):
        return str(self.queue)


# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

    def __init__(self, sentinel_token_ids: list, starting_idx: int):
        transformers.StoppingCriteria.__init__(self)
        self.sentinel_token_ids = sentinel_token_ids
        self.starting_idx = starting_idx

    def __call__(self, input_ids: torch.LongTensor, _scores: torch.FloatTensor) -> bool:
        for sample in input_ids:
            trimmed_sample = sample[self.starting_idx:]

            for i in range(len(self.sentinel_token_ids)):
                # Can't unfold, output is still too tiny. Skip.
                if trimmed_sample.shape[-1] < self.sentinel_token_ids[i].shape[-1]:
                    continue
                for window in trimmed_sample.unfold(0, self.sentinel_token_ids[i].shape[-1], 1):
                    if torch.all(torch.eq(self.sentinel_token_ids[i][0], window)):
                        return True
        return False


class Stream(transformers.StoppingCriteria):
    def __init__(self, callback_func=None):
        self.callback_func = callback_func

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if shared.stop_everything:
            raise ValueError
        if self.callback_func is not None:
            self.callback_func(input_ids[0])
        return False


class Iteratorize:
    """
    Transforms a function that takes a callback
    into a lazy iterator (generator).
    """

    thread: ThreadWithException = None

    def __new__(cls, *args, **kwargs):
        # create a new instance
        instance = super().__new__(cls)
        # additional per-instance setup could go here
        return instance

    def __init__(self, func, kwargs={}, callback=None):
        self.mfunc = func
        self.c_callback = callback
        self.q = Queue()
        self.sentinel = object()
        self.kwargs = kwargs

        def _callback(val):
            if shared.stop_everything:
                raise ValueError
            self.q.put(val)

        def gen():
            try:
                ret = self.mfunc(callback=_callback, **self.kwargs)
            except ValueError:
                print("print(ValueError)")
            except:
                traceback.print_exc()
                print("traceback.print_exc()")
            self.q.put(self.sentinel)

        self.thread = ThreadWithException(target=gen)
        self.thread.start()

    def __iter__(self):
        shared.stop_everything = False
        return self

    def __next__(self):
        obj = self.q.get(True, None)
        if obj is self.sentinel:
            raise StopIteration
        else:
            return obj

    def __del__(self):
        shared.stop_everything = False
        self.q.empty()
        shared.loaderCheckPoint.clear_torch_cache()

    def __enter__(self):
        shared.stop_everything = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        shared.stop_everything = True
        shared.loaderCheckPoint.clear_torch_cache()
        self.thread.raise_exception()
@@ -1,10 +0,0 @@
import gc
import traceback
import torch

# This iterator returns the extensions in the order specified in the command-line
def iterator():
    state_extensions = {}
    for name in sorted(state_extensions, key=lambda x: state_extensions[x][1]):
        if state_extensions[name][0]:
            yield getattr(extensions, name).script, name
@@ -1,64 +0,0 @@
'''
Based on
https://github.com/abetlen/llama-cpp-python

Documentation:
https://abetlen.github.io/llama-cpp-python/
'''

from llama_cpp import Llama, LlamaCache

from modules import shared
from modules.callbacks import Iteratorize


class LlamaCppModel:
    def __init__(self):
        self.initialized = False

    @classmethod
    def from_pretrained(self, path):
        result = self()

        params = {
            'model_path': str(path),
            'n_ctx': 2048,
            'seed': 0,
            'n_threads': shared.args.threads or None
        }
        self.model = Llama(**params)
        self.model.set_cache(LlamaCache)

        # This is ugly, but the model and the tokenizer are the same object in this library.
        return result, result

    def encode(self, string):
        if type(string) is str:
            string = string.encode()
        return self.model.tokenize(string)

    def generate(self, context="", token_count=20, temperature=1, top_p=1, top_k=50, repetition_penalty=1, callback=None):
        if type(context) is str:
            context = context.encode()
        tokens = self.model.tokenize(context)

        output = b""
        count = 0
        for token in self.model.generate(tokens, top_k=top_k, top_p=top_p, temp=temperature, repeat_penalty=repetition_penalty):
            text = self.model.detokenize([token])
            output += text
            if callback:
                callback(text.decode())

            count += 1
            if count >= token_count or (token == self.model.token_eos()):
                break

        return output.decode()

    def generate_with_streaming(self, **kwargs):
        with Iteratorize(self.generate, kwargs, callback=None) as generator:
            reply = ''
            for token in generator:
                reply += token
                yield reply
@@ -1,30 +0,0 @@
# Python program raising
# exceptions in a python
# thread

import threading
import ctypes
import time


class ThreadWithException(threading.Thread):

    def get_id(self):
        return self.ident

    def raise_exception(self):
        """raises the exception, performs cleanup if needed"""
        try:
            thread_id = self.get_id()
            tid = ctypes.c_long(thread_id)
            res = ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, ctypes.py_object(SystemExit))
            if res == 0:
                # pass
                raise ValueError("invalid thread id")
            elif res != 1:
                # """if it returns a number greater than one, you're in trouble,
                # and you should call it again with exc=NULL to revert the effect"""
                ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
                raise SystemError("PyThreadState_SetAsyncExc failed")
        except Exception as err:
            print(err)
@@ -0,0 +1,54 @@
from abc import ABC
import requests
from typing import Optional, List
from langchain.llms.base import LLM

from models.loader import LoaderCheckPoint
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


class FastChatLLM(BaseAnswer, LLM, ABC):
    max_token: int = 10000
    temperature: float = 0.01
    top_p = 0.9
    checkPoint: LoaderCheckPoint = None
    # history = []
    history_len: int = 10

    def __init__(self, checkPoint: LoaderCheckPoint = None):
        super().__init__()
        self.checkPoint = checkPoint

    @property
    def _llm_type(self) -> str:
        return "FastChat"

    @property
    def _check_point(self) -> LoaderCheckPoint:
        return self.checkPoint

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        pass

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
                         streaming: bool = False,
                         generate_with_callback: AnswerResultStream = None) -> None:

        response = "fastchat 响应结果"
        history += [[prompt, response]]
        answer_result = AnswerResult()
        answer_result.history = history
        answer_result.llm_output = {"answer": response}

        generate_with_callback(answer_result)
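Note: a minimal usage sketch of the new class (the callback below is a stand-in for whatever AnswerResultStream callable the project actually wires in, so treat the wiring as an assumption rather than the project's entry point):

# Hypothetical smoke test for FastChatLLM._generate_answer (not part of the commit).
from models import FastChatLLM

def on_answer(answer_result):
    # AnswerResult carries the accumulated history and the llm_output dict.
    print(answer_result.llm_output["answer"])

llm = FastChatLLM(checkPoint=None)  # the stubbed response needs no loaded checkpoint
llm._generate_answer("Hello", history=[], streaming=False,
                     generate_with_callback=on_answer)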
@@ -8,28 +8,12 @@ from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList
from typing import Optional, List, Dict, Any
from models.loader import LoaderCheckPoint
-from models.extensions.callback import (Iteratorize, Stream, FixedLengthQueue)
-import models.shared as shared
from models.base import (BaseAnswer,
                         AnswerResult,
                         AnswerResultStream,
                         AnswerResultQueueSentinelTokenListenerQueue)


-def _streaming_response_template() -> Dict[str, Any]:
-    """
-    :return: response structure
-    """
-    return {
-        "text": ""
-    }
-
-
-def _update_response(response: Dict[str, Any], stream_response: str) -> None:
-    """Update response from the stream response."""
-    response["text"] += stream_response
-
-
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
@@ -105,16 +89,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
        reply = self.checkPoint.tokenizer.decode(output_ids, skip_special_tokens=True)
        return reply

-    def generate_with_callback(self, callback=None, **kwargs):
-        self.checkPoint.clear_torch_cache()
-        kwargs['stopping_criteria'].append(Stream(callback_func=callback))
-        with torch.no_grad():
-            self.checkPoint.model.generate(**kwargs)
-        print("方法结束")
-
-    def generate_with_streaming(self, **kwargs):
-        return Iteratorize(self.generate_with_callback, kwargs)
-
    # convert the history of past turns into text form
    def history_to_text(self, query):
        formatted_history = ''
@@ -144,45 +118,6 @@ class LLamaLLM(BaseAnswer, LLM, ABC):

        return input_ids, position_ids, attention_mask

-    def get_position_ids(self, input_ids: torch.LongTensor, mask_positions, device):
-        """
-        Attention position offsets.
-        :param input_ids:
-        :param mask_positions:
-        :param device:
-        :param use_gmasks:
-        :return:
-        """
-        batch_size, seq_length = input_ids.shape
-        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
-        position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-        for i, context_length in enumerate(context_lengths):
-            position_ids[i, context_length:] = mask_positions[i]
-        block_position_ids = [torch.cat((
-            torch.zeros(context_length, dtype=torch.long, device=device),
-            torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
-        )) for context_length in context_lengths]
-        block_position_ids = torch.stack(block_position_ids, dim=0)
-        position_ids = torch.stack((position_ids, block_position_ids), dim=1)
-        return position_ids
-
-    def get_masks(self, input_ids, device):
-        """
-        Build the attention mask.
-        :param input_ids:
-        :param device:
-        :return:
-        """
-        batch_size, seq_length = input_ids.shape
-        context_lengths = [seq.tolist().index(self.checkPoint.model_config.bos_token_id) for seq in input_ids]
-        attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
-        attention_mask.tril_()
-        for i, context_length in enumerate(context_lengths):
-            attention_mask[i, :, :context_length] = 1
-        attention_mask.unsqueeze_(1)
-        attention_mask = (attention_mask < 0.5).bool()
-        return attention_mask
-
    def generate_softprompt_history_tensors(self, query):
        """
        Soft prompt built from the conversation history
@@ -222,11 +157,11 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
                      "eos_token_id": self.eos_token_id,
                      "logits_processor": self.logits_processor}

-        # tensor concatenation
+        # tensor conversion
        input_ids = self.encode(prompt, add_bos_token=self.state['add_bos_token'], truncation_length=self.max_new_tokens)
        # input_ids, position_ids, attention_mask = self.prepare_inputs_for_generation(input_ids=filler_input_ids)

-        # chat-model prompt
        gen_kwargs.update({'inputs': input_ids})
        # attention mask
        # gen_kwargs.update({'attention_mask': attention_mask})
@@ -235,45 +170,13 @@ class LLamaLLM(BaseAnswer, LLM, ABC):
        self.stopping_criteria = transformers.StoppingCriteriaList()
        # observe the output
        gen_kwargs.update({'stopping_criteria': self.stopping_criteria})
-        shared.stop_everything = False
-        stopped = False
-        response_template = _streaming_response_template()
-
-        # TODO: this streaming output method needs to be rewritten!!!
-        # the stopping_criteria approach cannot be controlled; the iterator's variables cannot be shared
-        with self.generate_with_streaming(**gen_kwargs) as generator:
-            last_reply_len = 0
-            reply_index = 0
-            # Create a FixedLengthQueue with the desired stop sequence and a maximum length.
-            queue = FixedLengthQueue(stop)
-            for output in generator:
-                new_tokens = len(output) - len(input_ids[0])
-                reply = self.decode(output[-new_tokens:])
-
-                new_reply = len(reply) - last_reply_len
-                output_reply = reply[-new_reply:]
-                queue.add(reply_index, output_reply)
-                queue.contains_replace_sequence()
-                if stop:
-                    pos = queue.contains_stop_sequence()
-                    if pos != -1:
-                        shared.stop_everything = True
-                        stopped = True
-
-                #print(f"{reply_index}:reply {output_reply}")
-                english_reply = queue.put_replace_out(reply_index)
-                #print(f"{reply_index}:english_reply {english_reply}")
-                _update_response(response_template, english_reply)
-                last_reply_len = len(reply)
-
-                reply_index += 1
-                if new_tokens == self.max_new_tokens - 1 or stopped:
-                    break
-
-        response = response_template['text']
-        print(f"response:{response}")
-        self.history = self.history + [[None, response]]
-        return response
+        output_ids = self.checkPoint.model.generate(**gen_kwargs)
+        new_tokens = len(output_ids[0]) - len(input_ids[0])
+        reply = self.decode(output_ids[0][-new_tokens:])
+        print(f"response:{reply}")
+        self.history = self.history + [[None, reply]]
+        return reply

    def _generate_answer(self, prompt: str,
                         history: List[List[str]] = [],
@@ -1,6 +1,6 @@
import argparse
import os
+from configs.model_config import *


# Additional argparse types
@@ -32,28 +32,25 @@ def dir_path(string):


parser = argparse.ArgumentParser(prog='langchina-ChatGLM',
-                                 description='基于langchain和chatGML的LLM文档阅读器')
-
-parser.add_argument('--no-remote-model', action='store_true', default=False, help='remote in the model on loader checkpoint, if your load local model to add the ` --no-remote-model`')
-parser.add_argument('--model', type=str, default='chatglm-6b', help='Name of the model to load by default.')
+                                 description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain | '
+                                             '基于本地知识库的 ChatGLM 问答')
+
+parser.add_argument('--no-remote-model', action='store_true', default=NO_REMOTE_MODEL, help='remote in the model on '
+                                                                                            'loader checkpoint, '
+                                                                                            'if your load local '
+                                                                                            'model to add the ` '
+                                                                                            '--no-remote-model`')
+parser.add_argument('--model', type=str, default=LLM_MODEL, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
-parser.add_argument("--model-dir", type=str, default='model/', help="Path to directory with all the models")
-parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument("--model-dir", type=str, default=MODEL_DIR, help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default=LORA_DIR, help="Path to directory with all the loras")

# Accelerate/transformers
-parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text. Warning: Training on CPU is extremely slow.')
-parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
-parser.add_argument('--gpu-memory', type=str, nargs="+", help='Maxmimum GPU memory in GiB to be allocated per GPU. Example: --gpu-memory 10 for a single GPU, --gpu-memory 10 5 for two GPUs. You can also set values in MiB like --gpu-memory 3500MiB.')
-parser.add_argument('--cpu-memory', type=str, help='Maximum CPU memory in GiB to allocate for offloaded weights. Same as above.')
-parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
-parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
+parser.add_argument('--load-in-8bit', action='store_true', default=LOAD_IN_8BIT,
+                    help='Load the model with 8-bit precision.')
+parser.add_argument('--bf16', action='store_true', default=BF16,
+                    help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')


args = parser.parse_args([])
# Generares dict with a default value for each argument
DEFAULT_ARGS = vars(args)
@@ -4,8 +4,6 @@ from models.loader.args import parser
from models.loader import LoaderCheckPoint
from configs.model_config import (llm_model_dict, LLM_MODEL)
from models.base import BaseAnswer
-"""Whether the iterator should stop."""
-stop_everything = False

loaderCheckPoint: LoaderCheckPoint = None
@@ -36,6 +34,9 @@ def loaderLLM(llm_model: str = None, no_remote_model: bool = False, use_ptuning_

        loaderCheckPoint.model_path = llm_model_info["local_model_path"]

+    if 'fastChat' in loaderCheckPoint.model_name:
+        loaderCheckPoint.unload_model()
+    else:
        loaderCheckPoint.reload_model()

    provides_class = getattr(sys.modules['models'], llm_model_info['provides'])
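Note: with the branch above, selecting the fastChat entry unloads local weights instead of reloading them, and loaderLLM dispatches to the class named by the entry's "provides" field. A hedged sketch of the call path (the LoaderCheckPoint constructor arguments and loaderLLM's return value are assumptions beyond what this hunk shows):

# Hypothetical call path, not part of the commit.
import models.shared as shared
from models.loader import LoaderCheckPoint
from models.loader.args import DEFAULT_ARGS

shared.loaderCheckPoint = LoaderCheckPoint(DEFAULT_ARGS)            # constructor signature assumed
llm = shared.loaderLLM(llm_model="fastChat", no_remote_model=True)  # dispatches via "provides"
print(type(llm).__name__)                                           # expected: "FastChatLLM"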