from __future__ import annotations import asyncio import json from typing import Any, Dict, List, Optional from uuid import UUID from langchain.callbacks import AsyncIteratorCallbackHandler from langchain.schema import AgentAction, AgentFinish from langchain_core.outputs import LLMResult from chatchat.utils import build_logger logger = build_logger() def dumps(obj: Dict) -> str: return json.dumps(obj, ensure_ascii=False) class AgentStatus: llm_start: int = 1 llm_new_token: int = 2 llm_end: int = 3 agent_action: int = 4 agent_finish: int = 5 tool_start: int = 6 tool_end: int = 7 error: int = 8 class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler): def __init__(self): super().__init__() self.queue = asyncio.Queue() self.done = asyncio.Event() self.out = True logger.info(f"init....") async def on_llm_start( self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any ) -> None: data = { "status": AgentStatus.llm_start, "text": "", } self.done.clear() self.queue.put_nowait(dumps(data)) logger.info(f"prompts:{prompts}") async def on_llm_new_token(self, token: str, **kwargs: Any) -> None: special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"] for stoken in special_tokens: if stoken in token: before_action = token.split(stoken)[0] data = { "status": AgentStatus.llm_new_token, "text": before_action + "\n", } self.queue.put_nowait(dumps(data)) self.out = False break if token is not None and token != "" and self.out: data = { "status": AgentStatus.llm_new_token, "text": token, } self.queue.put_nowait(dumps(data)) async def on_chat_model_start( self, serialized: Dict[str, Any], messages: List[List], *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: data = { "status": AgentStatus.llm_start, "text": "", } self.done.clear() self.queue.put_nowait(dumps(data)) logger.info(f"messages:{messages}") async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: data = { "status": AgentStatus.llm_end, "text": response.generations[0][0].message.content, } self.queue.put_nowait(dumps(data)) logger.info(f"response:{response.json}") async def on_llm_error( self, error: Exception | KeyboardInterrupt, **kwargs: Any ) -> None: data = { "status": AgentStatus.error, "text": str(error), } self.queue.put_nowait(dumps(data)) async def on_tool_start( self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> None: data = { "run_id": str(run_id), "status": AgentStatus.tool_start, "tool": serialized["name"], "tool_input": input_str, } self.queue.put_nowait(dumps(data)) logger.info(f"input_str:{input_str}") async def on_tool_end( self, output: str, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when tool ends running.""" data = { "run_id": str(run_id), "status": AgentStatus.tool_end, "tool_output": output, } # self.done.clear() self.queue.put_nowait(dumps(data)) logger.info(f"output:{output}") async def on_tool_error( self, error: BaseException, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Run when tool errors.""" data = { "run_id": str(run_id), "status": AgentStatus.tool_end, "tool_output": str(error), "is_error": True, } # self.done.clear() self.queue.put_nowait(dumps(data)) logger.error(f"error:{error.__class__}") async def on_agent_action( self, action: AgentAction, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: data = { "status": AgentStatus.agent_action, "tool_name": action.tool, "tool_input": action.tool_input, "text": action.log, } self.queue.put_nowait(dumps(data)) logger.error(f"tool_name:{action.tool},tool_input:{ action.tool_input}") async def on_agent_finish( self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None, tags: Optional[List[str]] = None, **kwargs: Any, ) -> None: if "Thought:" in finish.return_values["output"]: finish.return_values["output"] = finish.return_values["output"].replace( "Thought:", "" ) data = { "status": AgentStatus.agent_finish, "text": finish.return_values["output"], } self.queue.put_nowait(dumps(data)) logger.error(f"data:{data}") async def on_chain_end( self, outputs: Dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any, ) -> None: self.done.set() self.out = True logger.info(f"outputs:{outputs}")