Validate required parameters
This commit is contained in:
parent
1e2124ff54
commit
22d08f5ec5
@@ -16,7 +16,7 @@ embedding_model_dict = {
     "ernie-tiny": "nghuyong/ernie-3.0-nano-zh",
     "ernie-base": "nghuyong/ernie-3.0-base-zh",
     "text2vec-base": "shibing624/text2vec-base-chinese",
-    "text2vec": "GanymedeNil/text2vec-large-chinese",
+    "text2vec": "/media/checkpoint/text2vec-large-chinese/",
     "m3e-small": "moka-ai/m3e-small",
     "m3e-base": "moka-ai/m3e-base",
 }
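The only change here points "text2vec" at a local checkpoint directory instead of a Hugging Face hub id, so the same dict now mixes both kinds of values. A minimal sketch of how a loader can tell them apart and fail early on a missing checkpoint (resolve_embedding_source is a hypothetical helper, not part of this commit):

import os

embedding_model_dict = {
    "text2vec": "/media/checkpoint/text2vec-large-chinese/",  # local directory
    "m3e-base": "moka-ai/m3e-base",                           # hub id
}

def resolve_embedding_source(name: str) -> str:
    # Values are either hub ids or absolute paths, so a local path
    # is checked before it ever reaches the model loader.
    if name not in embedding_model_dict:
        raise ValueError(f"unknown embedding model: {name}")
    source = embedding_model_dict[name]
    if os.path.isabs(source) and not os.path.isdir(source):
        raise FileNotFoundError(f"checkpoint directory missing: {source}")
    return source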
@@ -186,7 +186,7 @@ llm_model_dict = {
 }
 
 # LLM name
-LLM_MODEL = "chatglm-6b"
+LLM_MODEL = "fastchat-chatglm-6b"
 # Load the model quantized to 8 bits
 LOAD_IN_8BIT = False
 # Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.
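LLM_MODEL now selects the fastchat-served variant. Since the value must match a key of llm_model_dict, a startup check along these lines (hypothetical, not in the diff, shown with a stub dict) would fail fast on a typo:

# Sketch with a stub llm_model_dict; the real dict lives in this config file.
llm_model_dict = {"chatglm-6b": {}, "fastchat-chatglm-6b": {}}

LLM_MODEL = "fastchat-chatglm-6b"
if LLM_MODEL not in llm_model_dict:
    raise ValueError(f"LLM_MODEL={LLM_MODEL!r} is not configured; "
                     f"choose one of {sorted(llm_model_dict)}")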
@@ -52,15 +52,20 @@ def build_message_list(query, history: List[List[str]]) -> Collection[Dict[str,
     system_build_message['role'] = 'system'
     system_build_message['content'] = "You are a helpful assistant."
     build_messages.append(system_build_message)
-    for i, (old_query, response) in enumerate(history):
-        user_build_message = _build_message_template()
-        user_build_message['role'] = 'user'
-        user_build_message['content'] = old_query
-
-        system_build_message = _build_message_template()
-        system_build_message['role'] = 'assistant'
-        system_build_message['content'] = response
-        build_messages.append(user_build_message)
-        build_messages.append(system_build_message)
+    if history:
+        for i, (user, assistant) in enumerate(history):
+            if user:
+                user_build_message = _build_message_template()
+                user_build_message['role'] = 'user'
+                user_build_message['content'] = user
+                build_messages.append(user_build_message)
+
+            if not assistant:
+                raise RuntimeError("历史数据结构不正确")
+            system_build_message = _build_message_template()
+            system_build_message['role'] = 'assistant'
+            system_build_message['content'] = assistant
+            build_messages.append(system_build_message)
 
     user_build_message = _build_message_template()
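The rewritten loop unpacks each history entry as (user, assistant), skips empty user turns, and raises when the assistant reply is missing ("历史数据结构不正确" roughly means "history data structure is incorrect"). A self-contained sketch of the same validation, with _build_message_template inlined as an assumed {"role": "", "content": ""} template:

from typing import Collection, Dict, List

def _build_message_template() -> Dict[str, str]:
    # Assumed shape of the repo's template helper.
    return {"role": "", "content": ""}

def build_message_list(query: str, history: List[List[str]]) -> Collection[Dict[str, str]]:
    build_messages: List[Dict[str, str]] = [
        {"role": "system", "content": "You are a helpful assistant."}
    ]
    if history:
        for user, assistant in history:
            if user:  # empty user turns are silently skipped
                msg = _build_message_template()
                msg["role"], msg["content"] = "user", user
                build_messages.append(msg)
            if not assistant:  # a missing reply is treated as corrupt history
                raise RuntimeError("历史数据结构不正确")
            msg = _build_message_template()
            msg["role"], msg["content"] = "assistant", assistant
            build_messages.append(msg)
    msg = _build_message_template()
    msg["role"], msg["content"] = "user", query
    build_messages.append(msg)
    return build_messages

# e.g. build_message_list("你好", [["hi", "hello"]]) yields system/user/assistant/user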
@@ -181,10 +186,10 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
              run_manager: Optional[CallbackManagerForChainRun] = None,
              generate_with_callback: AnswerResultStream = None) -> None:
 
-        history = inputs[self.history_key]
-        streaming = inputs[self.streaming_key]
+        history = inputs.get(self.history_key, [])
+        streaming = inputs.get(self.streaming_key, False)
         prompt = inputs[self.prompt_key]
-        stop = inputs.get("stop", None)
+        stop = inputs.get("stop", "stop")
         print(f"__call:{prompt}")
         try:
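Replacing inputs[...] with inputs.get(...) makes history and streaming optional while prompt stays required, which is the parameter validation the commit title refers to; note the stop default also changed from None to the literal string "stop", so downstream always receives a concrete value. The behavior in isolation (read_call_inputs is a hypothetical stand-alone rendering of these lines):

def read_call_inputs(inputs: dict) -> tuple:
    # history and streaming fall back to safe defaults;
    # prompt still raises KeyError when absent, i.e. it is required.
    history = inputs.get("history", [])
    streaming = inputs.get("streaming", False)
    prompt = inputs["prompt"]
    stop = inputs.get("stop", "stop")
    return history, streaming, prompt, stop

# read_call_inputs({"prompt": "你好"}) -> ([], False, "你好", "stop")
# read_call_inputs({})                 -> KeyError: 'prompt'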
@@ -205,16 +210,18 @@ class FastChatOpenAILLMChain(RemoteRpcModel, Chain, ABC):
             params = {"stream": streaming,
                       "model": self.model_name,
                       "stop": stop}
+            out_str = ""
             for stream_resp in self.completion_with_retry(
                     messages=msg,
                     **params
             ):
                 role = stream_resp["choices"][0]["delta"].get("role", "")
                 token = stream_resp["choices"][0]["delta"].get("content", "")
-                history += [[prompt, token]]
+                out_str += token
+                history[-1] = [prompt, out_str]
                 answer_result = AnswerResult()
                 answer_result.history = history
-                answer_result.llm_output = {"answer": token}
+                answer_result.llm_output = {"answer": out_str}
                 generate_with_callback(answer_result)
         else:
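Instead of appending a new history pair per streamed token, the loop now accumulates the deltas in out_str and rewrites the last history entry, so history ends with a single complete [prompt, answer] pair and llm_output carries the full text so far. The accumulation pattern in isolation (the delta stream is simulated, and the sketch assumes history already holds a slot for the current turn, as history[-1] requires):

# Minimal sketch of the delta-accumulation pattern, with a fake stream.
def fake_stream():
    for piece in ["你", "好", "！"]:
        yield {"choices": [{"delta": {"content": piece}}]}

prompt = "你好"
history = [[prompt, ""]]  # assumption: a slot for the current turn exists
out_str = ""
for stream_resp in fake_stream():
    token = stream_resp["choices"][0]["delta"].get("content", "")
    out_str += token
    history[-1] = [prompt, out_str]  # overwrite the slot, don't append

assert history == [["你好", "你好！"]]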
@@ -239,10 +246,10 @@ if __name__ == "__main__":
     chain = FastChatOpenAILLMChain()
 
     chain.set_api_key("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")
-    chain.set_api_base_url("https://api.openai.com/v1")
-    chain.call_model_name("gpt-3.5-turbo")
+    # chain.set_api_base_url("https://api.openai.com/v1")
+    # chain.call_model_name("gpt-3.5-turbo")
 
-    answer_result_stream_result = chain({"streaming": False,
+    answer_result_stream_result = chain({"streaming": True,
                                          "prompt": "你好",
                                          "history": []
                                          })
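The smoke test now exercises streaming with the class defaults for base URL and model name. If committing the key is a concern, the same setup works with an environment variable (a sketch, assuming set_api_key accepts any string):

import os

chain = FastChatOpenAILLMChain()
# Read the key from the environment instead of hardcoding it in the repo.
chain.set_api_key(os.environ["OPENAI_API_KEY"])

answer_result_stream_result = chain({"streaming": True,
                                     "prompt": "你好",
                                     "history": []})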