添加对话评分与历史消息保存功能 (#1940)

* 新功能:
- WEBUI 添加对话评分功能
- 增加 /chat/feedback 接口,用于接收对话评分
- /chat/chat 接口返回值由 str 改为 {"text":str, "chat_history_id": str}
- init_database.py 添加 --create-tables --clear-tables 参数

依赖:
- streamlit-chatbox==1.1.11

开发者:
- ChatHistoryModel 的 id 字段支持自动生成
- SAVE_CHAT_HISTORY 改到 basic_config.py

* 修复:点击反馈后页面未刷新

---------

Co-authored-by: liqiankun.1111 <liqiankun.1111@bytedance.com>
Co-authored-by: liunux4odoo <liunux@qq.com>
Co-authored-by: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
This commit is contained in:
qiankunli 2023-11-03 11:31:45 +08:00 committed by GitHub
parent 554122f60e
commit fa906b33a8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 246 additions and 56 deletions

View File

@ -6,6 +6,8 @@ import langchain
log_verbose = False
langchain.verbose = False
# 是否保存聊天记录
SAVE_CHAT_HISTORY = False
# 通常情况下不需要更改以下内容

View File

@ -24,10 +24,15 @@ if __name__ == "__main__":
'''
)
)
parser.add_argument(
"--create-tables",
action="store_true",
help=("create empty tables if not existed")
)
parser.add_argument(
"--clear-tables",
action="store_true",
help=("drop the database tables before recreate vector stores")
help=("create empty tables, or drop the database tables before recreate vector stores")
)
parser.add_argument(
"--import-db",
@ -87,12 +92,10 @@ if __name__ == "__main__":
help=("specify embeddings model.")
)
if len(sys.argv) <= 1:
parser.print_help()
else:
args = parser.parse_args()
start_time = datetime.now()
if args.create_tables:
create_tables() # confirm tables exist
if args.clear_tables:

View File

@ -48,7 +48,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox==1.1.10
streamlit-chatbox==1.1.11
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]>=0.25.0
watchdog

View File

@ -44,7 +44,7 @@ pandas~=2.0.3
streamlit>=1.26.0
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.1.11
streamlit-chatbox>=1.1.9
streamlit-chatbox==1.1.11
streamlit-aggrid>=0.3.4.post3
httpx~=0.24.1
watchdog

View File

@ -3,7 +3,7 @@ pandas~=2.0.3
streamlit>=1.27.2
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox==1.1.10
streamlit-chatbox==1.1.11
streamlit-aggrid>=0.3.4.post3
httpx>=0.25.0
nltk>=3.8.1

View File

@ -16,6 +16,7 @@ from server.chat.chat import chat
from server.chat.openai_chat import openai_chat
from server.chat.search_engine_chat import search_engine_chat
from server.chat.completion import completion
from server.chat.feedback import chat_feedback
from server.embeddings_api import embed_texts_endpoint
from server.llm_api import (list_running_models, list_config_models,
change_llm_model, stop_llm_model,
@ -73,6 +74,11 @@ def mount_app_routes(app: FastAPI, run_mode: str = None):
summary="与搜索引擎对话",
)(search_engine_chat)
app.post("/chat/feedback",
tags=["Chat"],
summary="返回llm模型对话评分",
)(chat_feedback)
# 知识库相关接口
mount_knowledge_routes(app)

View File

@ -1,15 +1,17 @@
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODEL, TEMPERATURE
from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY
from server.utils import wrap_done, get_ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
import json
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.utils import get_prompt_template
from server.db.repository.chat_history_repository import add_chat_history_to_db, update_chat_history
async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
@ -53,16 +55,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callback.done),
)
answer = ""
chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query)
if stream:
async for token in callback.aiter():
answer += token
# Use server-sent-events to stream the response
yield token
yield json.dumps(
{"text": token, "chat_history_id": chat_history_id},
ensure_ascii=False)
else:
answer = ""
async for token in callback.aiter():
answer += token
yield answer
yield json.dumps(
{"text": answer, "chat_history_id": chat_history_id},
ensure_ascii=False)
if SAVE_CHAT_HISTORY and len(chat_history_id) > 0:
# 后续可以加入一些其他信息比如真实的prompt等
update_chat_history(chat_history_id, response=answer)
await task
return StreamingResponse(chat_iterator(query=query,

19
server/chat/feedback.py Normal file
View File

@ -0,0 +1,19 @@
from fastapi import Body
from configs import logger, log_verbose
from server.utils import BaseResponse
from server.db.repository.chat_history_repository import feedback_chat_history_to_db
def chat_feedback(chat_history_id: str = Body("", max_length=32, description="聊天记录id"),
                  score: int = Body(0, le=100, description="用户评分满分100越大表示评价越高"),
                  reason: str = Body("", description="用户评分理由,比如不符合事实等")
                  ):
    """
    Persist user feedback (score and optional reason) for one chat history record.

    Parameters are received as JSON body fields.  NOTE: the original code used
    ``max=100`` which is not a recognized Pydantic constraint keyword, so the
    upper bound on ``score`` was never enforced; ``le=100`` is the correct form.

    Returns:
        BaseResponse: code 200 on success, code 500 (with the error message)
        when writing the feedback to the database fails.
    """
    try:
        feedback_chat_history_to_db(chat_history_id, score, reason)
    except Exception as e:
        msg = f"反馈聊天记录出错: {e}"
        # Only attach the traceback when verbose logging is enabled.
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
        return BaseResponse(code=500, msg=msg)
    return BaseResponse(code=200, msg=f"已反馈聊天记录 {chat_history_id}")

View File

@ -0,0 +1,25 @@
from sqlalchemy import Column, Integer, String, DateTime, JSON, func
from server.db.base import Base
class ChatHistoryModel(Base):
    """
    ORM model for one chat exchange (user query + model response) together
    with optional user feedback (score and reason).
    """
    __tablename__ = 'chat_history'
    # Primary key is a frontend-generated uuid string; an auto-increment id
    # would have to be sent back to the frontend, which is awkward with
    # streaming responses.
    id = Column(String(32), primary_key=True, comment='聊天记录ID')
    # Conversation type, e.g. "chat" / "agent_chat".
    chat_type = Column(String(50), comment='聊天类型')
    query = Column(String(4096), comment='用户问题')
    response = Column(String(4096), comment='模型回答')
    # Extra info (e.g. knowledge-base ids) kept as JSON for future extension.
    # Named meta_data because "metadata" is reserved by SQLAlchemy declarative.
    meta_data = Column(JSON, default={})
    # Score out of 100; higher means a better rating. -1 means "not rated".
    feedback_score = Column(Integer, default=-1, comment='用户评分')
    feedback_reason = Column(String(255), default="", comment='用户评分理由')
    create_time = Column(DateTime, default=func.now(), comment='创建时间')
    def __repr__(self):
        return f"<ChatHistory(id='{self.id}', chat_type='{self.chat_type}', query='{self.query}', response='{self.response}',meta_data='{self.meta_data}',feedback_score='{self.feedback_score}',feedback_reason='{self.feedback_reason}', create_time='{self.create_time}')>"

View File

@ -0,0 +1,75 @@
from server.db.session import with_session
from server.db.models.chat_history_model import ChatHistoryModel
import re
import uuid
from typing import Dict, List
def _convert_query(query: str) -> str:
p = re.sub(r"\s+", "%", query)
return f"%{p}%"
@with_session
def add_chat_history_to_db(session, chat_type, query, response="", chat_history_id=None, metadata: Dict = None):
    """
    Insert a new chat history row and return its id.

    Args:
        session: DB session injected by the @with_session decorator.
        chat_type: conversation type, e.g. "llm_chat".
        query: the user's question.
        response: the model's answer (may be filled in later via update).
        chat_history_id: optional client-supplied uuid; generated when absent.
        metadata: optional extra info stored in the JSON column.

    Returns:
        str: the id of the inserted record.
    """
    if not chat_history_id:
        chat_history_id = uuid.uuid4().hex
    # The model column is meta_data ("metadata" is reserved by SQLAlchemy
    # declarative, so the original metadata=... kwarg raised at runtime).
    # Avoid a shared mutable {} default by normalizing None here.
    ch = ChatHistoryModel(id=chat_history_id, chat_type=chat_type, query=query, response=response,
                          meta_data=metadata or {})
    session.add(ch)
    session.commit()
    return ch.id
@with_session
def update_chat_history(session, chat_history_id, response: str = None, metadata: Dict = None):
    """
    Update an existing chat history record in place.

    Args:
        session: DB session injected by the @with_session decorator.
        chat_history_id: id of the record to update.
        response: new model answer; left untouched when None.
        metadata: new metadata dict; left untouched unless a dict is given.

    Returns:
        str | None: the record id, or None when no record matches.
    """
    # Query inside the current session rather than calling the decorated
    # get_chat_history_by_id(), which opens a *second* session and returns a
    # detached instance that cannot safely be re-added to this session.
    ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
    if ch is not None:
        if response is not None:
            ch.response = response
        if isinstance(metadata, dict):
            ch.meta_data = metadata
        session.add(ch)
        return ch.id
    # Unconditionally dereferencing ch.id here would raise AttributeError
    # when the record does not exist; return None instead.
    return None
@with_session
def feedback_chat_history_to_db(session, chat_history_id, feedback_score, feedback_reason):
    """
    Attach a user feedback score and reason to an existing chat record.

    Returns the record id, or None when the record does not exist.
    """
    record = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
    if not record:
        return None
    record.feedback_score = feedback_score
    record.feedback_reason = feedback_reason
    return record.id
@with_session
def get_chat_history_by_id(session, chat_history_id) -> ChatHistoryModel:
    """Fetch a single chat history record by id; None when absent."""
    return session.query(ChatHistoryModel).filter_by(id=chat_history_id).first()
@with_session
def filter_chat_history(session, query=None, response=None, score=None, reason=None) -> List[ChatHistoryModel]:
    """
    Search chat history records with optional fuzzy filters.

    Args:
        session: DB session injected by the @with_session decorator.
        query/response/reason: substring patterns (whitespace acts as a
            wildcard, see _convert_query), matched case-insensitively.
        score: exact feedback score to match.

    Returns:
        List[ChatHistoryModel]: matching records.
    """
    q = session.query(ChatHistoryModel)
    if query is not None:
        q = q.filter(ChatHistoryModel.query.ilike(_convert_query(query)))
    if response is not None:
        q = q.filter(ChatHistoryModel.response.ilike(_convert_query(response)))
    if score is not None:
        q = q.filter_by(feedback_score=score)
    if reason is not None:
        q = q.filter(ChatHistoryModel.feedback_reason.ilike(_convert_query(reason)))
    # Materialize before @with_session closes the session; the original
    # returned the lazy Query object, contradicting the List annotation and
    # failing if evaluated after the session is gone.
    return q.all()

View File

@ -5,6 +5,7 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
list_files_from_folder,files2docs_in_thread,
KnowledgeFile,)
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.db.models.chat_history_model import ChatHistoryModel
from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported
from server.db.base import Base, engine
from server.db.session import session_scope

View File

@ -7,6 +7,7 @@ from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from typing import List, Dict
chat_box = ChatBox(
assistant_avatar=os.path.join(
"img",
@ -43,6 +44,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
f"当前运行的模型`{default_model}`, 您可以开始提问了."
)
chat_box.init_session()
with st.sidebar:
# TODO: 对话模型与会话绑定
def on_mode_change():
@ -173,17 +175,34 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K)
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter "
def on_feedback(
feedback,
chat_history_id: str = "",
history_index: int = -1,
):
reason = feedback["text"]
score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
api.chat_feedback(chat_history_id=chat_history_id,
score=score_int,
reason=reason)
st.session_state["need_rerun"] = True
feedback_kwargs = {
"feedback_type": "thumbs",
"optional_text_label": "欢迎反馈您打分的理由",
}
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
history = get_messages_history(history_len)
chat_box.user_say(prompt)
if dialogue_mode == "LLM 对话":
chat_box.ai_say("正在思考...")
text = ""
chat_history_id = ""
r = api.chat_chat(prompt,
history=history,
model=llm_model,
@ -193,11 +212,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
if error_msg := check_error_msg(t): # check whether error occured
st.error(error_msg)
break
text += t
text += t.get("text", "")
chat_box.update_msg(text)
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
chat_history_id = t.get("chat_history_id", "")
metadata = {
"chat_history_id": chat_history_id,
}
chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
chat_box.show_feedback(**feedback_kwargs,
key=chat_history_id,
on_submit=on_feedback,
kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1})
elif dialogue_mode == "自定义Agent问答":
if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
@ -280,6 +306,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
chat_box.update_msg(text, element_index=0, streaming=False)
chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False
st.rerun()
now = datetime.now()
with st.sidebar:
@ -290,7 +320,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
use_container_width=True,
):
chat_box.reset_history()
st.experimental_rerun()
st.rerun()
export_btn.download_button(
"导出记录",

View File

@ -135,7 +135,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.toast(ret.get("msg", " "))
st.session_state["selected_kb_name"] = kb_name
st.session_state["selected_kb_info"] = kb_info
st.experimental_rerun()
st.rerun()
elif selected_kb:
kb = selected_kb
@ -256,7 +256,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
zh_title_enhance=zh_title_enhance)
st.experimental_rerun()
st.rerun()
# 将文件从向量库中删除,但不删除文件本身。
if cols[2].button(
@ -266,7 +266,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
):
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names)
st.experimental_rerun()
st.rerun()
if cols[3].button(
"从知识库中删除",
@ -275,7 +275,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
):
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
st.experimental_rerun()
st.rerun()
st.divider()
@ -298,7 +298,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
st.toast(msg)
else:
empty.progress(d["finished"] / d["total"], d["msg"])
st.experimental_rerun()
st.rerun()
if cols[2].button(
"删除知识库",
@ -307,4 +307,4 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None):
ret = api.delete_knowledge_base(kb)
st.toast(ret.get("msg", " "))
time.sleep(1)
st.experimental_rerun()
st.rerun()

View File

@ -314,7 +314,7 @@ class ApiRequest:
pprint(data)
response = self.post("/chat/chat", json=data, stream=True, **kwargs)
return self._httpx_stream2generator(response)
return self._httpx_stream2generator(response, as_json=True)
def agent_chat(
self,
@ -873,6 +873,23 @@ class ApiRequest:
)
return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data"))
def chat_feedback(
    self,
    chat_history_id: str,
    score: int,
    reason: str = "",
) -> int:
    '''
    Submit user feedback (score out of 100, plus an optional free-text
    reason) for one chat record via the /chat/feedback endpoint.
    '''
    payload = {
        "chat_history_id": chat_history_id,
        "score": score,
        "reason": reason,
    }
    response = self.post("/chat/feedback", json=payload)
    return self._get_response_value(response)
class AsyncApiRequest(ApiRequest):
def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):