update webui.py
This commit is contained in:
parent
89fe20b59f
commit
c7106317a0
96
webui.py
96
webui.py
|
|
@ -25,8 +25,6 @@ def get_vs_list():
|
||||||
return lst_default + lst
|
return lst_default + lst
|
||||||
|
|
||||||
|
|
||||||
vs_list = get_vs_list()
|
|
||||||
|
|
||||||
embedding_model_dict_list = list(embedding_model_dict.keys())
|
embedding_model_dict_list = list(embedding_model_dict.keys())
|
||||||
|
|
||||||
llm_model_dict_list = list(llm_model_dict.keys())
|
llm_model_dict_list = list(llm_model_dict.keys())
|
||||||
|
|
@ -44,7 +42,8 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
||||||
query=query, chat_history=history, streaming=streaming):
|
query=query, chat_history=history, streaming=streaming):
|
||||||
source = "\n\n"
|
source = "\n\n"
|
||||||
source += "".join(
|
source += "".join(
|
||||||
[f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n"""
|
[
|
||||||
|
f"""<details> <summary>出处 [{i + 1}] <a href="{doc.metadata["source"]}" target="_blank">{doc.metadata["source"]}</a> </summary>\n"""
|
||||||
f"""{doc.page_content}\n"""
|
f"""{doc.page_content}\n"""
|
||||||
f"""</details>"""
|
f"""</details>"""
|
||||||
for i, doc in
|
for i, doc in
|
||||||
|
|
@ -89,7 +88,6 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
||||||
else:
|
else:
|
||||||
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
for answer_result in local_doc_qa.llm.generatorAnswer(prompt=query, history=history,
|
||||||
streaming=streaming):
|
streaming=streaming):
|
||||||
|
|
||||||
resp = answer_result.llm_output["answer"]
|
resp = answer_result.llm_output["answer"]
|
||||||
history = answer_result.history
|
history = answer_result.history
|
||||||
history[-1][-1] = resp + (
|
history[-1][-1] = resp + (
|
||||||
|
|
@ -99,9 +97,15 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
|
||||||
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
flag_csv_logger.flag([query, vs_path, history, mode], username=FLAG_USER_NAME)
|
||||||
|
|
||||||
|
|
||||||
def init_model(llm_model: BaseAnswer = None):
|
def init_model():
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
args_dict = vars(args)
|
||||||
|
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
||||||
|
llm_model_ins = shared.loaderLLM()
|
||||||
|
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
||||||
try:
|
try:
|
||||||
local_doc_qa.init_cfg(llm_model=llm_model)
|
local_doc_qa.init_cfg(llm_model=llm_model_ins)
|
||||||
generator = local_doc_qa.llm.generatorAnswer("你好")
|
generator = local_doc_qa.llm.generatorAnswer("你好")
|
||||||
for answer_result in generator:
|
for answer_result in generator:
|
||||||
print(answer_result.llm_output)
|
print(answer_result.llm_output)
|
||||||
|
|
@ -119,7 +123,9 @@ def init_model(llm_model: BaseAnswer = None):
|
||||||
return reply
|
return reply
|
||||||
|
|
||||||
|
|
||||||
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k, history):
|
def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora, top_k,
|
||||||
|
history):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
llm_model_ins = shared.loaderLLM(llm_model, no_remote_model, use_ptuning_v2)
|
||||||
llm_model_ins.history_len = llm_history_len
|
llm_model_ins.history_len = llm_history_len
|
||||||
|
|
@ -138,8 +144,6 @@ def reinit_model(llm_model, embedding_model, llm_history_len, no_remote_model, u
|
||||||
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
|
||||||
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
vs_path = os.path.join(VS_ROOT_PATH, vs_id)
|
||||||
filelist = []
|
filelist = []
|
||||||
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_id)):
|
|
||||||
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_id))
|
|
||||||
if local_doc_qa.llm and local_doc_qa.embeddings:
|
if local_doc_qa.llm and local_doc_qa.embeddings:
|
||||||
if isinstance(files, list):
|
if isinstance(files, list):
|
||||||
for file in files:
|
for file in files:
|
||||||
|
|
@ -166,9 +170,8 @@ def change_vs_name_input(vs_id, history):
|
||||||
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=False), None, history
|
||||||
else:
|
else:
|
||||||
file_status = f"已加载知识库{vs_id},请开始提问"
|
file_status = f"已加载知识库{vs_id},请开始提问"
|
||||||
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), os.path.join(VS_ROOT_PATH,
|
return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), \
|
||||||
vs_id), history + [
|
os.path.join(VS_ROOT_PATH, vs_id), history + [[None, file_status]]
|
||||||
[None, file_status]]
|
|
||||||
|
|
||||||
|
|
||||||
knowledge_base_test_mode_info = ("【注意】\n\n"
|
knowledge_base_test_mode_info = ("【注意】\n\n"
|
||||||
|
|
@ -206,19 +209,29 @@ def change_chunk_conent(mode, label_conent, history):
|
||||||
return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
|
return gr.update(visible=False), history + [[None, f"【已关闭{conent}】"]]
|
||||||
|
|
||||||
|
|
||||||
def add_vs_name(vs_name, vs_list, chatbot):
|
def add_vs_name(vs_name, chatbot):
|
||||||
if vs_name in vs_list:
|
if vs_name in get_vs_list():
|
||||||
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
vs_status = "与已有知识库名称冲突,请重新选择其他名称后提交"
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(visible=True), vs_list, gr.update(visible=True), gr.update(visible=True), gr.update(
|
return gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(
|
||||||
visible=False), chatbot
|
visible=False), chatbot
|
||||||
else:
|
else:
|
||||||
|
# 新建上传文件存储路径
|
||||||
|
if not os.path.exists(os.path.join(UPLOAD_ROOT_PATH, vs_name)):
|
||||||
|
os.makedirs(os.path.join(UPLOAD_ROOT_PATH, vs_name))
|
||||||
|
# 新建向量库存储路径
|
||||||
|
if not os.path.exists(os.path.join(VS_ROOT_PATH, vs_name)):
|
||||||
|
os.makedirs(os.path.join(VS_ROOT_PATH, vs_name))
|
||||||
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
vs_status = f"""已新增知识库"{vs_name}",将在上传文件并载入成功后进行存储。请在开始对话前,先完成文件上传。 """
|
||||||
chatbot = chatbot + [[None, vs_status]]
|
chatbot = chatbot + [[None, vs_status]]
|
||||||
return gr.update(visible=True, choices=[vs_name] + vs_list, value=vs_name), [vs_name] + vs_list, gr.update(
|
return gr.update(visible=True, choices=get_vs_list(), value=vs_name), gr.update(
|
||||||
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
|
visible=False), gr.update(visible=False), gr.update(visible=True), chatbot
|
||||||
|
|
||||||
|
|
||||||
|
def refresh_vs_list():
|
||||||
|
return gr.update(choices=get_vs_list())
|
||||||
|
|
||||||
|
|
||||||
block_css = """.importantButton {
|
block_css = """.importantButton {
|
||||||
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
|
||||||
border: none !important;
|
border: none !important;
|
||||||
|
|
@ -232,7 +245,7 @@ webui_title = """
|
||||||
# 🎉langchain-ChatGLM WebUI🎉
|
# 🎉langchain-ChatGLM WebUI🎉
|
||||||
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
|
👍 [https://github.com/imClumsyPanda/langchain-ChatGLM](https://github.com/imClumsyPanda/langchain-ChatGLM)
|
||||||
"""
|
"""
|
||||||
default_vs = vs_list[0] if len(vs_list) > 1 else "为空"
|
default_vs = get_vs_list()[0] if len(get_vs_list()) > 1 else "为空"
|
||||||
init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
|
init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
|
||||||
|
|
||||||
请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
|
请在右侧切换模式,目前支持直接与 LLM 模型对话或基于本地知识库问答。
|
||||||
|
|
@ -243,16 +256,7 @@ init_message = f"""欢迎使用 langchain-ChatGLM Web UI!
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 初始化消息
|
# 初始化消息
|
||||||
args = None
|
model_status = init_model()
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
args_dict = vars(args)
|
|
||||||
shared.loaderCheckPoint = LoaderCheckPoint(args_dict)
|
|
||||||
llm_model_ins = shared.loaderLLM()
|
|
||||||
llm_model_ins.set_history_len(LLM_HISTORY_LEN)
|
|
||||||
|
|
||||||
model_status = init_model(llm_model=llm_model_ins)
|
|
||||||
|
|
||||||
|
|
||||||
default_theme_args = dict(
|
default_theme_args = dict(
|
||||||
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
font=["Source Sans Pro", 'ui-sans-serif', 'system-ui', 'sans-serif'],
|
||||||
|
|
@ -260,10 +264,9 @@ default_theme_args = dict(
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as demo:
|
||||||
vs_path, file_status, model_status, vs_list = gr.State(
|
vs_path, file_status, model_status = gr.State(
|
||||||
os.path.join(VS_ROOT_PATH, vs_list[0]) if len(vs_list) > 1 else ""), gr.State(""), gr.State(
|
os.path.join(VS_ROOT_PATH, get_vs_list()[0]) if len(get_vs_list()) > 1 else ""), gr.State(""), gr.State(
|
||||||
model_status), gr.State(vs_list)
|
model_status)
|
||||||
|
|
||||||
gr.Markdown(webui_title)
|
gr.Markdown(webui_title)
|
||||||
with gr.Tab("对话"):
|
with gr.Tab("对话"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
|
@ -283,10 +286,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
||||||
inputs=[mode, chatbot],
|
inputs=[mode, chatbot],
|
||||||
outputs=[vs_setting, knowledge_set, chatbot])
|
outputs=[vs_setting, knowledge_set, chatbot])
|
||||||
with vs_setting:
|
with vs_setting:
|
||||||
select_vs = gr.Dropdown(vs_list.value,
|
vs_refresh = gr.Button("更新已有知识库选项")
|
||||||
|
select_vs = gr.Dropdown(get_vs_list(),
|
||||||
label="请选择要加载的知识库",
|
label="请选择要加载的知识库",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
value=vs_list.value[0] if len(vs_list.value) > 0 else None
|
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None
|
||||||
)
|
)
|
||||||
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
||||||
lines=1,
|
lines=1,
|
||||||
|
|
@ -302,19 +306,21 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
||||||
interactive=True, visible=True)
|
interactive=True, visible=True)
|
||||||
with gr.Tab("上传文件"):
|
with gr.Tab("上传文件"):
|
||||||
files = gr.File(label="添加文件",
|
files = gr.File(label="添加文件",
|
||||||
file_types=['.txt', '.md', '.docx', '.pdf'],
|
file_types=['.txt', '.md', '.docx', '.pdf', '.png', '.jpg'],
|
||||||
file_count="multiple",
|
file_count="multiple",
|
||||||
show_label=False)
|
show_label=False)
|
||||||
load_file_button = gr.Button("上传文件并加载知识库")
|
load_file_button = gr.Button("上传文件并加载知识库")
|
||||||
with gr.Tab("上传文件夹"):
|
with gr.Tab("上传文件夹"):
|
||||||
folder_files = gr.File(label="添加文件",
|
folder_files = gr.File(label="添加文件",
|
||||||
# file_types=['.txt', '.md', '.docx', '.pdf'],
|
|
||||||
file_count="directory",
|
file_count="directory",
|
||||||
show_label=False)
|
show_label=False)
|
||||||
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
load_folder_button = gr.Button("上传文件夹并加载知识库")
|
||||||
|
vs_refresh.click(fn=refresh_vs_list,
|
||||||
|
inputs=[],
|
||||||
|
outputs=select_vs)
|
||||||
vs_add.click(fn=add_vs_name,
|
vs_add.click(fn=add_vs_name,
|
||||||
inputs=[vs_name, vs_list, chatbot],
|
inputs=[vs_name, chatbot],
|
||||||
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
|
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
|
||||||
select_vs.change(fn=change_vs_name_input,
|
select_vs.change(fn=change_vs_name_input,
|
||||||
inputs=[select_vs, chatbot],
|
inputs=[select_vs, chatbot],
|
||||||
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
||||||
|
|
@ -366,10 +372,11 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
||||||
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
|
inputs=[chunk_conent, gr.Textbox(value="chunk_conent", visible=False), chatbot],
|
||||||
outputs=[chunk_sizes, chatbot])
|
outputs=[chunk_sizes, chatbot])
|
||||||
with vs_setting:
|
with vs_setting:
|
||||||
select_vs = gr.Dropdown(vs_list.value,
|
vs_refresh = gr.Button("更新已有知识库选项")
|
||||||
|
select_vs = gr.Dropdown(get_vs_list(),
|
||||||
label="请选择要加载的知识库",
|
label="请选择要加载的知识库",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
value=vs_list.value[0] if len(vs_list.value) > 0 else None)
|
value=get_vs_list()[0] if len(get_vs_list()) > 0 else None)
|
||||||
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
vs_name = gr.Textbox(label="请输入新建知识库名称,当前知识库命名暂不支持中文",
|
||||||
lines=1,
|
lines=1,
|
||||||
interactive=True,
|
interactive=True,
|
||||||
|
|
@ -402,9 +409,12 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
||||||
interactive=True)
|
interactive=True)
|
||||||
load_conent_button = gr.Button("添加内容并加载知识库")
|
load_conent_button = gr.Button("添加内容并加载知识库")
|
||||||
# 将上传的文件保存到content文件夹下,并更新下拉框
|
# 将上传的文件保存到content文件夹下,并更新下拉框
|
||||||
|
vs_refresh.click(fn=refresh_vs_list,
|
||||||
|
inputs=[],
|
||||||
|
outputs=select_vs)
|
||||||
vs_add.click(fn=add_vs_name,
|
vs_add.click(fn=add_vs_name,
|
||||||
inputs=[vs_name, vs_list, chatbot],
|
inputs=[vs_name, chatbot],
|
||||||
outputs=[select_vs, vs_list, vs_name, vs_add, file2vs, chatbot])
|
outputs=[select_vs, vs_name, vs_add, file2vs, chatbot])
|
||||||
select_vs.change(fn=change_vs_name_input,
|
select_vs.change(fn=change_vs_name_input,
|
||||||
inputs=[select_vs, chatbot],
|
inputs=[select_vs, chatbot],
|
||||||
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
outputs=[vs_name, vs_add, file2vs, vs_path, chatbot])
|
||||||
|
|
@ -455,8 +465,8 @@ with gr.Blocks(css=block_css, theme=gr.themes.Default(**default_theme_args)) as
|
||||||
label="向量匹配 top k", interactive=True)
|
label="向量匹配 top k", interactive=True)
|
||||||
load_model_button = gr.Button("重新加载模型")
|
load_model_button = gr.Button("重新加载模型")
|
||||||
load_model_button.click(reinit_model, show_progress=True,
|
load_model_button.click(reinit_model, show_progress=True,
|
||||||
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2, use_lora,
|
inputs=[llm_model, embedding_model, llm_history_len, no_remote_model, use_ptuning_v2,
|
||||||
top_k, chatbot], outputs=chatbot)
|
use_lora, top_k, chatbot], outputs=chatbot)
|
||||||
|
|
||||||
(demo
|
(demo
|
||||||
.queue(concurrency_count=3)
|
.queue(concurrency_count=3)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue