Compare commits

..

2 Commits

Author SHA1 Message Date
weiweiw 22c666b4bf 去掉不用的tool 2025-02-14 07:40:25 +08:00
weiweiw b78edb72a1 大模型意图识别提示词 2025-02-13 10:19:08 +08:00
39 changed files with 502 additions and 576 deletions

40
.gitignore vendored
View File

@ -1,40 +0,0 @@
*.csv
*.yaml
*.xlsx
*.pdf
*.txt
*.log
*.pyc
/chatchat_data.bak
/chatchat_data/data/knowledge_base/samples
/chatchat_data
.idea/inspectionProfiles/profiles_settings.xml
.idea/Langchain-Chatchat.iml
.idea/misc.xml
.idea/modules.xml
.idea/prettier.xml
.idea/vcs.xml
.idea/inspectionProfiles/profiles_settings.xml
.idea/Langchain-Chatchat.iml
.idea/modules.xml
.idea/prettier.xml
.idea/vcs.xml
/.idea
/test_tool
chatchat_data/tool_settings.yaml
chatchat_data/prompt_settings.yaml
chatchat_data/model_settings.yaml
chatchat_data/basic_settings.yaml
localconfig/data/knowledge_base/samples/content/分布式训练技术原理.md
localconfig/data/knowledge_base/samples/content/大模型应用技术原理.md
localconfig/data/knowledge_base/samples/content/大模型技术栈-实战与应用.md
localconfig/data/knowledge_base/samples/content/大模型技术栈-算法与原理.md
localconfig/data/knowledge_base/samples/content/大模型指令对齐训练原理.md
localconfig/data/knowledge_base/samples/content/大模型推理优化策略.md
localconfig/data/knowledge_base/samples/vector_store/bge-large-zh-v1.5/index.faiss
localconfig/data/knowledge_base/samples/vector_store/bge-large-zh-v1.5/index.pkl
localconfig/data/knowledge_base/info.db
chatchat_data/basic_settings.yaml
chatchat_data/model_settings.yaml
chatchat_data/prompt_settings.yaml
chatchat_data/tool_settings.yaml

8
.idea/.gitignore vendored
View File

@ -1,8 +0,0 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@ -1,7 +0,0 @@
<component name="ProjectDictionaryState">
<dictionary name="Guan">
<words>
<w>aggrid</w>
</words>
</dictionary>
</component>

View File

@ -2,7 +2,7 @@
# 默认选用的 LLM 名称 # 默认选用的 LLM 名称
DEFAULT_LLM_MODEL: qwen2.5-instruct DEFAULT_LLM_MODEL: qwen2-instruct
# 默认选用的 Embedding 名称 # 默认选用的 Embedding 名称
DEFAULT_EMBEDDING_MODEL: bge-large-zh-v1.5 DEFAULT_EMBEDDING_MODEL: bge-large-zh-v1.5
@ -112,78 +112,78 @@ LLM_MODEL_CONFIG:
MODEL_PLATFORMS: MODEL_PLATFORMS:
- platform_name: xinference - platform_name: xinference
platform_type: xinference platform_type: xinference
api_base_url: http://192.168.0.21:9997/v1 api_base_url: http://127.0.0.1:9997/v1
api_key: EMPTY api_key: EMPTY
api_proxy: '' api_proxy: ''
api_concurrencies: 5 api_concurrencies: 5
auto_detect_model: true auto_detect_model: true
llm_models: [qwen2.5-instruct] llm_models: []
embed_models: [bge-large-zh-v1.5] embed_models: []
text2image_models: []
image2text_models: []
rerank_models: [bge-reranker-large]
speech2text_models: []
text2speech_models: []
- platform_name: ollama
platform_type: ollama
api_base_url: http://127.0.0.1:11434/v1
api_key: EMPTY
api_proxy: ''
api_concurrencies: 5
auto_detect_model: false
llm_models:
- qwen:7b
- qwen2:7b
embed_models:
- quentinz/bge-large-zh-v1.5
text2image_models: []
image2text_models: []
rerank_models: []
speech2text_models: []
text2speech_models: []
- platform_name: oneapi
platform_type: oneapi
api_base_url: http://127.0.0.1:3000/v1
api_key: sk-
api_proxy: ''
api_concurrencies: 5
auto_detect_model: false
llm_models:
- chatglm_pro
- chatglm_turbo
- chatglm_std
- chatglm_lite
- qwen-turbo
- qwen-plus
- qwen-max
- qwen-max-longcontext
- ERNIE-Bot
- ERNIE-Bot-turbo
- ERNIE-Bot-4
- SparkDesk
embed_models:
- text-embedding-v1
- Embedding-V1
text2image_models: []
image2text_models: []
rerank_models: []
speech2text_models: []
text2speech_models: []
- platform_name: openai
platform_type: openai
api_base_url: https://api.openai.com/v1
api_key: sk-proj-
api_proxy: ''
api_concurrencies: 5
auto_detect_model: false
llm_models:
- gpt-4o
- gpt-3.5-turbo
embed_models:
- text-embedding-3-small
- text-embedding-3-large
text2image_models: [] text2image_models: []
image2text_models: [] image2text_models: []
rerank_models: [] rerank_models: []
speech2text_models: [] speech2text_models: []
text2speech_models: [] text2speech_models: []
# - platform_name: ollama
# platform_type: ollama
# api_base_url: http://127.0.0.1:11434/v1
# api_key: EMPTY
# api_proxy: ''
# api_concurrencies: 5
# auto_detect_model: false
# llm_models:
# - qwen:7b
# - qwen2:7b
# embed_models:
# - quentinz/bge-large-zh-v1.5
# text2image_models: []
# image2text_models: []
# rerank_models: []
# speech2text_models: []
# text2speech_models: []
# - platform_name: oneapi
# platform_type: oneapi
# api_base_url: http://127.0.0.1:3000/v1
# api_key: sk-
# api_proxy: ''
# api_concurrencies: 5
# auto_detect_model: false
# llm_models:
# - chatglm_pro
# - chatglm_turbo
# - chatglm_std
# - chatglm_lite
# - qwen-turbo
# - qwen-plus
# - qwen-max
# - qwen-max-longcontext
# - ERNIE-Bot
# - ERNIE-Bot-turbo
# - ERNIE-Bot-4
# - SparkDesk
# embed_models:
# - text-embedding-v1
# - Embedding-V1
# text2image_models: []
# image2text_models: []
# rerank_models: []
# speech2text_models: []
# text2speech_models: []
# - platform_name: openai
# platform_type: openai
# api_base_url: https://api.openai.com/v1
# api_key: sk-proj-
# api_proxy: ''
# api_concurrencies: 5
# auto_detect_model: false
# llm_models:
# - gpt-4o
# - gpt-3.5-turbo
# embed_models:
# - text-embedding-3-small
# - text-embedding-3-large
# text2image_models: []
# image2text_models: []
# rerank_models: []
# speech2text_models: []
# text2speech_models: []

View File

@ -14,7 +14,7 @@ search_local_knowledgebase:
# 搜索引擎工具配置项。推荐自己部署 searx 搜索引擎,国内使用最方便。 # 搜索引擎工具配置项。推荐自己部署 searx 搜索引擎,国内使用最方便。
search_internet: search_internet:
use: false use: false
search_engine_name: zhipu_search search_engine_name: searx
search_engine_config: search_engine_config:
bing: bing:
bing_search_url: https://api.bing.microsoft.com/v7.0/search bing_search_url: https://api.bing.microsoft.com/v7.0/search
@ -30,14 +30,6 @@ search_internet:
engines: [] engines: []
categories: [] categories: []
language: zh-CN language: zh-CN
tavily:
tavily_api_key: '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'
include_answer: true
search_depth: advanced
include_raw_content: True
max_results: 1
zhipu_search:
zhipu_api_key: 'e2bdc39618624fd782ebcd721185645c.pcvcrTPFT69Jda8B'
top_k: 5 top_k: 5
verbose: Origin verbose: Origin
conclude_prompt: "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 conclude_prompt: "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。

View File

@ -13,6 +13,6 @@ print(f"cuDNN 版本: {cudnn_version}")
# 检查是否可以访问 CUDA # 检查是否可以访问 CUDA
if torch.cuda.is_available(): if torch.cuda.is_available():
print("pip install sentence-transformers -i https://pypi.mirrors.ustc.edu.cn/simpleCUDA is available. GPU name:", torch.cuda.get_device_name(0)) print("CUDA is available. GPU name:", torch.cuda.get_device_name(0))
else: else:
print("CUDA is not available. Please check your installation.") print("CUDA is not available. Please check your installation.")

View File

@ -0,0 +1,111 @@
你是一名意图识别专家,任务是根据用户输入提取意图并提取相关的参数信息。意图分为以下六类:
1.日计划数量
2.日计划作业内容
3.周计划数量
4.周计划作业内容
5.页面操作
6.其他
模版结构与提取要求
1. 意图 1 和 2日计划相关
1)提取参数的要求如下:
a.时间 (time):必须字段,缺失时提示用户输入时间。
b.工程名称 (project):去除 "工程" 后的内容。
c.公司名称 (company):去除 "公司" 后的内容。
d.项目(部)名称 (program):去除 "项目" 或 "项目部" 后的内容。
e.项目经理名称 (manager):去除 "项目经理" 后的内容。
f.班组名称 (class):去除 "班组" 后的内容。
g.风险等级 (risk):限定为 "一"、"二"、"三"、"四"、"五"、"六"。
2)返回格式:
{
"intention": "日计划数量",
"time": "时间",
"program": "项目(部)名称",
"company": "公司名称",
"project": "工程名称",
"manager": "项目经理名称",
"class": "班组名称",
"risk": "风险等级"
}
3)未提取到的字段:不包含在结果中。
4)时间缺失时:提示用户输入特定时间。
5)风险等级无效时:提示用户提供有效风险等级("一" 到 "六")。
2. 意图 3 和 4周计划相关
1)提取参数的要求如下:
a.与日计划相同,增加施工状态 (status),限定为:"未开始","进行中"和"已结束"
2)返回格式:
{
"intention": "周计划数量",
"time": "时间",
"program": "项目(部)名称",
"company": "公司名称",
"project": "工程名称",
"manager": "项目经理名称",
"class": "班组名称",
"risk": "风险等级",
"status": "施工状态"
}
3)时间缺失时:提示用户输入特定时间。
4)风险等级无效或施工状态不匹配时:提示用户提供有效值。
3. 意图 5页面操作
1提取参数的要求如下
操作类型 (action):存储 "打开" 或 "切换"。若用户输入单一名词,默认为 "切换"。
模块名称 (module):去除 "页面"、"模块"、"菜单" 后的部分内容。
2返回格式
{
"intention": "页面操作",
"action": "操作类型",
"module": "模块名称"
}
4. 意图 6其他
1)提取参数:不需要有任何要求。
2返回格式
{
"intention": "其他",
"content": "用户输入的原始内容"
}
5.示例
1示例 1
用户输入'今天送变电一公司1号工程B项目5号班组有多少项二级风险作业计划',
返回:
{
'intention': '日计划数量',
'time': '今天',
'company': '变电一',
'project': '1号',
'program': 'B',
'class': '5号',
'risk': '二'
}
2)示例 2
用户输入:
本周1号项目部多少项一级风险作业计划正在施工
返回:
{
"intention": "周计划数量",
"time": "本周",
"program": "1号",
"risk": "一",
"status": "进行中"
}
3示例 3
用户输入:
切换到首页
返回:
{
"intention": "页面操作",
"action": "切换",
"module": "首页"
}
4示例 4
用户输入:
你好,请帮我查一下
返回:
{
"intention": "其他",
"content": "你好,请帮我查一下"
}

View File

@ -0,0 +1,144 @@
你是一名意图识别专家,任务是根据用户输入提取意图并提取相关的参数信息。意图分为以下9类:
1.日计划数量 - 用户询问日计划的数量相关。
2.日计划作业内容 - 用户询问日计划的作业内容相关。
3.周计划数量 - 用户询问周计划的数量相关。
4.周计划作业内容 - 用户询问周计划的作业内容相关。
5.页面操作 - 用户希望打开或跳转具体页面。
6.联网查询 - 用户要求获取世界、历史、实时新闻、或除电力系统之外的信息。
7.天气查询 - 用户要求查某地方某时间的天气。
8.知识库查询 - 用户寻找特定的信息或知识,如国家电网各部门规章制度、安徽送变电规章制度等相关的问题,需要通过知识库来回答。
9.其他 - 无法匹配到以上任何一个意图,要求用户补充问题信息。
模版结构与提取要求
1. 意图 1 和 2日计划相关
1)提取参数的要求如下:
a.时间 (time):必须字段,缺失时提示用户输入时间。
b.工程名称 (project):去除 '工程' 后的内容。
c.公司名称 (company):去除 '公司' 后的内容。
d.项目(部)名称 (program):去除 '项目' 或 '项目部' 后的内容。
e.项目经理名称 (manager):去除 '项目经理' 后的内容。
f.班组名称 (class):去除 '班组' 后的内容。
g.风险等级 (risk):限定为 '一'、'二'、'三'、'四'、'五'、'六'。
2)返回格式:
{
'intention': '日计划数量',
'time': '时间',
'program': '项目(部)名称',
'company': '公司名称',
'project': '工程名称',
'manager': '项目经理名称',
'class': '班组名称',
'risk': '风险等级'
}
3)未提取到的字段:不包含在结果中。
4)时间缺失时:提示用户输入特定时间。
5)风险等级无效时:提示用户提供有效风险等级('一' 到 '六')。
2. 意图 3 和 4周计划相关
1)提取参数的要求如下:
a.与日计划相同,增加施工状态 (status),限定为:'未开始','进行中'和'已结束'
2)返回格式:
{
'intention': '周计划数量',
'time': '时间',
'program': '项目(部)名称',
'company': '公司名称',
'project': '工程名称',
'manager': '项目经理名称',
'class': '班组名称',
'risk': '风险等级',
'status': '施工状态'
}
3)时间缺失时:提示用户输入特定时间。
4)风险等级无效或施工状态不匹配时:提示用户提供有效值。
3. 意图 5页面操作
1提取参数的要求如下
操作类型 (action):存储 '打开' 或 '切换'。若用户输入单一名词,默认为 '切换'。
模块名称 (module):去除 '页面'、'模块'、'菜单' 后的部分内容。
2返回格式
{
'intention': '页面操作',
'action': '操作类型',
'module': '模块名称'
}
4.意图 6联网查询
1不需要提取任何参数。
2返回格式
{
'intention': '联网查询'
}
5. 意图7天气查询
1不需要提取任何参数。
2返回格式
{
'intention': '天气查询'
}
6.意图8知识库查询
1不需要提取任何参数。
2返回格式
{
'intention': '知识库查询'
}
7. 意图 9其他
1不需要提取任何参数。
2返回格式
{
'intention': '其他',
'content': '用户输入的原始内容'
}
8.示例
1示例 1
用户输入'今天送变电一公司1号工程B项目5号班组有多少项二级风险作业计划',
返回:
{
'intention': '日计划数量',
'time': '今天',
'company': '变电一',
'project': '1号',
'program': 'B',
'class': '5号',
'risk': '二'
}
2示例 2
本周1号项目部多少项一级风险作业计划正在施工
返回:
{
'intention': '周计划数量',
'time': '本周',
'program': '1号',
'risk': '一',
'status': '进行中'
}
3示例 3
用户输入:
切换到首页
返回:
{
'intention': '页面操作',
'action': '切换',
'module': '首页'
}
4示例 4
用户输入:
本周合肥会有降雨吗?
返回:
{
'intention': '天气查询'
}
5示例 5
用户输入:
你好,请帮我查一下
返回:
{
'intention': '其他',
'content': '你好,请帮我查一下'
}

View File

@ -0,0 +1,24 @@
你是一名意图识别专家,任务是根据用户输入提取意图。意图分为以下八类:
1.日计划数量
2.日计划作业内容
3.周计划数量
4.周计划作业内容
5.页面操作
6.联网查询
7.天气
8.知识库查询
你是一名意图识别专家,任务是根据下面提供的用户输入,确定其对应的意图类别。意图类别包括:
1. 日计划数量 - 用户询问日计划的数量相关。
2. 日计划作业内容 - 用户询问日计划的作业内容相关。
3. 周计划数量 - 用户询问周计划的数量相关。
4. 周计划作业内容 - 用户询问周计划的作业内容相关。
5. 页面操作 - 用户希望打开或跳转具体页面。
6. 联网查询或天气 - 用户要求获取世界、历史、实时新闻、天气或除电力系统之外的信息。
7. 知识库查询 - 用户寻找特定的信息或知识,如国家电网各部门规章制度、安徽送变电规章制度等相关的问题,需要通过知识库来回答。
规则:
- 对于每个输入,请明确指出它属于类别的序号(1、2、3、4、5、6、7),一定不要有其他多余描述。
- 尽可能从用户输入中确定以上类别中的一个,如果无法确定,请用户补充更多信息。

View File

@ -0,0 +1,72 @@
你是一名意图识别专家,任务是根据用户输入提取意图并提取相关的参数信息。意图分为以下4类:
1.日计划数量 - 用户询问日计划的数量相关。
2.日计划作业内容 - 用户询问日计划的作业内容相关。
3.周计划数量 - 用户询问周计划的数量相关。
4.周计划作业内容 - 用户询问周计划的作业内容相关。
模版结构与提取要求
1. 意图 1 和 2日计划相关
1)提取参数的要求如下:
a.时间 (time):必须字段,缺失时提示用户输入时间。
b.工程名称 (project):去除 '工程' 后的内容。
c.公司名称 (company):去除 '公司' 后的内容。
d.项目(部)名称 (program):去除 '项目' 或 '项目部' 后的内容。
e.项目经理名称 (manager):去除 '项目经理' 后的内容。
f.班组名称 (class):去除 '班组' 后的内容。
g.风险等级 (risk):限定为 '一'、'二'、'三'、'四'、'五'、'六'。
2)返回格式:
{
'intention': '日计划数量',
'time': '时间',
'program': '项目(部)名称',
'company': '公司名称',
'project': '工程名称',
'manager': '项目经理名称',
'class': '班组名称',
'risk': '风险等级'
}
3)未提取到的字段:不包含在结果中。
4)时间缺失时:提示用户输入特定时间。
5)风险等级无效时:提示用户提供有效风险等级('一' 到 '六')。
2. 意图 3 和 4周计划相关
1)提取参数的要求如下:
a.与日计划相同,增加施工状态 (status),限定为:'未开始','进行中'和'已结束'
2)返回格式:
{
'intention': '周计划数量',
'time': '时间',
'program': '项目(部)名称',
'company': '公司名称',
'project': '工程名称',
'manager': '项目经理名称',
'class': '班组名称',
'risk': '风险等级',
'status': '施工状态'
}
3)时间缺失时:提示用户输入特定时间。
4)风险等级无效或施工状态不匹配时:提示用户提供有效值。
3.示例
1示例 1
用户输入'今天送变电一公司1号工程B项目5号班组有多少项二级风险作业计划',
返回:
{
'intention': '日计划数量',
'time': '今天',
'company': '变电一',
'project': '1号',
'program': 'B',
'class': '5号',
'risk': '二'
}
2示例 2
本周1号项目部多少项一级风险作业计划正在施工
返回:
{
'intention': '周计划数量',
'time': '本周',
'program': '1号',
'risk': '一',
'status': '进行中'
}

View File

@ -20,7 +20,7 @@ def amap_poi_search_engine(keywords: str,types: str,config: dict):
#@regist_tool(title="高德地图POI搜索") # @regist_tool(title="高德地图POI搜索")
def amap_poi_search(location: str = Field(description="'实际地名'或者'具体的地址',不能使用简称或者别称"), def amap_poi_search(location: str = Field(description="'实际地名'或者'具体的地址',不能使用简称或者别称"),
types: str = Field(description="POI类型比如商场、学校、医院等等")): types: str = Field(description="POI类型比如商场、学校、医院等等")):
""" A wrapper that uses Amap to search.""" """ A wrapper that uses Amap to search."""

View File

@ -36,7 +36,7 @@ def get_weather(adcode: str, config: dict) -> dict:
else: else:
return {"error": "API request failed"} return {"error": "API request failed"}
#@regist_tool(title="高德地图天气查询") # @regist_tool(title="高德地图天气查询")
def amap_weather(city: str = Field(description="城市名")): def amap_weather(city: str = Field(description="城市名")):
"""A wrapper that uses Amap to get weather information.""" """A wrapper that uses Amap to get weather information."""
tool_config = get_tool_config("amap") tool_config = get_tool_config("amap")

View File

@ -4,7 +4,7 @@ from chatchat.server.pydantic_v1 import Field
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="ARXIV论文") # @regist_tool(title="ARXIV论文")
def arxiv(query: str = Field(description="The search query title")): def arxiv(query: str = Field(description="The search query title")):
"""A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields.""" """A wrapper around Arxiv.org for searching and retrieving scientific articles in various fields."""
from langchain.tools.arxiv.tool import ArxivQueryRun from langchain.tools.arxiv.tool import ArxivQueryRun

View File

@ -3,7 +3,7 @@ from chatchat.server.pydantic_v1 import Field
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="数学计算器") # @regist_tool(title="数学计算器")
def calculate(text: str = Field(description="a math expression")) -> float: def calculate(text: str = Field(description="a math expression")) -> float:
""" """
Useful to answer questions about simple calculations. Useful to answer questions about simple calculations.

View File

@ -1,8 +1,5 @@
import json
import uuid
from typing import Dict, List from typing import Dict, List
import requests
from langchain.docstore.document import Document from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.utilities.bing_search import BingSearchAPIWrapper from langchain.utilities.bing_search import BingSearchAPIWrapper
@ -10,21 +7,15 @@ from langchain.utilities.duckduckgo_search import DuckDuckGoSearchAPIWrapper
from langchain.utilities.searx_search import SearxSearchWrapper from langchain.utilities.searx_search import SearxSearchWrapper
from markdownify import markdownify from markdownify import markdownify
from strsimpy.normalized_levenshtein import NormalizedLevenshtein from strsimpy.normalized_levenshtein import NormalizedLevenshtein
from langchain_community.tools.tavily_search import TavilySearchResults
import os
from chatchat.settings import Settings from chatchat.settings import Settings
from chatchat.server.pydantic_v1 import Field from chatchat.server.pydantic_v1 import Field
from chatchat.server.utils import get_tool_config from chatchat.server.utils import get_tool_config
from chatchat.utils import build_logger
# from tavily import TavilyClient
from .tools_registry import BaseToolOutput, regist_tool, format_context from .tools_registry import BaseToolOutput, regist_tool, format_context
logger = build_logger()
def searx_search(text ,config, top_k: int):
def searx_search(text, config, top_k: int):
print(f"searx_search: text: {text},config:{config},top_k:{top_k}") print(f"searx_search: text: {text},config:{config},top_k:{top_k}")
search = SearxSearchWrapper( search = SearxSearchWrapper(
searx_host=config["host"], searx_host=config["host"],
@ -35,7 +26,7 @@ def searx_search(text, config, top_k: int):
return search.results(text, top_k) return search.results(text, top_k)
def bing_search(text, config, top_k: int): def bing_search(text, config, top_k:int):
search = BingSearchAPIWrapper( search = BingSearchAPIWrapper(
bing_subscription_key=config["bing_key"], bing_subscription_key=config["bing_key"],
bing_search_url=config["bing_search_url"], bing_search_url=config["bing_search_url"],
@ -43,15 +34,15 @@ def bing_search(text, config, top_k: int):
return search.results(text, top_k) return search.results(text, top_k)
def duckduckgo_search(text, config, top_k: int): def duckduckgo_search(text, config, top_k:int):
search = DuckDuckGoSearchAPIWrapper() search = DuckDuckGoSearchAPIWrapper()
return search.results(text, top_k) return search.results(text, top_k)
def metaphor_search( def metaphor_search(
text: str, text: str,
config: dict, config: dict,
top_k: int top_k:int
) -> List[Dict]: ) -> List[Dict]:
from metaphor_python import Metaphor from metaphor_python import Metaphor
@ -94,77 +85,21 @@ def metaphor_search(
return docs return docs
def tavily_search(text, config, top_k):
    """Run a Tavily web search through LangChain's ``TavilySearchResults`` tool.

    Args:
        text: Query string passed to the tool's ``run`` method.
        config: Mapping read for ``tavily_api_key``, ``include_answer``,
            ``search_depth``, ``include_raw_content`` and ``max_results``.
        top_k: Accepted so the signature matches the other search engines;
            it is not used here (the result count comes from
            ``config["max_results"]``).

    Returns:
        A list with one dict per raw result, each with its ``url`` key removed.
    """
    # Publish the Tavily API key through the process environment.
    os.environ["TAVILY_API_KEY"] = config["tavily_api_key"]

    # Tool options are taken verbatim from the engine configuration.
    tool_options = {
        "include_answer": config["include_answer"],  # enables answer generation
        "search_depth": config["search_depth"],  # original note: advanced search mode is required
        "include_raw_content": config["include_raw_content"],
        "max_results": config["max_results"],
    }
    tavily_tool = TavilySearchResults(**tool_options)

    # NOTE(review): "url" is dropped from every result here, while
    # search_result2docs looks up "url" for this engine — confirm whether the
    # source link is meant to be discarded.
    stripped_results = []
    for entry in tavily_tool.run(text):
        stripped_results.append(
            {key: value for key, value in entry.items() if key != 'url'}
        )
    return stripped_results
def zhipu_search(text, config, top_k):
    """Call the Zhipu (bigmodel.cn) web-search endpoint and return its JSON reply.

    Args:
        text: Query string, sent as ``search_query``.
        config: Mapping that must contain ``zhipu_api_key``. An optional
            ``timeout`` entry (seconds) overrides the default request timeout.
        top_k: Accepted so the signature matches the other search engines;
            it is not used in this function.

    Returns:
        The decoded JSON body of the response, unmodified.

    Raises:
        requests.exceptions.RequestException: on connection failure or when
            the timeout elapses.
    """
    api_key = config["zhipu_api_key"]
    endpoint = "https://open.bigmodel.cn/api/paas/v4/web_search"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }
    payload = {
        "search_engine": "Search-Pro",  # selects the dedicated web-search model
        "search_query": text
    }
    # requests applies no timeout by default, so an unresponsive endpoint
    # would block the tool call forever; bound the wait instead.
    request_timeout = config.get("timeout", 30)
    response = requests.post(
        endpoint, headers=headers, json=payload, timeout=request_timeout
    )
    result = response.json()
    print(f"================!! result: {result}")
    return result
SEARCH_ENGINES = { SEARCH_ENGINES = {
"bing": bing_search, "bing": bing_search,
"duckduckgo": duckduckgo_search, "duckduckgo": duckduckgo_search,
"metaphor": metaphor_search, "metaphor": metaphor_search,
"searx": searx_search, "searx": searx_search,
"tavily": tavily_search,
"zhipu_search": zhipu_search
} }
def search_result2docs(search_results, engine_name, top_k) -> List[Document]: def search_result2docs(search_results) -> List[Document]:
docs = [] docs = []
if engine_name == "zhipu_search":
try:
# search_results_json = json.loads(search_results)
results = search_results["search_result"]
except (KeyError, IndexError) as e:
print(f"结构异常: {e}")
results = []
# 遍历并处理每个结果
for item in results[:top_k]:
doc = Document(
page_content=item['content'],
metadata={"link": item['link'], "title": item['title']}
)
docs.append(doc)
return docs
page_contents_key = "snippet" if engine_name != "tavily" else "content"
metadata_key = "link" if engine_name != "tavily" else "url"
for result in search_results: for result in search_results:
doc = Document( doc = Document(
page_content=result[page_contents_key] if page_contents_key in result.keys() else "", page_content=result["snippet"] if "snippet" in result.keys() else "",
metadata={ metadata={
"source": result[metadata_key] if metadata_key in result.keys() else "", "source": result["link"] if "link" in result.keys() else "",
"filename": result["title"] if "title" in result.keys() else "", "filename": result["title"] if "title" in result.keys() else "",
}, },
) )
@ -172,7 +107,7 @@ def search_result2docs(search_results, engine_name, top_k) -> List[Document]:
return docs return docs
def search_engine(query: str, top_k: int = 0, engine_name: str = "", config: dict = {}): def search_engine(query: str, top_k:int=0, engine_name: str="", config: dict={}):
config = config or get_tool_config("search_internet") config = config or get_tool_config("search_internet")
if top_k <= 0: if top_k <= 0:
top_k = config.get("top_k", Settings.kb_settings.SEARCH_ENGINE_TOP_K) top_k = config.get("top_k", Settings.kb_settings.SEARCH_ENGINE_TOP_K)
@ -182,20 +117,12 @@ def search_engine(query: str, top_k: int = 0, engine_name: str = "", config: dic
results = search_engine_use( results = search_engine_use(
text=query, config=config["search_engine_config"][engine_name], top_k=top_k text=query, config=config["search_engine_config"][engine_name], top_k=top_k
) )
docs = [x for x in search_result2docs(results) if x.page_content and x.page_content.strip()]
docs = [x for x in search_result2docs(results, engine_name, top_k) if x.page_content and x.page_content.strip()]
print(f"len(docs): {len(docs)}")
# print(f"docs:{docs}")
# # print(f"docs: {docs[:150]}")
return {"docs": docs, "search_engine": engine_name} return {"docs": docs, "search_engine": engine_name}
@regist_tool(title="互联网搜索") @regist_tool(title="互联网搜索")
def search_internet(query: str = Field(description="query for Internet search")): def search_internet(query: str = Field(description="query for Internet search")):
"""用这个工具实现获取世界、历史、实时新闻、或除电力系统之外的信息查询""" """用这个工具实现获取世界、历史、实时新闻、或除电力系统之外的信息查询"""
try: print(f"search_internet: query: {query}")
print(f"search_internet: query: {query}") return BaseToolOutput(search_engine(query=query), format=format_context)
return BaseToolOutput(data=search_engine(query=query), format=format_context)
except Exception as e:
logger.error(f"未知错误: {str(e)}")
return BaseToolOutput(f"搜索过程中发生未知错误,{str(e)}", format=format_context)

View File

@ -3,7 +3,7 @@ from chatchat.server.pydantic_v1 import Field
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="油管视频") # @regist_tool(title="油管视频")
def search_youtube(query: str = Field(description="Query for Videos search")): def search_youtube(query: str = Field(description="Query for Videos search")):
"""use this tools_factory to search youtube videos""" """use this tools_factory to search youtube videos"""
from langchain_community.tools import YouTubeSearchTool from langchain_community.tools import YouTubeSearchTool

View File

@ -6,7 +6,7 @@ from chatchat.server.pydantic_v1 import Field
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="系统命令") # @regist_tool(title="系统命令")
def shell(query: str = Field(description="The command to execute")): def shell(query: str = Field(description="The command to execute")):
"""Use Shell to execute system shell commands""" """Use Shell to execute system shell commands"""
tool = ShellTool() tool = ShellTool()

View File

@ -14,7 +14,7 @@ from chatchat.server.utils import MsgType, get_tool_config, get_model_info
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="文生图", return_direct=True) # @regist_tool(title="文生图", return_direct=True)
def text2images( def text2images(
prompt: str, prompt: str,
n: int = Field(1, description="需生成图片的数量"), n: int = Field(1, description="需生成图片的数量"),

View File

@ -108,7 +108,7 @@ def query_prometheus(query: str, config: dict) -> str:
return content return content
#@regist_tool(title="Prometheus对话") # @regist_tool(title="Prometheus对话")
def text2promql( def text2promql(
query: str = Field( query: str = Field(
description="Tool for querying a Prometheus server, No need for PromQL statements, " description="Tool for querying a Prometheus server, No need for PromQL statements, "

View File

@ -129,7 +129,7 @@ def query_database(query: str, config: dict):
return context return context
#@regist_tool(title="数据库对话") # @regist_tool(title="数据库对话")
def text2sql( def text2sql(
query: str = Field( query: str = Field(
description="No need for SQL statements,just input the natural language that you want to chat with database" description="No need for SQL statements,just input the natural language that you want to chat with database"

View File

@ -176,7 +176,7 @@ def format_context(self: BaseToolOutput) -> str:
doc = DocumentWithVSId.parse_obj(doc) doc = DocumentWithVSId.parse_obj(doc)
source_documents.append(doc.page_content) source_documents.append(doc.page_content)
# print(f"format_context: doc.page_content: {doc.page_content}") print(f"format_context: doc.page_content: {doc.page_content}")
if len(source_documents) == 0: if len(source_documents) == 0:
context = "没有找到相关文档,请更换关键词重试" context = "没有找到相关文档,请更换关键词重试"
else: else:

View File

@ -13,7 +13,7 @@ from chatchat.server.agent.tools_factory.tools_registry import format_context
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="URL内容阅读") # @regist_tool(title="URL内容阅读")
def url_reader( def url_reader(
url: str = Field( url: str = Field(
description="The URL to be processed, so that its web content can be made more clear to read. Then provide a detailed description of the content in about 500 words. As structured as possible. ONLY THE LINK SHOULD BE PASSED IN."), description="The URL to be processed, so that its web content can be made more clear to read. Then provide a detailed description of the content in about 500 words. As structured as possible. ONLY THE LINK SHOULD BE PASSED IN."),

View File

@ -12,132 +12,26 @@ from .tools_registry import BaseToolOutput, regist_tool
@regist_tool(title="天气查询") @regist_tool(title="天气查询")
def weather_check( def weather_check(
city: str = Field(description="城市名称,包括市和县,例如 '厦门'"), city: str = Field(description="City name,include city and county,like '厦门'"),
date: str = Field(
default=None,
description="日期参数,支持以下格式:\n"
"- '今天':获取当前实时天气\n"
"- '明天'/'后天'获取未来24/48小时预报\n"
"- '未来X天'获取最多X天预报'未来3天',X的抽取要符合客户意图\n"
"- 不支持其他参数如果是其他参数则时间参数为None\n"
)
): ):
"""用这个工具获取指定地点和指定时间的天气""" """用这个工具获取指定地点和指定时间的天气"""
# """Use this tool to check the weather at a specific city"""
# 参数校验 print(f"weather_check tool内部调用city{city}")
missing_params = []
if not city:
missing_params.append("城市名称")
if not date:
missing_params.append("日期参数")
if missing_params:
return BaseToolOutput(data={"error_message": f"缺少必要参数:{', '.join(missing_params)},请补充完整查询信息"},
require_additional_input=True
)
print(f"city:{city}, date:{date}")
try:
weather_type, number = parse_date_parameter(date)
except ValueError as e:
logging.error(f"日期参数解析失败: {str(e)}")
return BaseToolOutput(data={"error_message": str(e)})
# 获取API配置
tool_config = get_tool_config("weather_check") tool_config = get_tool_config("weather_check")
api_key = tool_config.get("api_key") api_key = tool_config.get("api_key")
if not api_key:
return BaseToolOutput(data={"error_message": "API密钥未配置请联系管理员"})
# 根据天气类型调用API
if weather_type == "daily":
return _get_current_weather(city, api_key)
elif weather_type == "future":
return _get_future_weather(city, api_key, number)
else:
return BaseToolOutput(data={"error_message": "不支持的天气类型"})
def _get_current_weather(city: str, api_key: str) -> BaseToolOutput:
"""获取当前实时天气"""
url = f"http://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c" url = f"http://api.seniverse.com/v3/weather/now.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
logging.info(f"请求URL: {url}") logging.info(f"url:{url}")
response = requests.get(url) response = requests.get(url)
if response.status_code == 200:
if response.status_code != 200: data = response.json()
logging.error(f"天气查询失败: {response.status_code}") logging.info(f"response.json():{data}")
return BaseToolOutput(data={"error_message": "天气查询API请求失败"})
data = response.json()
weather = {
"temperature": data["results"][0]["now"]["temperature"],
"description": data["results"][0]["now"]["text"],
}
return BaseToolOutput(data=weather)
def _get_future_weather(city: str, api_key: str, days: int) -> BaseToolOutput:
"""获取未来天气预报"""
url = f"http://api.seniverse.com/v3/weather/daily.json?key={api_key}&location={city}&language=zh-Hans&unit=c"
logging.info(f"请求URL: {url}")
response = requests.get(url)
if response.status_code != 200:
logging.error(f"天气查询失败: {response.status_code}")
return BaseToolOutput("天气查询API请求失败")
data = response.json()
daily_data = data["results"][0]["daily"]
if days == 1:
weather = { weather = {
"date": "明天", "temperature": data["results"][0]["now"]["temperature"],
"low_temperature": daily_data[1]["low"], "description": data["results"][0]["now"]["text"],
"high_temperature": daily_data[1]["high"],
"description": daily_data[1]["text_day"],
}
elif days == 2:
weather = {
"date": "后天",
"low_temperature": daily_data[2]["low"],
"high_temperature": daily_data[2]["high"],
"description": daily_data[2]["text_day"],
}
elif days == 3:
weather = {
"今天天气": daily_data[0]["text_day"],
"今天最低温度": daily_data[0]["low"],
"今天最高温度": daily_data[0]["high"],
"明天天气": daily_data[1]["text_day"],
"明天最低温度": daily_data[1]["low"],
"明天最高温度": daily_data[1]["high"],
"后天天气": daily_data[2]["text_day"],
"后天最低温度": daily_data[2]["low"],
"后天最高温度": daily_data[2]["high"],
} }
return BaseToolOutput(weather)
else: else:
return BaseToolOutput(data={"error_message": "不支持的天数参数"}) logging.error(f"Failed to retrieve weather: {response.status_code}")
raise Exception(f"Failed to retrieve weather: {response.status_code}")
return BaseToolOutput(data=weather)
def parse_date_parameter(date: str) -> tuple:
    """Translate a date phrase into a ``(weather_type, days)`` pair.

    Recognised phrases:
        * ``"今天"``    -> ``("daily", 1)``
        * ``"明天"``    -> ``("future", 1)``
        * ``"后天"``    -> ``("future", 2)``
        * ``"未来X天"`` -> ``("future", X)`` for X from 1 to 3

    Args:
        date: The date phrase to interpret.

    Returns:
        A tuple of the weather type (``"daily"`` or ``"future"``) and a day count.

    Raises:
        ValueError: if the phrase is not one of the recognised forms, or the
            requested range is outside 1-3 days.
    """
    if date == "今天":
        return "daily", 1
    elif date == "明天":
        return "future", 1
    elif date == "后天":
        return "future", 2
    elif date.startswith("未来") and date.endswith("天"):
        # The suffix must really be "天": an empty-suffix check matches every
        # string, which let inputs such as "未来12" be read as 1 day.
        try:
            days = int(date[2:-1])
        except ValueError:
            # Non-numeric count: report it with the function's own message
            # rather than int()'s "invalid literal" text.
            raise ValueError("不支持的日期参数") from None
        if 1 <= days <= 3:
            return "future", days
        else:
            raise ValueError("未来预报仅支持1-3天")
    else:
        raise ValueError("不支持的日期参数")
if __name__ == "__main__":
weather_check("合肥", "明天")

View File

@ -8,7 +8,7 @@ from chatchat.server.pydantic_v1 import Field
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool(title="维基百科搜索") # @regist_tool(title="维基百科搜索")
def wikipedia_search(query: str = Field(description="The search query")): def wikipedia_search(query: str = Field(description="The search query")):
""" A wrapper that uses Wikipedia to search.""" """ A wrapper that uses Wikipedia to search."""

View File

@ -6,7 +6,7 @@ from chatchat.server.utils import get_tool_config
from .tools_registry import BaseToolOutput, regist_tool from .tools_registry import BaseToolOutput, regist_tool
#@regist_tool # @regist_tool
def wolfram(query: str = Field(description="The formula to be calculated")): def wolfram(query: str = Field(description="The formula to be calculated")):
"""Useful for when you need to calculate difficult formulas""" """Useful for when you need to calculate difficult formulas"""

View File

@ -65,7 +65,7 @@ async def chat_completions(
# import rich # import rich
# rich.print(body) # rich.print(body)
# 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值 # 当调用本接口且 body 中没有传入 "max_tokens" 参数时, 默认使用配置中定义的值
# logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}") logger.info(f"body.model_config:{body.model_config},body.tools: {body.tools},body.messages:{body.messages}")
if body.max_tokens in [None, 0]: if body.max_tokens in [None, 0]:
body.max_tokens = Settings.model_settings.MAX_TOKENS body.max_tokens = Settings.model_settings.MAX_TOKENS

View File

@ -22,7 +22,7 @@ from chatchat.server.utils import (wrap_done, get_ChatOpenAI, get_default_llm,
BaseResponse, get_prompt_template, build_logger, BaseResponse, get_prompt_template, build_logger,
check_embed_model, api_address check_embed_model, api_address
) )
import time
logger = build_logger() logger = build_logger()
@ -60,8 +60,6 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
return_direct: bool = Body(False, description="直接返回检索结果,不送入 LLM"), return_direct: bool = Body(False, description="直接返回检索结果,不送入 LLM"),
request: Request = None, request: Request = None,
): ):
logger.info(f"kb_chat:,mode {mode}")
start_time = time.time()
if mode == "local_kb": if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name) kb = KBServiceFactory.get_service_by_name(kb_name)
if kb is None: if kb is None:
@ -69,8 +67,6 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
async def knowledge_base_chat_iterator() -> AsyncIterable[str]: async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
try: try:
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}")
start_time1 = time.time()
nonlocal history, prompt_name, max_tokens nonlocal history, prompt_name, max_tokens
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
@ -78,10 +74,8 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
if mode == "local_kb": if mode == "local_kb":
kb = KBServiceFactory.get_service_by_name(kb_name) kb = KBServiceFactory.get_service_by_name(kb_name)
ok, msg = kb.check_embed_model() ok, msg = kb.check_embed_model()
logger.info(f"***********************************knowledge_base_chat_iterator:,mode {mode}kb_name{kb_name}")
if not ok: if not ok:
raise ValueError(msg) raise ValueError(msg)
# docs = search_docs( query = query,knowledge_base_name = kb_name,top_k = top_k, score_threshold = score_threshold,)
docs = await run_in_threadpool(search_docs, docs = await run_in_threadpool(search_docs,
query=query, query=query,
knowledge_base_name=kb_name, knowledge_base_name=kb_name,
@ -89,13 +83,7 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
score_threshold=score_threshold, score_threshold=score_threshold,
file_name="", file_name="",
metadata={}) metadata={})
source_documents = format_reference(kb_name, docs, api_address(is_public=True)) source_documents = format_reference(kb_name, docs, api_address(is_public=True))
# logger.info(
# f"***********************************knowledge_base_chat_iterator:,after format_reference:{docs}")
end_time1 = time.time()
execution_time1 = end_time1 - start_time1
logger.info(f"kb_chat Execution time检索完成: {execution_time1:.6f} seconds")
elif mode == "temp_kb": elif mode == "temp_kb":
ok, msg = check_embed_model() ok, msg = check_embed_model()
if not ok: if not ok:
@ -151,7 +139,6 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
if max_tokens in [None, 0]: if max_tokens in [None, 0]:
max_tokens = Settings.model_settings.MAX_TOKENS max_tokens = Settings.model_settings.MAX_TOKENS
start_time1 = time.time()
llm = get_ChatOpenAI( llm = get_ChatOpenAI(
model_name=model, model_name=model,
temperature=temperature, temperature=temperature,
@ -236,12 +223,6 @@ async def kb_chat(query: str = Body(..., description="用户输入", examples=["
return return
if stream: if stream:
eventSource = EventSourceResponse(knowledge_base_chat_iterator()) return EventSourceResponse(knowledge_base_chat_iterator())
# 记录结束时间
end_time = time.time()
# 计算执行时间
execution_time = end_time - start_time
logger.info(f"final kb_chat Execution time: {execution_time:.6f} seconds")
return eventSource
else: else:
return await knowledge_base_chat_iterator().__anext__() return await knowledge_base_chat_iterator().__anext__()

View File

@ -32,7 +32,6 @@ from chatchat.server.utils import (
get_default_embedding, get_default_embedding,
) )
from chatchat.utils import build_logger from chatchat.utils import build_logger
from typing import List, Dict,Tuple
logger = build_logger() logger = build_logger()
@ -72,15 +71,8 @@ def search_docs(
if kb is not None: if kb is not None:
if query: if query:
docs = kb.search_docs(query, top_k, score_threshold) docs = kb.search_docs(query, top_k, score_threshold)
if docs is not None: logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold},len(docs):{len(docs)}") # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
docs_key = kb.search_content_internal(query,2)
if docs_key is not None:
logger.info(f"before merge_and_deduplicate ,len(docs_key):{len(docs_key)}")
docs = merge_and_deduplicate(docs, docs_key)
if docs is not None:
logger.info(f"after merge_and_deduplicate len(docs):{len(docs)}")
data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs] data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
elif file_name or metadata: elif file_name or metadata:
data = kb.list_docs(file_name=file_name, metadata=metadata) data = kb.list_docs(file_name=file_name, metadata=metadata)
@ -89,20 +81,6 @@ def search_docs(
del d.metadata["vector"] del d.metadata["vector"]
return [x.dict() for x in data] return [x.dict() for x in data]
def merge_and_deduplicate(list1: List[Document], list2: List[Document]) -> List[Document]:
    """Merge two document lists, keeping one document per distinct ``page_content``.

    Documents are visited in order, ``list1`` first and then ``list2``. The
    first document seen for a given ``page_content`` is kept and later
    duplicates are dropped, whether the duplicate sits in the same list or in
    the other one. The result preserves first-seen order.

    Args:
        list1: Primary documents; may be ``None``.
        list2: Secondary documents; may be ``None``.

    Returns:
        A new list of the unique documents. Empty when both inputs are
        ``None`` or empty.
    """
    # A dict keyed by page_content gives O(1) duplicate detection and keeps
    # insertion order for the final result.
    unique_by_content = {}
    for docs in (list1, list2):
        if docs is None:
            continue
        for doc in docs:
            # setdefault keeps the first occurrence. The previous dict
            # comprehension let a later duplicate inside list1 overwrite an
            # earlier one, which was inconsistent with how list2 was handled.
            unique_by_content.setdefault(doc.page_content, doc)
    return list(unique_by_content.values())
def list_files(knowledge_base_name: str) -> ListResponse: def list_files(knowledge_base_name: str) -> ListResponse:
if not validate_kb_name(knowledge_base_name): if not validate_kb_name(knowledge_base_name):

View File

@ -210,12 +210,6 @@ class KBService(ABC):
docs = self.do_search(query, top_k, score_threshold) docs = self.do_search(query, top_k, score_threshold)
return docs return docs
def search_content_internal(self,
                            query: str,
                            top_k: int,
                            ) -> List[Document]:
    """Run the backend's content search and return its result unchanged.

    This is a thin wrapper: it forwards ``query`` and ``top_k`` to the
    subclass hook ``searchbyContentInternal`` and returns whatever that
    hook returns.
    """
    return self.searchbyContentInternal(query, top_k)
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
return [] return []
@ -325,16 +319,6 @@ class KBService(ABC):
""" """
pass pass
@abstractmethod
def searchbyContentInternal(self,
                            query: str,
                            top_k: int,
                            ) -> List[Document]:
    """Content-based search hook; each subclass supplies its own logic.

    Args:
        query: Text to look for in the stored document content.
        top_k: Maximum number of documents to return.

    Returns:
        Plain ``Document`` objects, not ``(Document, score)`` tuples. This
        matches the wrapper ``search_content_internal`` and the
        Elasticsearch implementation, and callers read ``page_content``
        directly on each element.
        NOTE(review): some backends visible in this change set return
        ``None`` from this hook; callers appear to guard with an
        ``is not None`` check. Confirm before relying on a list.
    """
    pass
@abstractmethod @abstractmethod
def do_add_doc( def do_add_doc(
self, self,

View File

@ -16,7 +16,7 @@ from chatchat.server.knowledge_base.kb_service.base import KBService, SupportedV
from chatchat.server.knowledge_base.utils import KnowledgeFile from chatchat.server.knowledge_base.utils import KnowledgeFile
from chatchat.server.utils import get_Embeddings from chatchat.server.utils import get_Embeddings
from chatchat.utils import build_logger from chatchat.utils import build_logger
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
logger = build_logger() logger = build_logger()
@ -37,12 +37,9 @@ class ESKBService(KBService):
self.client_cert = kb_config.get("client_cert", None) self.client_cert = kb_config.get("client_cert", None)
self.dims_length = kb_config.get("dims_length", None) self.dims_length = kb_config.get("dims_length", None)
self.embeddings_model = get_Embeddings(self.embed_model) self.embeddings_model = get_Embeddings(self.embed_model)
logger.info(f"self.kb_path:{self.kb_path },self.index_name:{self.index_name}, self.scheme:{self.scheme},self.IP:{self.IP},"
f"self.PORT:{self.PORT},self.user:{self.user},self.password:{self.password},self.verify_certs:{self.verify_certs},"
f"self.client_cert:{self.client_cert},self.client_key:{self.client_key},self.dims_length:{self.dims_length}")
try: try:
connection_info = dict( connection_info = dict(
hosts=f"{self.scheme}://{self.IP}:{self.PORT}" host=f"{self.scheme}://{self.IP}:{self.PORT}"
) )
if self.user != "" and self.password != "": if self.user != "" and self.password != "":
connection_info.update(basic_auth=(self.user, self.password)) connection_info.update(basic_auth=(self.user, self.password))
@ -56,9 +53,7 @@ class ESKBService(KBService):
connection_info.update(client_key=self.client_key) connection_info.update(client_key=self.client_key)
connection_info.update(client_cert=self.client_cert) connection_info.update(client_cert=self.client_cert)
# ES python客户端连接仅连接 # ES python客户端连接仅连接
logger.info(f"connection_info:{connection_info}")
self.es_client_python = Elasticsearch(**connection_info) self.es_client_python = Elasticsearch(**connection_info)
logger.info(f"after Elasticsearch connection_info:{connection_info}")
except ConnectionError: except ConnectionError:
logger.error("连接到 Elasticsearch 失败!") logger.error("连接到 Elasticsearch 失败!")
raise ConnectionError raise ConnectionError
@ -89,10 +84,9 @@ class ESKBService(KBService):
es_url=f"{self.scheme}://{self.IP}:{self.PORT}", es_url=f"{self.scheme}://{self.IP}:{self.PORT}",
index_name=self.index_name, index_name=self.index_name,
query_field="context", query_field="context",
distance_strategy="COSINE",
vector_query_field="dense_vector", vector_query_field="dense_vector",
embedding=self.embeddings_model, embedding=self.embeddings_model,
# strategy=ApproxRetrievalStrategy(), strategy=ApproxRetrievalStrategy(),
es_params={ es_params={
"timeout": 60, "timeout": 60,
}, },
@ -107,7 +101,6 @@ class ESKBService(KBService):
params["es_params"].update(client_key=self.client_key) params["es_params"].update(client_key=self.client_key)
params["es_params"].update(client_cert=self.client_cert) params["es_params"].update(client_cert=self.client_cert)
self.db = ElasticsearchStore(**params) self.db = ElasticsearchStore(**params)
logger.info(f"after ElasticsearchStore create params:{params}")
except ConnectionError: except ConnectionError:
logger.error("### 初始化 Elasticsearch 失败!") logger.error("### 初始化 Elasticsearch 失败!")
raise ConnectionError raise ConnectionError
@ -140,72 +133,16 @@ class ESKBService(KBService):
def vs_type(self) -> str: def vs_type(self) -> str:
return SupportedVSType.ES return SupportedVSType.ES
def do_search(self, query: str, top_k: int, score_threshold: float)->List[Document]: def do_search(self, query: str, top_k: int, score_threshold: float):
# 确保 ElasticsearchStore 正确初始化
if not hasattr(self, "db") or self.db is None:
raise ValueError("ElasticsearchStore (db) not initialized.")
# 文本相似性检索 # 文本相似性检索
retriever = get_Retriever("vectorstore").from_vectorstore( retriever = get_Retriever("vectorstore").from_vectorstore(
self.db, self.db,
top_k=top_k, top_k=top_k,
score_threshold=score_threshold, score_threshold=score_threshold,
) )
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query:str, top_k: int = 2):
    """Full-text search on the ``context`` field, returning highlighted snippets.

    Args:
        query: Text to match against the ``context`` field.
        top_k: Maximum number of hits requested from Elasticsearch.

    Returns:
        A list of ``DocumentWithVSId``. Each item's ``page_content`` is the
        hit's highlight fragments for ``context`` joined with single spaces
        (an empty string when the hit has no highlight), with the hit's
        stored metadata and ``_id``. Returns an empty list when the index
        does not exist.
    """
    # Without the index there is nothing to search; return a well-defined
    # empty result instead of falling through with no list to return.
    if not self.es_client_python.indices.exists(index=self.index_name):
        return []

    logger.info(f"******ESKBService searchByContent {self.index_name},query:{query}")
    # NOTE(review): a `match` query analyzes its input and does not treat "*"
    # as a wildcard, so the surrounding asterisks are presumably ignored by
    # the analyzer. Kept as-is to avoid changing what is sent; confirm.
    tem_query = {
        "query": {"match": {
            "context": "*" + query + "*"
        }},
        "highlight":{"fields":{
            "context":{}
        }}
    }
    search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)

    docs_and_scores = []
    for hit in search_results["hits"]["hits"]:
        # Elasticsearch omits "highlight" when no fragment was produced.
        highlighted_contexts = ""
        if 'highlight' in hit:
            highlighted_contexts = " ".join(hit['highlight']['context'])
        docs_and_scores.append(DocumentWithVSId(
            page_content=highlighted_contexts,
            metadata=hit["_source"]["metadata"],
            id = hit["_id"],
        ))
    return docs_and_scores
def searchbyContentInternal(self, query:str, top_k: int = 2) -> List[Document]:
    """Full-text search on the ``context`` field, returning the stored documents.

    Args:
        query: Text to match against the ``context`` field.
        top_k: Maximum number of hits requested from Elasticsearch.

    Returns:
        A list of ``Document`` built from each hit's stored ``context`` and
        ``metadata``. No scores are attached. Returns an empty list when the
        index does not exist.
    """
    # Without the index there is nothing to search; return a well-defined
    # empty result instead of falling through with no list to return.
    if not self.es_client_python.indices.exists(index=self.index_name):
        return []

    logger.info(f"******ESKBService searchbyContentInternal {self.index_name},query:{query}")
    # NOTE(review): a `match` query analyzes its input and does not treat "*"
    # as a wildcard, so the surrounding asterisks are presumably ignored by
    # the analyzer. Kept as-is to avoid changing what is sent; confirm.
    tem_query = {
        "query": {"match": {
            "context": "*" + query + "*"
        }}
    }
    search_results = self.es_client_python.search(index=self.index_name, body=tem_query, size=top_k)
    return [
        Document(
            page_content=hit["_source"]["context"],
            metadata=hit["_source"]["metadata"],
        )
        for hit in search_results["hits"]["hits"]
    ]
def get_doc_by_ids(self, ids: List[str]) -> List[Document]: def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
results = [] results = []
for doc_id in ids: for doc_id in ids:
@ -242,13 +179,10 @@ class ESKBService(KBService):
}, },
"track_total_hits": True, "track_total_hits": True,
} }
print(f"***do_delete_doc: kb_file.filepath:{kb_file.filepath}, kb.filename:{kb_file.filename}") # 注意设置size默认返回10个es检索设置track_total_hits为True返回数据库中真实的size。
print(f"***do_delete_doc: kb.filename:{kb_file.filename}") size = self.es_client_python.search(body=query)["hits"]["total"]["value"]
# 注意设置size默认返回10个。 search_results = self.es_client_python.search(body=query, size=size)
search_results = self.es_client_python.search(index=self.index_name, body=query,size=200) delete_list = [hit["_id"] for hit in search_results["hits"]["hits"]]
delete_list = [hit["_id"] for hit in search_results['hits']['hits']]
size = len(delete_list)
#print(f"***do_delete_doc: 删除的size:{size}, {delete_list}")
if len(delete_list) == 0: if len(delete_list) == 0:
return None return None
else: else:
@ -278,34 +212,20 @@ class ESKBService(KBService):
if self.es_client_python.indices.exists(index=self.index_name): if self.es_client_python.indices.exists(index=self.index_name):
file_path = docs[0].metadata.get("source") file_path = docs[0].metadata.get("source")
print(f"****************do_add_doc, file_path:{file_path}") query = {
# enhanced by weiweiwang 2025/2/24 to specific index name
# query = {
# "query": {
# "term": {"metadata.source.keyword": file_path},
# # "term": {"_index": self.index_name},
# }
# }
query = {
"query": { "query": {
"bool": { "term": {"metadata.source.keyword": file_path},
"must": [ "term": {"_index": self.index_name},
{ "term": { "metadata.source.keyword": file_path } },
{ "term": { "_index": self.index_name } }
]
}
} }
} }
# 注意设置size默认返回10个。 # 注意设置size默认返回10个。
search_results = self.es_client_python.search(body=query, size=200) search_results = self.es_client_python.search(body=query, size=50)
if len(search_results["hits"]["hits"]) == 0: if len(search_results["hits"]["hits"]) == 0:
raise ValueError("召回元素个数为0") raise ValueError("召回元素个数为0")
info_docs = [ info_docs = [
{"id": hit["_id"], "metadata": hit["_source"]["metadata"]} {"id": hit["_id"], "metadata": hit["_source"]["metadata"]}
for hit in search_results["hits"]["hits"] for hit in search_results["hits"]["hits"]
] ]
# size = len(info_docs)
# print(f"do_add_doc 召回元素个数:{size}")
return info_docs return info_docs
def do_clear_vs(self): def do_clear_vs(self):

View File

@ -78,12 +78,6 @@ class FaissKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query:str, top_k: int = 2):
pass
def searchbyContentInternal(self, query:str, top_k: int = 2):
return None
def do_add_doc( def do_add_doc(
self, self,
docs: List[Document], docs: List[Document],

View File

@ -88,12 +88,6 @@ class MilvusKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query:str, top_k: int = 2):
pass
def searchbyContentInternal(self, query:str, top_k: int = 2):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:
for k, v in doc.metadata.items(): for k, v in doc.metadata.items():

View File

@ -84,12 +84,6 @@ class PGKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query:str, top_k: int = 2):
pass
def searchbyContentInternal(self, query:str, top_k: int = 2):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
ids = self.pg_vector.add_documents(docs) ids = self.pg_vector.add_documents(docs)
doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)] doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]

View File

@ -94,13 +94,6 @@ class RelytKBService(KBService):
docs = self.relyt.similarity_search_with_score(query, top_k) docs = self.relyt.similarity_search_with_score(query, top_k)
return score_threshold_process(score_threshold, top_k, docs) return score_threshold_process(score_threshold, top_k, docs)
def searchbyContent(self, query:str, top_k: int = 2):
pass
def searchbyContentInternal(self, query:str, top_k: int = 2):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
print(docs) print(docs)
ids = self.relyt.add_documents(docs) ids = self.relyt.add_documents(docs)

View File

@ -79,12 +79,6 @@ class ZillizKBService(KBService):
docs = retriever.get_relevant_documents(query) docs = retriever.get_relevant_documents(query)
return docs return docs
def searchbyContent(self, query:str, top_k: int = 2):
pass
def searchbyContentInternal(self, query:str, top_k: int = 2):
pass
def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]: def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
for doc in docs: for doc in docs:
for k, v in doc.metadata.items(): for k, v in doc.metadata.items():

View File

@ -70,9 +70,6 @@ def list_files_from_folder(kb_name: str):
for x in ["temp", "tmp", ".", "~$"]: for x in ["temp", "tmp", ".", "~$"]:
if tail.startswith(x): if tail.startswith(x):
return True return True
if "_source.txt" in tail.lower() or "_split.txt" in tail.lower():
return True
return False return False
def process_entry(entry): def process_entry(entry):
@ -425,15 +422,15 @@ class KnowledgeFile:
docs = zh_first_title_enhance(docs) docs = zh_first_title_enhance(docs)
docs = customize_zh_title_enhance(docs) docs = customize_zh_title_enhance(docs)
i = 1 # i = 1
outputfile = file_name_without_extension + "_split.txt" # outputfile = file_name_without_extension + "_split.txt"
# 打开文件以写入模式 # # 打开文件以写入模式
with open(outputfile, 'w') as file: # with open(outputfile, 'w') as file:
for doc in docs: # for doc in docs:
#print(f"**********切分段{i}{doc}") # #print(f"**********切分段{i}{doc}")
file.write(f"\n**********切分段{i}") # file.write(f"\n**********切分段{i}")
file.write(doc.page_content) # file.write(doc.page_content)
i = i+1 # i = i+1
self.splited_docs = docs self.splited_docs = docs
return self.splited_docs return self.splited_docs

View File

@ -488,7 +488,7 @@ class ToolSettings(BaseFileSettings):
search_internet: dict = { search_internet: dict = {
"use": False, "use": False,
"search_engine_name": "zhipu_search", "search_engine_name": "duckduckgo",
"search_engine_config": { "search_engine_config": {
"bing": { "bing": {
"bing_search_url": "https://api.bing.microsoft.com/v7.0/search", "bing_search_url": "https://api.bing.microsoft.com/v7.0/search",
@ -506,21 +506,11 @@ class ToolSettings(BaseFileSettings):
"engines": [], "engines": [],
"categories": [], "categories": [],
"language": "zh-CN", "language": "zh-CN",
},
"tavily":{
"tavily_api_key": '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',
"include_answer": True,
"search_depth": "advanced",
"include_raw_content": True,
"max_results": 1
},
"zhipu_search":{
"zhipu_api_key": ""
} }
}, },
"top_k": 1, "top_k": 5,
"verbose": "Origin", "verbose": "Origin",
"conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题,不得包含有重复的词汇或句子。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 " "conclude_prompt": "<指令>这是搜索到的互联网信息,请你根据这些信息进行提取并有调理,简洁的回答问题。如果无法从中得到答案,请说 “无法搜索到能回答问题的内容”。 "
"</指令>\n<已知信息>{{ context }}</已知信息>\n" "</指令>\n<已知信息>{{ context }}</已知信息>\n"
"<问题>\n" "<问题>\n"
"{{ question }}\n" "{{ question }}\n"
@ -660,7 +650,7 @@ class PromptSettings(BaseFileSettings):
rag: dict = { rag: dict = {
"default": ( "default": (
"【指令】根据已知信息,简洁和专业的来回答问题,不得包含有重复的词汇或句子" "【指令】根据已知信息,简洁和专业的来回答问题"
"如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n" "如果无法从中得到答案,请说 “根据已知信息无法回答该问题”,不允许在答案中添加编造成分,答案请使用中文。\n\n"
"【已知信息】{{context}}\n\n" "【已知信息】{{context}}\n\n"
"【问题】{{question}}\n" "【问题】{{question}}\n"
@ -751,8 +741,6 @@ class PromptSettings(BaseFileSettings):
"Begin!\n\n" "Begin!\n\n"
"Question: {input}\n\n" "Question: {input}\n\n"
"{agent_scratchpad}\n\n" "{agent_scratchpad}\n\n"
"Important: After the last Observation, you must always add a Final Answer "
"summarizing the result. Do not skip this step."
), ),
"structured-chat-agent": ( "structured-chat-agent": (
"Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n" "Respond to the human as helpfully and accurately as possible. You have access to the following tools:\n\n"

View File

@ -46,7 +46,7 @@ if __name__ == "__main__":
with st.sidebar: with st.sidebar:
st.image( st.image(
get_img_base64("logo-long-chatchat-trans-v2.png"), use_column_width=True get_img_base64("logo-long-chatchat-trans-v2.png"), use_container_width=True
) )
st.caption( st.caption(
f"""<p align="right">当前版本:{__version__}</p>""", f"""<p align="right">当前版本:{__version__}</p>""",

View File

@ -238,7 +238,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
doc_details, doc_details,
{ {
("No", "序号"): {}, ("No", "序号"): {},
("file_name", "文档名称"): {"filter": "agTextColumnFilter"}, ("file_name", "文档名称"): {},
# ("file_ext", "文档类型"): {}, # ("file_ext", "文档类型"): {},
# ("file_version", "文档版本"): {}, # ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {}, ("document_loader", "文档加载器"): {},
@ -398,7 +398,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
cellEditor="agLargeTextCellEditor", cellEditor="agLargeTextCellEditor",
cellEditorPopup=True, cellEditorPopup=True,
autoWidth=True, autoWidth=True,
cellEditorParams= { "maxLength": 1500} cellEditorParams= { "maxLength": 1000}
) )
gb.configure_column( gb.configure_column(
"to_del", "to_del",
@ -406,8 +406,8 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
editable=True, editable=True,
width=50, width=50,
wrapHeaderText=True, wrapHeaderText=True,
cellEditor="agTextCellEditor", cellEditor="agCheckboxCellEditor",
cellRender="agTextCellRenderer", cellRender="agCheckboxCellRenderer",
) )
# 启用分页 # 启用分页
gb.configure_pagination( gb.configure_pagination(
@ -428,15 +428,15 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
changed_docs = [] changed_docs = []
for index, row in edit_docs.data.iterrows(): for index, row in edit_docs.data.iterrows():
origin_doc = origin_docs[row["id"]] origin_doc = origin_docs[row["id"]]
# if row["page_content"] != origin_doc["page_content"]: if row["page_content"] != origin_doc["page_content"]:
if row["to_del"] not in ["Y", "y", 1]: if row["to_del"] not in ["Y", "y", 1]:
changed_docs.append( changed_docs.append(
{ {
"page_content": row["page_content"], "page_content": row["page_content"],
"type": row["type"], "type": row["type"],
"metadata": json.loads(row["metadata"]), "metadata": json.loads(row["metadata"]),
} }
) )
if changed_docs: if changed_docs:
if api.update_kb_docs( if api.update_kb_docs(