Chat APIs support the temperature parameter (#1455)
parent a03b8d330d
commit 8b040620de
server/chat/chat.py

@@ -1,6 +1,6 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import llm_model_dict, LLM_MODEL, TEMPERATURE
 from server.chat.utils import wrap_done
 from langchain.chat_models import ChatOpenAI
 from langchain import LLMChain
@@ -19,8 +19,10 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"}]]
            ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+       temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
+       # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
        ):
     history = [History.from_data(h) for h in history]

@@ -37,6 +39,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
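With this change, callers of the /chat/chat endpoint can pass temperature in the JSON body alongside stream and model_name. A minimal client sketch, assuming the API server listens on http://localhost:7861 and that the model name below is configured in llm_model_dict (both assumptions, not part of the diff):

import requests

data = {
    "query": "恼羞成怒",
    "stream": True,
    "model_name": "chatglm2-6b",  # assumed: must be a key in llm_model_dict
    "temperature": 0.7,           # validated server-side: 0.0 < temperature <= 1.0
}
resp = requests.post("http://localhost:7861/chat/chat", json=data, stream=True)
for chunk in resp.iter_content(None, decode_unicode=True):
    print(chunk, end="", flush=True)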
server/chat/knowledge_base_chat.py

@@ -1,7 +1,8 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
+                                  VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD,
+                                  TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -32,6 +33,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
            ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+       temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
        local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
        request: Request = None,
        ):
@@ -55,6 +57,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
     docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
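The knowledge-base endpoint gains the same parameter. As the test changes further down suggest, the streamed response mixes incremental "answer" chunks with a "docs" list of retrieved sources. A hedged consumer sketch (base URL and knowledge base name are assumptions):

import json
import requests

data = {
    "query": "如何提问?",
    "knowledge_base_name": "samples",  # assumed demo knowledge base
    "stream": True,
    "temperature": 0.5,
}
resp = requests.post("http://localhost:7861/chat/knowledge_base_chat",
                     json=data, stream=True)
for line in resp.iter_content(None, decode_unicode=True):
    chunk = json.loads(line)
    if "docs" in chunk:
        print(chunk["docs"])            # retrieved source documents
    if "answer" in chunk:
        print(chunk["answer"], end="", flush=True)  # incremental answer text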
server/chat/search_engine_chat.py

@@ -3,7 +3,8 @@ from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
-from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
+from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K,
+                                  PROMPT_TEMPLATE, TEMPERATURE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 from langchain.chat_models import ChatOpenAI
@@ -72,6 +73,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
            ),
        stream: bool = Body(False, description="流式输出"),
        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+       temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -95,6 +97,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         openai_api_key=llm_model_dict[model_name]["api_key"],
         openai_api_base=llm_model_dict[model_name]["api_base_url"],
         model_name=model_name,
+        temperature=temperature,
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
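One consequence of declaring gt=0.0, le=1.0 on the Body field is that FastAPI rejects out-of-range temperatures with a 422 validation error before the handler runs. A quick illustration (URL and search engine name assumed, as above):

import requests

bad = {
    "query": "最新新闻",
    "search_engine_name": "duckduckgo",  # assumed entry in SEARCH_ENGINES
    "temperature": 1.5,                  # violates le=1.0
}
r = requests.post("http://localhost:7861/chat/search_engine_chat", json=bad)
print(r.status_code)  # expect 422 from FastAPI's Body validation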
tests/api/test_stream_chat_api.py

@@ -43,7 +43,8 @@ data = {
             "content": "你好,我是人工智能大模型"
         }
     ],
-    "stream": True
+    "stream": True,
+    "temperature": 0.7,
 }

@@ -88,14 +89,12 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     response = requests.post(url, headers=headers, json=data, stream=True)
     print("\n")
     print("=" * 30 + api + " output" + "="*30)
-    first = True
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
-        if first:
-            for doc in data["docs"]:
-                print(doc)
-            first = False
-        print(data["answer"], end="", flush=True)
+        if "answer" in data:
+            print(data["answer"], end="", flush=True)
+    assert "docs" in data and len(data["docs"]) > 0
+    pprint(data["docs"])
     assert response.status_code == 200

@@ -116,14 +115,11 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):

     print("\n")
     print("=" * 30 + api + " by {se} output" + "="*30)
-    first = True
     for line in response.iter_content(None, decode_unicode=True):
         data = json.loads(line)
-        assert "docs" in data and len(data["docs"]) > 0
-        if first:
-            for doc in data.get("docs", []):
-                print(doc)
-            first = False
-        print(data["answer"], end="", flush=True)
+        if "answer" in data:
+            print(data["answer"], end="", flush=True)
+    assert "docs" in data and len(data["docs"]) > 0
+    pprint(data["docs"])
     assert response.status_code == 200
webui_pages/dialogue/dialogue.py

@@ -5,7 +5,7 @@ from streamlit_chatbox import *
 from datetime import datetime
 from server.chat.search_engine_chat import SEARCH_ENGINES
 import os
-from configs.model_config import llm_model_dict, LLM_MODEL
+from configs.model_config import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict

@@ -55,7 +55,7 @@ def dialogue_page(api: ApiRequest):
         st.toast(text)
         # sac.alert(text, description="descp", type="success", closable=True, banner=True)

-        dialogue_mode = st.selectbox("请选择对话模式",
+        dialogue_mode = st.selectbox("请选择对话模式:",
                                      ["LLM 对话",
                                       "知识库问答",
                                       "搜索引擎问答",
@@ -95,6 +95,7 @@ def dialogue_page(api: ApiRequest):
                 r = api.change_llm_model(st.session_state.get("prev_llm_model"), llm_model)
                 st.session_state["cur_llm_model"] = llm_model

+        temperature = st.number_input("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
         history_len = st.number_input("历史对话轮数:", 0, 10, HISTORY_LEN)

         def on_kb_change():
@@ -135,7 +136,7 @@ def dialogue_page(api: ApiRequest):
         if dialogue_mode == "LLM 对话":
             chat_box.ai_say("正在思考...")
             text = ""
-            r = api.chat_chat(prompt, history=history, model=llm_model)
+            r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
             for t in r:
                 if error_msg := check_error_msg(t):  # check whether error occured
                     st.error(error_msg)
@@ -150,7 +151,13 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="知识库匹配结果"),
             ])
             text = ""
-            for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, score_threshold, history, model=llm_model):
+            for d in api.knowledge_base_chat(prompt,
+                                             knowledge_base_name=selected_kb,
+                                             top_k=kb_top_k,
+                                             score_threshold=score_threshold,
+                                             history=history,
+                                             model=llm_model,
+                                             temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
@@ -164,7 +171,11 @@ def dialogue_page(api: ApiRequest):
                 Markdown("...", in_expander=True, title="网络搜索结果"),
             ])
             text = ""
-            for d in api.search_engine_chat(prompt, search_engine, se_top_k, model=llm_model):
+            for d in api.search_engine_chat(prompt,
+                                            search_engine_name=search_engine,
+                                            top_k=se_top_k,
+                                            model=llm_model,
+                                            temperature=temperature):
                 if error_msg := check_error_msg(d):  # check whether error occured
                     st.error(error_msg)
                 elif chunk := d.get("answer"):
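For context, the new sidebar widget uses Streamlit's positional st.number_input signature (label, min_value, max_value, value, step): the control starts at the configured TEMPERATURE default and moves in steps of 0.05. A standalone sketch of the same pattern, with a stand-in constant:

import streamlit as st

TEMPERATURE = 0.7  # stand-in for the configs.model_config default
# positional args: label, min_value, max_value, value, step
temperature = st.number_input("Temperature:", 0.0, 1.0, TEMPERATURE, 0.05)
st.write(f"current temperature: {temperature}")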
webui_pages/utils.py

@@ -8,6 +8,7 @@ from configs.model_config import (
     LLM_MODEL,
     llm_model_dict,
     HISTORY_LEN,
+    TEMPERATURE,
     SCORE_THRESHOLD,
     VECTOR_SEARCH_TOP_K,
     SEARCH_ENGINE_TOP_K,
@@ -269,7 +270,7 @@ class ApiRequest:
         messages: List[Dict],
         stream: bool = True,
         model: str = LLM_MODEL,
-        temperature: float = 0.7,
+        temperature: float = TEMPERATURE,
         max_tokens: int = 1024,  # todo:根据message内容自动计算max_tokens
         no_remote_api: bool = None,
         **kwargs: Any,
@@ -310,6 +311,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -323,6 +325,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }

         print(f"received input message:")
@@ -345,6 +348,7 @@ class ApiRequest:
         history: List[Dict] = [],
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -361,6 +365,7 @@ class ApiRequest:
             "history": history,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
             "local_doc_url": no_remote_api,
         }

@@ -386,6 +391,7 @@ class ApiRequest:
         top_k: int = SEARCH_ENGINE_TOP_K,
         stream: bool = True,
         model: str = LLM_MODEL,
+        temperature: float = TEMPERATURE,
         no_remote_api: bool = None,
     ):
         '''
@@ -400,6 +406,7 @@ class ApiRequest:
             "top_k": top_k,
             "stream": stream,
             "model_name": model,
+            "temperature": temperature,
         }

         print(f"received input message:")