rewrite streamlit ui with streamlit-chatbox;fix bugs (#850)

parent df3744f501
commit 5cbb86a823

482 webui_st.py

@@ -1,5 +1,5 @@
import streamlit as st
# from st_btn_select import st_btn_select
from streamlit_chatbox import st_chatbox
import tempfile
###### 从webui借用的代码 #####
###### 做了少量修改 #####
@@ -31,7 +31,6 @@ def get_vs_list():

embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys())
# flag_csv_logger = gr.CSVLogger()


def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
@@ -50,6 +49,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
            history[-1][-1] += source
            yield history, ""
    elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
        local_doc_qa.top_k = vector_search_top_k
        local_doc_qa.chunk_conent = chunk_conent
        local_doc_qa.chunk_size = chunk_size
        for resp, history in local_doc_qa.get_knowledge_based_answer(
                query=query, vs_path=vs_path, chat_history=history, streaming=streaming):
            source = "\n\n"
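
get_answer is a generator: each partial LLM response is accumulated into history[-1][-1] and yielded together with an empty string, so the caller can redraw the last chat bubble as tokens stream in. A minimal sketch of consuming it, mirroring the loop this commit adds further down (answer wraps get_answer; update_last_box_text comes from streamlit-chatbox as used in this diff):

    for history, _ in answer(question, history=[], mode=mode):
        # history[-1][-1] holds the partial answer produced so far
        chat_box.update_last_box_text(history[-1][-1])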
@@ -95,17 +97,101 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
                    "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "")
        yield history, ""
    logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}")
    # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)


def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'):
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    filelist = []
    if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
        os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
    qa = st.session_state.local_doc_qa
    if qa.llm_model_chain and qa.embeddings:
        if isinstance(files, list):
            for file in files:
                filename = os.path.split(file.name)[-1]
                shutil.move(file.name, os.path.join(
                    KB_ROOT_PATH, vs_id, "content", filename))
                filelist.append(os.path.join(
                    KB_ROOT_PATH, vs_id, "content", filename))
            vs_path, loaded_files = qa.init_knowledge_vector_store(
                filelist, vs_path, sentence_size)
        else:
            vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
                                                         sentence_size)
        if len(loaded_files):
            file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
        else:
            file_status = "文件未成功加载,请重新上传文件"
    else:
        file_status = "模型未完成加载,请先在加载模型后再导入文件"
        vs_path = None
    logger.info(file_status)
    return vs_path, None, history + [[None, file_status]]


knowledge_base_test_mode_info = ("【注意】\n\n"
                                 "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
                                 "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
                                 "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
                                 """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
                                 "4. 单条内容长度建议设置在100-150左右。")


webui_title = """
# 🎉langchain-ChatGLM WebUI🎉
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
"""
###### #####


###### todo #####
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
###### #####
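
# A minimal illustration (not part of this diff) of the two Streamlit mechanisms the todo
# above leans on: st.cache_resource keeps one object per argument tuple for the whole
# process (shared by all sessions and reruns), while st.session_state is private to each
# browser session. Names below are hypothetical.
#
# @st.cache_resource(max_entries=1)
# def get_shared_qa():                          # one LocalDocQA for the whole process
#     return LocalDocQA()
#
# st.session_state.setdefault("history", [])    # per-session chat history
# st.session_state.local_doc_qa = get_shared_qa()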


###### 配置项 #####
class ST_CONFIG:
    default_mode = "知识库问答"
    default_kb = ""
###### #####


class TempFile:
    '''
    为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
    '''

    def __init__(self, path):
        self.name = path
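
# Illustration only (not a line from this diff): st.file_uploader returns in-memory
# UploadedFile objects, while get_vector_store expects objects that expose a filesystem
# path via `.name`. The pattern used later in this file is roughly:
#
#     path = os.path.join(tempfile.gettempdir(), uploaded.name)
#     with open(path, 'wb') as fp:
#         fp.write(uploaded.getvalue())      # persist the upload to disk
#     file_list.append(TempFile(path))       # wrap it so `.name` points at the file
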
@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(
    llm_model: str = LLM_MODEL,
    embedding_model: str = EMBEDDING_MODEL,
    use_ptuning_v2: bool = USE_PTUNING_V2,
):
    '''
    对应init_model,利用streamlit cache避免模型重复加载
    '''
    local_doc_qa = LocalDocQA()
    # 初始化消息
    args = parser.parse_args()
    args_dict = vars(args)
    args_dict.update(model=llm_model)
    shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    llm_model_ins = shared.loaderLLM()
    if shared.loaderCheckPoint is None:  # avoid checkpoint reloading when reinit model
        shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
    # shared.loaderCheckPoint.model_name is different by no_remote_model.
    # if it is not set properly error occurs when reinit llm model(issue#473).
    # as no_remote_model is removed from model_config, need workaround to set it automaticlly.
    local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or ''
    no_remote_model = os.path.isdir(local_model_path)
    llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
    llm_model_ins.history_len = LLM_HISTORY_LEN

    try:
        local_doc_qa.init_cfg(llm_model=llm_model_ins,
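
Because load_model is wrapped in st.cache_resource with max_entries=1, Streamlit hands back the already-built LocalDocQA on every rerun as long as the call arguments are unchanged; selecting a different model in the sidebar changes the argument tuple, so the single cached entry is evicted and the models load again. A rough sketch of that behaviour (the call sites and model names are illustrative, not lines from this diff):

    qa_a = load_model('chatglm-6b', 'text2vec', False)    # first call: models actually load
    qa_b = load_model('chatglm-6b', 'text2vec', False)    # same args on a rerun: cached object, no reload
    qa_c = load_model('other-llm', 'text2vec', False)     # new args: max_entries=1 evicts the old entry, reloads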

@@ -128,235 +214,9 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'
    return local_doc_qa


# 暂未使用到,先保留
# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
#     try:
#         llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
#         llm_model_ins.history_len = llm_history_len
#         local_doc_qa.init_cfg(llm_model=llm_model_ins,
#                               embedding_model=embedding_model,
#                               top_k=top_k)
#         model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话"""
#         logger.info(model_status)
#     except Exception as e:
#         logger.error(e)
#         model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮"""
#         logger.info(model_status)
#     return history + [[None, model_status]]


def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
    vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
    filelist = []
    if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
        os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
    if local_doc_qa.llm and local_doc_qa.embeddings:
        if isinstance(files, list):
            for file in files:
                filename = os.path.split(file.name)[-1]
                shutil.move(file.name, os.path.join(
                    KB_ROOT_PATH, vs_id, "content", filename))
                filelist.append(os.path.join(
                    KB_ROOT_PATH, vs_id, "content", filename))
            vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store(
                filelist, vs_path, sentence_size)
        else:
            vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
                                                                   sentence_size)
        if len(loaded_files):
            file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
        else:
            file_status = "文件未成功加载,请重新上传文件"
    else:
        file_status = "模型未完成加载,请先在加载模型后再导入文件"
        vs_path = None
    logger.info(file_status)
    return vs_path, None, history + [[None, file_status]]


knowledge_base_test_mode_info = ("【注意】\n\n"
                                 "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询,"
                                 "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n"
                                 "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。"
                                 """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n"""
                                 "4. 单条内容长度建议设置在100-150左右。\n\n"
                                 "5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中,"
                                 "本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。"
                                 "相关参数将在后续版本中支持本界面直接修改。")


webui_title = """
# 🎉langchain-ChatGLM WebUI🎉
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
"""
###### #####


###### todo #####
# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。
# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。
# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。
# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。
# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。
###### #####


###### 配置项 #####
class ST_CONFIG:
    user_bg_color = '#77ff77'
    user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7'
    robot_bg_color = '#ccccee'
    robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0'
    default_mode = '知识库问答'
    defalut_kb = ''
###### #####


class MsgType:
    '''
    目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。
    '''
    TEXT = 1
    IMAGE = 2
    VIDEO = 3
    AUDIO = 4

class TempFile:
    '''
    为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式
    '''

    def __init__(self, path):
        self.name = path


def init_session():
    st.session_state.setdefault('history', [])


# def get_query_params():
#     '''
#     可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。
#     该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode
#     方便将固定的配置分享给特定的人。
#     '''
#     params = st.experimental_get_query_params()
#     return {k: v[0] for k, v in params.items() if v}


def robot_say(msg, kb=''):
    st.session_state['history'].append(
        {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb})


def user_say(msg):
    st.session_state['history'].append(
        {'is_user': True, 'type': MsgType.TEXT, 'content': msg})


def format_md(msg, is_user=False, bg_color='', margin='10%'):
    '''
    将文本消息格式化为markdown文本
    '''
    if is_user:
        bg_color = bg_color or ST_CONFIG.user_bg_color
        text = f'''
        <div style="background:{bg_color};
                    margin-left:{margin};
                    word-break:break-all;
                    float:right;
                    padding:2%;
                    border-radius:2%;">
        {msg}
        </div>
        '''
    else:
        bg_color = bg_color or ST_CONFIG.robot_bg_color
        text = f'''
        <div style="background:{bg_color};
                    margin-right:{margin};
                    word-break:break-all;
                    padding:2%;
                    border-radius:2%;">
        {msg}
        </div>
        '''
    return text


def message(msg,
            is_user=False,
            msg_type=MsgType.TEXT,
            icon='',
            bg_color='',
            margin='10%',
            kb='',
            ):
    '''
    渲染单条消息。目前仅支持文本
    '''
    cols = st.columns([1, 10, 1])
    empty = cols[1].empty()
    if is_user:
        icon = icon or ST_CONFIG.user_icon
        bg_color = bg_color or ST_CONFIG.user_bg_color
        cols[2].image(icon, width=40)
        if msg_type == MsgType.TEXT:
            text = format_md(msg, is_user, bg_color, margin)
            empty.markdown(text, unsafe_allow_html=True)
        else:
            raise RuntimeError('only support text message now.')
    else:
        icon = icon or ST_CONFIG.robot_icon
        bg_color = bg_color or ST_CONFIG.robot_bg_color
        cols[0].image(icon, width=40)
        if kb:
            cols[0].write(f'({kb})')
        if msg_type == MsgType.TEXT:
            text = format_md(msg, is_user, bg_color, margin)
            empty.markdown(text, unsafe_allow_html=True)
        else:
            raise RuntimeError('only support text message now.')
    return empty


def output_messages(
    user_bg_color='',
    robot_bg_color='',
    user_icon='',
    robot_icon='',
):
    with chat_box.container():
        last_response = None
        for msg in st.session_state['history']:
            bg_color = user_bg_color if msg['is_user'] else robot_bg_color
            icon = user_icon if msg['is_user'] else robot_icon
            empty = message(msg['content'],
                            is_user=msg['is_user'],
                            icon=icon,
                            msg_type=msg['type'],
                            bg_color=bg_color,
                            kb=msg.get('kb', '')
                            )
            if not msg['is_user']:
                last_response = empty
    return last_response


@st.cache_resource(show_spinner=False, max_entries=1)
def load_model(llm_model: str, embedding_model: str):
    '''
    对应init_model,利用streamlit cache避免模型重复加载
    '''
    local_doc_qa = init_model(llm_model, embedding_model)
    robot_say('模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。\n请尽量不要刷新页面,以免模型出错或重复加载。')
    return local_doc_qa


# @st.cache_data
def answer(query, vs_path='', history=[], mode='', score_threshold=0,
           vector_search_top_k=5, chunk_conent=True, chunk_size=100, qa=None
           vector_search_top_k=5, chunk_conent=True, chunk_size=100
           ):
    '''
    对应get_answer,--利用streamlit cache缓存相同问题的答案--
@@ -365,48 +225,24 @@ def answer(query, vs_path='', history=[], mode='', score_threshold=0,
                      vector_search_top_k, chunk_conent, chunk_size)


def load_vector_store(
    vs_id,
    files,
    sentence_size=100,
    history=[],
    one_conent=None,
    one_content_segmentation=None,
):
    return get_vector_store(
        local_doc_qa,
        vs_id,
        files,
        sentence_size,
        history,
        one_conent,
        one_content_segmentation,
    )
def use_kb_mode(m):
    return m in ["知识库问答", "知识库测试"]


# main ui
st.set_page_config(webui_title, layout='wide')
init_session()
# params = get_query_params()
# llm_model = params.get('llm_model', LLM_MODEL)
# embedding_model = params.get('embedding_model', EMBEDDING_MODEL)

with st.spinner(f'正在加载模型({LLM_MODEL} + {EMBEDDING_MODEL}),请耐心等候...'):
    local_doc_qa = load_model(LLM_MODEL, EMBEDDING_MODEL)


def use_kb_mode(m):
    return m in ['知识库问答', '知识库测试']

chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
# 使用 help(st_chatbox) 查看自定义参数

# sidebar
modes = ['LLM 对话', '知识库问答', 'Bing搜索问答', '知识库测试']
with st.sidebar:
    def on_mode_change():
        m = st.session_state.mode
        robot_say(f'已切换到"{m}"模式')
        chat_box.robot_say(f'已切换到"{m}"模式')
        if m == '知识库测试':
            robot_say(knowledge_base_test_mode_info)
            chat_box.robot_say(knowledge_base_test_mode_info)

    index = 0
    try:
@@ -416,7 +252,7 @@ with st.sidebar:
    mode = st.selectbox('对话模式', modes, index,
                        on_change=on_mode_change, key='mode')

    with st.expander('模型配置', '知识' not in mode):
    with st.expander('模型配置', not use_kb_mode(mode)):
        with st.form('model_config'):
            index = 0
            try:
@@ -425,9 +261,8 @@ with st.sidebar:
                pass
            llm_model = st.selectbox('LLM模型', llm_model_dict_list, index)

            no_remote_model = st.checkbox('加载本地模型', False)
            use_ptuning_v2 = st.checkbox('使用p-tuning-v2微调过的模型', False)
            use_lora = st.checkbox('使用lora微调的权重', False)

            try:
                index = embedding_model_dict_list.index(EMBEDDING_MODEL)
            except:
@@ -437,9 +272,12 @@ with st.sidebar:

            btn_load_model = st.form_submit_button('重新加载模型')
            if btn_load_model:
                local_doc_qa = load_model(llm_model, embedding_model)
                local_doc_qa = load_model(llm_model, embedding_model, use_ptuning_v2)

    if mode in ['知识库问答', '知识库测试']:
        history_len = st.slider(
            "LLM对话轮数", 1, 50, LLM_HISTORY_LEN)

    if use_kb_mode(mode):
        vs_list = get_vs_list()
        vs_list.remove('新建知识库')
@@ -452,7 +290,7 @@ with st.sidebar:
            st.session_state.vs_path = name

        def on_vs_change():
            robot_say(f'已加载知识库: {st.session_state.vs_path}')
            chat_box.robot_say(f'已加载知识库: {st.session_state.vs_path}')
        with st.expander('知识库配置', True):
            cols = st.columns([12, 10])
            kb_name = cols[0].text_input(
@@ -460,21 +298,22 @@ with st.sidebar:
            if 'kb_name' not in st.session_state:
                st.session_state.kb_name = kb_name
            cols[1].button('新建知识库', on_click=on_new_kb)
            index = 0
            try:
                index = vs_list.index(ST_CONFIG.default_kb)
            except:
                pass
            vs_path = st.selectbox(
                '选择知识库', vs_list, on_change=on_vs_change, key='vs_path')
                '选择知识库', vs_list, index, on_change=on_vs_change, key='vs_path')

            st.text('')

            score_threshold = st.slider(
                '知识相关度阈值', 0, 1000, VECTOR_SEARCH_SCORE_THRESHOLD)
            top_k = st.slider('向量匹配数量', 1, 20, VECTOR_SEARCH_TOP_K)
            history_len = st.slider(
                'LLM对话轮数', 1, 50, LLM_HISTORY_LEN)  # 也许要跟知识库分开设置
            # local_doc_qa.llm.set_history_len(history_len)
            chunk_conent = st.checkbox('启用上下文关联', False)
            st.text('')
            # chunk_conent = st.checkbox('分割文本', True)  # 知识库文本分割入库
            chunk_size = st.slider('上下文关联长度', 1, 1000, CHUNK_SIZE)
            st.text('')
            sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
            files = st.file_uploader('上传知识文件',
                                     ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
@@ -487,56 +326,61 @@ with st.sidebar:
                    with open(file, 'wb') as fp:
                        fp.write(f.getvalue())
                    file_list.append(TempFile(file))
                _, _, history = load_vector_store(
                _, _, history = get_vector_store(
                    vs_path, file_list, sentence_size, [], None, None)
                st.session_state.files = []


# main body
chat_box = st.empty()
# load model after params rendered
with st.spinner(f"正在加载模型({llm_model} + {embedding_model}),请耐心等候..."):
    local_doc_qa = load_model(
        llm_model,
        embedding_model,
        use_ptuning_v2,
    )
local_doc_qa.llm_model_chain.history_len = history_len
if use_kb_mode(mode):
    local_doc_qa.chunk_conent = chunk_conent
    local_doc_qa.chunk_size = chunk_size
# local_doc_qa.llm_model_chain.temperature = temperature  # 这样设置temperature似乎不起作用
st.session_state.local_doc_qa = local_doc_qa

with st.form('my_form', clear_on_submit=True):
# input form
with st.form("my_form", clear_on_submit=True):
    cols = st.columns([8, 1])
    question = cols[0].text_input(
    question = cols[0].text_area(
        'temp', key='input_question', label_visibility='collapsed')

    def on_send():
        q = st.session_state.input_question
        if q:
            user_say(q)
    if cols[1].form_submit_button("发送"):
        chat_box.user_say(question)
        history = []
        if mode == "LLM 对话":
            chat_box.robot_say("正在思考...")
            chat_box.output_messages()
            for history, _ in answer(question,
                                     history=[],
                                     mode=mode):
                chat_box.update_last_box_text(history[-1][-1])
        elif use_kb_mode(mode):
            chat_box.robot_say(f"正在查询 [{vs_path}] ...")
            chat_box.output_messages()
            for history, _ in answer(question,
                                     vs_path=os.path.join(
                                         KB_ROOT_PATH, vs_path, 'vector_store'),
                                     history=[],
                                     mode=mode,
                                     score_threshold=score_threshold,
                                     vector_search_top_k=top_k,
                                     chunk_conent=chunk_conent,
                                     chunk_size=chunk_size):
                chat_box.update_last_box_text(history[-1][-1])
        else:
            chat_box.robot_say(f"正在执行Bing搜索...")
            chat_box.output_messages()
            for history, _ in answer(question,
                                     history=[],
                                     mode=mode):
                chat_box.update_last_box_text(history[-1][-1])

            if mode == 'LLM 对话':
                robot_say('正在思考...')
                last_response = output_messages()
                for history, _ in answer(q,
                                         history=[],
                                         mode=mode):
                    last_response.markdown(
                        format_md(history[-1][-1], False),
                        unsafe_allow_html=True
                    )
            elif use_kb_mode(mode):
                robot_say('正在思考...', vs_path)
                last_response = output_messages()
                for history, _ in answer(q,
                                         vs_path=os.path.join(
                                             KB_ROOT_PATH, vs_path, "vector_store"),
                                         history=[],
                                         mode=mode,
                                         score_threshold=score_threshold,
                                         vector_search_top_k=top_k,
                                         chunk_conent=chunk_conent,
                                         chunk_size=chunk_size):
                    last_response.markdown(
                        format_md(history[-1][-1], False, 'ligreen'),
                        unsafe_allow_html=True
                    )
            else:
                robot_say('正在思考...')
                last_response = output_messages()
            st.session_state['history'][-1]['content'] = history[-1][-1]
    submit = cols[1].form_submit_button('发送', on_click=on_send)

output_messages()

# st.write(st.session_state['history'])
# st.write(chat_box.history)
chat_box.output_messages()
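
Taken together, the rewritten flow is: render the sidebar widgets, fetch (or load) the cached models, then on each form submit echo the question, stream the answer into the last bubble, and re-render the whole history. In terms of the streamlit-chatbox calls used above, a condensed sketch (not an additional part of the commit; signatures as they appear in this diff):

    chat_box = st_chatbox(greetings=["模型已成功加载,可以开始对话,或从左侧选择模式后开始对话。"])
    chat_box.user_say(question)            # echo the user's input
    chat_box.robot_say("正在思考...")       # placeholder bubble for the reply
    chat_box.output_messages()             # draw the history so far
    for history, _ in answer(question, history=[], mode=mode):
        chat_box.update_last_box_text(history[-1][-1])   # stream partial answers into the last bubble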