From ec85cd1954ae6d1050ef6e13e83bab72079cb3d6 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Sun, 17 Sep 2023 16:19:50 +0800
Subject: [PATCH] move wrap_done & get_ChatOpenAI from server.chat.utils to
 server.utils (#1506)

---
 chains/llmchain_with_history.py    |  2 +-
 server/agent/math.py               |  2 +-
 server/agent/translator.py         |  2 +-
 server/agent/weather.py            |  2 +-
 server/chat/agent_chat.py          |  2 +-
 server/chat/chat.py                |  2 +-
 server/chat/knowledge_base_chat.py |  2 +-
 server/chat/search_engine_chat.py  |  2 +-
 server/chat/utils.py               | 41 +-----------------------------
 server/utils.py                    | 40 ++++++++++++++++++++++++++++-
 tests/agent/test_agent_function.py |  2 +-
 webui_pages/utils.py               |  1 +
 12 files changed, 50 insertions(+), 50 deletions(-)

diff --git a/chains/llmchain_with_history.py b/chains/llmchain_with_history.py
index 044f470..2dbdff0 100644
--- a/chains/llmchain_with_history.py
+++ b/chains/llmchain_with_history.py
@@ -1,4 +1,4 @@
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (
diff --git a/server/agent/math.py b/server/agent/math.py
index c865125..aa6c759 100644
--- a/server/agent/math.py
+++ b/server/agent/math.py
@@ -1,6 +1,6 @@
 from langchain import PromptTemplate
 from langchain.chains import LLMMathChain
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain.chat_models import ChatOpenAI
 from langchain.callbacks.manager import CallbackManagerForToolRun
diff --git a/server/agent/translator.py b/server/agent/translator.py
index 5740034..4477c17 100644
--- a/server/agent/translator.py
+++ b/server/agent/translator.py
@@ -2,7 +2,7 @@ from langchain import PromptTemplate, LLMChain
 import sys
 import os
 
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 from langchain.chains.llm_math.prompt import PROMPT
diff --git a/server/agent/weather.py b/server/agent/weather.py
index 9a4464b..5822eba 100644
--- a/server/agent/weather.py
+++ b/server/agent/weather.py
@@ -4,7 +4,7 @@ from __future__ import annotations
 import sys
 import os
 
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
 
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index 8d7cf3b..4d2496f 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -5,7 +5,7 @@ from server.agent.custom_template import CustomOutputParser, prompt
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 8f4d343..02856ea 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,7 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 4c461aa..0fb7ab4 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,7 +1,7 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index cfc1fa9..78246ae 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
diff --git a/server/chat/utils.py b/server/chat/utils.py
index 60ea2d1..dd3c333 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,46 +1,7 @@
-import asyncio
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
-from server.utils import get_model_worker_config, fschat_openai_api_address
-from langchain.chat_models import ChatOpenAI
-from typing import Awaitable, List, Tuple, Dict, Union, Callable, Any
-
-
-def get_ChatOpenAI(
-        model_name: str,
-        temperature: float,
-        streaming: bool = True,
-        callbacks: List[Callable] = [],
-        **kwargs: Any,
-) -> ChatOpenAI:
-    config = get_model_worker_config(model_name)
-    model = ChatOpenAI(
-        streaming=streaming,
-        verbose=True,
-        callbacks=callbacks,
-        openai_api_key=config.get("api_key", "EMPTY"),
-        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
-        model_name=model_name,
-        temperature=temperature,
-        openai_proxy=config.get("openai_proxy"),
-        **kwargs
-    )
-    return model
-
-
-async def wrap_done(fn: Awaitable, event: asyncio.Event):
-    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
-    try:
-        await fn
-    except Exception as e:
-        # TODO: handle exception
-        msg = f"Caught exception: {e}"
-        logger.error(f'{e.__class__.__name__}: {msg}',
-                     exc_info=e if log_verbose else None)
-    finally:
-        # Signal the aiter to stop.
-        event.set()
+from typing import List, Tuple, Dict, Union
 
 
 class History(BaseModel):
diff --git a/server/utils.py b/server/utils.py
index e0a595d..0efe122 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -10,12 +10,50 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple
+from langchain.chat_models import ChatOpenAI
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable
 
 
 thread_pool = ThreadPoolExecutor(os.cpu_count())
 
 
+async def wrap_done(fn: Awaitable, event: asyncio.Event):
+    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
+    try:
+        await fn
+    except Exception as e:
+        # TODO: handle exception
+        msg = f"Caught exception: {e}"
+        logger.error(f'{e.__class__.__name__}: {msg}',
+                     exc_info=e if log_verbose else None)
+    finally:
+        # Signal the aiter to stop.
+        event.set()
+
+
+def get_ChatOpenAI(
+        model_name: str,
+        temperature: float,
+        streaming: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> ChatOpenAI:
+    config = get_model_worker_config(model_name)
+    model = ChatOpenAI(
+        streaming=streaming,
+        verbose=verbose,
+        callbacks=callbacks,
+        openai_api_key=config.get("api_key", "EMPTY"),
+        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+        model_name=model_name,
+        temperature=temperature,
+        openai_proxy=config.get("openai_proxy"),
+        **kwargs
+    )
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
diff --git a/tests/agent/test_agent_function.py b/tests/agent/test_agent_function.py
index 91c3a32..0ee9863 100644
--- a/tests/agent/test_agent_function.py
+++ b/tests/agent/test_agent_function.py
@@ -2,7 +2,7 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from langchain import LLMChain
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.tools import tools, tool_names
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 176dbb4..cca30fa 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -377,6 +377,7 @@ class ApiRequest:
         else:
             response = self.post("/chat/agent_chat", json=data, stream=True)
         return self._httpx_stream2generator(response)
+
     def knowledge_base_chat(
         self,
         query: str,
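
Usage note (not part of the patch): after this commit both helpers are imported from server.utils, and the chat handlers in this repository pair them with langchain's AsyncIteratorCallbackHandler to stream tokens. Below is a minimal sketch of that pattern under stated assumptions: the model name "chatglm2-6b" and the prompt are illustrative placeholders, and an fschat OpenAI-compatible endpoint is assumed to be running as set up by the repository's configs.

# Minimal sketch, assuming this patch is applied; model name and prompt
# below are placeholders, not values taken from the commit.
import asyncio

from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate

from server.utils import wrap_done, get_ChatOpenAI


async def stream_answer(question: str) -> None:
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name="chatglm2-6b",  # placeholder: any model served via fschat
        temperature=0.7,
        callbacks=[callback],
    )
    prompt = ChatPromptTemplate.from_messages([("human", "{input}")])
    chain = LLMChain(prompt=prompt, llm=model)

    # Run the chain in the background. wrap_done sets callback.done when the
    # coroutine finishes or raises, which terminates the aiter() loop below.
    task = asyncio.create_task(
        wrap_done(chain.acall({"input": question}), callback.done)
    )
    async for token in callback.aiter():
        print(token, end="", flush=True)
    await task


asyncio.run(stream_answer("What does wrap_done do?"))

The point of wrap_done is visible here: callback.aiter() only stops when callback.done is set, so the chain coroutine is wrapped to guarantee the event fires even if the chain raises.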