Langchain-Chatchat/libs/chatchat-server/chatchat/server/agent/tools_factory/text2sql.py

141 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from langchain.chains import LLMChain
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts.prompt import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain, SQLDatabaseSequentialChain
from sqlalchemy import event
from sqlalchemy.exc import OperationalError
from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config
from .tools_registry import BaseToolOutput, regist_tool
# Prompt used when the tool runs in read-only mode: it asks the LLM whether the
# user's request can be served without write access. `{query}` is the only
# template variable. query_database looks for the literal phrase
# 'SQL cannot be executed normally' in the model's reply, so that wording must
# stay in sync with the check there.
READ_ONLY_PROMPT_TEMPLATE = """You are a MySQL expert. The database is currently in read-only mode.
Given an input question, determine if the related SQL can be executed in read-only mode.
If the SQL can be executed normally, return Answer:'SQL can be executed normally'.
If the SQL cannot be executed normally, return Answer: 'SQL cannot be executed normally'.
Use the following format:
Answer: Final answer here
Question: {query}
"""
# Interceptor that inspects every SQL statement to enforce read-only mode.
# Adjust the keyword tuple below to match the write statements of your database.
def intercept_sql(conn, cursor, statement, parameters, context, executemany):
    """Reject write statements before they are executed.

    The function is registered by query_database as a listener for the
    engine's "before_cursor_execute" event. Only ``statement`` is inspected;
    the other arguments are accepted because the listener signature requires
    them.

    Raises:
        OperationalError: if ``statement``, ignoring leading whitespace and
            letter case, begins with one of the write keywords.
    """
    # Leading keywords that mark a statement as a write operation.
    write_prefixes = (
        "insert",
        "update",
        "delete",
        "create",
        "drop",
        "alter",
        "truncate",
        "rename",
    )
    normalized = statement.strip().lower()
    # str.startswith accepts a tuple, so one call covers every keyword.
    if not normalized.startswith(write_prefixes):
        return
    raise OperationalError(
        "Database is read-only. Write operations are not allowed.",
        params=None,
        orig=None,
    )
def _extract_sql_detail(intermediate_steps):
    """Return the executed-SQL excerpt from a chain's intermediate steps, or None.

    The excerpt is the text between "SQLQuery:" and the following "Answer:" in
    the "input" field of the second-to-last step. None is returned whenever
    that text cannot be located, so callers can simply omit the detail.
    """
    # Only traces with more than two entries are inspected (same threshold as
    # the original code: earlier entries contain sample data that is
    # misleading to show).
    if not intermediate_steps or len(intermediate_steps) <= 2:
        return None
    step = intermediate_steps[-2]
    # Presumably this step is the prompt-input dict of the final LLM call;
    # guard against other shapes instead of raising.
    if not isinstance(step, dict):
        return None
    prompt_text = step.get("input")
    if not isinstance(prompt_text, str):
        return None
    marker = "SQLQuery:"
    start = prompt_text.find(marker)
    if start == -1:
        return None
    start += len(marker)
    # Search for the terminator only after the SQL marker so an "Answer:" that
    # appears earlier (e.g. inside the user's question) cannot empty the slice.
    end = prompt_text.find("Answer:", start)
    if end == -1:
        end = len(prompt_text)
    return prompt_text[start:end]


def query_database(query: str, config: dict):
    """Answer a natural-language question by generating and running SQL.

    Args:
        query: The user's question in natural language.
        config: Tool settings. The keys read are ``model_name``, ``top_k``,
            ``return_intermediate_steps``, ``sqlalchemy_connect_str``,
            ``read_only``, ``table_names`` and ``table_comments``.

    Returns:
        A string containing the query result and, when the chain exposes
        intermediate steps from which the SQL can be recovered, the executed
        SQL. In read-only mode, if the LLM judges that the request needs write
        access, a fixed refusal message is returned instead.

    Raises:
        KeyError: if a required key is missing from ``config``.
    """
    model_name = config["model_name"]
    top_k = config["top_k"]
    return_intermediate_steps = config["return_intermediate_steps"]
    sqlalchemy_connect_str = config["sqlalchemy_connect_str"]
    read_only = config["read_only"]
    table_names = config["table_names"]
    table_comments = config["table_comments"]

    db = SQLDatabase.from_uri(sqlalchemy_connect_str)

    from chatchat.server.utils import get_ChatOpenAI

    llm = get_ChatOpenAI(
        model_name=model_name,
        temperature=0.1,
        streaming=True,
        local_wrap=True,
        verbose=True,
    )

    # If the model picks the wrong tables, extra per-table notes can help it,
    # especially with SQLDatabaseSequentialChain, which predicts tables from
    # their names alone and misjudges easily. LangChain fixes the chain's input
    # parameters, so the notes can only be passed along inside the query text.
    if table_comments:
        table_comment_prompt = (
            "\n\nI will provide some special notes for a few tables:\n\n"
        )
        table_comments_str = "\n".join(
            f"{name}:{comment}" for name, comment in table_comments.items()
        )
        query = query + table_comment_prompt + table_comments_str + "\n\n"

    if read_only:
        # Ask the model first whether read-only access is enough, so the user
        # gets a friendly message instead of a failure later during execution.
        read_only_prompt = PromptTemplate(
            input_variables=["query"],
            template=READ_ONLY_PROMPT_TEMPLATE,
        )
        read_only_chain = LLMChain(
            prompt=read_only_prompt,
            llm=llm,
        )
        read_only_result = read_only_chain.invoke(query)
        if "SQL cannot be executed normally" in read_only_result["text"]:
            return "当前数据库为只读状态,无法满足您的需求!"
        # The model's judgement is not guaranteed to be right, so write
        # statements are additionally rejected by a listener on the engine.
        event.listen(db._engine, "before_cursor_execute", intercept_sql)

    # Without table_names, SQLDatabaseSequentialChain is used: it first
    # predicts the relevant tables and then feeds only those to
    # SQLDatabaseChain. Using SQLDatabaseChain directly in that case would send
    # the schema of every table to the model, which wastes tokens and may
    # exceed the context limit. With table_names, SQLDatabaseChain is used
    # directly and restricted to the named tables.
    # Truthiness (rather than len()) also tolerates table_names being None.
    if table_names:
        db_chain = SQLDatabaseChain.from_llm(
            llm,
            db,
            verbose=True,
            top_k=top_k,
            return_intermediate_steps=return_intermediate_steps,
        )
        result = db_chain.invoke({"query": query, "table_names_to_use": table_names})
    else:
        # Predict the tables first, then give the question and those tables to
        # the model.
        db_chain = SQLDatabaseSequentialChain.from_llm(
            llm,
            db,
            verbose=True,
            top_k=top_k,
            return_intermediate_steps=return_intermediate_steps,
        )
        result = db_chain.invoke(query)

    context = f"""查询结果:{result['result']}\n\n"""
    # "intermediate_steps" is absent from the output when
    # return_intermediate_steps is false, so it must not be subscripted.
    sql_detail = _extract_sql_detail(result.get("intermediate_steps"))
    if sql_detail is not None:
        context = context + "执行的sql:'" + sql_detail + "'\n\n"
    return context
# @regist_tool(title="数据库对话")
def text2sql(
    query: str = Field(
        description="No need for SQL statements,just input the natural language that you want to chat with database"
    ),
):
    """Use this tool to chat with database,Input natural language, then it will convert it into SQL and execute it in the database, then return the execution result."""
    # NOTE(review): the docstring above is kept verbatim because it presumably
    # doubles as the tool description shown to the agent — confirm before editing.
    # Load the "text2sql" tool settings, run the query, and wrap the answer.
    return BaseToolOutput(
        query_database(query=query, config=get_tool_config("text2sql"))
    )