From 565a94c1bb2b262bc4e57953e9d384335184aa3f Mon Sep 17 00:00:00 2001
From: wvivi2023
Date: Wed, 10 Jan 2024 10:45:47 +0800
Subject: [PATCH] customize word loader

---
 .DS_Store                                         | Bin 10244 -> 10244 bytes
 configs/model_config.py.example                   |  3 +
 document_loaders/__init__.py                      |  3 +-
 document_loaders/mywordload.py                    | 75 +++++++++++++++++++++++
 requirements.txt                                  |  3 +-
 server/knowledge_base/kb_doc_api.py               | 28 ++++++++++++++++++++-
 server/knowledge_base/utils.py                    |  5 +++--
 .../chinese_recursive_text_splitter.py            |  4 +++-
 8 files changed, 115 insertions(+), 6 deletions(-)
 create mode 100644 document_loaders/mywordload.py

diff --git a/.DS_Store b/.DS_Store
index 5276a037243498ee07863db78023769616a2cc84..a261c3058273757a0c55a8a824398f30f10cb55e 100644
GIT binary patch
delta 155
zcmZn(XbG6$&&abeU^hP_&tx8f@A3i+DGd1x$qc0oxeTccc?=~C@eDaYkwl<)Dnrp^
z8G$AdQ%h4F1yiG19ffK`6Eh1P1q)N-$(sa>32EN^Nx+J6vZg>OyM>vNj)IBV)&Kwi

delta 45
zcmZn(XbG6$&&ahgU^hP_*JK`n@0+ECG#Mv53+>$`!p^g?L7Z_jyTUJ)%@ShF%m7c*
B4W9r2

diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index b203e93..5bcdee4 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -17,6 +17,9 @@ RERANKER_MODEL = "bge-reranker-large"
 USE_RERANKER = False
 RERANKER_MAX_LENGTH = 1024
 
+# Whether to enable fine-grained TF-IDF re-ranking of retrieved documents
+USE_RANKING = False
+
 # 如果需要在 EMBEDDING_MODEL 中增加自定义的关键字时配置
 EMBEDDING_KEYWORD_FILE = "keywords.txt"
 EMBEDDING_MODEL_OUTPUT_PATH = "output"

diff --git a/document_loaders/__init__.py b/document_loaders/__init__.py
index 22340ae..b1de210 100644
--- a/document_loaders/__init__.py
+++ b/document_loaders/__init__.py
@@ -1,3 +1,4 @@
 from .mypdfloader import RapidOCRPDFLoader
 from .myimgloader import RapidOCRLoader
-from .customiedpdfloader import CustomizedPDFLoader
\ No newline at end of file
+from .customiedpdfloader import CustomizedPDFLoader
+from .mywordload import RapidWordLoader
\ No newline at end of file

diff --git a/document_loaders/mywordload.py b/document_loaders/mywordload.py
new file mode 100644
index 0000000..9980d35
--- /dev/null
+++ b/document_loaders/mywordload.py
@@ -0,0 +1,75 @@
+from typing import List
+from langchain.document_loaders.unstructured import UnstructuredFileLoader
+from docx import Document as docxDocument
+from docx.document import Document as _Document
+from docx.oxml.text.paragraph import CT_P
+from docx.oxml.table import CT_Tbl
+from docx.table import _Cell, Table
+from docx.text.paragraph import Paragraph
+
+class RapidWordLoader(UnstructuredFileLoader):
+    def _get_elements(self) -> List:
+        def iter_block_items(parent):
+            """
+            Yield each paragraph and table child within *parent*, in document order.
+            Each returned value is an instance of either Table or Paragraph.
+ """ + #Document + if isinstance(parent, _Document): + parent_elm = parent._element.body + elif isinstance(parent, _Cell): + parent_elm = parent._element + else: + raise ValueError("something's not right") + + for child in parent_elm.iterchildren(): + if isinstance(child, CT_P): + yield Paragraph(child, parent) + elif isinstance(child, CT_Tbl): + yield Table(child, parent) + + def read_table(table): + # 获取表格列标题 + headers = [cell.text.strip() for cell in table.rows[0].cells] + # 存储表格数据的字符串 + table_string = "" + + # 遍历表格行 + for row_index, row in enumerate(table.rows[1:], 2): # 从第二行开始遍历,因为第一行是标题 + row_data = [] + + # 遍历行中的单元格 + for cell_index, cell in enumerate(row.cells, 1): + cell_text = cell.text.strip() + row_data.append(f'"{headers[cell_index - 1]}": "{cell_text}"') + + # 将每一行的数据连接为字符串,用逗号分隔 + row_string = ", ".join(row_data) + # 将每一行的字符串添加到总的表格字符串中 + table_string += f"{{{row_string}}}\n" + + return table_string + + def word2text(filepath): + resp = "" + try: + doc = docxDocument(filepath) + for block in iter_block_items(doc): + if isinstance(block,Paragraph): + resp += (block.text + "\n\n") + elif isinstance(block, Table): + resp += read_table(block) + "\n" + except ValueError: + print(f"Error:input invalid parameter") + except Exception as e: + print(f"word2text error:{e}") + return resp + + text = word2text(self.file_path) + from unstructured.partition.text import partition_text + return partition_text(text=text, **self.unstructured_kwargs) + + +if __name__ == "__main__": + loader = RapidWordLoader(file_path="/Users/wangvivi/Desktop/MySelf/AI/Test/国家电网公司供电企业组织机构规范标准.docx") + docs = loader.load() + print(docs) diff --git a/requirements.txt b/requirements.txt index 16a6a62..6c9d4df 100644 --- a/requirements.txt +++ b/requirements.txt @@ -79,4 +79,5 @@ streamlit-aggrid>=0.3.4.post3 watchdog>=3.0.0 docx2txt elasticsearch -PyPDF2 \ No newline at end of file +PyPDF2 +jieba \ No newline at end of file diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 702b2b3..16be378 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -17,6 +17,11 @@ from server.db.repository.knowledge_file_repository import get_file_detail from langchain.docstore.document import Document from server.knowledge_base.model.kb_document_model import DocumentWithVSId from typing import List, Dict +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.metrics.pairwise import cosine_similarity +from configs import USE_RANKING +import jieba + def search_docs( @@ -38,7 +43,42 @@ def search_docs( print(f"search_docs, query:{query}") docs = kb.search_docs(query, top_k, score_threshold) print(f"search_docs, docs:{docs}") - data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs] + + if USE_RANKING: + queryList = [] + queryList.append(query) + doc_contents = [doc[0].page_content for doc in docs] + + doc_contents = [" ".join(jieba.cut(doc)) for doc in doc_contents] + queryList = [" ".join(jieba.cut(doc)) for doc in queryList] + + #print(f"****** search_docs, doc_contents:{doc_contents}") + #print(f"****** search_docs, queryList:{queryList}") + + vectorizer = TfidfVectorizer() + tfidf_matrix = vectorizer.fit_transform(doc_contents) + print(f"****** search_docs, tfidf_matrix:{tfidf_matrix}") + query_vector = vectorizer.transform(queryList) + print(f"****** search_docs, query_vector:{query_vector}") + cosine_similarities = cosine_similarity(query_vector, tfidf_matrix).flatten() + print(f"****** 
+                print(f"search_docs, cosine_similarities:{cosine_similarities}")
+
+                # Pair each document with its new score and sort best-first
+                docs_with_scores = [(doc, score) for (doc, _), score in zip(docs, cosine_similarities)]
+                sorted_docs = sorted(docs_with_scores, key=lambda x: x[1], reverse=True)
+                print(f"search_docs, sorted_docs:{sorted_docs}")
+                data = [DocumentWithVSId(**doc.dict(), score=score, id=doc.metadata.get("id"))
+                        for doc, score in sorted_docs]
+            else:
+                data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
+
     elif file_name or metadata:
         print(f"search_docs, kb:{knowledge_base_name}, filename:{file_name}")
         data = kb.list_docs(file_name=file_name, metadata=metadata)

diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 9ae42f1..9e024f3 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -122,7 +122,8 @@ LOADER_DICT = {"UnstructuredHTMLLoader": ['.html'],
                "UnstructuredPowerPointLoader": ['.ppt', '.pptx'],
                "EverNoteLoader": ['.enex'],
                "UnstructuredFileLoader": ['.txt'],
-               "Docx2txtLoader":['.docx','.doc'],
+               "Docx2txtLoader":['.doc'],
+               "RapidWordLoader":['.docx'],
                }
 SUPPORTED_EXTS = [ext for sublist in LOADER_DICT.values() for ext in sublist]
 
@@ -162,7 +163,7 @@ def get_loader(loader_name: str, file_path: str, loader_kwargs: Dict = None):
     '''
     loader_kwargs = loader_kwargs or {}
     try:
-        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader"]:
+        if loader_name in ["RapidOCRPDFLoader", "RapidOCRLoader","FilteredCSVLoader","RapidWordLoader"]:
             document_loaders_module = importlib.import_module('document_loaders')
         else:
             document_loaders_module = importlib.import_module('langchain.document_loaders')

diff --git a/text_splitter/chinese_recursive_text_splitter.py b/text_splitter/chinese_recursive_text_splitter.py
index 7c66321..66ab041 100644
--- a/text_splitter/chinese_recursive_text_splitter.py
+++ b/text_splitter/chinese_recursive_text_splitter.py
@@ -64,9 +64,11 @@
         text = re.sub(r'(\n+\d+[^\S\n]+[^\s\.]+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) #通过1 这样的
         text = re.sub(r'(手工分段\*\*\s*)', r"\n\n\n\n\n\n\n\n\n\n", text) # 将“手工分段**”替换
         text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 章
-
+        #text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split on "表 A.2"-style table captions
+
+        text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)', r"\n\n\n\n\n\n\n\n\1", text)  # split on "表 A.2"-style table captions
         text = re.sub(r'(\n+(?
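
The re-ranking path added to search_docs() is easiest to see in isolation. The
sketch below is illustrative only (the rerank() helper and the sample query and
documents are made up for the example, not part of the patch), but it runs
as-is and uses the same jieba + TF-IDF + cosine-similarity pipeline as the
patched code:

    import jieba
    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    def rerank(query, doc_texts):
        # jieba pre-tokenizes the Chinese text so that TfidfVectorizer's
        # whitespace-based tokenizer sees individual words
        corpus = [" ".join(jieba.cut(t)) for t in doc_texts]
        vectorizer = TfidfVectorizer()
        tfidf_matrix = vectorizer.fit_transform(corpus)  # fitted on the candidates only
        query_vector = vectorizer.transform([" ".join(jieba.cut(query))])
        scores = cosine_similarity(query_vector, tfidf_matrix).flatten()
        # highest cosine similarity first
        return sorted(zip(doc_texts, scores), key=lambda pair: pair[1], reverse=True)

    if __name__ == "__main__":
        docs = ["供电企业组织机构的设置规范", "变电站设备运行维护规程"]
        for text, score in rerank("供电企业的组织机构", docs):
            print(f"{score:.3f}  {text}")

Because the vectorizer is fitted only on the few retrieved candidates, the IDF
weights come from a very small corpus; this keeps re-ranking cheap per request,
but the scores are only meaningful relative to that candidate set.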