diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py index 25d411c..bad5e55 100644 --- a/chains/local_doc_qa.py +++ b/chains/local_doc_qa.py @@ -8,6 +8,7 @@ import sentence_transformers import os from configs.model_config import * import datetime +from typing import List # return top-k text chunk from vector store VECTOR_SEARCH_TOP_K = 10 @@ -42,25 +43,35 @@ class LocalDocQA: self.top_k = top_k def init_knowledge_vector_store(self, - filepath: str): - if not os.path.exists(filepath): - print("路径不存在") - return None - elif os.path.isfile(filepath): - file = os.path.split(filepath)[-1] - try: - loader = UnstructuredFileLoader(filepath, mode="elements") - docs = loader.load() - print(f"{file} 已成功加载") - except: - print(f"{file} 未能成功加载") + filepath: str or List[str]): + if isinstance(filepath, str): + if not os.path.exists(filepath): + print("路径不存在") return None - elif os.path.isdir(filepath): - docs = [] - for file in os.listdir(filepath): - fullfilepath = os.path.join(filepath, file) + elif os.path.isfile(filepath): + file = os.path.split(filepath)[-1] try: - loader = UnstructuredFileLoader(fullfilepath, mode="elements") + loader = UnstructuredFileLoader(filepath, mode="elements") + docs = loader.load() + print(f"{file} 已成功加载") + except: + print(f"{file} 未能成功加载") + return None + elif os.path.isdir(filepath): + docs = [] + for file in os.listdir(filepath): + fullfilepath = os.path.join(filepath, file) + try: + loader = UnstructuredFileLoader(fullfilepath, mode="elements") + docs += loader.load() + print(f"{file} 已成功加载") + except: + print(f"{file} 未能成功加载") + else: + docs = [] + for file in filepath: + try: + loader = UnstructuredFileLoader(file, mode="elements") docs += loader.load() print(f"{file} 已成功加载") except: @@ -74,7 +85,7 @@ class LocalDocQA: def get_knowledge_based_answer(self, query, vs_path, - chat_history=[],): + chat_history=[], ): prompt_template = """基于以下已知信息,简洁和专业的来回答用户的问题。 如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。 diff --git a/webui.py b/webui.py index 0e19e1a..cf82075 100644 --- a/webui.py +++ b/webui.py @@ -1,7 +1,8 @@ import gradio as gr import os import shutil -import cli_demo as kb +from chains.local_doc_qa import LocalDocQA +from configs.model_config import * def get_file_list(): @@ -12,9 +13,11 @@ def get_file_list(): file_list = get_file_list() -embedding_model_dict_list = list(kb.embedding_model_dict.keys()) +embedding_model_dict_list = list(embedding_model_dict.keys()) -llm_model_dict_list = list(kb.llm_model_dict.keys()) +llm_model_dict_list = list(llm_model_dict.keys()) + +local_doc_qa = LocalDocQA() def upload_file(file): @@ -27,9 +30,9 @@ def upload_file(file): return gr.Dropdown.update(choices=file_list, value=filename) -def get_answer(query, vector_store, history): - resp, history = kb.get_knowledge_based_answer( - query=query, vector_store=vector_store, chat_history=history) +def get_answer(query, vs_path, history): + resp, history = local_doc_qa.get_knowledge_based_answer( + query=query, vs_path=vs_path, chat_history=history) return history, history @@ -41,6 +44,25 @@ def get_file_status(history): return history + [[None, "文档已完成加载,请开始提问"]] +def init_model(): + try: + local_doc_qa.init_cfg() + return """模型已成功加载,请选择文件后点击"加载文件"按钮""" + except: + return """模型未成功加载,请重新选择后点击"加载模型"按钮""" + + +def reinit_model(llm_model, embedding_model, llm_history_len, top_k): + local_doc_qa.init_cfg(llm_model=llm_model, + embedding_model=embedding_model, + llm_history_len=llm_history_len, + top_k=top_k), + + +model_status = gr.State() +history = gr.State([]) +vs_path = gr.State() +model_status = init_model() with gr.Blocks(css=""" .importantButton { background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important; @@ -63,89 +85,78 @@ with gr.Blocks(css=""" with gr.Row(): with gr.Column(scale=2): chatbot = gr.Chatbot([[None, """欢迎使用 langchain-ChatGLM Web UI,开始提问前,请依次如下 3 个步骤: -1. 选择语言模型、Embedding 模型及相关参数后点击"step.1: setting",并等待加载完成提示 -2. 上传或选择已有文件作为本地知识文档输入后点击"step.2 loading",并等待加载完成提示 -3. 输入要提交的问题后点击"step.3 asking" """]], +1. 选择语言模型、Embedding 模型及相关参数后点击"重新加载模型",并等待加载完成提示 +2. 上传或选择已有文件作为本地知识文档输入后点击"重新加载文档",并等待加载完成提示 +3. 输入要提交的问题后,点击回车提交 """], [None, str(model_status)]], elem_id="chat-box", show_label=False).style(height=600) - with gr.Column(scale=1): - with gr.Column(): - llm_model = gr.Radio(llm_model_dict_list, - label="llm model", - value="chatglm-6b", - interactive=True) - LLM_HISTORY_LEN = gr.Slider(0, - 10, - value=3, - step=1, - label="LLM history len", - interactive=True) - embedding_model = gr.Radio(embedding_model_dict_list, - label="embedding model", - value="text2vec", - interactive=True) - VECTOR_SEARCH_TOP_K = gr.Slider(1, - 20, - value=6, - step=1, - label="vector search top k", - interactive=True) - load_model_button = gr.Button("step.1:setting") - load_model_button.click(lambda *args: - kb.init_cfg(args[0], args[1], args[2], args[3]), - show_progress=True, - api_name="init_cfg", - inputs=[llm_model, embedding_model, LLM_HISTORY_LEN,VECTOR_SEARCH_TOP_K] - ).then( - get_model_status, chatbot, chatbot - ) - - with gr.Column(): - with gr.Tab("select"): - selectFile = gr.Dropdown(file_list, - label="content file", - interactive=True, - value=file_list[0] if len(file_list) > 0 else None) - with gr.Tab("upload"): - file = gr.File(label="content file", - file_types=['.txt', '.md', '.docx', '.pdf'] - ).style(height=100) - # 将上传的文件保存到content文件夹下,并更新下拉框 - file.upload(upload_file, - inputs=file, - outputs=selectFile) - history = gr.State([]) - vector_store = gr.State() - load_button = gr.Button("step.2:loading") - load_button.click(lambda fileName: - kb.init_knowledge_vector_store( - "content/" + fileName), - show_progress=True, - api_name="init_knowledge_vector_store", - inputs=selectFile, - outputs=vector_store - ).then( - get_file_status, - chatbot, - chatbot, - show_progress=True, - ) - - with gr.Row(): - with gr.Column(scale=2): query = gr.Textbox(show_label=False, - placeholder="Prompts", + placeholder="请提问", lines=1, value="用200字总结一下" ).style(container=False) + with gr.Column(scale=1): - generate_button = gr.Button("step.3:asking", - elem_classes="importantButton") - generate_button.click(get_answer, - [query, vector_store, chatbot], - [chatbot, history], - api_name="get_knowledge_based_answer" - ) + llm_model = gr.Radio(llm_model_dict_list, + label="LLM 模型", + value="chatglm-6b", + interactive=True) + llm_history_len = gr.Slider(0, + 10, + value=3, + step=1, + label="LLM history len", + interactive=True) + embedding_model = gr.Radio(embedding_model_dict_list, + label="Embedding 模型", + value="text2vec", + interactive=True) + top_k = gr.Slider(1, + 20, + value=6, + step=1, + label="向量匹配 top k", + interactive=True) + load_model_button = gr.Button("重新加载模型") + + # with gr.Column(): + with gr.Tab("select"): + selectFile = gr.Dropdown(file_list, + label="content file", + interactive=True, + value=file_list[0] if len(file_list) > 0 else None) + with gr.Tab("upload"): + file = gr.File(label="content file", + file_types=['.txt', '.md', '.docx', '.pdf'] + ) # .style(height=100) + load_button = gr.Button("重新加载文件") + load_model_button.click(reinit_model, + show_progress=True, + api_name="init_cfg", + inputs=[llm_model, embedding_model, llm_history_len, top_k] + ).then( + get_model_status, chatbot, chatbot + ) + # 将上传的文件保存到content文件夹下,并更新下拉框 + file.upload(upload_file, + inputs=file, + outputs=selectFile) + # load_button.click(local_doc_qa.init_knowledge_vector_store, + # show_progress=True, + # api_name="init_knowledge_vector_store", + # inputs=selectFile, + # outputs=vs_path + # ).then( + # get_file_status, + # chatbot, + # chatbot, + # show_progress=True, + # ) + # query.submit(get_answer, + # [query, vs_path, chatbot], + # [chatbot, history], + # api_name="get_knowledge_based_answer" + # ) demo.queue(concurrency_count=3).launch( server_name='0.0.0.0', share=False, inbrowser=False)