diff --git a/webui.py b/webui.py
index cfe7b58..a93de0c 100644
--- a/webui.py
+++ b/webui.py
@@ -25,8 +25,6 @@ def get_vs_list():
return lst_default + lst
-vs_list = get_vs_list()
-
embedding_model_dict_list = list(embedding_model_dict.keys())
llm_model_dict_list = list(llm_model_dict.keys())
@@ -44,11 +42,12 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
query=query, chat_history=history, streaming=streaming):
source = "\n\n"
source += "".join(
- [f"""<details> <summary>出处 [{i + 1}] {doc.metadata["source"]}</summary>\n"""
- f"""{doc.page_content}\n"""
- f"""</details>"""
- for i, doc in
- enumerate(resp["source_documents"])])
+ [
+ f"""<details> <summary>出处 [{i + 1}] {doc.metadata["source"]}</summary>\n"""
+ f"""{doc.page_content}\n"""
+ f"""</details>"""
+ for i, doc in
+ enumerate(resp["source_documents"])])
history[-1][-1] += source
yield history, ""
elif mode == "知识库问答" and vs_path is not None and os.path.exists(vs_path):
@@ -89,7 +88,6 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
else:
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
streaming=streaming):
-
resp = answer_result.llm_output["answer"]
history = answer_result.history
history[-1][-1] = resp + (
@@ -99,9 +97,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
-def init_model(llm_model: BaseAnswer = None):
+def init_model():
+ args = parser.parse_args()
+
+ args_dict = vars(args)
+ shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
+ llm_model_ins = shared.loaderLLM()
+ llm_model_ins.set_history_len(LLM_HISTORY_LEN)
try:
- local_doc_qa.init_cfg(llm_model=llm_model)
+ local_doc_qa.init_cfg(llm_model=llm_model_ins)
generator = local_doc_qa.llm.generatorAnswer("你好")
for answer_result in generator:
print(answer_result.llm_output)
@@ -119,7 +123,9 @@ def init_model(llm_model: BaseAnswer = None):
return reply
-def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
+def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
+ history):
+
try:
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
llm_model_ins.history_len = llm_history_len
@@ -138,8 +144,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
filelist = []
- if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
- os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
if local_doc_qa.llm and local_doc_qa.embeddings:
if isinstance(files, list):
for file in files:
@@ -166,9 +170,8 @@ def change_vs_name_input(vs_id, history):
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
else:
file_status = f"已加载知识库{vs_id},请开始提问"
- return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
- vs_id), history + [
- [None, file_status]]
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
+ os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]]
knowledge_base_test_mode_info = ("【注意】\n\n"
@@ -206,19 +209,29 @@ def change_chunk_conent(mode, label_conent, history):
return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
-def add_vs_name(vs_name, vs_list, chatbot):
- if vs_name in vs_list:
+def add_vs_name(vs_name, chatbot):
+ if vs_name in get_vs_list():
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
chatbot = chatbot + [[None, vs_status]]
- return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update(
+ return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
visible=False), chatbot
else:
+ # 新建上传文件存储路径
+ if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)):
+ os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name))
+ # 新建向量库存储路径
+ if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)):
+ os.makedirs(os.path.join(VS_ROOT_PATH, vs_name))
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
chatbot = chatbot + [[None, vs_status]]
- return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update(
+ return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
+def refresh_vs_list():
+ return gr.update(choices=get_vs_list())
+
+
block_css = """.importantButton {
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
border: none !important;
@@ -232,7 +245,7 @@ webui_title = """
# 🎉langchain-ChatGLM WebUI🎉
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
"""
-default_vs = vs_list[0] if len(vs_list) > 1 else "为空"
+default_vs = get_vs_list()[0] if len(get_vs_list()) > 1 else "为空"
init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
@@ -243,16 +256,7 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
"""
# 初始化消息
-args = None
-args = parser.parse_args()
-
-args_dict = vars(args)
-shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
-llm_model_ins = shared.loaderLLM()
-llm_model_ins.set_history_len(LLM_HISTORY_LEN)
-
-model_status = init_model(llm_model=llm_model_ins)
-
+model_status = init_model()
default_theme_args = dict(
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
@@ -260,10 +264,9 @@ default_theme_args = dict(
)
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
- vs_path, file_status, model_status, vs_list = gr.State(
- os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State(
- model_status), gr.State(vs_list)
-
+ vs_path, file_status, model_status = gr.State(
+ os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
+ model_status)
gr.Markdown(webui_title)
with gr.Tab("对话"):
with gr.Row():
@@ -283,10 +286,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[mode, chatbot],
outputs=[vs_setting, knowledge_set, chatbot])
with vs_setting:
- select_vs = gr.Dropdown(vs_list.value,
+ vs_refresh = gr.Button("更新已有知识库选项")
+ select_vs = gr.Dropdown(get_vs_list(),
label="请选择要加载的知识库",
interactive=True,
- value=vs_list.value[0] if len(vs_list.value) > 0 else None
+ value=get_vs_list()[0] if len(get_vs_list()) > 0 else None
)
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
lines=1,
@@ -302,19 +306,21 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
interactive=True, visible=True)
with gr.Tab("上传文件"):
files = gr.File(label="添加文件",
- file_types=['.txt', '.md', '.docx', '.pdf'],
+ file_types=['.txt', '.md', '.docx', '.pdf', '.png', '.jpg'],
file_count="multiple",
show_label=False)
load_file_button = gr.Button("上传文件并加载知识库")
with gr.Tab("上传文件夹"):
folder_files = gr.File(label="添加文件",
- # file_types=['.txt', '.md', '.docx', '.pdf'],
file_count="directory",
show_label=False)
load_folder_button = gr.Button("上传文件夹并加载知识库")
+ vs_refresh.click(fn=refresh_vs_list,
+ inputs=[],
+ outputs=select_vs)
vs_add.click(fn=add_vs_name,
- inputs=[vs_name, vs_list, chatbot],
- outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
+ inputs=[vs_name, chatbot],
+ outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
select_vs.change(fn=change_vs_name_input,
inputs=[select_vs, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
@@ -366,10 +372,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
outputs=[chunk_sizes, chatbot])
with vs_setting:
- select_vs = gr.Dropdown(vs_list.value,
+ vs_refresh = gr.Button("更新已有知识库选项")
+ select_vs = gr.Dropdown(get_vs_list(),
label="请选择要加载的知识库",
interactive=True,
- value=vs_list.value[0] if len(vs_list.value) > 0 else None)
+ value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
lines=1,
interactive=True,
@@ -402,9 +409,12 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
interactive=True)
load_conent_button = gr.Button("添加内容并加载知识库")
# 将上传的文件保存到content文件夹下,并更新下拉框
+ vs_refresh.click(fn=refresh_vs_list,
+ inputs=[],
+ outputs=select_vs)
vs_add.click(fn=add_vs_name,
- inputs=[vs_name, vs_list, chatbot],
- outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
+ inputs=[vs_name, chatbot],
+ outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
select_vs.change(fn=change_vs_name_input,
inputs=[select_vs, chatbot],
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
@@ -455,8 +465,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
label="向量匹配 top k", interactive=True)
load_model_button = gr.Button("重新加载模型")
load_model_button.click(reinit_model, show_progress=True,
- inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora,
- top_k, chatbot], outputs=chatbot)
+ inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
+ use_lora, top_k, chatbot], outputs=chatbot)
(demo
.queue(concurrency_count=3)
@@ -464,4 +474,4 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
server_port=7860,
show_api=False,
share=False,
- inbrowser=False))
\ No newline at end of file
+ inbrowser=False))