"""Async LangChain callback handler that streams agent progress as JSON.

Every callback serializes the handler's current state dict to a JSON string
and pushes it onto an ``asyncio.Queue`` for a consumer to read.
"""
from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult

# Markers after which a tool input is considered model "run-on" text.
_STOP_WORDS = ("Observation:", "Thought", "\"", "(", "\n", "\t")


def dumps(obj: Dict) -> str:
    """Serialize *obj* to a JSON string, keeping non-ASCII characters as-is."""
    return json.dumps(obj, ensure_ascii=False)


def _truncate_at_stop_word(text: str) -> str:
    """Return *text* cut at the earliest occurrence of any stop word.

    If no stop word occurs, *text* is returned unchanged.
    """
    cut = len(text)
    for stop_word in _STOP_WORDS:
        index = text.find(stop_word)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


class Status:
    """Integer status codes written to the ``status`` field of each event."""

    start: int = 1         # set by on_llm_start
    running: int = 2       # set by on_llm_new_token
    complete: int = 3      # set by on_llm_end
    agent_action: int = 4  # set by on_tool_start
    agent_finish: int = 5  # set by on_agent_finish
    error: int = 6         # set by on_tool_error / on_llm_error
    tool_finish: int = 7   # set by on_tool_end


class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Callback handler that queues a JSON snapshot of ``cur_tool`` per event.

    ``cur_tool`` is a mutable dict describing the step in progress. Each
    callback updates a few of its fields and then enqueues the whole dict,
    serialized with :func:`dumps`, onto ``self.queue``.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): the base class presumably creates these already;
        # they are re-created here exactly as the original code did.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool: Dict[str, Any] = {}

    def _emit(self) -> None:
        """Enqueue the current ``cur_tool`` state as a JSON string."""
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Start a new ``cur_tool`` record for a tool call and emit it.

        Args:
            serialized: Tool description; its ``"name"`` entry is reported
                as ``tool_name``.
            input_str: Raw tool input. It is truncated at the earliest stop
                word before being reported.
            run_id: Run identifier; its hex form is reported as ``run_id``.

        Raises:
            KeyError: If *serialized* has no ``"name"`` entry.
        """
        # Some LLMs do not stop generating on their own, so truncate for
        # them. Cut at the earliest stop word rather than the first one in
        # list order, so that no later-listed marker survives in the input.
        input_str = _truncate_at_stop_word(input_str)

        self.cur_tool = {
            "tool_name": serialized["name"],
            "input_str": input_str,
            "output_str": "",
            "status": Status.agent_action,
            "run_id": run_id.hex,
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }
        self._emit()

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record the tool output (with ``"Answer:"`` removed) and emit it."""
        self.cur_tool.update(
            status=Status.tool_finish,
            output_str=output.replace("Answer:", ""),
        )
        self._emit()

    async def on_tool_error(
        self,
        error: Exception | KeyboardInterrupt,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record a tool failure as ``str(error)`` and emit it."""
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self._emit()

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Emit an event carrying *token*; empty tokens are ignored."""
        if token:
            self.cur_tool.update(
                status=Status.running,
                llm_token=token,
            )
            self._emit()

    async def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        **kwargs: Any,
    ) -> None:
        """Emit a ``start`` event with an empty ``llm_token``."""
        self.cur_tool.update(
            status=Status.start,
            llm_token="",
        )
        self._emit()

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Emit a ``complete`` event whose ``llm_token`` is a newline."""
        self.cur_tool.update(
            status=Status.complete,
            llm_token="\n",
        )
        self._emit()

    async def on_llm_error(
        self,
        error: Exception | KeyboardInterrupt,
        **kwargs: Any,
    ) -> None:
        """Record an LLM failure as ``str(error)`` and emit it."""
        self.cur_tool.update(
            status=Status.error,
            error=str(error),
        )
        self._emit()

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit the final answer, then reset ``cur_tool`` to an empty dict.

        Raises:
            KeyError: If ``finish.return_values`` has no ``"output"`` entry.
        """
        # Return the final answer.
        self.cur_tool.update(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        self._emit()
        self.cur_tool = {}