fix: streamlit ui在上传多个文件到知识库时出错

This commit is contained in:
liunux4odoo 2023-07-25 22:36:49 +08:00
parent 466f0c9c97
commit 5f74f70515
1 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,8 @@
import streamlit as st
from streamlit_chatbox import st_chatbox
import tempfile
from pathlib import Path
###### 从webui借用的代码 #####
###### 做了少量修改 #####
import os
@ -101,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store")
filelist = []
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")):
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store"
con_path = Path(KB_ROOT_PATH) / vs_id / "content"
con_path.mkdir(parents=True, exist_ok=True)
qa = st.session_state.local_doc_qa
if qa.llm_model_chain and qa.embeddings:
filelist = []
if isinstance(files, list):
for file in files:
filename = os.path.split(file.name)[-1]
shutil.move(file.name, os.path.join(
KB_ROOT_PATH, vs_id, "content", filename))
filelist.append(os.path.join(
KB_ROOT_PATH, vs_id, "content", filename))
target = con_path / filename
shutil.move(file.name, target)
filelist.append(str(target))
vs_path, loaded_files = qa.init_knowledge_vector_store(
filelist, vs_path, sentence_size)
filelist, str(vs_path), sentence_size)
else:
vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation,
vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation,
sentence_size)
if len(loaded_files):
file_status = f"已添加 {''.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
@ -322,7 +324,8 @@ with st.sidebar:
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
files = st.file_uploader('上传知识文件',
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
accept_multiple_files=True)
accept_multiple_files=True,
)
if st.button('添加文件到知识库'):
temp_dir = tempfile.mkdtemp()
file_list = []
@ -331,8 +334,8 @@ with st.sidebar:
with open(file, 'wb') as fp:
fp.write(f.getvalue())
file_list.append(TempFile(file))
_, _, history = get_vector_store(
vs_path, file_list, sentence_size, [], None, None)
_, _, history = get_vector_store(
vs_path, file_list, sentence_size, [], None, None)
st.session_state.files = []