添加对话评分与历史消息保存功能 (#1940)
* 新功能:
- WEBUI 添加对话评分功能
- 增加 /chat/feedback 接口,用于接收对话评分
- /chat/chat 接口返回值由 str 改为 {"text":str, "chat_history_id": str}
- init_database.py 添加 --create-tables --clear-tables 参数
依赖:
- streamlit-chatbox==1.1.11
开发者:
- ChatHistoryModel 的 id 字段支持自动生成
- SAVE_CHAT_HISTORY 改到 basic_config.py
* 修复:点击反馈后页面未刷新
---------
Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
This commit is contained in:
parent
554122f60e
commit
fa906b33a8
|
|
@ -6,6 +6,8 @@ import langchain
|
|||
log_verbose = False
|
||||
langchain.verbose = False
|
||||
|
||||
# 是否保存聊天记录
|
||||
SAVE_CHAT_HISTORY = False
|
||||
|
||||
# 通常情况下不需要更改以下内容
|
||||
|
||||
|
|
|
|||
|
|
@ -24,10 +24,15 @@ if __name__ == "__main__":
|
|||
'''
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--create-tables",
|
||||
action="store_true",
|
||||
help=("create empty tables if not existed")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--clear-tables",
|
||||
action="store_true",
|
||||
help=("drop the database tables before recreate vector stores")
|
||||
help=("create empty tables, or drop the database tables before recreate vector stores")
|
||||
)
|
||||
parser.add_argument(
|
||||
"--import-db",
|
||||
|
|
@ -87,12 +92,10 @@ if __name__ == "__main__":
|
|||
help=("specify embeddings model.")
|
||||
)
|
||||
|
||||
if len(sys.argv) <= 1:
|
||||
parser.print_help()
|
||||
else:
|
||||
args = parser.parse_args()
|
||||
start_time = datetime.now()
|
||||
|
||||
if args.create_tables:
|
||||
create_tables() # confirm tables exist
|
||||
|
||||
if args.clear_tables:
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ pandas~=2.0.3
|
|||
streamlit>=1.26.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox==1.1.10
|
||||
streamlit-chatbox==1.1.11
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx[brotli,http2,socks]>=0.25.0
|
||||
watchdog
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ pandas~=2.0.3
|
|||
streamlit>=1.26.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox>=1.1.9
|
||||
streamlit-chatbox==1.1.11
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
watchdog
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ pandas~=2.0.3
|
|||
streamlit>=1.27.2
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.2.3
|
||||
streamlit-chatbox==1.1.10
|
||||
streamlit-chatbox==1.1.11
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx>=0.25.0
|
||||
nltk>=3.8.1
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from server.chat.chat import chat
|
|||
from server.chat.openai_chat import openai_chat
|
||||
from server.chat.search_engine_chat import search_engine_chat
|
||||
from server.chat.completion import completion
|
||||
from server.chat.feedback import chat_feedback
|
||||
from server.embeddings_api import embed_texts_endpoint
|
||||
from server.llm_api import (list_running_models, list_config_models,
|
||||
change_llm_model, stop_llm_model,
|
||||
|
|
@ -73,6 +74,11 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
|
|||
summary="与搜索引擎对话",
|
||||
)(search_engine_chat)
|
||||
|
||||
app.post("/chat/feedback",
|
||||
tags=["Chat"],
|
||||
summary="返回llm模型对话评分",
|
||||
)(chat_feedback)
|
||||
|
||||
# 知识库相关接口
|
||||
mount_knowledge_routes(app)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,17 @@
|
|||
from fastapi import Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
from configs import LLM_MODEL, TEMPERATURE
|
||||
from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain.chains import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
import json
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List, Optional
|
||||
from server.chat.utils import History
|
||||
from server.utils import get_prompt_template
|
||||
from server.db.repository.chat_history_repository import add_chat_history_to_db, update_chat_history
|
||||
|
||||
|
||||
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
|
|
@ -53,16 +55,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||
callback.done),
|
||||
)
|
||||
|
||||
answer = ""
|
||||
chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
|
||||
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
# Use server-sent-events to stream the response
|
||||
yield token
|
||||
yield json.dumps(
|
||||
{"text": token, "chat_history_id": chat_history_id},
|
||||
ensure_ascii=False)
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield answer
|
||||
yield json.dumps(
|
||||
{"text": answer, "chat_history_id": chat_history_id},
|
||||
ensure_ascii=False)
|
||||
|
||||
if SAVE_CHAT_HISTORY and len(chat_history_id) > 0:
|
||||
# 后续可以加入一些其他信息,比如真实的prompt等
|
||||
update_chat_history(chat_history_id, response=answer)
|
||||
await task
|
||||
|
||||
return StreamingResponse(chat_iterator(query=query,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,19 @@
|
|||
from fastapi import Body
|
||||
from configs import logger, log_verbose
|
||||
from server.utils import BaseResponse
|
||||
from server.db.repository.chat_history_repository import feedback_chat_history_to_db
|
||||
|
||||
|
||||
def chat_feedback(chat_history_id: str = Body("", max_length=32, description="聊天记录id"),
                  score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"),
                  reason: str = Body("", description="用户评分理由,比如不符合事实等")
                  ):
    """Persist a user's rating for one chat-history record.

    Delegates to the chat-history repository; any failure is logged and
    reported back as an HTTP-style error payload instead of raising.

    Returns:
        BaseResponse with code 200 on success, 500 on any error.
    """
    try:
        feedback_chat_history_to_db(chat_history_id, score, reason)
    except Exception as e:
        # Keep the endpoint from ever raising: log and wrap the error.
        msg = f"反馈聊天记录出错: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)
    else:
        return BaseResponse(code=200, msg=f"已反馈聊天记录 {chat_history_id}")
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
from sqlalchemy import Column, Integer, String, DateTime, JSON, func
|
||||
|
||||
from server.db.base import Base
|
||||
|
||||
|
||||
class ChatHistoryModel(Base):
    """
    Chat-history record: one row per question/answer exchange, plus the
    user's optional feedback (score and reason) on the model's answer.
    """
    __tablename__ = 'chat_history'
    # The id is a uuid generated by the caller/front end rather than an
    # autoincrement integer: with streaming responses the id must be known
    # before the row is complete, so it cannot wait for the DB to assign one.
    id = Column(String(32), primary_key=True, comment='聊天记录ID')
    # Kind of conversation, e.g. "llm_chat" / "agent_chat".
    chat_type = Column(String(50), comment='聊天类型')
    query = Column(String(4096), comment='用户问题')
    response = Column(String(4096), comment='模型回答')
    # Extra context (e.g. knowledge-base id) kept as JSON for future extension.
    meta_data = Column(JSON, default={})
    # Score out of 100, higher is better; default -1 presumably marks
    # "not yet rated" (valid ratings are 0-100) — TODO confirm with callers.
    feedback_score = Column(Integer, default=-1, comment='用户评分')
    feedback_reason = Column(String(255), default="", comment='用户评分理由')
    create_time = Column(DateTime, default=func.now(), comment='创建时间')

    def __repr__(self):
        return f"<ChatHistory(id='{self.id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
from server.db.session import with_session
|
||||
from server.db.models.chat_history_model import ChatHistoryModel
|
||||
import re
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
def _convert_query(query: str) -> str:
|
||||
p = re.sub(r"\s+", "%", query)
|
||||
return f"%{p}%"
|
||||
|
||||
|
||||
@with_session
def add_chat_history_to_db(session, chat_type, query, response="", chat_history_id=None, metadata: Dict = None):
    """
    Insert a new chat-history row and return its id.

    Args:
        session: DB session injected by the @with_session decorator.
        chat_type: kind of conversation, e.g. "llm_chat".
        query: the user's question.
        response: the model's answer (may be filled in later via
            update_chat_history for streaming responses).
        chat_history_id: optional caller-supplied id; a uuid4 hex is
            generated when omitted so streaming callers know it up front.
        metadata: optional extra context stored in the JSON column.

    Returns:
        The id of the inserted ChatHistoryModel row.
    """
    if not chat_history_id:
        chat_history_id = uuid.uuid4().hex
    # Bug fix: the model's JSON column is named meta_data (``metadata`` is a
    # reserved attribute on SQLAlchemy declarative models), so the keyword
    # must be meta_data= — the original metadata= silently failed/raised.
    # Also avoid the mutable-default-argument pitfall by defaulting to None.
    ch = ChatHistoryModel(id=chat_history_id, chat_type=chat_type, query=query, response=response,
                          meta_data=metadata if metadata is not None else {})
    session.add(ch)
    session.commit()
    return ch.id
|
||||
|
||||
|
||||
@with_session
def update_chat_history(session, chat_history_id, response: str = None, metadata: Dict = None):
    """
    Update an existing chat-history record.

    Only the fields actually supplied are changed: ``response`` when not
    None, ``metadata`` when it is a dict.  Returns the record id, or None
    (implicitly) when no record with that id exists.
    """
    # NOTE(review): get_chat_history_by_id is itself wrapped in @with_session,
    # so ``ch`` may come from a different session than the one injected here;
    # session.add() below re-attaches it before the decorator commits — verify
    # with_session's session-scoping semantics support this.
    ch = get_chat_history_by_id(chat_history_id)
    if ch is not None:
        if response is not None:
            ch.response = response
        if isinstance(metadata, dict):
            ch.meta_data = metadata
        session.add(ch)
        return ch.id
|
||||
|
||||
|
||||
@with_session
def feedback_chat_history_to_db(session, chat_history_id, feedback_score, feedback_reason):
    """Attach a user's feedback (score and reason) to an existing chat record.

    Returns the record id, or None when no record with that id exists.
    """
    record = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
    if not record:
        # Unknown id: nothing to update (original returned None implicitly).
        return None
    record.feedback_score = feedback_score
    record.feedback_reason = feedback_reason
    return record.id
|
||||
|
||||
|
||||
@with_session
def get_chat_history_by_id(session, chat_history_id) -> ChatHistoryModel:
    """Look up a single chat-history record by primary key.

    Returns None when the id is unknown.
    """
    return session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
|
||||
|
||||
|
||||
@with_session
def filter_chat_history(session, query=None, response=None, score=None, reason=None) -> List[ChatHistoryModel]:
    """
    Search chat-history records; every non-None argument adds a filter.

    Args:
        session: DB session injected by the @with_session decorator.
        query / response / reason: substring patterns matched case-insensitively
            (whitespace in the term acts as a wildcard, see _convert_query).
        score: exact feedback_score to match.

    Returns:
        The matching ChatHistoryModel rows as a list.
    """
    stmt = session.query(ChatHistoryModel)
    if query is not None:
        stmt = stmt.filter(ChatHistoryModel.query.ilike(_convert_query(query)))
    if response is not None:
        stmt = stmt.filter(ChatHistoryModel.response.ilike(_convert_query(response)))
    if score is not None:
        stmt = stmt.filter_by(feedback_score=score)
    if reason is not None:
        stmt = stmt.filter(ChatHistoryModel.feedback_reason.ilike(_convert_query(reason)))

    # Materialize inside the session scope: the original returned the lazy
    # Query object despite the List[...] annotation, which risks evaluation
    # after the @with_session-managed session has been closed.
    return stmt.all()
|
||||
|
|
@ -5,6 +5,7 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
|||
list_files_from_folder,files2docs_in_thread,
|
||||
KnowledgeFile,)
|
||||
from server.knowledge_base.kb_service.base import KBServiceFactory
|
||||
from server.db.models.chat_history_model import ChatHistoryModel
|
||||
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
|
||||
from server.db.base import Base, engine
|
||||
from server.db.session import session_scope
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
|
|||
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar=os.path.join(
|
||||
"img",
|
||||
|
|
@ -43,6 +44,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
f"当前运行的模型`{default_model}`, 您可以开始提问了."
|
||||
)
|
||||
chat_box.init_session()
|
||||
|
||||
with st.sidebar:
|
||||
# TODO: 对话模型与会话绑定
|
||||
def on_mode_change():
|
||||
|
|
@ -173,17 +175,34 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
|
||||
chat_box.output_messages()
|
||||
|
||||
chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
|
||||
|
||||
def on_feedback(
    feedback,
    chat_history_id: str = "",
    history_index: int = -1,
):
    # Callback for chat_box.show_feedback: forwards the user's rating of one
    # chat message to the /chat/feedback endpoint.
    reason = feedback["text"]  # free-text justification entered by the user
    # set_feedback records the rating in the local chat history and maps the
    # thumbs feedback to an integer score.
    score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
    api.chat_feedback(chat_history_id=chat_history_id,
                      score=score_int,
                      reason=reason)
    # Set a flag instead of rerunning here — presumably because Streamlit
    # cannot rerun from inside a widget callback; the main script body checks
    # this flag and calls st.rerun() (TODO confirm).
    st.session_state["need_rerun"] = True
|
||||
|
||||
feedback_kwargs = {
|
||||
"feedback_type": "thumbs",
|
||||
"optional_text_label": "欢迎反馈您打分的理由",
|
||||
}
|
||||
|
||||
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
|
||||
history = get_messages_history(history_len)
|
||||
chat_box.user_say(prompt)
|
||||
if dialogue_mode == "LLM 对话":
|
||||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
chat_history_id = ""
|
||||
r = api.chat_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
|
|
@ -193,11 +212,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t
|
||||
text += t.get("text", "")
|
||||
chat_box.update_msg(text)
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
|
||||
chat_history_id = t.get("chat_history_id", "")
|
||||
|
||||
metadata = {
|
||||
"chat_history_id": chat_history_id,
|
||||
}
|
||||
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
|
||||
chat_box.show_feedback(**feedback_kwargs,
|
||||
key=chat_history_id,
|
||||
on_submit=on_feedback,
|
||||
kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1})
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
|
||||
|
|
@ -280,6 +306,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
|
||||
|
||||
if st.session_state.get("need_rerun"):
|
||||
st.session_state["need_rerun"] = False
|
||||
st.rerun()
|
||||
|
||||
now = datetime.now()
|
||||
with st.sidebar:
|
||||
|
||||
|
|
@ -290,7 +320,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
|
|||
use_container_width=True,
|
||||
):
|
||||
chat_box.reset_history()
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
export_btn.download_button(
|
||||
"导出记录",
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
st.toast(ret.get("msg", " "))
|
||||
st.session_state["selected_kb_name"] = kb_name
|
||||
st.session_state["selected_kb_info"] = kb_info
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
elif selected_kb:
|
||||
kb = selected_kb
|
||||
|
|
@ -256,7 +256,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
zh_title_enhance=zh_title_enhance)
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
# 将文件从向量库中删除,但不删除文件本身。
|
||||
if cols[2].button(
|
||||
|
|
@ -266,7 +266,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
):
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names)
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
if cols[3].button(
|
||||
"从知识库中删除",
|
||||
|
|
@ -275,7 +275,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
):
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
st.divider()
|
||||
|
||||
|
|
@ -298,7 +298,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
st.toast(msg)
|
||||
else:
|
||||
empty.progress(d["finished"] / d["total"], d["msg"])
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
||||
if cols[2].button(
|
||||
"删除知识库",
|
||||
|
|
@ -307,4 +307,4 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
|
|||
ret = api.delete_knowledge_base(kb)
|
||||
st.toast(ret.get("msg", " "))
|
||||
time.sleep(1)
|
||||
st.experimental_rerun()
|
||||
st.rerun()
|
||||
|
|
|
|||
|
|
@ -314,7 +314,7 @@ class ApiRequest:
|
|||
pprint(data)
|
||||
|
||||
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
|
||||
return self._httpx_stream2generator(response)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
||||
def agent_chat(
|
||||
self,
|
||||
|
|
@ -873,6 +873,23 @@ class ApiRequest:
|
|||
)
|
||||
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
|
||||
|
||||
def chat_feedback(
    self,
    chat_history_id: str,
    score: int,
    reason: str = "",
) -> int:
    '''
    Send a user's rating for one chat exchange to the /chat/feedback endpoint.
    '''
    payload = {
        "chat_history_id": chat_history_id,
        "score": score,
        "reason": reason,
    }
    # POST the rating and unwrap the server response value.
    return self._get_response_value(self.post("/chat/feedback", json=payload))
|
||||
|
||||
|
||||
class AsyncApiRequest(ApiRequest):
|
||||
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):
|
||||
|
|
|
|||
Loading…
Reference in New Issue