From d053950aeec6d79ddb436d3924ad6ef7f364d50c Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Thu, 19 Oct 2023 22:09:15 +0800
Subject: [PATCH] New features: (#1801)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Update the langchain/fastchat dependencies and add the xformers dependency.
- Default max_tokens to None so that the number of generated tokens automatically uses the maximum supported by the model.

Fixes:
- When history_len=0, one incomplete history message was still carried over, causing LLM errors.
- When the number of dialogue turns reached history_len, the history passed to the model was empty.
---
 requirements.txt                   |  9 +++++----
 requirements_api.txt               |  5 +++--
 requirements_webui.txt             |  2 +-
 server/agent/callbacks.py          | 11 ++++++++++-
 server/chat/agent_chat.py          |  3 +--
 server/chat/chat.py                |  3 +--
 server/chat/knowledge_base_chat.py |  3 +--
 server/chat/openai_chat.py         |  2 +-
 server/chat/search_engine_chat.py  |  3 +--
 server/llm_api.py                  |  2 +-
 server/model_workers/SparkApi.py   |  2 +-
 server/model_workers/base.py       | 11 +----------
 webui_pages/dialogue/dialogue.py   |  2 +-
 webui_pages/utils.py               | 10 +++++-----
 14 files changed, 33 insertions(+), 35 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 7631247..959003f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,7 @@
-langchain>=0.0.314
-langchain-experimental>=0.0.30
-fschat[model_worker]==0.2.30
+langchain==0.0.317
+langchain-experimental==0.0.30
+fschat[model_worker]==0.2.31
+xformers==0.0.22.post4
 openai
 sentence_transformers
 transformers>=4.34
@@ -43,7 +44,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 watchdog
diff --git a/requirements_api.txt b/requirements_api.txt
index af4e7e0..f1d2f5f 100644
--- a/requirements_api.txt
+++ b/requirements_api.txt
@@ -1,6 +1,7 @@
-langchain==0.0.313
+langchain==0.0.317
 langchain-experimental==0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
+xformers==0.0.22.post4
 openai
 sentence_transformers>=2.2.2
 transformers>=4.34
diff --git a/requirements_webui.txt b/requirements_webui.txt
index 9caf085..561b10a 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -3,7 +3,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 nltk
diff --git a/server/agent/callbacks.py b/server/agent/callbacks.py
index 3a82b9c..ddc6ffe 100644
--- a/server/agent/callbacks.py
+++ b/server/agent/callbacks.py
@@ -98,7 +98,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
         )
         self.queue.put_nowait(dumps(self.cur_tool))
 
-    async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any,
+    async def on_chat_model_start(
+            self,
+            serialized: Dict[str, Any],
+            messages: List[List],
+            *,
+            run_id: UUID,
+            parent_run_id: Optional[UUID] = None,
+            tags: Optional[List[str]] = None,
+            metadata: Optional[Dict[str, Any]] = None,
+            **kwargs: Any,
     ) -> None:
         self.cur_tool.update(
             status=Status.start,
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index c78add6..5a71478 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                     # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                     max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                      prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                      ):
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 3ec6855..4402185 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                stream: bool = Body(False, description="流式输出"),
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-               max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-               # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+               max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index c39b147..19ca871 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                              max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                              # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                              max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                               prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py
index 4a46ddd..7efb0a8 100644
--- a/server/chat/openai_chat.py
+++ b/server/chat/openai_chat.py
@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
     messages: List[OpenAiMessage]
     temperature: float = 0.7
     n: int = 1
-    max_tokens: int = 1024
+    max_tokens: int = None
     stop: List[str] = []
     stream: bool = False
     presence_penalty: int = 0
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 06ac856..e1ccaa4 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -114,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-                             max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"),
-                             # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
+                             max_tokens: int = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"),
                              prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
diff --git a/server/llm_api.py b/server/llm_api.py
index dc9ddce..0d3ba3f 100644
--- a/server/llm_api.py
+++ b/server/llm_api.py
@@ -16,7 +16,7 @@ def list_running_models(
         with get_httpx_client() as client:
             r = client.post(controller_address + "/list_models")
             models = r.json()["models"]
-            data = {m: get_model_worker_config(m) for m in models}
+            data = {m: get_model_config(m).data for m in models}
             return BaseResponse(data=data)
     except Exception as e:
         logger.error(f'{e.__class__.__name__}: {e}',
diff --git a/server/model_workers/SparkApi.py b/server/model_workers/SparkApi.py
index e1dce6a..c4e090e 100644
--- a/server/model_workers/SparkApi.py
+++ b/server/model_workers/SparkApi.py
@@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
         "chat": {
             "domain": domain,
             "random_threshold": 0.5,
-            "max_tokens": 2048,
+            "max_tokens": None,
             "auditing": "default",
             "temperature": temperature,
         }
diff --git a/server/model_workers/base.py b/server/model_workers/base.py
index ea14104..f280680 100644
--- a/server/model_workers/base.py
+++ b/server/model_workers/base.py
@@ -1,7 +1,7 @@
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
-from fastchat.serve.model_worker import BaseModelWorker
+from fastchat.serve.base_model_worker import BaseModelWorker
 import uuid
 import json
 import sys
@@ -62,15 +62,6 @@ class ApiModelWorker(BaseModelWorker):
         print("embedding")
         print(params)
 
-    # workaround to make program exit with Ctrl+c
-    # it should be deleted after pr is merged by fastchat
-    def init_heart_beat(self):
-        self.register_to_controller()
-        self.heart_beat_thread = threading.Thread(
-            target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
-        )
-        self.heart_beat_thread.start()
-
     # help methods
     def get_config(self):
         from server.utils import get_model_worker_config
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 2852250..c7a029d 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -53,7 +53,7 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
         if local_models:
             return local_models[0], True
 
-    return running_models[0], False
+    return list(running_models)[0], False
 
 
 def dialogue_page(api: ApiRequest):
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 7b9e161..8190dba 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -245,7 +245,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         **kwargs: Any,
     ):
         '''
@@ -278,7 +278,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
         **kwargs,
     ):
         '''
@@ -308,7 +308,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''
@@ -340,7 +340,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''
@@ -378,7 +378,7 @@ class ApiRequest:
         stream: bool = True,
         model: str = LLM_MODEL,
         temperature: float = TEMPERATURE,
-        max_tokens: int = 1024,
+        max_tokens: int = None,
         prompt_name: str = "default",
     ):
         '''
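
Reviewer note (not part of the patch): the two history_len fixes listed in the commit message concern how the chat history window is sliced before it is sent to the LLM. The sketch below only illustrates the intended behaviour; the helper name cut_history and its shape are hypothetical and are not taken from this repository.

    from typing import Dict, List

    def cut_history(history: List[Dict], history_len: int) -> List[Dict]:
        # history_len == 0 must mean "send no history at all"; a naive
        # history[-history_len:] returns the whole list when history_len is 0,
        # which is how a stray, incomplete message can leak into the prompt.
        if history_len <= 0:
            return []
        # When the number of stored turns equals history_len, the full window
        # should still be passed through rather than coming back empty.
        return history[-history_len:]

    msgs = [{"role": "user", "content": f"turn {i}"} for i in range(3)]
    assert cut_history(msgs, 0) == []      # no partial leftover message
    assert len(cut_history(msgs, 3)) == 3  # window kept when turns == history_len

Likewise, the new max_tokens=None defaults are simply passed through so that, per the commit message, the serving backend falls back to the maximum output length supported by the model.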