fix: streamlit ui在上传多个文件到知识库时出错

This commit is contained in:
liunux4odoo 2023-07-25 22:36:49 +08:00
parent 466f0c9c97
commit 5f74f70515
1 changed files with 16 additions and 13 deletions

View File

@ -1,6 +1,8 @@
import streamlit as st import streamlit as st
from streamlit_chatbox import st_chatbox from streamlit_chatbox import st_chatbox
import tempfile import tempfile
from pathlib import Path
###### 从webui借用的代码 ##### ###### 从webui借用的代码 #####
###### 做了少量修改 ##### ###### 做了少量修改 #####
import os import os
@ -101,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR
def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation):
vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store"
filelist = [] con_path = Path(KB_ROOT_PATH) / vs_id / "content"
if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): con_path.mkdir(parents=True, exist_ok=True)
os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content"))
qa = st.session_state.local_doc_qa qa = st.session_state.local_doc_qa
if qa.llm_model_chain and qa.embeddings: if qa.llm_model_chain and qa.embeddings:
filelist = []
if isinstance(files, list): if isinstance(files, list):
for file in files: for file in files:
filename = os.path.split(file.name)[-1] filename = os.path.split(file.name)[-1]
shutil.move(file.name, os.path.join( target = con_path / filename
KB_ROOT_PATH, vs_id, "content", filename)) shutil.move(file.name, target)
filelist.append(os.path.join( filelist.append(str(target))
KB_ROOT_PATH, vs_id, "content", filename))
vs_path, loaded_files = qa.init_knowledge_vector_store( vs_path, loaded_files = qa.init_knowledge_vector_store(
filelist, vs_path, sentence_size) filelist, str(vs_path), sentence_size)
else: else:
vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation,
sentence_size) sentence_size)
if len(loaded_files): if len(loaded_files):
file_status = f"已添加 {''.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" file_status = f"已添加 {''.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问"
@ -322,7 +324,8 @@ with st.sidebar:
sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE) sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE)
files = st.file_uploader('上传知识文件', files = st.file_uploader('上传知识文件',
['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'], ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'],
accept_multiple_files=True) accept_multiple_files=True,
)
if st.button('添加文件到知识库'): if st.button('添加文件到知识库'):
temp_dir = tempfile.mkdtemp() temp_dir = tempfile.mkdtemp()
file_list = [] file_list = []