修改Agent的内容
This commit is contained in:
zR 2023-10-27 22:53:43 +08:00 committed by GitHub
parent bb72d9ac26
commit aa7c580974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 75 additions and 74 deletions

View File

@ -1,11 +1,11 @@
## 导入所有的工具类 ## 导入所有的工具类
from .search_knowledge_simple import knowledge_search_simple from .search_knowledge_simple import knowledge_search_simple
from .search_all_knowledge_once import knowledge_search_once from .search_all_knowledge_once import knowledge_search_once, KnowledgeSearchInput
from .search_all_knowledge_more import knowledge_search_more from .search_all_knowledge_more import knowledge_search_more, KnowledgeSearchInput
from .calculate import calculate from .calculate import calculate, CalculatorInput
from .translator import translate from .translator import translate, TranslateInput
from .weather import weathercheck from .weather import weathercheck, WhetherSchema
from .shell import shell from .shell import shell, ShellInput
from .search_internet import search_internet from .search_internet import search_internet, SearchInternetInput
from .wolfram import wolfram from .wolfram import wolfram, WolframInput
from .youtube import youtube_search from .youtube import youtube_search, YoutubeInput

View File

@ -1,6 +1,7 @@
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.chains import LLMMathChain from langchain.chains import LLMMathChain
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
_PROMPT_TEMPLATE = """ _PROMPT_TEMPLATE = """
将数学问题翻译成可以使用Python的numexpr库执行的表达式使用运行此代码的输出来回答问题 将数学问题翻译成可以使用Python的numexpr库执行的表达式使用运行此代码的输出来回答问题
@ -58,6 +59,9 @@ PROMPT = PromptTemplate(
) )
class CalculatorInput(BaseModel):
query: str = Field()
def calculate(query: str): def calculate(query: str):
model = model_container.MODEL model = model_container.MODEL
llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT) llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
@ -68,3 +72,5 @@ if __name__ == "__main__":
result = calculate("2的三次方") result = calculate("2的三次方")
print("答案:",result) print("答案:",result)

View File

@ -14,7 +14,7 @@ from server.chat.knowledge_base_chat import knowledge_base_chat
from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio import asyncio
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str) -> str: async def search_knowledge_base_iter(database: str, query: str) -> str:
response = await knowledge_base_chat(query=query, response = await knowledge_base_chat(query=query,
@ -264,6 +264,8 @@ def knowledge_search_more(query: str):
ans = llm_knowledge.run(query) ans = llm_knowledge.run(query)
return ans return ans
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="知识库查询的内容")
if __name__ == "__main__": if __name__ == "__main__":
result = knowledge_search_more("机器人和大数据在代码教学上有什么区别") result = knowledge_search_more("机器人和大数据在代码教学上有什么区别")

View File

@ -23,7 +23,7 @@ from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
import asyncio import asyncio
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
async def search_knowledge_base_iter(database: str, query: str): async def search_knowledge_base_iter(database: str, query: str):
response = await knowledge_base_chat(query=query, response = await knowledge_base_chat(query=query,
@ -225,6 +225,10 @@ def knowledge_search_once(query: str):
return ans return ans
class KnowledgeSearchInput(BaseModel):
location: str = Field(description="知识库查询的内容")
if __name__ == "__main__": if __name__ == "__main__":
result = knowledge_search_once("大数据的男女比例") result = knowledge_search_once("大数据的男女比例")
print(result) print(result)

View File

@ -3,6 +3,7 @@ from server.chat.search_engine_chat import search_engine_chat
from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
import asyncio import asyncio
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
async def search_engine_iter(query: str): async def search_engine_iter(query: str):
response = await search_engine_chat(query=query, response = await search_engine_chat(query=query,
@ -25,9 +26,11 @@ async def search_engine_iter(query: str):
return contents return contents
def search_internet(query: str): def search_internet(query: str):
return asyncio.run(search_engine_iter(query)) return asyncio.run(search_engine_iter(query))
class SearchInternetInput(BaseModel):
location: str = Field(description="需要查询的内容")
if __name__ == "__main__": if __name__ == "__main__":
result = search_internet("今天星期几") result = search_internet("今天星期几")

View File

@ -1,6 +1,9 @@
# LangChain 的 Shell 工具 # LangChain 的 Shell 工具
from pydantic import BaseModel, Field
from langchain.tools import ShellTool from langchain.tools import ShellTool
def shell(query: str): def shell(query: str):
tool = ShellTool() tool = ShellTool()
return tool.run(tool_input=query) return tool.run(tool_input=query)
class ShellInput(BaseModel):
query: str = Field(description="一个能在Linux命令行运行的Shell命令")

View File

@ -1,6 +1,7 @@
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain from langchain.chains import LLMChain
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
_PROMPT_TEMPLATE = ''' _PROMPT_TEMPLATE = '''
# 指令 # 指令
@ -29,6 +30,9 @@ def translate(query: str):
ans = llm_translate.run(query) ans = llm_translate.run(query)
return ans return ans
class TranslateInput(BaseModel):
location: str = Field(description="需要被翻译的内容")
if __name__ == "__main__": if __name__ == "__main__":
result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文") result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文")
print("答案:",result) print("答案:",result)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
## 单独运行的时候需要添加 ## 单独运行的时候需要添加
import sys import sys
import os import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
import re import re
@ -23,11 +24,11 @@ from typing import List, Any, Optional
from datetime import datetime from datetime import datetime
from langchain.prompts import PromptTemplate from langchain.prompts import PromptTemplate
from server.agent import model_container from server.agent import model_container
from pydantic import BaseModel, Field
## 使用和风天气API查询天气 ## 使用和风天气API查询天气
KEY = "ac880e5a877042809ac7ffdd19d95b0d" KEY = "ac880e5a877042809ac7ffdd19d95b0d"
#key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key # key长这样这里提供了示例的key这个key没法使用你需要自己去注册和风天气的账号然后在这里填入你的key
_PROMPT_TEMPLATE = """ _PROMPT_TEMPLATE = """
@ -95,7 +96,7 @@ def get_city_info(location, adm, key):
return data return data
def format_weather_data(data,place): def format_weather_data(data, place):
hourly_forecast = data['hourly'] hourly_forecast = data['hourly']
formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n" formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
for forecast in hourly_forecast: for forecast in hourly_forecast:
@ -141,7 +142,7 @@ def format_weather_data(data,place):
return formatted_data return formatted_data
def get_weather(key, location_id,place): def get_weather(key, location_id, place):
url = "https://devapi.qweather.com/v7/weather/24h?" url = "https://devapi.qweather.com/v7/weather/24h?"
params = { params = {
'location': location_id, 'location': location_id,
@ -149,18 +150,20 @@ def get_weather(key, location_id,place):
} }
response = requests.get(url, params=params) response = requests.get(url, params=params)
data = response.json() data = response.json()
return format_weather_data(data,place) return format_weather_data(data, place)
def split_query(query): def split_query(query):
parts = query.split() parts = query.split()
adm = parts[0] adm = parts[0]
if len(parts) == 1:
return adm, adm
location = parts[1] if parts[1] != 'None' else adm location = parts[1] if parts[1] != 'None' else adm
return location, adm return location, adm
def weather(query): def weather(query):
location, adm= split_query(query) location, adm = split_query(query)
key = KEY key = KEY
if key == "": if key == "":
return "请先在代码中填入和风天气API Key" return "请先在代码中填入和风天气API Key"
@ -169,17 +172,19 @@ def weather(query):
location_id = city_info['location'][0]['id'] location_id = city_info['location'][0]['id']
place = adm + "" + location + "" place = adm + "" + location + ""
weather_data = get_weather(key=key, location_id=location_id,place=place) weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "以上是查询到的天气信息,请你查收\n" return weather_data + "以上是查询到的天气信息,请你查收\n"
except KeyError: except KeyError:
try: try:
city_info = get_city_info(location=adm, adm=adm, key=key) city_info = get_city_info(location=adm, adm=adm, key=key)
location_id = city_info['location'][0]['id'] location_id = city_info['location'][0]['id']
place = adm + "" place = adm + ""
weather_data = get_weather(key=key, location_id=location_id,place=place) weather_data = get_weather(key=key, location_id=location_id, place=place)
return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n" return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
except KeyError: except KeyError:
return "输入的地区不存在,无法提供天气预报" return "输入的地区不存在,无法提供天气预报"
class LLMWeatherChain(Chain): class LLMWeatherChain(Chain):
llm_chain: LLMChain llm_chain: LLMChain
llm: Optional[BaseLanguageModel] = None llm: Optional[BaseLanguageModel] = None
@ -319,12 +324,15 @@ class LLMWeatherChain(Chain):
return cls(llm_chain=llm_chain, **kwargs) return cls(llm_chain=llm_chain, **kwargs)
def weathercheck(query: str): def weathercheck(query: str):
model = model_container.MODEL model = model_container.MODEL
llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT) llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
ans = llm_weather.run(query) ans = llm_weather.run(query)
return ans return ans
class WhetherSchema(BaseModel):
location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
if __name__ == '__main__': if __name__ == '__main__':
result = weathercheck("苏州姑苏区今晚热不热?") result = weathercheck("苏州姑苏区今晚热不热?")

View File

@ -1,7 +1,11 @@
# Langchain 自带的 Wolfram Alpha API 封装 # Langchain 自带的 Wolfram Alpha API 封装
from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
from pydantic import BaseModel, Field
wolfram_alpha_appid = "your key" wolfram_alpha_appid = "your key"
def wolfram(query: str): def wolfram(query: str):
wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid) wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
ans = wolfram.run(query) ans = wolfram.run(query)
return ans return ans
class WolframInput(BaseModel):
location: str = Field(description="需要运算的具体问题")

View File

@ -1,5 +1,9 @@
# Langchain 自带的 YouTube 搜索工具封装 # Langchain 自带的 YouTube 搜索工具封装
from langchain.tools import YouTubeSearchTool from langchain.tools import YouTubeSearchTool
from pydantic import BaseModel, Field
def youtube_search(query: str): def youtube_search(query: str):
tool = YouTubeSearchTool() tool = YouTubeSearchTool()
return tool.run(tool_input=query) return tool.run(tool_input=query)
class YoutubeInput(BaseModel):
location: str = Field(description="要搜索视频关键字")

View File

@ -1,60 +1,20 @@
from langchain.tools import Tool from langchain.tools import Tool
from server.agent.tools import * from server.agent.tools import *
## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本下面的内容是英文版本。 ## 请注意如果你是为了使用AgentLM在这里你应该使用英文版本。
# tools = [
# Tool.from_function(
# func=calculate,
# name="Simple Calculator Tool",
# description="Perform simple mathematical operations, Just simple, Use Wolfram Math Tool for more complex operations"
# ),
# Tool.from_function(
# func=translate,
# name="Translation Tool",
# description="Use this tool if you can't access the internet and need to translate various languages"
# ),
# Tool.from_function(
# func=weathercheck,
# name="Weather Checking Tool",
# description="Check the weather for various places in China for the next 24 hours without needing internet access"
# ),
# Tool.from_function(
# func=shell,
# name="Shell Tool",
# description="Use command line tool output"
# ),
# Tool.from_function(
# func=knowledge_search_more,
# name="Knowledge Base Tool",
# description="Prioritize accessing the knowledge base to get answers"
# ),
# Tool.from_function(
# func=search_internet,
# name="Internet Tool",
# description="If you can't access the internet, this tool can help you access Bing to answer questions"
# ),
# Tool.from_function(
# func=wolfram,
# name="Wolfram Math Tool",
# description="Use this tool to perform more complex mathematical operations"
# ),
# Tool.from_function(
# func=youtube_search,
# name="Youtube Search Tool",
# description="Use this tool to search for videos on Youtube"
# )
# ]
tools = [ tools = [
Tool.from_function( Tool.from_function(
func=calculate, func=calculate,
name="计算器工具", name="计算器工具",
description="进行简单的数学运算, 只是简单的, 使用Wolfram数学工具进行更复杂的运算", description="进行简单的数学运算, 只是简单的, 使用Wolfram数学工具进行更复杂的运算",
args_schema=CalculatorInput,
), ),
Tool.from_function( Tool.from_function(
func=translate, func=translate,
name="翻译工具", name="翻译工具",
description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具" description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具",
args_schema=TranslateInput,
), ),
Tool.from_function( Tool.from_function(
func=weathercheck, func=weathercheck,
@ -79,13 +39,15 @@ tools = [
Tool.from_function( Tool.from_function(
func=wolfram, func=wolfram,
name="Wolfram数学工具", name="Wolfram数学工具",
description="高级的数学运算工具,能够完成非常复杂的数学问题" description="高级的数学运算工具,能够完成非常复杂的数学问题",
args_schema=WolframInput,
), ),
Tool.from_function( Tool.from_function(
func=youtube_search, func=youtube_search,
name="Youtube搜索工具", name="Youtube搜索工具",
description="使用这个工具在Youtube上搜索视频" description="使用这个工具在Youtube上搜索视频",
) args_schema=YoutubeInput,
),
] ]
tool_names = [tool.name for tool in tools] tool_names = [tool.name for tool in tools]

View File

@ -1,7 +1,7 @@
from langchain.memory import ConversationBufferWindowMemory from langchain.memory import ConversationBufferWindowMemory
from server.agent.tools_select import tools, tool_names from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import AgentExecutor, LLMSingleActionAgent from langchain.agents import AgentExecutor, LLMSingleActionAgent, initialize_agent, BaseMultiActionAgent
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body from fastapi import Body
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse

View File

@ -1,8 +1,9 @@
from configs import CACHED_VS_NUM
from server.knowledge_base.kb_cache.base import * from server.knowledge_base.kb_cache.base import *
from server.knowledge_base.utils import get_vs_path from server.knowledge_base.utils import get_vs_path
from langchain.vectorstores import FAISS from langchain.vectorstores import FAISS
import os import os
from langchain.schema import Document
class ThreadSafeFaiss(ThreadSafeObject): class ThreadSafeFaiss(ThreadSafeObject):
def __repr__(self) -> str: def __repr__(self) -> str: