diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example index 6bd8c8d..3540872 100644 --- a/configs/basic_config.py.example +++ b/configs/basic_config.py.example @@ -6,6 +6,8 @@ import langchain log_verbose = False langchain.verbose = False +# 是否保存聊天记录 +SAVE_CHAT_HISTORY = False # 通常情况下不需要更改以下内容 diff --git a/init_database.py b/init_database.py index 9af7747..3a4405d 100644 --- a/init_database.py +++ b/init_database.py @@ -24,10 +24,15 @@ if __name__ == "__main__": ''' ) ) + parser.add_argument( + "--create-tables", + action="store_true", + help=("create empty tables if not existed") + ) parser.add_argument( "--clear-tables", action="store_true", - help=("drop the database tables before recreate vector stores") + help=("create empty tables, or drop the database tables before recreate vector stores") ) parser.add_argument( "--import-db", @@ -87,31 +92,29 @@ if __name__ == "__main__": help=("specify embeddings model.") ) - if len(sys.argv) <= 1: - parser.print_help() - else: - args = parser.parse_args() - start_time = datetime.now() + args = parser.parse_args() + start_time = datetime.now() + if args.create_tables: create_tables() # confirm tables exist - if args.clear_tables: - reset_tables() - print("database talbes reseted") + if args.clear_tables: + reset_tables() + print("database talbes reseted") - if args.recreate_vs: - print("recreating all vector stores") - folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) - elif args.import_db: - import_from_db(args.import_db) - elif args.update_in_db: - folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) - elif args.increament: - folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model) - elif args.prune_db: - prune_db_docs(args.kb_name) - elif args.prune_folder: - prune_folder_files(args.kb_name) + if args.recreate_vs: + print("recreating all vector stores") + folder2db(kb_names=args.kb_name, mode="recreate_vs", embed_model=args.embed_model) + elif args.import_db: + import_from_db(args.import_db) + elif args.update_in_db: + folder2db(kb_names=args.kb_name, mode="update_in_db", embed_model=args.embed_model) + elif args.increament: + folder2db(kb_names=args.kb_name, mode="increament", embed_model=args.embed_model) + elif args.prune_db: + prune_db_docs(args.kb_name) + elif args.prune_folder: + prune_folder_files(args.kb_name) - end_time = datetime.now() - print(f"总计用时: {end_time-start_time}") + end_time = datetime.now() + print(f"总计用时: {end_time-start_time}") diff --git a/requirements.txt b/requirements.txt index 419c650..0af4bd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -48,7 +48,7 @@ pandas~=2.0.3 streamlit>=1.26.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 -streamlit-chatbox==1.1.10 +streamlit-chatbox==1.1.11 streamlit-aggrid>=0.3.4.post3 httpx[brotli,http2,socks]>=0.25.0 watchdog diff --git a/requirements_lite.txt b/requirements_lite.txt index a20cb8b..8b0a32d 100644 --- a/requirements_lite.txt +++ b/requirements_lite.txt @@ -44,7 +44,7 @@ pandas~=2.0.3 streamlit>=1.26.0 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.1.11 -streamlit-chatbox>=1.1.9 +streamlit-chatbox==1.1.11 streamlit-aggrid>=0.3.4.post3 httpx~=0.24.1 watchdog diff --git a/requirements_webui.txt b/requirements_webui.txt index a36fec4..e1ec9ff 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -3,7 +3,7 @@ pandas~=2.0.3 streamlit>=1.27.2 streamlit-option-menu>=0.3.6 streamlit-antd-components>=0.2.3 -streamlit-chatbox==1.1.10 +streamlit-chatbox==1.1.11 streamlit-aggrid>=0.3.4.post3 httpx>=0.25.0 nltk>=3.8.1 diff --git a/server/api.py b/server/api.py index 093f81b..941b238 100644 --- a/server/api.py +++ b/server/api.py @@ -16,6 +16,7 @@ from server.chat.chat import chat from server.chat.openai_chat import openai_chat from server.chat.search_engine_chat import search_engine_chat from server.chat.completion import completion +from server.chat.feedback import chat_feedback from server.embeddings_api import embed_texts_endpoint from server.llm_api import (list_running_models, list_config_models, change_llm_model, stop_llm_model, @@ -73,6 +74,11 @@ def mount_app_routes(app: FastAPI, run_mode: str = None): summary="与搜索引擎对话", )(search_engine_chat) + app.post("/chat/feedback", + tags=["Chat"], + summary="返回llm模型对话评分", + )(chat_feedback) + # 知识库相关接口 mount_knowledge_routes(app) diff --git a/server/chat/chat.py b/server/chat/chat.py index 00d2fbb..7225576 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -1,31 +1,33 @@ from fastapi import Body from fastapi.responses import StreamingResponse -from configs import LLM_MODEL, TEMPERATURE +from configs import LLM_MODEL, TEMPERATURE, SAVE_CHAT_HISTORY from server.utils import wrap_done, get_ChatOpenAI from langchain.chains import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio +import json from langchain.prompts.chat import ChatPromptTemplate from typing import List, Optional from server.chat.utils import History from server.utils import get_prompt_template +from server.db.repository.chat_history_repository import add_chat_history_to_db, update_chat_history async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), - temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), - max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), - # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), - prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), - ): + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0), + max_tokens: Optional[int] = Body(None, description="限制LLM生成Token数量,默认None代表模型最大值"), + # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0), + prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), + ): history = [History.from_data(h) for h in history] async def chat_iterator(query: str, @@ -53,16 +55,26 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 callback.done), ) + answer = "" + chat_history_id = add_chat_history_to_db(chat_type="llm_chat", query=query) + if stream: async for token in callback.aiter(): + answer += token # Use server-sent-events to stream the response - yield token + yield json.dumps( + {"text": token, "chat_history_id": chat_history_id}, + ensure_ascii=False) else: - answer = "" async for token in callback.aiter(): answer += token - yield answer + yield json.dumps( + {"text": answer, "chat_history_id": chat_history_id}, + ensure_ascii=False) + if SAVE_CHAT_HISTORY and len(chat_history_id) > 0: + # 后续可以加入一些其他信息,比如真实的prompt等 + update_chat_history(chat_history_id, response=answer) await task return StreamingResponse(chat_iterator(query=query, diff --git a/server/chat/feedback.py b/server/chat/feedback.py new file mode 100644 index 0000000..6c8dd44 --- /dev/null +++ b/server/chat/feedback.py @@ -0,0 +1,19 @@ +from fastapi import Body +from configs import logger, log_verbose +from server.utils import BaseResponse +from server.db.repository.chat_history_repository import feedback_chat_history_to_db + + +def chat_feedback(chat_history_id: str = Body("", max_length=32, description="聊天记录id"), + score: int = Body(0, max=100, description="用户评分,满分100,越大表示评价越高"), + reason: str = Body("", description="用户评分理由,比如不符合事实等") + ): + try: + feedback_chat_history_to_db(chat_history_id, score, reason) + except Exception as e: + msg = f"反馈聊天记录出错: {e}" + logger.error(f'{e.__class__.__name__}: {msg}', + exc_info=e if log_verbose else None) + return BaseResponse(code=500, msg=msg) + + return BaseResponse(code=200, msg=f"已反馈聊天记录 {chat_history_id}") diff --git a/server/db/models/chat_history_model.py b/server/db/models/chat_history_model.py new file mode 100644 index 0000000..fec27ef --- /dev/null +++ b/server/db/models/chat_history_model.py @@ -0,0 +1,25 @@ +from sqlalchemy import Column, Integer, String, DateTime, JSON, func + +from server.db.base import Base + + +class ChatHistoryModel(Base): + """ + 聊天记录模型 + """ + __tablename__ = 'chat_history' + # 由前端生成的uuid,如果是自增的话,则需要将id 传给前端,这在流式返回里有点麻烦 + id = Column(String(32), primary_key=True, comment='聊天记录ID') + # chat/agent_chat等 + chat_type = Column(String(50), comment='聊天类型') + query = Column(String(4096), comment='用户问题') + response = Column(String(4096), comment='模型回答') + # 记录知识库id等,以便后续扩展 + meta_data = Column(JSON, default={}) + # 满分100 越高表示评价越好 + feedback_score = Column(Integer, default=-1, comment='用户评分') + feedback_reason = Column(String(255), default="", comment='用户评分理由') + create_time = Column(DateTime, default=func.now(), comment='创建时间') + + def __repr__(self): + return f"" diff --git a/server/db/repository/chat_history_repository.py b/server/db/repository/chat_history_repository.py new file mode 100644 index 0000000..031be1b --- /dev/null +++ b/server/db/repository/chat_history_repository.py @@ -0,0 +1,75 @@ +from server.db.session import with_session +from server.db.models.chat_history_model import ChatHistoryModel +import re +import uuid +from typing import Dict, List + + +def _convert_query(query: str) -> str: + p = re.sub(r"\s+", "%", query) + return f"%{p}%" + + +@with_session +def add_chat_history_to_db(session, chat_type, query, response="", chat_history_id=None, metadata: Dict = {}): + """ + 新增聊天记录 + """ + if not chat_history_id: + chat_history_id = uuid.uuid4().hex + ch = ChatHistoryModel(id=chat_history_id, chat_type=chat_type, query=query, response=response, + metadata=metadata) + session.add(ch) + session.commit() + return ch.id + + +@with_session +def update_chat_history(session, chat_history_id, response: str = None, metadata: Dict = None): + """ + 更新已有的聊天记录 + """ + ch = get_chat_history_by_id(chat_history_id) + if ch is not None: + if response is not None: + ch.response = response + if isinstance(metadata, dict): + ch.meta_data = metadata + session.add(ch) + return ch.id + + +@with_session +def feedback_chat_history_to_db(session, chat_history_id, feedback_score, feedback_reason): + """ + 反馈聊天记录 + """ + ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first() + if ch: + ch.feedback_score = feedback_score + ch.feedback_reason = feedback_reason + return ch.id + + +@with_session +def get_chat_history_by_id(session, chat_history_id) -> ChatHistoryModel: + """ + 查询聊天记录 + """ + ch = session.query(ChatHistoryModel).filter_by(id=chat_history_id).first() + return ch + + +@with_session +def filter_chat_history(session, query=None, response=None, score=None, reason=None) -> List[ChatHistoryModel]: + ch =session.query(ChatHistoryModel) + if query is not None: + ch = ch.filter(ChatHistoryModel.query.ilike(_convert_query(query))) + if response is not None: + ch = ch.filter(ChatHistoryModel.response.ilike(_convert_query(response))) + if score is not None: + ch = ch.filter_by(feedback_score=score) + if reason is not None: + ch = ch.filter(ChatHistoryModel.feedback_reason.ilike(_convert_query(reason))) + + return ch diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index e8ebc06..58a2ffd 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -5,6 +5,7 @@ from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder, list_files_from_folder,files2docs_in_thread, KnowledgeFile,) from server.knowledge_base.kb_service.base import KBServiceFactory +from server.db.models.chat_history_model import ChatHistoryModel from server.db.repository.knowledge_file_repository import add_file_to_db # ensure Models are imported from server.db.base import Base, engine from server.db.session import session_scope diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 594b14e..823b8fb 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -7,6 +7,7 @@ from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES, DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL) from typing import List, Dict + chat_box = ChatBox( assistant_avatar=os.path.join( "img", @@ -43,6 +44,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): f"当前运行的模型`{default_model}`, 您可以开始提问了." ) chat_box.init_session() + with st.sidebar: # TODO: 对话模型与会话绑定 def on_mode_change(): @@ -173,17 +175,34 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, SEARCH_ENGINE_TOP_K) # Display chat messages from history on app rerun - chat_box.output_messages() chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter " + def on_feedback( + feedback, + chat_history_id: str = "", + history_index: int = -1, + ): + reason = feedback["text"] + score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index) + api.chat_feedback(chat_history_id=chat_history_id, + score=score_int, + reason=reason) + st.session_state["need_rerun"] = True + + feedback_kwargs = { + "feedback_type": "thumbs", + "optional_text_label": "欢迎反馈您打分的理由", + } + if prompt := st.chat_input(chat_input_placeholder, key="prompt"): history = get_messages_history(history_len) chat_box.user_say(prompt) if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" + chat_history_id = "" r = api.chat_chat(prompt, history=history, model=llm_model, @@ -193,11 +212,18 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): if error_msg := check_error_msg(t): # check whether error occured st.error(error_msg) break - text += t + text += t.get("text", "") chat_box.update_msg(text) - chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 - + chat_history_id = t.get("chat_history_id", "") + metadata = { + "chat_history_id": chat_history_id, + } + chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标 + chat_box.show_feedback(**feedback_kwargs, + key=chat_history_id, + on_submit=on_feedback, + kwargs={"chat_history_id": chat_history_id, "history_index": len(chat_box.history) - 1}) elif dialogue_mode == "自定义Agent问答": if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL): @@ -280,6 +306,10 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): chat_box.update_msg(text, element_index=0, streaming=False) chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False) + if st.session_state.get("need_rerun"): + st.session_state["need_rerun"] = False + st.rerun() + now = datetime.now() with st.sidebar: @@ -290,7 +320,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False): use_container_width=True, ): chat_box.reset_history() - st.experimental_rerun() + st.rerun() export_btn.download_button( "导出记录", diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index ca8425c..57d6202 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -135,7 +135,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.toast(ret.get("msg", " ")) st.session_state["selected_kb_name"] = kb_name st.session_state["selected_kb_info"] = kb_info - st.experimental_rerun() + st.rerun() elif selected_kb: kb = selected_kb @@ -256,7 +256,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): chunk_size=chunk_size, chunk_overlap=chunk_overlap, zh_title_enhance=zh_title_enhance) - st.experimental_rerun() + st.rerun() # 将文件从向量库中删除,但不删除文件本身。 if cols[2].button( @@ -266,7 +266,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): ): file_names = [row["file_name"] for row in selected_rows] api.delete_kb_docs(kb, file_names=file_names) - st.experimental_rerun() + st.rerun() if cols[3].button( "从知识库中删除", @@ -275,7 +275,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): ): file_names = [row["file_name"] for row in selected_rows] api.delete_kb_docs(kb, file_names=file_names, delete_content=True) - st.experimental_rerun() + st.rerun() st.divider() @@ -298,7 +298,7 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): st.toast(msg) else: empty.progress(d["finished"] / d["total"], d["msg"]) - st.experimental_rerun() + st.rerun() if cols[2].button( "删除知识库", @@ -307,4 +307,4 @@ def knowledge_base_page(api: ApiRequest, is_lite: bool = None): ret = api.delete_knowledge_base(kb) st.toast(ret.get("msg", " ")) time.sleep(1) - st.experimental_rerun() + st.rerun() diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 4a3963a..80b855b 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -314,7 +314,7 @@ class ApiRequest: pprint(data) response = self.post("/chat/chat", json=data, stream=True, **kwargs) - return self._httpx_stream2generator(response) + return self._httpx_stream2generator(response, as_json=True) def agent_chat( self, @@ -873,6 +873,23 @@ class ApiRequest: ) return self._get_response_value(resp, as_json=True, value_func=lambda r: r.get("data")) + def chat_feedback( + self, + chat_history_id: str, + score: int, + reason: str = "", + ) -> int: + ''' + 反馈对话评价 + ''' + data = { + "chat_history_id": chat_history_id, + "score": score, + "reason": reason, + } + resp = self.post("/chat/feedback", json=data) + return self._get_response_value(resp) + class AsyncApiRequest(ApiRequest): def __init__(self, base_url: str = api_address(), timeout: float = HTTPX_DEFAULT_TIMEOUT):