move wrap_done & get_ChatOpenAI from server.chat.utils to server.utils (#1506)

liunux4odoo authored 2023-09-17 16:19:50 +08:00 (committed by GitHub)
parent 7d31e84cc7
commit ec85cd1954
12 changed files with 50 additions and 50 deletions

View File

@@ -1,4 +1,4 @@
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain import LLMChain
 from langchain.prompts.chat import (

View File

@@ -1,6 +1,6 @@
 from langchain import PromptTemplate
 from langchain.chains import LLMMathChain
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from configs.model_config import LLM_MODEL, TEMPERATURE
 from langchain.chat_models import ChatOpenAI
 from langchain.callbacks.manager import CallbackManagerForToolRun

View File

@@ -2,7 +2,7 @@ from langchain import PromptTemplate, LLMChain
 import sys
 import os
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from langchain.chains.llm_math.prompt import PROMPT

View File

@@ -4,7 +4,7 @@ from __future__ import annotations
 import sys
 import os
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

View File

@@ -5,7 +5,7 @@ from server.agent.custom_template import CustomOutputParser, prompt
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler

View File

@@ -1,7 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable

View File

@@ -1,7 +1,7 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs import (LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler

View File

@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY,
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from server.chat.utils import wrap_done, get_ChatOpenAI
+from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
 from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler

View File

@@ -1,46 +1,7 @@
-import asyncio
 from pydantic import BaseModel, Field
 from langchain.prompts.chat import ChatMessagePromptTemplate
 from configs import logger, log_verbose
-from server.utils import get_model_worker_config, fschat_openai_api_address
-from langchain.chat_models import ChatOpenAI
-from typing import Awaitable, List, Tuple, Dict, Union, Callable, Any
-
-
-def get_ChatOpenAI(
-        model_name: str,
-        temperature: float,
-        streaming: bool = True,
-        callbacks: List[Callable] = [],
-        **kwargs: Any,
-) -> ChatOpenAI:
-    config = get_model_worker_config(model_name)
-    model = ChatOpenAI(
-        streaming=streaming,
-        verbose=True,
-        callbacks=callbacks,
-        openai_api_key=config.get("api_key", "EMPTY"),
-        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
-        model_name=model_name,
-        temperature=temperature,
-        openai_proxy=config.get("openai_proxy"),
-        **kwargs
-    )
-    return model
-
-
-async def wrap_done(fn: Awaitable, event: asyncio.Event):
-    """Wrap an awaitable with a event to signal when it's done or an exception is raised."""
-    try:
-        await fn
-    except Exception as e:
-        # TODO: handle exception
-        msg = f"Caught exception: {e}"
-        logger.error(f'{e.__class__.__name__}: {msg}',
-                     exc_info=e if log_verbose else None)
-    finally:
-        # Signal the aiter to stop.
-        event.set()
+from typing import List, Tuple, Dict, Union
 
 
 class History(BaseModel):
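The removed block above also shows why the helpers moved: server.chat.utils had to import get_model_worker_config and fschat_openai_api_address from server.utils, so relocating wrap_done and get_ChatOpenAI into server.utils puts them next to their dependencies and leaves server.chat.utils with little more than the History model. Every caller then makes a one-line change; a sketch of the migration (the surrounding module is hypothetical):

```python
# Old import path, removed by this commit:
# from server.chat.utils import wrap_done, get_ChatOpenAI

# New import path:
from server.utils import wrap_done, get_ChatOpenAI
```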

View File

@ -10,12 +10,50 @@ from configs import (LLM_MODEL, LLM_DEVICE, EMBEDDING_DEVICE,
FSCHAT_MODEL_WORKERS) FSCHAT_MODEL_WORKERS)
import os import os
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Literal, Optional, Callable, Generator, Dict, Any, Tuple from langchain.chat_models import ChatOpenAI
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable
thread_pool = ThreadPoolExecutor(os.cpu_count()) thread_pool = ThreadPoolExecutor(os.cpu_count())
async def wrap_done(fn: Awaitable, event: asyncio.Event):
"""Wrap an awaitable with a event to signal when it's done or an exception is raised."""
try:
await fn
except Exception as e:
# TODO: handle exception
msg = f"Caught exception: {e}"
logger.error(f'{e.__class__.__name__}: {msg}',
exc_info=e if log_verbose else None)
finally:
# Signal the aiter to stop.
event.set()
def get_ChatOpenAI(
model_name: str,
temperature: float,
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
**kwargs: Any,
) -> ChatOpenAI:
config = get_model_worker_config(model_name)
model = ChatOpenAI(
streaming=streaming,
verbose=verbose,
callbacks=callbacks,
openai_api_key=config.get("api_key", "EMPTY"),
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
return model
class BaseResponse(BaseModel): class BaseResponse(BaseModel):
code: int = pydantic.Field(200, description="API status code") code: int = pydantic.Field(200, description="API status code")
msg: str = pydantic.Field("success", description="API status message") msg: str = pydantic.Field("success", description="API status message")
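For orientation, here is a minimal sketch of how the chat endpoints in this commit typically wire the two relocated helpers together to stream tokens. The stream_answer wrapper and the one-line prompt are illustrative assumptions, not code from this diff, and the LangChain 0.0.x API used by the project at the time is assumed:

```python
import asyncio
from typing import AsyncIterable

from langchain import LLMChain, PromptTemplate
from langchain.callbacks import AsyncIteratorCallbackHandler

from configs import LLM_MODEL, TEMPERATURE
from server.utils import wrap_done, get_ChatOpenAI


async def stream_answer(query: str) -> AsyncIterable[str]:
    # Hypothetical helper mirroring the pattern in the endpoints above.
    callback = AsyncIteratorCallbackHandler()
    model = get_ChatOpenAI(
        model_name=LLM_MODEL,
        temperature=TEMPERATURE,
        callbacks=[callback],
    )
    chain = LLMChain(prompt=PromptTemplate.from_template("{input}"), llm=model)

    # wrap_done sets callback.done once the chain finishes or raises,
    # which terminates the aiter() loop below.
    task = asyncio.create_task(
        wrap_done(chain.acall({"input": query}), callback.done)
    )
    async for token in callback.aiter():
        yield token  # each token streamed back from the model
    await task
```

A generator like this is presumably what StreamingResponse consumes in the endpoints above, which is why each of them imports wrap_done alongside AsyncIteratorCallbackHandler.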

View File

@@ -2,7 +2,7 @@ import sys
 import os
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
 from configs import LLM_MODEL, TEMPERATURE
-from server.chat.utils import get_ChatOpenAI
+from server.utils import get_ChatOpenAI
 from langchain import LLMChain
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.tools import tools, tool_names

View File

@@ -377,6 +377,7 @@ class ApiRequest:
         else:
             response = self.post("/chat/agent_chat", json=data, stream=True)
             return self._httpx_stream2generator(response)
+
     def knowledge_base_chat(
         self,
         query: str,