diff --git a/server/agent/tools/__init__.py b/server/agent/tools/__init__.py
index d6cdff5..89527a6 100644
--- a/server/agent/tools/__init__.py
+++ b/server/agent/tools/__init__.py
@@ -1,11 +1,11 @@
 ## 导入所有的工具类
 from .search_knowledge_simple import knowledge_search_simple
-from .search_all_knowledge_once import knowledge_search_once
-from .search_all_knowledge_more import knowledge_search_more
-from .calculate import calculate
-from .translator import translate
-from .weather import weathercheck
-from .shell import shell
-from .search_internet import search_internet
-from .wolfram import wolfram
-from .youtube import youtube_search
+from .search_all_knowledge_once import knowledge_search_once, KnowledgeSearchInput
+from .search_all_knowledge_more import knowledge_search_more, KnowledgeSearchInput
+from .calculate import calculate, CalculatorInput
+from .translator import translate, TranslateInput
+from .weather import weathercheck, WhetherSchema
+from .shell import shell, ShellInput
+from .search_internet import search_internet, SearchInternetInput
+from .wolfram import wolfram, WolframInput
+from .youtube import youtube_search, YoutubeInput
diff --git a/server/agent/tools/calculate.py b/server/agent/tools/calculate.py
index 629331f..bb0cbcc 100644
--- a/server/agent/tools/calculate.py
+++ b/server/agent/tools/calculate.py
@@ -1,6 +1,7 @@
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMMathChain
 from server.agent import model_container
+from pydantic import BaseModel, Field
 
 _PROMPT_TEMPLATE = """
 将数学问题翻译成可以使用Python的numexpr库执行的表达式。使用运行此代码的输出来回答问题。
@@ -58,6 +59,9 @@ PROMPT = PromptTemplate(
 )
 
 
+class CalculatorInput(BaseModel):
+    query: str = Field()
+
 def calculate(query: str):
     model = model_container.MODEL
     llm_math = LLMMathChain.from_llm(model, verbose=True, prompt=PROMPT)
@@ -68,3 +72,5 @@ if __name__ == "__main__":
     result = calculate("2的三次方")
     print("答案:",result)
 
+
+
diff --git a/server/agent/tools/search_all_knowledge_more.py b/server/agent/tools/search_all_knowledge_more.py
index a2d76b0..7da9081 100644
--- a/server/agent/tools/search_all_knowledge_more.py
+++ b/server/agent/tools/search_all_knowledge_more.py
@@ -14,7 +14,7 @@ from server.chat.knowledge_base_chat import knowledge_base_chat
 from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
 import asyncio
 from server.agent import model_container
-
+from pydantic import BaseModel, Field
 
 async def search_knowledge_base_iter(database: str, query: str) -> str:
     response = await knowledge_base_chat(query=query,
@@ -264,6 +264,8 @@ def knowledge_search_more(query: str):
     ans = llm_knowledge.run(query)
     return ans
 
+class KnowledgeSearchInput(BaseModel):
+    location: str = Field(description="知识库查询的内容")
 
 if __name__ == "__main__":
     result = knowledge_search_more("机器人和大数据在代码教学上有什么区别")
diff --git a/server/agent/tools/search_all_knowledge_once.py b/server/agent/tools/search_all_knowledge_once.py
index 6e9d965..cf706b8 100644
--- a/server/agent/tools/search_all_knowledge_once.py
+++ b/server/agent/tools/search_all_knowledge_once.py
@@ -23,7 +23,7 @@ from configs import VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, MAX_TOKENS
 import asyncio
 from server.agent import model_container
 
-
+from pydantic import BaseModel, Field
 
 async def search_knowledge_base_iter(database: str, query: str):
     response = await knowledge_base_chat(query=query,
@@ -225,6 +225,10 @@ def knowledge_search_once(query: str):
     return ans
 
 
+class KnowledgeSearchInput(BaseModel):
+    location: str = Field(description="知识库查询的内容")
+
+
 if __name__ == "__main__":
     result = knowledge_search_once("大数据的男女比例")
     print(result)
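A quick aside on the pattern these hunks introduce: each tool module gains a small pydantic model whose Field(description=...) text is what LangChain surfaces to the agent when it fills in the tool's arguments. Below is a minimal standalone sketch of that behaviour; it is not part of the diff, the sample query string is made up, and it assumes the pydantic v1 API used by this generation of LangChain.

    from pydantic import BaseModel, Field

    class KnowledgeSearchInput(BaseModel):
        location: str = Field(description="知识库查询的内容")

    # Validation happens at construction time...
    args = KnowledgeSearchInput(location="机器人和大数据在代码教学上有什么区别")
    print(args.location)

    # ...and the description ends up in the JSON schema that LangChain shows to the model.
    print(KnowledgeSearchInput.schema()["properties"]["location"]["description"])
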
diff --git a/server/agent/tools/search_internet.py b/server/agent/tools/search_internet.py
index 0d52789..f224849 100644
--- a/server/agent/tools/search_internet.py
+++ b/server/agent/tools/search_internet.py
@@ -3,6 +3,7 @@ from server.chat.search_engine_chat import search_engine_chat
 from configs import VECTOR_SEARCH_TOP_K, MAX_TOKENS
 import asyncio
 from server.agent import model_container
+from pydantic import BaseModel, Field
 
 async def search_engine_iter(query: str):
     response = await search_engine_chat(query=query,
@@ -25,9 +26,11 @@ async def search_engine_iter(query: str):
     return contents
 
 def search_internet(query: str):
-    return asyncio.run(search_engine_iter(query))
 
+class SearchInternetInput(BaseModel):
+    location: str = Field(description="需要查询的内容")
+
 if __name__ == "__main__":
     result = search_internet("今天星期几")
diff --git a/server/agent/tools/shell.py b/server/agent/tools/shell.py
index 12bb000..db07415 100644
--- a/server/agent/tools/shell.py
+++ b/server/agent/tools/shell.py
@@ -1,6 +1,9 @@
 # LangChain 的 Shell 工具
+from pydantic import BaseModel, Field
 from langchain.tools import ShellTool
 
 def shell(query: str):
     tool = ShellTool()
     return tool.run(tool_input=query)
+class ShellInput(BaseModel):
+    query: str = Field(description="一个能在Linux命令行运行的Shell命令")
\ No newline at end of file
diff --git a/server/agent/tools/translator.py b/server/agent/tools/translator.py
index 78422fb..2655d52 100644
--- a/server/agent/tools/translator.py
+++ b/server/agent/tools/translator.py
@@ -1,6 +1,7 @@
 from langchain.prompts import PromptTemplate
 from langchain.chains import LLMChain
 from server.agent import model_container
+from pydantic import BaseModel, Field
 
 _PROMPT_TEMPLATE = '''
 # 指令
@@ -29,6 +30,9 @@ def translate(query: str):
     ans = llm_translate.run(query)
     return ans
 
+class TranslateInput(BaseModel):
+    location: str = Field(description="需要被翻译的内容")
+
 if __name__ == "__main__":
     result = translate("Can Love remember the question and the answer? 这句话如何诗意的翻译成中文")
     print("答案:",result)
\ No newline at end of file
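The tool functions in this group wrap async chat coroutines with asyncio.run so they can be exposed to the agent as plain synchronous callables. A stripped-down sketch of that pattern follows; the coroutine body is a stand-in, not the real search_engine_chat call.

    import asyncio

    async def search_engine_iter(query: str) -> str:
        # stand-in for the real coroutine, which streams an answer from search_engine_chat
        await asyncio.sleep(0)
        return f"搜索结果: {query}"

    def search_internet(query: str) -> str:
        # the agent calls tools synchronously, so the coroutine is driven to completion here;
        # asyncio.run only works when no event loop is already running in this thread
        return asyncio.run(search_engine_iter(query))

    print(search_internet("今天星期几"))
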
diff --git a/server/agent/tools/weather.py b/server/agent/tools/weather.py
index 3d99fb6..db88b87 100644
--- a/server/agent/tools/weather.py
+++ b/server/agent/tools/weather.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 ## 单独运行的时候需要添加
 import sys
 import os
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 import re
@@ -23,11 +24,11 @@ from typing import List, Any, Optional
 from datetime import datetime
 from langchain.prompts import PromptTemplate
 from server.agent import model_container
+from pydantic import BaseModel, Field
 
 ## 使用和风天气API查询天气
 KEY = "ac880e5a877042809ac7ffdd19d95b0d"
-#key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
-
+# key长这样,这里提供了示例的key,这个key没法使用,你需要自己去注册和风天气的账号,然后在这里填入你的key
 
 
 _PROMPT_TEMPLATE = """
@@ -95,7 +96,7 @@ def get_city_info(location, adm, key):
     return data
 
 
-def format_weather_data(data,place):
+def format_weather_data(data, place):
     hourly_forecast = data['hourly']
     formatted_data = f"\n 这是查询到的关于{place}未来24小时的天气信息: \n"
     for forecast in hourly_forecast:
@@ -141,7 +142,7 @@ def format_weather_data(data, place):
     return formatted_data
 
 
-def get_weather(key, location_id,place):
+def get_weather(key, location_id, place):
     url = "https://devapi.qweather.com/v7/weather/24h?"
     params = {
         'location': location_id,
@@ -149,18 +150,20 @@
     }
     response = requests.get(url, params=params)
     data = response.json()
-    return format_weather_data(data,place)
+    return format_weather_data(data, place)
 
 
 def split_query(query):
     parts = query.split()
     adm = parts[0]
+    if len(parts) == 1:
+        return adm, adm
     location = parts[1] if parts[1] != 'None' else adm
     return location, adm
 
 
 def weather(query):
-    location, adm= split_query(query)
+    location, adm = split_query(query)
     key = KEY
     if key == "":
         return "请先在代码中填入和风天气API Key"
@@ -169,17 +172,19 @@
         location_id = city_info['location'][0]['id']
         place = adm + "市" + location + "区"
-        weather_data = get_weather(key=key, location_id=location_id,place=place)
-        return weather_data + "以上是查询到的天气信息,请你查收\n"
+        weather_data = get_weather(key=key, location_id=location_id, place=place)
+        return weather_data + "以上是查询到的天气信息,请你查收\n"
     except KeyError:
         try:
             city_info = get_city_info(location=adm, adm=adm, key=key)
             location_id = city_info['location'][0]['id']
             place = adm + "市"
-            weather_data = get_weather(key=key, location_id=location_id,place=place)
+            weather_data = get_weather(key=key, location_id=location_id, place=place)
             return weather_data + "重要提醒:用户提供的市和区中,区的信息不存在,或者出现错别字,因此该信息是关于市的天气,请你查收\n"
         except KeyError:
             return "输入的地区不存在,无法提供天气预报"
+
+
 class LLMWeatherChain(Chain):
     llm_chain: LLMChain
     llm: Optional[BaseLanguageModel] = None
@@ -319,12 +324,15 @@ class LLMWeatherChain(Chain):
         return cls(llm_chain=llm_chain, **kwargs)
 
 
-
 def weathercheck(query: str):
     model = model_container.MODEL
     llm_weather = LLMWeatherChain.from_llm(model, verbose=True, prompt=PROMPT)
     ans = llm_weather.run(query)
     return ans
 
+
+class WhetherSchema(BaseModel):
+    location: str = Field(description="应该是一个地区的名称,用空格隔开,例如:上海 浦东,如果没有区的信息,可以只输入上海")
+
 if __name__ == '__main__':
-    result = weathercheck("苏州姑苏区今晚热不热?")
\ No newline at end of file
+    result = weathercheck("苏州姑苏区今晚热不热?")
diff --git a/server/agent/tools/wolfram.py b/server/agent/tools/wolfram.py
index 18958e3..c322da1 100644
--- a/server/agent/tools/wolfram.py
+++ b/server/agent/tools/wolfram.py
@@ -1,7 +1,11 @@
 # Langchain 自带的 Wolfram Alpha API 封装
 from langchain.utilities.wolfram_alpha import WolframAlphaAPIWrapper
+from pydantic import BaseModel, Field
 wolfram_alpha_appid = "your key"
 def wolfram(query: str):
     wolfram = WolframAlphaAPIWrapper(wolfram_alpha_appid=wolfram_alpha_appid)
     ans = wolfram.run(query)
-    return ans
\ No newline at end of file
+    return ans
+
+class WolframInput(BaseModel):
+    location: str = Field(description="需要运算的具体问题")
\ No newline at end of file
diff --git a/server/agent/tools/youtube.py b/server/agent/tools/youtube.py
index 08c5bda..27eb8bd 100644
--- a/server/agent/tools/youtube.py
+++ b/server/agent/tools/youtube.py
@@ -1,5 +1,9 @@
 # Langchain 自带的 YouTube 搜索工具封装
 from langchain.tools import YouTubeSearchTool
+from pydantic import BaseModel, Field
 def youtube_search(query: str):
     tool = YouTubeSearchTool()
-    return tool.run(tool_input=query)
\ No newline at end of file
+    return tool.run(tool_input=query)
+
+class YoutubeInput(BaseModel):
+    location: str = Field(description="要搜索视频关键字")
\ No newline at end of file
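Besides the new WhetherSchema model, the notable behavioural change in weather.py is the guard added to split_query, which keeps single-word queries from raising an IndexError. A standalone copy of the changed helper shows the effect; the sample inputs follow the "地区 区" format described by WhetherSchema.

    def split_query(query):
        parts = query.split()
        adm = parts[0]
        if len(parts) == 1:  # new guard: queries like "上海" no longer index past the end
            return adm, adm
        location = parts[1] if parts[1] != 'None' else adm
        return location, adm

    print(split_query("上海 浦东"))  # ('浦东', '上海')
    print(split_query("上海"))       # ('上海', '上海')
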
diff --git a/server/agent/tools_select.py b/server/agent/tools_select.py
index be5b30d..e0be597 100644
--- a/server/agent/tools_select.py
+++ b/server/agent/tools_select.py
@@ -1,60 +1,20 @@
 from langchain.tools import Tool
 from server.agent.tools import *
-## 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本,下面的内容是英文版本。
-# tools = [
-#     Tool.from_function(
-#         func=calculate,
-#         name="Simple Calculator Tool",
-#         description="Perform simple mathematical operations, Just simple, Use Wolfram Math Tool for more complex operations"
-#     ),
-#     Tool.from_function(
-#         func=translate,
-#         name="Translation Tool",
-#         description="Use this tool if you can't access the internet and need to translate various languages"
-#     ),
-#     Tool.from_function(
-#         func=weathercheck,
-#         name="Weather Checking Tool",
-#         description="Check the weather for various places in China for the next 24 hours without needing internet access"
-#     ),
-#     Tool.from_function(
-#         func=shell,
-#         name="Shell Tool",
-#         description="Use command line tool output"
-#     ),
-#     Tool.from_function(
-#         func=knowledge_search_more,
-#         name="Knowledge Base Tool",
-#         description="Prioritize accessing the knowledge base to get answers"
-#     ),
-#     Tool.from_function(
-#         func=search_internet,
-#         name="Internet Tool",
-#         description="If you can't access the internet, this tool can help you access Bing to answer questions"
-#     ),
-#     Tool.from_function(
-#         func=wolfram,
-#         name="Wolfram Math Tool",
-#         description="Use this tool to perform more complex mathematical operations"
-#     ),
-#     Tool.from_function(
-#         func=youtube_search,
-#         name="Youtube Search Tool",
-#         description="Use this tool to search for videos on Youtube"
-#     )
-# ]
+## 请注意,如果你是为了使用AgentLM,在这里,你应该使用英文版本。
 tools = [
     Tool.from_function(
         func=calculate,
         name="计算器工具",
         description="进行简单的数学运算, 只是简单的, 使用Wolfram数学工具进行更复杂的运算",
+        args_schema=CalculatorInput,
     ),
     Tool.from_function(
         func=translate,
         name="翻译工具",
-        description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具"
+        description="如果你无法访问互联网,并且需要翻译各种语言,应该使用这个工具",
+        args_schema=TranslateInput,
     ),
     Tool.from_function(
         func=weathercheck,
@@ -79,13 +39,15 @@ tools = [
     Tool.from_function(
         func=wolfram,
         name="Wolfram数学工具",
-        description="高级的数学运算工具,能够完成非常复杂的数学问题"
+        description="高级的数学运算工具,能够完成非常复杂的数学问题",
+        args_schema=WolframInput,
     ),
     Tool.from_function(
         func=youtube_search,
         name="Youtube搜索工具",
-        description="使用这个工具在Youtube上搜索视频"
-    )
+        description="使用这个工具在Youtube上搜索视频",
+        args_schema=YoutubeInput,
+    ),
 ]
 
 tool_names = [tool.name for tool in tools]
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index 35ef869..7047850 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -1,7 +1,7 @@
 from langchain.memory import ConversationBufferWindowMemory
 from server.agent.tools_select import tools, tool_names
 from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
-from langchain.agents import AgentExecutor, LLMSingleActionAgent
+from langchain.agents import AgentExecutor, LLMSingleActionAgent, initialize_agent, BaseMultiActionAgent
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
 from fastapi.responses import StreamingResponse
diff --git a/server/knowledge_base/kb_cache/faiss_cache.py b/server/knowledge_base/kb_cache/faiss_cache.py
index c00ac4f..80e09c1 100644
--- a/server/knowledge_base/kb_cache/faiss_cache.py
+++ b/server/knowledge_base/kb_cache/faiss_cache.py
@@ -1,8 +1,9 @@
+from configs import CACHED_VS_NUM
 from server.knowledge_base.kb_cache.base import *
 from server.knowledge_base.utils import get_vs_path
 from langchain.vectorstores import FAISS
 import os
-
+from langchain.schema import Document
 
 class ThreadSafeFaiss(ThreadSafeObject):
     def __repr__(self) -> str:
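For context on how the new schemas are consumed, the tools_select.py entries attach each pydantic model to its tool via args_schema. The following self-contained sketch illustrates that wiring; the calculate body is a stub rather than the LLMMathChain implementation from calculate.py, and exact behaviour depends on the LangChain version in use.

    from langchain.tools import Tool
    from pydantic import BaseModel, Field

    class CalculatorInput(BaseModel):
        query: str = Field()

    def calculate(query: str) -> str:
        # stub standing in for the LLMMathChain-backed implementation
        return "8"

    calc_tool = Tool.from_function(
        func=calculate,
        name="计算器工具",
        description="进行简单的数学运算, 只是简单的, 使用Wolfram数学工具进行更复杂的运算",
        args_schema=CalculatorInput,
    )

    print(calc_tool.args)            # argument spec derived from CalculatorInput
    print(calc_tool.run("2的三次方"))  # the stub returns "8"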