From 8b040620de6a046c145c53a85e6960376c0ef342 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Wed, 13 Sep 2023 10:00:54 +0800
Subject: [PATCH] Support the temperature parameter in the chat APIs (#1455)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 server/chat/chat.py                |  9 ++++++---
 server/chat/knowledge_base_chat.py |  5 ++++-
 server/chat/search_engine_chat.py  |  5 ++++-
 tests/api/test_stream_chat_api.py  | 24 ++++++++++--------------
 webui_pages/dialogue/dialogue.py   | 21 ++++++++++++++++-----
 webui_pages/utils.py               |  9 ++++++++-
 6 files changed, 48 insertions(+), 25 deletions(-)

diff --git a/server/chat/chat.py b/server/chat/chat.py
index b4fd6bb..c025c3c 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -19,8 +19,10 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                           {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                           {"role": "assistant", "content": "虎头虎脑"}]]
                       ),
-             stream: bool = Body(False, description="流式输出"),
-             model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+             stream: bool = Body(False, description="流式输出"),
+             model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+             temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+             # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
              ):
     history = [History.from_data(h) for h in history]
 
@@ -37,6 +39,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
             openai_api_key=llm_model_dict[model_name]["api_key"],
             openai_api_base=llm_model_dict[model_name]["api_base_url"],
             model_name=model_name,
+            temperature=temperature,
             openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
 
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 8234396..5d1aac8 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,7 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                                  TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -32,6 +33,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               ),
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
                               local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
                               ):
@@ -55,6 +57,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             openai_api_key=llm_model_dict[model_name]["api_key"],
             openai_api_base=llm_model_dict[model_name]["api_base_url"],
             model_name=model_name,
+            temperature=temperature,
             openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
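With the two hunks above applied, /chat/chat and /chat/knowledge_base_chat accept an optional temperature field in the request body. A minimal sketch of a streaming call against the updated /chat/chat endpoint, assuming the API server is listening on its default local port (adjust the base URL and model name to your deployment):

    import requests

    # Assumed deployment values; substitute your own API address and configured model.
    base_url = "http://127.0.0.1:7861"
    payload = {
        "query": "Hello",
        "history": [],
        "stream": True,
        "model_name": "chatglm2-6b",  # assumption: any model key present in llm_model_dict
        "temperature": 0.7,           # new field from this patch; must satisfy 0 < t <= 1
    }

    resp = requests.post(base_url + "/chat/chat", json=payload, stream=True)
    for chunk in resp.iter_content(None, decode_unicode=True):
        print(chunk, end="", flush=True)

Omitting temperature keeps the server-side default (TEMPERATURE from configs.model_config), so existing clients are unaffected.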
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index f5e6370..f8e4ebe 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -3,7 +3,8 @@ from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                                  PROMPT_TEMPLATE, TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -72,6 +73,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              ),
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+                             temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -95,6 +97,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             openai_api_key=llm_model_dict[model_name]["api_key"],
             openai_api_base=llm_model_dict[model_name]["api_base_url"],
             model_name=model_name,
+            temperature=temperature,
             openai_proxy=llm_model_dict[model_name].get("openai_proxy")
         )
 
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index 75c995e..1431485 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -43,7 +43,8 @@ data = {
             "content": "你好,我是人工智能大模型"
         }
     ],
-    "stream": True
+    "stream": True,
+    "temperature": 0.7,
 }
 
 
@@ -88,14 +89,12 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     response = requests.post(url, headers=headers, json=data, stream=True)
     print("\n")
     print("=" * 30 + api + " output" + "="*30)
-    first = True
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
-        if first:
-            for doc in data["docs"]:
-                print(doc)
-            first = False
-        print(data["answer"], end="", flush=True)
+        if "answer" in data:
+            print(data["answer"], end="", flush=True)
+    assert "docs" in data and len(data["docs"]) > 0
+    pprint(data["docs"])
     assert response.status_code == 200
 
 
@@ -116,14 +115,11 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
 
         print("\n")
         print("=" * 30 + api + f" by {se} output" + "="*30)
-        first = True
         for line in response.iter_content(None, decode_unicode=True):
             data = json.loads(line)
-            assert "docs" in data and len(data["docs"]) > 0
-            if first:
-                for doc in data.get("docs", []):
-                    print(doc)
-                first = False
-            print(data["answer"], end="", flush=True)
+            if "answer" in data:
+                print(data["answer"], end="", flush=True)
+        assert "docs" in data and len(data["docs"]) > 0
+        pprint(data["docs"])
         assert response.status_code == 200
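Because temperature is declared with gt=0.0, le=1.0 in every endpoint above, FastAPI validates the field before the handler runs and rejects out-of-range values. A small sketch of such a negative check, under the same assumption of a locally running API server:

    import requests

    # Assumption: API server at its default local address; endpoint path as patched above.
    url = "http://127.0.0.1:7861/chat/chat"
    bad_payload = {"query": "hi", "stream": False, "temperature": 1.5}  # violates le=1.0

    r = requests.post(url, json=bad_payload)
    assert r.status_code == 422  # FastAPI's standard status for request-body validation errors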
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index ec31093..a0b8bb5 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -5,7 +5,7 @@ from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict
 
@@ -55,7 +55,7 @@ def dialogue_page(api: ApiRequest):
         st.toast(text)
         # sac.alert(text, description="descp", type="success", closable=True, banner=True)
 
-    dialogue_mode = st.selectbox("请选择对话模式",
+    dialogue_mode = st.selectbox("请选择对话模式:",
                                  ["LLM 对话",
                                   "知识库问答",
                                   "搜索引擎问答",
@@ -95,6 +95,7 @@ def dialogue_page(api: ApiRequest):
                     r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
                 st.session_state["cur_llm_model"] = llm_model
 
+        temperature = st.number_input("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)
 
         def on_kb_change():
@@ -135,7 +136,7 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history=history, model=llm_model)
+            r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether an error occurred
                     st.error(error_msg)
@@ -150,7 +151,13 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="知识库匹配结果"),
             ])
             text = ""
-            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
+            for d in api.knowledge_base_chat(prompt,
+                                             knowledge_base_name=selected_kb,
+                                             top_k=kb_top_k,
+                                             score_threshold=score_threshold,
+                                             history=history,
+                                             model=llm_model,
+                                             temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether an error occurred
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
@@ -164,7 +171,11 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="网络搜索结果"),
             ])
             text = ""
-            for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+            for d in api.search_engine_chat(prompt,
+                                            search_engine_name=search_engine,
+                                            top_k=se_top_k,
+                                            model=llm_model,
+                                            temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether an error occurred
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
 
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 3145b6f..3aed735 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -8,6 +8,7 @@ from configs.model_config import (
     LLM_MODEL,
     llm_model_dict,
     HISTORY_LEN,
+    TEMPERATURE,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
@@ -269,7 +270,7 @@ class ApiRequest:
         messages: List[Dict],
         stream: bool = True,
         model: str = LLM_MODEL,
-        temperature: float = 0.7,
+        temperature: float = TEMPERATURE,
         max_tokens: int = 1024,  # todo:根据message内容自动计算max_tokens
         no_remote_api: bool = None,
         **kwargs: Any,
@@ -310,6 +311,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -323,6 +325,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }
 
         print(f"received input message:")
@@ -345,6 +348,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -361,6 +365,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
             "local_doc_url": no_remote_api,
         }
 
@@ -386,6 +391,7 @@ class ApiRequest:
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -400,6 +406,7 @@ class ApiRequest:
             "top_k": top_k,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }
 
         print(f"received input message:")
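On the WebUI side, ApiRequest now threads temperature through chat_chat, knowledge_base_chat, and search_engine_chat, so the value from the Streamlit number input reaches the server unchanged. A minimal usage sketch, assuming api.py is already running at the address ApiRequest defaults to:

    from webui_pages.utils import ApiRequest

    api = ApiRequest()  # assumption: default base_url points at a running api.py server
    for t in api.chat_chat("Hello", history=[], temperature=0.2):
        print(t, end="", flush=True)  # chat_chat yields answer text chunks when streaming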