diff --git a/cudatest.py b/cudatest.py index 08f7e1a..5d0b8bd 100644 --- a/cudatest.py +++ b/cudatest.py @@ -13,6 +13,6 @@ print(f"cuDNN 版本: {cudnn_version}") # 检查是否可以访问 CUDA if torch.cuda.is_available(): - print("CUDA is available. GPU name:", torch.cuda.get_device_name(0)) + print("pip install sentence-transformers -i https://pypi.mirrors.ustc.edu.cn/simpleCUDA is available. GPU name:", torch.cuda.get_device_name(0)) else: print("CUDA is not available. Please check your installation.") \ No newline at end of file diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py index 71bf7f9..0f46550 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/search_internet.py @@ -104,8 +104,8 @@ def tavily_search(text, config, top_k): raw_results = tavily_tool.run(text) search_results = [{k: v for k, v in item.items() if k != 'url'} for item in raw_results] - print("=== 完整搜索返回值 ===") - print(search_results) + # print("=== 完整搜索返回值 ===") + # print(search_results) return search_results SEARCH_ENGINES = { @@ -158,7 +158,7 @@ def search_engine(query: str, top_k:int=0, engine_name: str="", config: dict={}) ) docs = [x for x in search_result2docs(results, engine_name) if x.page_content and x.page_content.strip()] - print(f"docs: {docs}") + print(f"len(docs): {len(docs)}") return {"docs": docs, "search_engine": engine_name} @@ -167,7 +167,7 @@ def search_internet(query: str = Field(description="query for Internet search")) """用这个工具实现获取世界、历史、实时新闻、或除电力系统之外的信息查询""" try: print(f"search_internet: query: {query}") - return BaseToolOutput(search_engine(query=query), format=format_context) + return BaseToolOutput(data= search_engine(query=query), format=format_context) except Exception as e: logger.error(f"未知错误: {str(e)}") return BaseToolOutput(f"搜索过程中发生未知错误,{str(e)}", format=format_context) diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py index e0fe1a3..79eb7fb 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/tools_registry.py @@ -176,7 +176,7 @@ def format_context(self: BaseToolOutput) -> str: doc = DocumentWithVSId.parse_obj(doc) source_documents.append(doc.page_content) - print(f"format_context: doc.page_content: {doc.page_content}") + # print(f"format_context: doc.page_content: {doc.page_content}") if len(source_documents) == 0: context = "没有找到相关文档,请更换关键词重试" else: diff --git a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py index 510535a..eb73b63 100644 --- a/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py +++ b/libs/chatchat-server/chatchat/server/agent/tools_factory/weather_check.py @@ -12,15 +12,15 @@ from .tools_registry import BaseToolOutput, regist_tool @regist_tool(title="天气查询") def weather_check( - city: str = Field(description="城市名称,包括市和县,例如 '厦门'"), - date: str = Field( - default=None, - description="日期参数,支持以下格式:\n" - "- '今天':获取当前实时天气\n" - "- '明天'/'后天':获取未来24/48小时预报\n" - "- '未来X天':获取最多X天预报(如'未来3天'),X的抽取要符合客户意图\n" - "- 不支持其他参数,如果是其他参数,则时间参数为None" - ) + city: str = Field(description="城市名称,包括市和县,例如 '厦门'"), + date: str = Field( + default=None, + description="日期参数,支持以下格式:\n" + "- '今天':获取当前实时天气\n" + "- '明天'/'后天':获取未来24/48小时预报\n" + "- '未来X天':获取最多X天预报(如'未来3天'),X的抽取要符合客户意图\n" + "- 不支持其他参数,如果是其他参数,则时间参数为None\n" + ) ): """用这个工具获取指定地点和指定时间的天气""" @@ -32,23 +32,22 @@ def weather_check( missing_params.append("日期参数") if missing_params: - return BaseToolOutput( - error_message=f"缺少必要参数:{', '.join(missing_params)},请补充完整查询信息", - require_additional_input=True - ) + return BaseToolOutput(data={"error_message": f"缺少必要参数:{', '.join(missing_params)},请补充完整查询信息"}, + require_additional_input=True + ) print(f"city:{city}, date:{date}") try: weather_type, number = parse_date_parameter(date) except ValueError as e: logging.error(f"日期参数解析失败: {str(e)}") - return BaseToolOutput(str(e)) + return BaseToolOutput(data={"error_message": str(e)}) # 获取API配置 tool_config = get_tool_config("weather_check") api_key = tool_config.get("api_key") if not api_key: - return BaseToolOutput("API密钥未配置,请联系管理员") + return BaseToolOutput(data={"error_message": "API密钥未配置,请联系管理员"}) # 根据天气类型调用API if weather_type == "daily": @@ -56,7 +55,8 @@ def weather_check( elif weather_type == "future": return _get_future_weather(city, api_key, number) else: - return BaseToolOutput("不支持的天气类型") + return BaseToolOutput(data={"error_message": "不支持的天气类型"}) + def _get_current_weather(city: str, api_key: str) -> BaseToolOutput: """获取当前实时天气""" @@ -66,14 +66,15 @@ def _get_current_weather(city: str, api_key: str) -> BaseToolOutput: if response.status_code != 200: logging.error(f"天气查询失败: {response.status_code}") - return BaseToolOutput("天气查询API请求失败") + return BaseToolOutput(data={"error_message": "天气查询API请求失败"}) data = response.json() weather = { "temperature": data["results"][0]["now"]["temperature"], "description": data["results"][0]["now"]["text"], } - return BaseToolOutput(weather) + return BaseToolOutput(data=weather) + def _get_future_weather(city: str, api_key: str, days: int) -> BaseToolOutput: """获取未来天气预报""" @@ -115,9 +116,10 @@ def _get_future_weather(city: str, api_key: str, days: int) -> BaseToolOutput: "后天最高温度": daily_data[2]["high"], } else: - return BaseToolOutput("不支持的天数参数") + return BaseToolOutput(data={"error_message": "不支持的天数参数"}) + + return BaseToolOutput(data=weather) - return BaseToolOutput(weather) def parse_date_parameter(date: str) -> tuple: """解析日期参数,返回天气类型和天数""" @@ -136,5 +138,6 @@ def parse_date_parameter(date: str) -> tuple: else: raise ValueError("不支持的日期参数") + if __name__ == "__main__": - weather_check("合肥","明天") \ No newline at end of file + weather_check("合肥", "明天") diff --git a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py index 8b8476e..b7e91ce 100644 --- a/libs/chatchat-server/chatchat/server/api_server/chat_routes.py +++ b/libs/chatchat-server/chatchat/server/api_server/chat_routes.py @@ -65,7 +65,7 @@ async def chat_completions( # import rich # rich.print(body) # 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值 - logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}") + # logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}") if body.max_tokens in [None, 0]: body.max_tokens = Settings.model_settings.MAX_TOKENS diff --git a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py index 0d42a2b..2e8d6d0 100644 --- a/libs/chatchat-server/chatchat/server/knowledge_base/utils.py +++ b/libs/chatchat-server/chatchat/server/knowledge_base/utils.py @@ -70,6 +70,9 @@ def list_files_from_folder(kb_name: str): for x in ["temp", "tmp", ".", "~$"]: if tail.startswith(x): return True + if "_source.txt" in tail.lower() or "_split.txt" in tail.lower(): + return True + return False def process_entry(entry): @@ -422,15 +425,15 @@ class KnowledgeFile: docs = zh_first_title_enhance(docs) docs = customize_zh_title_enhance(docs) - # i = 1 - # outputfile = file_name_without_extension + "_split.txt" - # # 打开文件以写入模式 - # with open(outputfile, 'w') as file: - # for doc in docs: - # #print(f"**********切分段{i}:{doc}") - # file.write(f"\n**********切分段{i}") - # file.write(doc.page_content) - # i = i+1 + i = 1 + outputfile = file_name_without_extension + "_split.txt" + # 打开文件以写入模式 + with open(outputfile, 'w') as file: + for doc in docs: + #print(f"**********切分段{i}:{doc}") + file.write(f"\n**********切分段{i}") + file.write(doc.page_content) + i = i+1 self.splited_docs = docs return self.splited_docs @@ -537,7 +540,8 @@ def format_reference(kb_name: str, docs: List[Dict], api_base_url: str="") -> Li f"{api_base_url}/knowledge_base/download_doc?" + parameters ) page_content = doc.get("page_content") - ref = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{page_content}\n\n""" + ref = f"""出处 [{inum + 1}] {filename}\n\n{page_content}\n\n""" + # ref = f"""出处 [{inum + 1}] [{filename}]({url}) \n\n{page_content}\n\n""" source_documents.append(ref) return source_documents diff --git a/libs/chatchat-server/chatchat/settings.py b/libs/chatchat-server/chatchat/settings.py index 2ff1b5d..3d59e4b 100644 --- a/libs/chatchat-server/chatchat/settings.py +++ b/libs/chatchat-server/chatchat/settings.py @@ -517,7 +517,7 @@ class ToolSettings(BaseFileSettings): }, "top_k": 5, "verbose": "Origin", - "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " + "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题,不得包含有重复的词汇或句子。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " "\n<已知信息>{{ context }}\n" "<问题>\n" "{{ question }}\n" @@ -657,7 +657,7 @@ class PromptSettings(BaseFileSettings): rag: dict = { "default": ( - "【指令】根据已知信息,简洁和专业的来回答问题。" + "【指令】根据已知信息,简洁和专业的来回答问题,不得包含有重复的词汇或句子。" "如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n" "【已知信息】{{context}}\n\n" "【问题】{{question}}\n"