110 lines
3.7 KiB
Python
110 lines
3.7 KiB
Python
from uuid import UUID
|
|
from langchain.callbacks import AsyncIteratorCallbackHandler
|
|
import json
|
|
import asyncio
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from langchain.schema import AgentFinish, AgentAction
|
|
from langchain.schema.output import LLMResult
|
|
|
|
|
|
def dumps(obj: Dict) -> str:
    """Serialize ``obj`` to a JSON string without escaping non-ASCII text.

    Equivalent to ``json.dumps(obj, ensure_ascii=False)``: characters outside
    ASCII (e.g. CJK text) are written literally rather than as ``\\uXXXX``
    escapes. Raises ``TypeError`` for values JSON cannot represent.
    """
    encoder = json.JSONEncoder(ensure_ascii=False)
    return encoder.encode(obj)
|
|
|
|
|
|
class Status:
    """Integer status codes used as the ``status`` field of streamed events.

    The values are plain ``int`` class attributes, so they serialize to JSON
    as bare numbers.
    """

    # Set by the LLM hooks: on_llm_start, on_llm_new_token, on_llm_end.
    start: int = 1
    running: int = 2
    complete: int = 3
    # Set by the tool hooks: on_tool_start reports ``agent_action`` and
    # on_tool_end (not on_agent_finish) reports ``agent_finish``.
    agent_action: int = 4
    agent_finish: int = 5
    # Set by on_tool_error and on_llm_error.
    error: int = 6
    # Not set by the callback handler in the visible code.
    # NOTE(review): presumably used elsewhere in the project; confirm.
    make_tool: int = 7
|
|
|
|
|
|
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Stream LLM / tool progress as JSON strings on ``self.queue``.

    Each callback merges its fields into ``self.cur_tool`` (the event being
    built) and enqueues a JSON snapshot of it. A consumer of the queue
    therefore receives one JSON object per callback. Every object carries the
    keys produced by ``_blank_event``, with ``status`` taken from ``Status``.

    LLM tokens are forwarded only while ``self.out`` is true. A token that
    contains ``"Action"`` switches forwarding off until a tool ends or
    errors, or the LLM ends or errors.

    This class never sets ``self.done``.
    NOTE(review): the caller presumably sets it when the whole run finishes;
    confirm.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): possibly re-creates attributes the base __init__
        # already sets up; kept so this class does not rely on that detail.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        # Event currently being accumulated. It starts with every key present,
        # so LLM events emitted before any tool has run have the full schema.
        self.cur_tool = self._blank_event()
        # Whether LLM tokens are currently forwarded to the queue.
        self.out = True

    @staticmethod
    def _blank_event() -> Dict[str, Any]:
        """Return a fresh event dict with every emitted key present.

        String fields are empty. ``status`` is a placeholder that every
        ``_enqueue_event`` call in this class overwrites.
        """
        return {
            "tool_name": "",
            "input_str": "",
            "output_str": "",
            "status": Status.start,
            "run_id": "",
            "llm_token": "",
            "final_answer": "",
            "error": "",
        }

    def _enqueue_event(self, **fields: Any) -> None:
        """Merge ``fields`` into the current event and enqueue its JSON snapshot."""
        self.cur_tool.update(fields)
        # Serialize immediately: later callbacks mutate this same dict.
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Begin a new tool event and emit it with status ``agent_action``.

        Raises ``KeyError`` if ``serialized`` has no ``"name"`` entry.
        """
        # Start from a clean event so the previous tool's output/error
        # does not leak into this one.
        self.cur_tool = self._blank_event()
        self._enqueue_event(
            tool_name=serialized["name"],
            input_str=input_str,
            status=Status.agent_action,
            run_id=run_id.hex,
        )

    async def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit the tool's output, with ``"Answer:"`` removed, as ``agent_finish``."""
        self.out = True
        self._enqueue_event(
            status=Status.agent_finish,
            # str() guards against tools that return non-string output.
            output_str=str(output).replace("Answer:", ""),
        )

    async def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit the tool failure message with status ``error``."""
        # Resume token forwarding, as the other end/error hooks do.
        self.out = True
        self._enqueue_event(status=Status.error, error=str(error))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Forward a streamed token as a ``running`` event while ``self.out`` is set.

        A token containing ``"Action"`` turns forwarding off and emits a
        ``"\\n\\n"`` separator instead of the token. Empty tokens are ignored.
        """
        if not token:
            return
        if "Action" in token:
            # NOTE(review): presumably hides the agent's tool-call text from
            # the streamed answer; forwarding resumes in the end/error hooks.
            self.out = False
            self._enqueue_event(status=Status.running, llm_token="\n\n")
        if self.out:
            self._enqueue_event(status=Status.running, llm_token=token)

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Emit a ``start`` event with an empty token."""
        self._enqueue_event(status=Status.start, llm_token="")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Re-enable token forwarding and emit a ``complete`` event.

        Does not set ``self.done``.
        NOTE(review): presumably because one run spans several LLM calls;
        confirm.
        """
        self.out = True
        self._enqueue_event(status=Status.complete, llm_token="")

    async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
        """Re-enable token forwarding and emit the LLM failure with status ``error``."""
        self.out = True
        self._enqueue_event(status=Status.error, error=str(error))

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Discard the accumulated event without emitting anything."""
        self.cur_tool = self._blank_event()
|