diff --git a/webui_st.py b/webui_st.py index 4300662..541851f 100644 --- a/webui_st.py +++ b/webui_st.py @@ -1,5 +1,5 @@ import streamlit as st -# from st_btn_select import st_btn_select +from streamlit_chatbox import st_chatbox import tempfile ###### 从webui借用的代码 ##### ###### 做了少量修改 ##### @@ -31,7 +31,6 @@ def get_vs_list(): embedding_model_dict_list = list(embedding_model_dict.keys()) llm_model_dict_list = list(llm_model_dict.keys()) -# flag_csv_logger = gr.CSVLogger() def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD, @@ -50,6 +49,9 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR history[-1][-1] += source yield history, "" elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path): + local_doc_qa.top_k = vector_search_top_k + local_doc_qa.chunk_conent = chunk_conent + local_doc_qa.chunk_size = chunk_size for resp, history in local_doc_qa.get_knowledge_based_answer( query=query, vs_path=vs_path, chat_history=history, streaming=streaming): source = "\n\n" @@ -95,17 +97,101 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR "\n\n当前知识库为空,如需基于知识库进行问答,请先加载知识库后,再进行提问。" if mode == "知识库问答" else "") yield history, "" logger.info(f"flagging: username={FLAG_USER_NAME},query={query},vs_path={vs_path},mode={mode},history={history}") - # flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME) -def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec'): +def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): + vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") + filelist = [] + if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): + os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) + qa = st.session_state.local_doc_qa + if qa.llm_model_chain and qa.embeddings: + if isinstance(files, list): + for file in files: + filename = os.path.split(file.name)[-1] + shutil.move(file.name, os.path.join( + KB_ROOT_PATH, vs_id, "content", filename)) + filelist.append(os.path.join( + KB_ROOT_PATH, vs_id, "content", filename)) + vs_path, loaded_files = qa.init_knowledge_vector_store( + filelist, vs_path, sentence_size) + else: + vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, + sentence_size) + if len(loaded_files): + file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" + else: + file_status = "文件未成功加载,请重新上传文件" + else: + file_status = "模型未完成加载,请先在加载模型后再导入文件" + vs_path = None + logger.info(file_status) + return vs_path, None, history + [[None, file_status]] + + +knowledge_base_test_mode_info = ("【注意】\n\n" + "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询," + "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n" + "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。" + """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n""" + "4. 单条内容长度建议设置在100-150左右。") + + +webui_title = """ +# 🎉langchain-ChatGLM WebUI🎉 +👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) +""" +###### ##### + + +###### todo ##### +# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。 +# 目前已经实现了local_doc_qa和shared.loaderCheckPoint的全局化。 +# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。 +# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。 +# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。 +###### ##### + + +###### 配置项 ##### +class ST_CONFIG: + default_mode = "知识库问答" + default_kb = "" +###### ##### + + +class TempFile: + ''' + 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式 + ''' + + def __init__(self, path): + self.name = path + + +@st.cache_resource(show_spinner=False, max_entries=1) +def load_model( + llm_model: str = LLM_MODEL, + embedding_model: str = EMBEDDING_MODEL, + use_ptuning_v2: bool = USE_PTUNING_V2, +): + ''' + 对应init_model,利用streamlit cache避免模型重复加载 + ''' local_doc_qa = LocalDocQA() # 初始化消息 args = parser.parse_args() args_dict = vars(args) args_dict.update(model=llm_model) - shared.loaderCheckPoint = LoaderCheckPoint(args_dict) - llm_model_ins = shared.loaderLLM() + if shared.loaderCheckPoint is None: # avoid checkpoint reloading when reinit model + shared.loaderCheckPoint = LoaderCheckPoint(args_dict) + # shared.loaderCheckPoint.model_name is different by no_remote_model. + # if it is not set properly error occurs when reinit llm model(issue#473). + # as no_remote_model is removed from model_config, need workaround to set it automaticlly. + local_model_path = llm_model_dict.get(llm_model, {}).get('local_model_path') or '' + no_remote_model = os.path.isdir(local_model_path) + llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) + llm_model_ins.history_len = LLM_HISTORY_LEN try: local_doc_qa.init_cfg(llm_model=llm_model_ins, @@ -128,235 +214,9 @@ def init_model(llm_model: str = 'chat-glm-6b', embedding_model: str = 'text2vec' return local_doc_qa -# 暂未使用到,先保留 -# def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history): -# try: -# llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2) -# llm_model_ins.history_len = llm_history_len -# local_doc_qa.init_cfg(llm_model=llm_model_ins, -# embedding_model=embedding_model, -# top_k=top_k) -# model_status = """模型已成功重新加载,可以开始对话,或从右侧选择模式后开始对话""" -# logger.info(model_status) -# except Exception as e: -# logger.error(e) -# model_status = """模型未成功重新加载,请到页面左上角"模型配置"选项卡中重新选择后点击"加载模型"按钮""" -# logger.info(model_status) -# return history + [[None, model_status]] - - -def get_vector_store(local_doc_qa, vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") - filelist = [] - if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): - os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) - if local_doc_qa.llm and local_doc_qa.embeddings: - if isinstance(files, list): - for file in files: - filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - filelist.append(os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - vs_path, loaded_files = local_doc_qa.init_knowledge_vector_store( - filelist, vs_path, sentence_size) - else: - vs_path, loaded_files = local_doc_qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, - sentence_size) - if len(loaded_files): - file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" - else: - file_status = "文件未成功加载,请重新上传文件" - else: - file_status = "模型未完成加载,请先在加载模型后再导入文件" - vs_path = None - logger.info(file_status) - return vs_path, None, history + [[None, file_status]] - - -knowledge_base_test_mode_info = ("【注意】\n\n" - "1. 您已进入知识库测试模式,您输入的任何对话内容都将用于进行知识库查询," - "并仅输出知识库匹配出的内容及相似度分值和及输入的文本源路径,查询的内容并不会进入模型查询。\n\n" - "2. 知识相关度 Score 经测试,建议设置为 500 或更低,具体设置情况请结合实际使用调整。" - """3. 使用"添加单条数据"添加文本至知识库时,内容如未分段,则内容越多越会稀释各查询内容与之关联的score阈值。\n\n""" - "4. 单条内容长度建议设置在100-150左右。\n\n" - "5. 本界面用于知识入库及知识匹配相关参数设定,但当前版本中," - "本界面中修改的参数并不会直接修改对话界面中参数,仍需前往`configs/model_config.py`修改后生效。" - "相关参数将在后续版本中支持本界面直接修改。") - - -webui_title = """ -# 🎉langchain-ChatGLM WebUI🎉 -👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM) -""" -###### ##### - - -###### todo ##### -# 1. streamlit运行方式与一般web服务器不同,使用模块是无法实现单例模式的,所以shared和local_doc_qa都需要进行全局化处理。 -# 目前已经实现了local_doc_qa的全局化,后面要考虑shared。 -# 2. 当前local_doc_qa是一个全局变量,一方面:任何一个session对其做出修改,都会影响所有session的对话;另一方面,如何处理所有session的请求竞争也是问题。 -# 这个暂时无法避免,在配置普通的机器上暂时也无需考虑。 -# 3. 目前只包含了get_answer对应的参数,以后可以添加其他参数,如temperature。 -###### ##### - - -###### 配置项 ##### -class ST_CONFIG: - user_bg_color = '#77ff77' - user_icon = 'https://tse2-mm.cn.bing.net/th/id/OIP-C.LTTKrxNWDr_k74wz6jKqBgHaHa?w=203&h=203&c=7&r=0&o=5&pid=1.7' - robot_bg_color = '#ccccee' - robot_icon = 'https://ts1.cn.mm.bing.net/th/id/R-C.5302e2cc6f5c7c4933ebb3394e0c41bc?rik=z4u%2b7efba5Mgxw&riu=http%3a%2f%2fcomic-cons.xyz%2fwp-content%2fuploads%2fStar-Wars-avatar-icon-C3PO.png&ehk=kBBvCvpJMHPVpdfpw1GaH%2brbOaIoHjY5Ua9PKcIs%2bAc%3d&risl=&pid=ImgRaw&r=0' - default_mode = '知识库问答' - defalut_kb = '' -###### ##### - - -class MsgType: - ''' - 目前仅支持文本类型的输入输出,为以后多模态模型预留图像、视频、音频支持。 - ''' - TEXT = 1 - IMAGE = 2 - VIDEO = 3 - AUDIO = 4 - - -class TempFile: - ''' - 为保持与get_vector_store的兼容性,需要将streamlit上传文件转化为其可以接受的方式 - ''' - - def __init__(self, path): - self.name = path - - -def init_session(): - st.session_state.setdefault('history', []) - - -# def get_query_params(): -# ''' -# 可以用url参数传递配置参数:llm_model, embedding_model, kb, mode。 -# 该参数将覆盖model_config中的配置。处于安全考虑,目前只支持kb和mode -# 方便将固定的配置分享给特定的人。 -# ''' -# params = st.experimental_get_query_params() -# return {k: v[0] for k, v in params.items() if v} - - -def robot_say(msg, kb=''): - st.session_state['history'].append( - {'is_user': False, 'type': MsgType.TEXT, 'content': msg, 'kb': kb}) - - -def user_say(msg): - st.session_state['history'].append( - {'is_user': True, 'type': MsgType.TEXT, 'content': msg}) - - -def format_md(msg, is_user=False, bg_color='', margin='10%'): - ''' - 将文本消息格式化为markdown文本 - ''' - if is_user: - bg_color = bg_color or ST_CONFIG.user_bg_color - text = f''' -