Dev (#1822)
* 支持了agentlm * 支持了agentlm和相关提示词 * 修改了Agent的一些功能,加入了Embed方面的一个优化 --------- Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
parent
1c5382d96b
commit
1b50547e60
|
|
@ -116,3 +116,6 @@ text_splitter_dict = {
|
||||||
|
|
||||||
# TEXT_SPLITTER 名称
|
# TEXT_SPLITTER 名称
|
||||||
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
|
TEXT_SPLITTER_NAME = "ChineseRecursiveTextSplitter"
|
||||||
|
|
||||||
|
## Embedding模型定制词语的词表文件
|
||||||
|
EMBEDDING_KEYWORD_FILE = "embedding_keywords.txt"
|
||||||
|
|
@ -0,0 +1,55 @@
|
||||||
|
'''
|
||||||
|
该功能是为了将关键词加入到embedding模型中,以便于在embedding模型中进行关键词的embedding
|
||||||
|
该功能的实现是通过修改embedding模型的tokenizer来实现的
|
||||||
|
该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效,输出后的模型保存在原本模型
|
||||||
|
该功能的Idea由社区贡献,感谢@CharlesJu1
|
||||||
|
|
||||||
|
保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳
|
||||||
|
'''
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append("..")
|
||||||
|
import os
|
||||||
|
from safetensors.torch import save_model
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from datetime import datetime
|
||||||
|
from configs import (
|
||||||
|
MODEL_PATH,
|
||||||
|
EMBEDDING_MODEL,
|
||||||
|
EMBEDDING_KEYWORD_FILE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def add_keyword_to_model(model_name: str = EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
    """Add the keywords listed in *keyword_file* to the tokenizer of a
    SentenceTransformer embedding model and resize its token embeddings.

    Args:
        model_name: Name or local path of the embedding model to load.
        keyword_file: Text file containing one keyword per line.
        output_model_path: Directory to save the merged model into; when
            falsy, the modified model is not persisted.
    """
    # One keyword per line; read as UTF-8 explicitly (keyword files are
    # Chinese text) and skip blank lines so trailing newlines do not
    # become empty tokens.
    with open(keyword_file, "r", encoding="utf-8") as f:
        key_words = [line.strip() for line in f if line.strip()]

    model = SentenceTransformer(model_name)
    word_embedding_model = model._first_module()
    tokenizer = word_embedding_model.tokenizer
    tokenizer.add_tokens(key_words)
    # Grow the embedding matrix to cover the new tokens; padding the vocab
    # size to a multiple of 32 keeps the matrix hardware-friendly.
    word_embedding_model.auto_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)

    if output_model_path:
        os.makedirs(output_model_path, exist_ok=True)
        tokenizer.save_pretrained(output_model_path)
        model.save(output_model_path)
        # Re-save the weights as safetensors with an explicit 'format'
        # metadata key so downstream loaders that require it can read it.
        safetensors_file = os.path.join(output_model_path, "model.safetensors")
        metadata = {'format': 'pt'}
        save_model(model, safetensors_file, metadata)
|
||||||
|
|
||||||
|
def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
    """Merge the keywords in *path* into the configured embedding model.

    The merged model is saved next to the original model under the name
    "<EMBEDDING_MODEL>_Merge_Keywords_<timestamp>", and the destination
    path is printed on completion.
    """
    keyword_file = path  # single-argument os.path.join(path) was a no-op
    model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
    model_parent_directory = os.path.dirname(model_name)
    # Timestamp suffix keeps repeated merges from overwriting each other.
    current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
    output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
    output_model_path = os.path.join(model_parent_directory, output_model_name)
    add_keyword_to_model(model_name, keyword_file, output_model_path)
    print("save model to {}".format(output_model_path))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Run as a one-shot script: merge the configured keyword file into
    # the configured embedding model.
    add_keyword_to_embedding_model(path=EMBEDDING_KEYWORD_FILE)
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
Langchain-Chatchat
|
||||||
|
数据科学与大数据技术
|
||||||
|
人工智能与先进计算
|
||||||
|
|
@ -37,7 +37,7 @@ class CustomOutputParser(AgentOutputParser):
|
||||||
|
|
||||||
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
def parse(self, llm_output: str) -> AgentFinish | tuple[dict[str, str], str] | AgentAction:
|
||||||
# Check if agent should finish
|
# Check if agent should finish
|
||||||
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
|
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
|
||||||
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
|
if not any(agent in model_container.MODEL for agent in support_agent) and self.begin:
|
||||||
self.begin = False
|
self.begin = False
|
||||||
stop_words = ["Observation:"]
|
stop_words = ["Observation:"]
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,3 @@
|
||||||
## 单独运行的时候需要添加
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.chains import LLMMathChain
|
from langchain.chains import LLMMathChain
|
||||||
from server.agent import model_container
|
from server.agent import model_container
|
||||||
|
|
@ -19,7 +14,7 @@ ${{运行代码的输出}}
|
||||||
```
|
```
|
||||||
答案: ${{答案}}
|
答案: ${{答案}}
|
||||||
|
|
||||||
这是两个例子:
|
这是两个例子:
|
||||||
|
|
||||||
问题: 37593 * 67是多少?
|
问题: 37593 * 67是多少?
|
||||||
```text
|
```text
|
||||||
|
|
@ -56,6 +51,7 @@ ${{运行代码的输出}}
|
||||||
现在,这是我的问题:
|
现在,这是我的问题:
|
||||||
问题: {question}
|
问题: {question}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
PROMPT = PromptTemplate(
|
PROMPT = PromptTemplate(
|
||||||
input_variables=["question"],
|
input_variables=["question"],
|
||||||
template=_PROMPT_TEMPLATE,
|
template=_PROMPT_TEMPLATE,
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,3 @@
|
||||||
## 单独运行的时候需要添加
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,3 @@
|
||||||
## 单独运行的时候需要添加
|
|
||||||
# import sys
|
|
||||||
# import os
|
|
||||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,3 @@
|
||||||
## 单独运行的时候需要添加
|
|
||||||
# import sys
|
|
||||||
# import os
|
|
||||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from server.chat import search_engine_chat
|
from server.chat import search_engine_chat
|
||||||
from configs import VECTOR_SEARCH_TOP_K
|
from configs import VECTOR_SEARCH_TOP_K
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,3 @@
|
||||||
## 最简单的版本,只支持固定的知识库
|
|
||||||
|
|
||||||
# ## 单独运行的时候需要添加
|
|
||||||
# import sys
|
|
||||||
# import os
|
|
||||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||||
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
|
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD
|
||||||
import json
|
import json
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,3 @@
|
||||||
## 单独运行的时候需要添加
|
|
||||||
# import sys
|
|
||||||
# import os
|
|
||||||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
from langchain.prompts import PromptTemplate
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from server.agent import model_container
|
from server.agent import model_container
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,4 @@
|
||||||
## 使用和风天气API查询天气,这个模型仅仅对免费的API进行了适配
|
## 使用和风天气API查询天气,这个模型仅仅对免费的API进行了适配,建议使用GPT4等高级模型进行适配
|
||||||
## 这个模型的提示词非常复杂,我们推荐使用GPT4模型进行运行
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
## 单独运行的时候需要添加
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
@ -30,8 +22,6 @@ from server.agent import model_container
|
||||||
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
|
KEY = "ac880e5a877042809ac7ffdd19d95b0d"
|
||||||
#key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
|
#key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
_PROMPT_TEMPLATE = """
|
_PROMPT_TEMPLATE = """
|
||||||
用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。
|
用户会提出一个关于天气的问题,你的目标是拆分出用户问题中的区,市 并按照我提供的工具回答。
|
||||||
例如 用户提出的问题是: 上海浦东未来1小时天气情况?
|
例如 用户提出的问题是: 上海浦东未来1小时天气情况?
|
||||||
|
|
|
||||||
|
|
@ -58,12 +58,12 @@ tools = [
|
||||||
),
|
),
|
||||||
Tool.from_function(
|
Tool.from_function(
|
||||||
func=knowledge_search_more,
|
func=knowledge_search_more,
|
||||||
name="Knowledge Base Query Tool",
|
name="Knowledge Base Tool",
|
||||||
description="Prioritize accessing the knowledge base to get answers"
|
description="Prioritize accessing the knowledge base to get answers"
|
||||||
),
|
),
|
||||||
Tool.from_function(
|
Tool.from_function(
|
||||||
func=search_internet,
|
func=search_internet,
|
||||||
name="Internet Query Tool",
|
name="Internet Tool",
|
||||||
description="If you can't access the internet, this tool can help you access Bing to answer questions"
|
description="If you can't access the internet, this tool can help you access Bing to answer questions"
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from langchain.agents import AgentExecutor, LLMSingleActionAgent
|
||||||
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
|
||||||
from fastapi import Body
|
from fastapi import Body
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN
|
from configs import LLM_MODEL, TEMPERATURE, HISTORY_LEN,Agent_MODEL
|
||||||
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
|
||||||
from langchain.chains import LLMChain
|
from langchain.chains import LLMChain
|
||||||
from typing import AsyncIterable, Optional, Dict
|
from typing import AsyncIterable, Optional, Dict
|
||||||
|
|
@ -49,7 +49,19 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
||||||
## 传入全局变量来实现agent调用
|
## 传入全局变量来实现agent调用
|
||||||
kb_list = {x["kb_name"]: x for x in get_kb_details()}
|
kb_list = {x["kb_name"]: x for x in get_kb_details()}
|
||||||
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
|
model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
|
||||||
model_container.MODEL = model
|
|
||||||
|
|
||||||
|
if Agent_MODEL:
|
||||||
|
## 如果有指定使用Agent模型来完成任务
|
||||||
|
model_agent = get_ChatOpenAI(
|
||||||
|
model_name=Agent_MODEL,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
callbacks=[callback],
|
||||||
|
)
|
||||||
|
model_container.MODEL = model_agent
|
||||||
|
else:
|
||||||
|
model_container.MODEL = model
|
||||||
|
|
||||||
prompt_template = get_prompt_template("agent_chat", prompt_name)
|
prompt_template = get_prompt_template("agent_chat", prompt_name)
|
||||||
prompt_template_agent = CustomPromptTemplate(
|
prompt_template_agent = CustomPromptTemplate(
|
||||||
|
|
|
||||||
|
|
@ -224,7 +224,7 @@ def dialogue_page(api: ApiRequest):
|
||||||
])
|
])
|
||||||
text = ""
|
text = ""
|
||||||
ans = ""
|
ans = ""
|
||||||
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
|
support_agent = ["Azure-OpenAI", "OpenAI", "Anthropic", "Qwen", "qwen-api", "baichuan-api","agentlm"] # 目前支持agent的模型
|
||||||
if not any(agent in llm_model for agent in support_agent):
|
if not any(agent in llm_model for agent in support_agent):
|
||||||
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
|
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!</span>\n\n\n"
|
||||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue