diff --git a/webui_st.py b/webui_st.py index 1584a55..8309210 100644 --- a/webui_st.py +++ b/webui_st.py @@ -1,6 +1,8 @@ import streamlit as st from streamlit_chatbox import st_chatbox import tempfile +from pathlib import Path + ###### 从webui借用的代码 ##### ###### 做了少量修改 ##### import os @@ -101,23 +103,23 @@ def get_answer(query, vs_path, history, mode, score_threshold=VECTOR_SEARCH_SCOR def get_vector_store(vs_id, files, sentence_size, history, one_conent, one_content_segmentation): - vs_path = os.path.join(KB_ROOT_PATH, vs_id, "vector_store") - filelist = [] - if not os.path.exists(os.path.join(KB_ROOT_PATH, vs_id, "content")): - os.makedirs(os.path.join(KB_ROOT_PATH, vs_id, "content")) + vs_path = Path(KB_ROOT_PATH) / vs_id / "vector_store" + con_path = Path(KB_ROOT_PATH) / vs_id / "content" + con_path.mkdir(parents=True, exist_ok=True) + qa = st.session_state.local_doc_qa if qa.llm_model_chain and qa.embeddings: + filelist = [] if isinstance(files, list): for file in files: filename = os.path.split(file.name)[-1] - shutil.move(file.name, os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) - filelist.append(os.path.join( - KB_ROOT_PATH, vs_id, "content", filename)) + target = con_path / filename + shutil.move(file.name, target) + filelist.append(str(target)) vs_path, loaded_files = qa.init_knowledge_vector_store( - filelist, vs_path, sentence_size) + filelist, str(vs_path), sentence_size) else: - vs_path, loaded_files = qa.one_knowledge_add(vs_path, files, one_conent, one_content_segmentation, + vs_path, loaded_files = qa.one_knowledge_add(str(vs_path), files, one_conent, one_content_segmentation, sentence_size) if len(loaded_files): file_status = f"已添加 {'、'.join([os.path.split(i)[-1] for i in loaded_files if i])} 内容至知识库,并已加载知识库,请开始提问" @@ -322,7 +324,8 @@ with st.sidebar: sentence_size = st.slider('文本入库分句长度限制', 1, 1000, SENTENCE_SIZE) files = st.file_uploader('上传知识文件', ['docx', 'txt', 'md', 'csv', 'xlsx', 'pdf'], - accept_multiple_files=True) + accept_multiple_files=True, + ) if st.button('添加文件到知识库'): temp_dir = tempfile.mkdtemp() file_list = [] @@ -331,8 +334,8 @@ with st.sidebar: with open(file, 'wb') as fp: fp.write(f.getvalue()) file_list.append(TempFile(file)) - _, _, history = get_vector_store( - vs_path, file_list, sentence_size, [], None, None) + _, _, history = get_vector_store( + vs_path, file_list, sentence_size, [], None, None) st.session_state.files = []