def load_file(filepath):
    """Load a document from *filepath* and split it into chunks.

    PDF files are loaded whole and split with a PDF-aware
    ``ChineseTextSplitter``; every other file type is loaded in
    ``"elements"`` mode and split with the default splitter settings.

    Returns the list of document chunks produced by the splitter.
    """
    is_pdf = filepath.lower().endswith(".pdf")
    if is_pdf:
        loader = UnstructuredFileLoader(filepath)
        splitter = ChineseTextSplitter(pdf=True)
        return loader.load_and_split(splitter)
    loader = UnstructuredFileLoader(filepath, mode="elements")
    splitter = ChineseTextSplitter(pdf=False)
    return loader.load_and_split(text_splitter=splitter)
class ChineseTextSplitter(CharacterTextSplitter):
    """Text splitter that breaks Chinese text on sentence-ending punctuation.

    Splits on Chinese/fullwidth sentence terminators (and the ASCII-style
    ones in the character class below), keeping up to two trailing closing
    quotes attached to the sentence they terminate.
    """

    def __init__(self, pdf: bool = False, **kwargs):
        """Create the splitter.

        ``pdf=True`` enables extra whitespace cleanup for text extracted
        from PDFs (collapses blank-line runs, normalizes all whitespace to
        single spaces).  Remaining keyword arguments are forwarded to
        ``CharacterTextSplitter``.
        """
        super().__init__(**kwargs)
        self.pdf = pdf

    def split_text(self, text: str) -> List[str]:
        """Split *text* into a list of sentence strings."""
        if self.pdf:
            # Collapse runs of 3+ newlines into one before normalizing.
            text = re.sub(r"\n{3,}", "\n", text)
            # BUG FIX: original used '\s' in a non-raw string — an invalid
            # escape sequence (SyntaxWarning, a future SyntaxError).  The
            # raw string yields the byte-identical regex, so behavior is
            # unchanged.
            text = re.sub(r"\s", " ", text)
            # NOTE(review): no-op after the \s substitution above (no
            # newlines can remain); kept for fidelity with the original.
            text = text.replace("\n\n", "")
        # A sentence terminator optionally followed by up to two closing
        # quotes; the zero-width lookahead alternative keeps re.split's
        # capture-group output aligned so separators can be re-attached.
        sent_sep_pattern = re.compile('([﹒﹔﹖﹗.。!?]["’”」』]{0,2}|(?=["‘“「『]{1,2}|$))')  # del :;
        sent_list: List[str] = []
        for ele in sent_sep_pattern.split(text):
            if sent_sep_pattern.match(ele) and sent_list:
                # Separator fragment: glue it onto the preceding sentence.
                sent_list[-1] += ele
            elif ele:
                sent_list.append(ele)
        return sent_list