move wrap_done & get_ChatOpenAI from server.chat.utils to server.utils (#1506)
This commit is contained in:
parent 7d31e84cc7
commit ec85cd1954
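The caller-side change is mechanical: every module that previously imported these helpers from server.chat.utils now imports them from server.utils, and the helpers are otherwise unchanged apart from get_ChatOpenAI gaining a verbose parameter (previously hard-coded to True). A minimal before/after sketch of the import rename:

# before this commit
# from server.chat.utils import wrap_done, get_ChatOpenAI

# after this commit
from server.utils import wrap_done, get_ChatOpenAI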
@@ -1,4 +1,4 @@
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (
@@ -1,6 +1,6 @@
 from langchain import PromptTemplate
 from langchain.chains import LLMMathChain
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain.chat_models import ChatOpenAI
 from langchain.callbacks.manager import CallbackManagerForToolRun
@@ -2,7 +2,7 @@ from langchain import PromptTemplate, LLMChain
 import sys
 import os
 
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from langchain.chains.llm_math.prompt import PROMPT
@@ -4,7 +4,7 @@ from __future__ import annotations
 import sys
 import os
 
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 
@@ -5,7 +5,7 @@ from server.agent.custom_template import CustomOutputParser, prompt
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
@@ -1,7 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
@@ -1,7 +1,7 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
@@ -1,46 +1,7 @@
-import asyncio
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
-from server.utils import get_model_worker_config, fschat_openai_api_address
-from langchain.chat_models import ChatOpenAI
-from typing import Awaitable, List, Tuple, Dict, Union, Callable, Any
-
-
-def get_ChatOpenAI(
-        model_name: str,
-        temperature: float,
-        streaming: bool = True,
-        callbacks: List[Callable] = [],
-        **kwargs: Any,
-) -> ChatOpenAI:
-    config = get_model_worker_config(model_name)
-    model = ChatOpenAI(
-        streaming=streaming,
-        verbose=True,
-        callbacks=callbacks,
-        openai_api_key=config.get("api_key", "EMPTY"),
-        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
-        model_name=model_name,
-        temperature=temperature,
-        openai_proxy=config.get("openai_proxy"),
-        **kwargs
-    )
-    return model
-
-
-async def wrap_done(fn: Awaitable, event: asyncio.Event):
-    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
-    try:
-        await fn
-    except Exception as e:
-        # TODO: handle exception
-        msg = f"Caught exception: {e}"
-        logger.error(f'{e.__class__.__name__}: {msg}',
-                     exc_info=e if log_verbose else None)
-    finally:
-        # Signal the aiter to stop.
-        event.set()
+from typing import List, Tuple, Dict, Union
 
 
 class History(BaseModel):
@@ -10,12 +10,50 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
                      FSCHAT_MODEL_WORKERS)
 import os
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple
+from langchain.chat_models import ChatOpenAI
+from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable
 
 
 thread_pool = ThreadPoolExecutor(os.cpu_count())
 
 
+async def wrap_done(fn: Awaitable, event: asyncio.Event):
+    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
+    try:
+        await fn
+    except Exception as e:
+        # TODO: handle exception
+        msg = f"Caught exception: {e}"
+        logger.error(f'{e.__class__.__name__}: {msg}',
+                     exc_info=e if log_verbose else None)
+    finally:
+        # Signal the aiter to stop.
+        event.set()
+
+
+def get_ChatOpenAI(
+        model_name: str,
+        temperature: float,
+        streaming: bool = True,
+        callbacks: List[Callable] = [],
+        verbose: bool = True,
+        **kwargs: Any,
+) -> ChatOpenAI:
+    config = get_model_worker_config(model_name)
+    model = ChatOpenAI(
+        streaming=streaming,
+        verbose=verbose,
+        callbacks=callbacks,
+        openai_api_key=config.get("api_key", "EMPTY"),
+        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
+        model_name=model_name,
+        temperature=temperature,
+        openai_proxy=config.get("openai_proxy"),
+        **kwargs
+    )
+    return model
+
+
 class BaseResponse(BaseModel):
     code: int = pydantic.Field(200, description="API status code")
     msg: str = pydantic.Field("success", description="API status message")
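For context, the chat handlers touched above pair these helpers with LangChain's AsyncIteratorCallbackHandler: get_ChatOpenAI builds the streaming chat client from the model worker config, and wrap_done sets the handler's done event once the chain coroutine finishes or raises. A minimal usage sketch of that pattern (the endpoint shape and prompt below are illustrative, not code from this commit):

import asyncio
from langchain import LLMChain, PromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler
from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_ChatOpenAI


async def stream_answer(query: str):
    callback = AsyncIteratorCallbackHandler()
    # Streaming ChatOpenAI client configured from the model worker config.
    model = get_ChatOpenAI(
        model_name=LLM_MODEL,
        temperature=TEMPERATURE,
        callbacks=[callback],
    )
    chain = LLMChain(prompt=PromptTemplate.from_template("{input}"), llm=model)

    # wrap_done sets callback.done when chain.acall completes or raises,
    # which terminates the token iterator below.
    task = asyncio.create_task(wrap_done(chain.acall({"input": query}), callback.done))
    async for token in callback.aiter():
        yield token
    await task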
@@ -2,7 +2,7 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from langchain import LLMChain
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.tools import tools, tool_names
@@ -377,6 +377,7 @@ class ApiRequest:
         else:
             response = self.post("/chat/agent_chat", json=data, stream=True)
             return self._httpx_stream2generator(response)
+
     def knowledge_base_chat(
         self,
         query: str,