Chat endpoints support the temperature parameter (#1455)

liunux4odoo 2023-09-13 10:00:54 +08:00 committed by GitHub
parent a03b8d330d
commit 8b040620de
6 changed files with 48 additions and 25 deletions

server/chat/chat.py

@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -19,8 +19,10 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                            {"role": "assistant", "content": "虎头虎脑"}]]
                        ),
-          stream: bool = Body(False, description="流式输出"),
-          model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+          stream: bool = Body(False, description="流式输出"),
+          model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+          temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+          # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
           ):
     history = [History.from_data(h) for h in history]
@@ -37,6 +39,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
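With this change, clients can tune sampling per request. Below is a minimal client sketch against the new endpoint; the host/port (the project's usual API default) and the model name are assumptions, and the validator (gt=0.0, le=1.0) rejects a temperature of exactly 0:

import requests

# Hypothetical direct call; address and model name are placeholders.
resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",
    json={
        "query": "你好",
        "history": [],
        "stream": False,
        "model_name": "chatglm2-6b",   # placeholder model
        "temperature": 0.7,            # must satisfy 0.0 < t <= 1.0
    },
)
print(resp.text)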

server/chat/knowledge_base_chat.py

@@ -1,7 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                                  TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -32,6 +33,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                            ),
           stream: bool = Body(False, description="流式输出"),
           model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+          temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
           local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
           request: Request = None,
           ):
@@ -55,6 +57,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
     docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
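Because the lower bound is declared with gt=0.0 rather than ge=0.0, FastAPI rejects out-of-range values before the handler runs. A small sketch of the expected behavior (the address and knowledge base name are placeholders, as above):

import requests

# A temperature outside (0.0, 1.0] should fail Body() validation with HTTP 422.
resp = requests.post(
    "http://127.0.0.1:7861/chat/knowledge_base_chat",
    json={"query": "test", "knowledge_base_name": "samples", "temperature": 1.5},
)
assert resp.status_code == 422, resp.text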

server/chat/search_engine_chat.py

@@ -3,7 +3,8 @@ from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                                  PROMPT_TEMPLATE, TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -72,6 +73,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                            ),
           stream: bool = Body(False, description="流式输出"),
           model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+          temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
           ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -95,6 +97,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
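All three endpoints forward the value unchanged into LangChain's ChatOpenAI wrapper, so the effective construction is roughly the sketch below (the key, base URL, and model name are placeholders; the real values come from llm_model_dict):

from langchain.chat_models import ChatOpenAI

# Equivalent to what each handler builds; all literal values are placeholders.
model = ChatOpenAI(
    openai_api_key="EMPTY",
    openai_api_base="http://127.0.0.1:8888/v1",
    model_name="chatglm2-6b",
    temperature=0.7,  # the per-request value received by the endpoint
)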

tests/api/test_stream_chat_api.py

@@ -43,7 +43,8 @@ data = {
             "content": "你好,我是人工智能大模型"
         }
     ],
-    "stream": True
+    "stream": True,
+    "temperature": 0.7,
 }
@@ -88,14 +89,12 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     response = requests.post(url, headers=headers, json=data, stream=True)
     print("\n")
     print("=" * 30 + api + " output" + "="*30)
-    first = True
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
-        if first:
-            for doc in data["docs"]:
-                print(doc)
-            first = False
-        print(data["answer"], end="", flush=True)
+        if "answer" in data:
+            print(data["answer"], end="", flush=True)
+    assert "docs" in data and len(data["docs"]) > 0
+    pprint(data["docs"])
     assert response.status_code == 200
@@ -116,14 +115,11 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
     print("\n")
     print("=" * 30 + api + " by {se} output" + "="*30)
-    first = True
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
-        assert "docs" in data and len(data["docs"]) > 0
-        if first:
-            for doc in data.get("docs", []):
-                print(doc)
-            first = False
-        print(data["answer"], end="", flush=True)
+        if "answer" in data:
+            print(data["answer"], end="", flush=True)
+    assert "docs" in data and len(data["docs"]) > 0
+    pprint(data["docs"])
     assert response.status_code == 200
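The revised tests tolerate chunks that carry only "docs" or only "answer", then assert on the docs once the stream completes. A standalone consumer would follow the same shape; in this sketch the address and request body are assumptions mirroring the test module:

import json
import requests

# Each streamed chunk is one JSON object; "answer" holds incremental text,
# "docs" the retrieved sources. URL and payload are placeholders.
response = requests.post(
    "http://127.0.0.1:7861/chat/knowledge_base_chat",
    json={"query": "虎头虎脑", "knowledge_base_name": "samples", "stream": True},
    stream=True,
)
for line in response.iter_content(None, decode_unicode=True):
    msg = json.loads(line)
    print(msg.get("answer", ""), end="", flush=True)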

webui_pages/dialogue/dialogue.py

@@ -5,7 +5,7 @@ from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict
@@ -55,7 +55,7 @@ def dialogue_page(api: ApiRequest):
         st.toast(text)
     # sac.alert(text, description="descp", type="success", closable=True, banner=True)
-    dialogue_mode = st.selectbox("请选择对话模式",
+    dialogue_mode = st.selectbox("请选择对话模式",
                                  ["LLM 对话",
                                   "知识库问答",
                                   "搜索引擎问答",
@@ -95,6 +95,7 @@ def dialogue_page(api: ApiRequest):
             r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
             st.session_state["cur_llm_model"] = llm_model
+    temperature = st.number_input("Temperature", 0.0, 1.0, TEMPERATURE, 0.05)
     history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

     def on_kb_change():
@@ -135,7 +136,7 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history=history, model=llm_model)
+            r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether an error occurred
                     st.error(error_msg)
@@ -150,7 +151,13 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="知识库匹配结果"),
             ])
             text = ""
-            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
+            for d in api.knowledge_base_chat(prompt,
+                                             knowledge_base_name=selected_kb,
+                                             top_k=kb_top_k,
+                                             score_threshold=score_threshold,
+                                             history=history,
+                                             model=llm_model,
+                                             temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether an error occurred
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
@@ -164,7 +171,11 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="网络搜索结果"),
             ])
             text = ""
-            for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+            for d in api.search_engine_chat(prompt,
+                                            search_engine_name=search_engine,
+                                            top_k=se_top_k,
+                                            model=llm_model,
+                                            temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether an error occurred
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
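Here st.number_input takes (label, min, max, default, step), so the sidebar offers 0.0–1.0 in steps of 0.05. One caveat worth noting: the widget's minimum is inclusive, while the endpoints declare gt=0.0, so selecting exactly 0.0 would likely hit a 422 validation error. A standalone sketch of the widget, with 0.7 standing in for the project's TEMPERATURE default:

import streamlit as st

# Mirrors the new sidebar control; 0.7 is an assumed value for
# configs.model_config.TEMPERATURE.
temperature = st.number_input("Temperature", 0.0, 1.0, 0.7, 0.05)
st.write(f"temperature forwarded to the API: {temperature}")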

webui_pages/utils.py

@@ -8,6 +8,7 @@ from configs.model_config import (
     LLM_MODEL,
     llm_model_dict,
     HISTORY_LEN,
+    TEMPERATURE,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
@@ -269,7 +270,7 @@ class ApiRequest:
         messages: List[Dict],
         stream: bool = True,
         model: str = LLM_MODEL,
-        temperature: float = 0.7,
+        temperature: float = TEMPERATURE,
         max_tokens: int = 1024,  # todo:根据message内容自动计算max_tokens
         no_remote_api: bool = None,
         **kwargs: Any,
@@ -310,6 +311,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -323,6 +325,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }
         print(f"received input message:")
@@ -345,6 +348,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -361,6 +365,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
             "local_doc_url": no_remote_api,
         }
@@ -386,6 +391,7 @@ class ApiRequest:
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -400,6 +406,7 @@ class ApiRequest:
             "top_k": top_k,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }
         print(f"received input message:")