New features (#1801):

- Bump the langchain/fastchat dependencies and add the xformers dependency
- Default max_tokens=None, so the number of generated tokens automatically falls back to the maximum supported by the model (see the sketch below)
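
A rough illustration of the new default (the helper below is hypothetical, not code from this commit): when max_tokens is None the cap is simply omitted from the request, so the serving backend uses the model's own maximum.

# Hypothetical sketch only -- build_chat_payload is not part of this repository.
from typing import Any, Dict, Optional

def build_chat_payload(prompt: str,
                       temperature: float = 0.7,
                       max_tokens: Optional[int] = None) -> Dict[str, Any]:
    payload: Dict[str, Any] = {"prompt": prompt, "temperature": temperature}
    if max_tokens is not None:
        # only cap generation when the caller explicitly asks for a limit
        payload["max_tokens"] = max_tokens
    return payload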

Fixes:
- With history_len=0, one incomplete history message was still carried over, causing LLM errors (illustrated in the sketch below)
- When the number of dialogue turns reached history_len, the history passed in was empty
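
Both history bugs come down to how the last history_len messages are sliced; a minimal sketch of the intended behaviour (hypothetical names, not the repository's implementation):

# Hypothetical sketch only -- pick_history is not part of this repository.
from typing import Dict, List

def pick_history(history: List[Dict[str, str]], history_len: int) -> List[Dict[str, str]]:
    # history_len == 0 must yield no history at all; a naive slice such as
    # history[-(history_len + 1):] would leak one extra, possibly incomplete message.
    if history_len <= 0:
        return []
    # keep only the most recent history_len messages, even once the
    # conversation has grown to history_len turns or more
    return history[-history_len:]
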
liunux4odoo 2023-10-19 22:09:15 +08:00 committed by GitHub
parent 7e28291e9f
commit d053950aee
14 changed files with 33 additions and 35 deletions

View File

@@ -1,6 +1,7 @@
-langchain>=0.0.314
+langchain==0.0.317
-langchain-experimental>=0.0.30
+langchain-experimental==0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
+xformers==0.0.22.post4
 openai
 sentence_transformers
 transformers>=4.34
@@ -43,7 +44,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 watchdog

View File

@@ -1,6 +1,7 @@
-langchain==0.0.313
+langchain==0.0.317
 langchain-experimental==0.0.30
-fschat[model_worker]==0.2.30
+fschat[model_worker]==0.2.31
+xformers==0.0.22.post4
 openai
 sentence_transformers>=2.2.2
 transformers>=4.34

View File

@@ -3,7 +3,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox>=1.1.9
+streamlit-chatbox==1.1.10
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 nltk

View File

@@ -98,7 +98,16 @@ class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
 )
 self.queue.put_nowait(dumps(self.cur_tool))
-async def on_chat_model_start(self,serialized: Dict[str, Any], **kwargs: Any,
+async def on_chat_model_start(
+    self,
+    serialized: Dict[str, Any],
+    messages: List[List],
+    *,
+    run_id: UUID,
+    parent_run_id: Optional[UUID] = None,
+    tags: Optional[List[str]] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+    **kwargs: Any,
 ) -> None:
 self.cur_tool.update(
 status=Status.start,

View File

@@ -26,8 +26,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-# TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
+max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
 ):

View File

@@ -22,8 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-# TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
+max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 ):

View File

@@ -31,8 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-# TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
+max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 ):
 kb = KBServiceFactory.get_service_by_name(knowledge_base_name)

View File

@@ -16,7 +16,7 @@ class OpenAiChatMsgIn(BaseModel):
 messages: List[OpenAiMessage]
 temperature: float = 0.7
 n: int = 1
-max_tokens: int = 1024
+max_tokens: int = None
 stop: List[str] = []
 stream: bool = False
 presence_penalty: int = 0

View File

@@ -114,8 +114,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
 stream: bool = Body(False, description="流式输出"),
 model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
-max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"),
-# TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
+max_tokens: int = Body(None, description="限制LLM生成Token数量默认None代表模型最大值"),
 prompt_name: str = Body("default",description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
 ):
 if search_engine_name not in SEARCH_ENGINES.keys():

View File

@@ -16,7 +16,7 @@ def list_running_models(
 with get_httpx_client() as client:
 r = client.post(controller_address + "/list_models")
 models = r.json()["models"]
-data = {m: get_model_worker_config(m) for m in models}
+data = {m: get_model_config(m).data for m in models}
 return BaseResponse(data=data)
 except Exception as e:
 logger.error(f'{e.__class__.__name__}: {e}',

View File

@@ -65,7 +65,7 @@ def gen_params(appid, domain,question, temperature):
 "chat": {
 "domain": domain,
 "random_threshold": 0.5,
-"max_tokens": 2048,
+"max_tokens": None,
 "auditing": "default",
 "temperature": temperature,
 }

View File

@@ -1,7 +1,7 @@
 from configs.basic_config import LOG_PATH
 import fastchat.constants
 fastchat.constants.LOGDIR = LOG_PATH
-from fastchat.serve.model_worker import BaseModelWorker
+from fastchat.serve.base_model_worker import BaseModelWorker
 import uuid
 import json
 import sys
@@ -62,15 +62,6 @@ class ApiModelWorker(BaseModelWorker):
 print("embedding")
 print(params)
-# workaround to make program exit with Ctrl+c
-# it should be deleted after pr is merged by fastchat
-def init_heart_beat(self):
-self.register_to_controller()
-self.heart_beat_thread = threading.Thread(
-target=fastchat.serve.model_worker.heart_beat_worker, args=(self,), daemon=True,
-)
-self.heart_beat_thread.start()
 # help methods
 def get_config(self):
 from server.utils import get_model_worker_config

View File

@@ -53,7 +53,7 @@ def get_default_llm_model(api: ApiRequest) -> (str, bool):
 if local_models:
 return local_models[0], True
-return running_models[0], False
+return list(running_models)[0], False
 def dialogue_page(api: ApiRequest):

View File

@@ -245,7 +245,7 @@ class ApiRequest:
 stream: bool = True,
 model: str = LLM_MODEL,
 temperature: float = TEMPERATURE,
-max_tokens: int = 1024,
+max_tokens: int = None,
 **kwargs: Any,
 ):
 '''
@@ -278,7 +278,7 @@ class ApiRequest:
 stream: bool = True,
 model: str = LLM_MODEL,
 temperature: float = TEMPERATURE,
-max_tokens: int = 1024,
+max_tokens: int = None,
 prompt_name: str = "default",
 **kwargs,
 ):
@@ -308,7 +308,7 @@ class ApiRequest:
 stream: bool = True,
 model: str = LLM_MODEL,
 temperature: float = TEMPERATURE,
-max_tokens: int = 1024,
+max_tokens: int = None,
 prompt_name: str = "default",
 ):
 '''
@@ -340,7 +340,7 @@ class ApiRequest:
 stream: bool = True,
 model: str = LLM_MODEL,
 temperature: float = TEMPERATURE,
-max_tokens: int = 1024,
+max_tokens: int = None,
 prompt_name: str = "default",
 ):
 '''
@@ -378,7 +378,7 @@ class ApiRequest:
 stream: bool = True,
 model: str = LLM_MODEL,
 temperature: float = TEMPERATURE,
-max_tokens: int = 1024,
+max_tokens: int = None,
 prompt_name: str = "default",
 ):
 '''