diff --git a/cudatest.py b/cudatest.py
index 08f7e1a..5d0b8bd 100644
--- a/cudatest.py
+++ b/cudatest.py
@@ -13,6 +13,6 @@ print(f"cuDNN 版本: {cudnn_version}")
# 检查是否可以访问 CUDA
if torch.cuda.is_available():
- print("CUDA is available. GPU name:", torch.cuda.get_device_name(0))
+ print("CUDA is available. GPU name:", torch.cuda.get_device_name(0))
else:
print("CUDA is not available. Please check your installation.")
\ No newline at end of file
diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py
index 71bf7f9..0f46550 100644
--- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py
+++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py
@@ -104,8 +104,8 @@ def tavily_search(text, config, top_k):
raw_results = tavily_tool.run(text)
search_results = [{k: v for k, v in item.items() if k != 'url'} for item in raw_results]
- print("=== 完整搜索返回值 ===")
- print(search_results)
+ # print("=== 完整搜索返回值 ===")
+ # print(search_results)
return search_results
SEARCH_ENGINES = {
@@ -158,7 +158,7 @@ def search_engine(query: str, top_k:int=0, engine_name: str="", config: dict={})
)
docs = [x for x in search_result2docs(results, engine_name) if x.page_content and x.page_content.strip()]
- print(f"docs: {docs}")
+ print(f"len(docs): {len(docs)}")
return {"docs": docs, "search_engine": engine_name}
@@ -167,7 +167,7 @@ def search_internet(query: str = Field(description="query for Internet search"))
"""用这个工具实现获取世界、历史、实时新闻、或除电力系统之外的信息查询"""
try:
print(f"search_internet: query: {query}")
- return BaseToolOutput(search_engine(query=query), format=format_context)
+ return BaseToolOutput(data= search_engine(query=query), format=format_context)
except Exception as e:
logger.error(f"未知错误: {str(e)}")
return BaseToolOutput(f"搜索过程中发生未知错误,{str(e)}", format=format_context)
diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py
index e0fe1a3..79eb7fb 100644
--- a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py
+++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py
@@ -176,7 +176,7 @@ def format_context(self: BaseToolOutput) -> str:
doc = DocumentWithVSId.parse_obj(doc)
source_documents.append(doc.page_content)
- print(f"format_context: doc.page_content: {doc.page_content}")
+ # print(f"format_context: doc.page_content: {doc.page_content}")
if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试"
else:
diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py
index 510535a..eb73b63 100644
--- a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py
+++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py
@@ -12,15 +12,15 @@ from .tools_registry import BaseToolOutput, regist_tool
@regist_tool(title="天气查询")
def weather_check(
- city: str = Field(description="城市名称,包括市和县,例如 '厦门'"),
- date: str = Field(
- default=None,
- description="日期参数,支持以下格式:\n"
- "- '今天':获取当前实时天气\n"
- "- '明天'/'后天':获取未来24/48小时预报\n"
- "- '未来X天':获取最多X天预报(如'未来3天'),X的抽取要符合客户意图\n"
- "- 不支持其他参数,如果是其他参数,则时间参数为None"
- )
+ city: str = Field(description="城市名称,包括市和县,例如 '厦门'"),
+ date: str = Field(
+ default=None,
+ description="日期参数,支持以下格式:\n"
+ "- '今天':获取当前实时天气\n"
+ "- '明天'/'后天':获取未来24/48小时预报\n"
+ "- '未来X天':获取最多X天预报(如'未来3天'),X的抽取要符合客户意图\n"
+ "- 不支持其他参数,如果是其他参数,则时间参数为None\n"
+ )
):
"""用这个工具获取指定地点和指定时间的天气"""
@@ -32,23 +32,22 @@ def weather_check(
missing_params.append("日期参数")
if missing_params:
- return BaseToolOutput(
- error_message=f"缺少必要参数:{', '.join(missing_params)},请补充完整查询信息",
- require_additional_input=True
- )
+ return BaseToolOutput(data={"error_message": f"缺少必要参数:{', '.join(missing_params)},请补充完整查询信息"},
+ require_additional_input=True
+ )
print(f"city:{city}, date:{date}")
try:
weather_type, number = parse_date_parameter(date)
except ValueError as e:
logging.error(f"日期参数解析失败: {str(e)}")
- return BaseToolOutput(str(e))
+ return BaseToolOutput(data={"error_message": str(e)})
# 获取API配置
tool_config = get_tool_config("weather_check")
api_key = tool_config.get("api_key")
if not api_key:
- return BaseToolOutput("API密钥未配置,请联系管理员")
+ return BaseToolOutput(data={"error_message": "API密钥未配置,请联系管理员"})
# 根据天气类型调用API
if weather_type == "daily":
@@ -56,7 +55,8 @@ def weather_check(
elif weather_type == "future":
return _get_future_weather(city, api_key, number)
else:
- return BaseToolOutput("不支持的天气类型")
+ return BaseToolOutput(data={"error_message": "不支持的天气类型"})
+
def _get_current_weather(city: str, api_key: str) -> BaseToolOutput:
"""获取当前实时天气"""
@@ -66,14 +66,15 @@ def _get_current_weather(city: str, api_key: str) -> BaseToolOutput:
if response.status_code != 200:
logging.error(f"天气查询失败: {response.status_code}")
- return BaseToolOutput("天气查询API请求失败")
+ return BaseToolOutput(data={"error_message": "天气查询API请求失败"})
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
- return BaseToolOutput(weather)
+ return BaseToolOutput(data=weather)
+
def _get_future_weather(city: str, api_key: str, days: int) -> BaseToolOutput:
"""获取未来天气预报"""
@@ -115,9 +116,10 @@ def _get_future_weather(city: str, api_key: str, days: int) -> BaseToolOutput:
"后天最高温度": daily_data[2]["high"],
}
else:
- return BaseToolOutput("不支持的天数参数")
+ return BaseToolOutput(data={"error_message": "不支持的天数参数"})
+
+ return BaseToolOutput(data=weather)
- return BaseToolOutput(weather)
def parse_date_parameter(date: str) -> tuple:
"""解析日期参数,返回天气类型和天数"""
@@ -136,5 +138,6 @@ def parse_date_parameter(date: str) -> tuple:
else:
raise ValueError("不支持的日期参数")
+
if __name__ == "__main__":
- weather_check("合肥","明天")
\ No newline at end of file
+ weather_check("合肥", "明天")
diff --git a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
index 8b8476e..b7e91ce 100644
--- a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
+++ b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py
@@ -65,7 +65,7 @@ async def chat_completions(
# import rich
# rich.print(body)
# 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值
- logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}")
+ # logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}")
if body.max_tokens in [None, 0]:
body.max_tokens = Settings.model_settings.MAX_TOKENS
diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py
index 0d42a2b..2e8d6d0 100644
--- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py
+++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py
@@ -70,6 +70,9 @@ def list_files_from_folder(kb_name: str):
for x in ["temp", "tmp", ".", "~$"]:
if tail.startswith(x):
return True
+ if "_source.txt" in tail.lower() or "_split.txt" in tail.lower():
+ return True
+
return False
def process_entry(entry):
@@ -422,15 +425,15 @@ class KnowledgeFile:
docs = zh_first_title_enhance(docs)
docs = customize_zh_title_enhance(docs)
- # i = 1
- # outputfile = file_name_without_extension + "_split.txt"
- # # 打开文件以写入模式
- # with open(outputfile, 'w') as file:
- # for doc in docs:
- # #print(f"**********切分段{i}:{doc}")
- # file.write(f"\n**********切分段{i}")
- # file.write(doc.page_content)
- # i = i+1
+ i = 1
+ outputfile = file_name_without_extension + "_split.txt"
+ # 打开文件以写入模式
+ with open(outputfile, 'w', encoding='utf-8') as file:
+ for doc in docs:
+ #print(f"**********切分段{i}:{doc}")
+ file.write(f"\n**********切分段{i}")
+ file.write(doc.page_content)
+ i = i+1
self.splited_docs = docs
return self.splited_docs
@@ -537,7 +540,8 @@ def format_reference(kb_name: str, docs: List[Dict], api_base_url: str="") -> Li
f"{api_base_url}/knowledge_base/download_doc?" + parameters
)
page_content = doc.get("page_content")
- ref = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{page_content}\n\n"""
+ ref = f"""出处 [{inum + 1}] {filename}\n\n{page_content}\n\n"""
+ # ref = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{page_content}\n\n"""
source_documents.append(ref)
return source_documents
diff --git a/libs/chatchat-server/chatchat/settings.py b/libs/chatchat-server/chatchat/settings.py
index 2ff1b5d..3d59e4b 100644
--- a/libs/chatchat-server/chatchat/settings.py
+++ b/libs/chatchat-server/chatchat/settings.py
@@ -517,7 +517,7 @@ class ToolSettings(BaseFileSettings):
},
"top_k": 5,
"verbose": "Origin",
- "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
+ "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有条理,简洁的回答问题,不得包含有重复的词汇或句子。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"指令>\n<已知信息>{{ context }}已知信息>\n"
"<问题>\n"
"{{ question }}\n"
@@ -657,7 +657,7 @@ class PromptSettings(BaseFileSettings):
rag: dict = {
"default": (
- "【指令】根据已知信息,简洁和专业的来回答问题。"
+ "【指令】根据已知信息,简洁和专业的来回答问题,不得包含有重复的词汇或句子。"
"如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n"
"【已知信息】{{context}}\n\n"
"【问题】{{question}}\n"