增加标题增强文档功能

This commit is contained in:
weiweiw 2025-01-13 11:30:14 +08:00
parent d096443b03
commit 74f4f8174d
1 changed file with 73 additions and 21 deletions

View File

@ -1,4 +1,4 @@
import logging #ChineseRecursiveTextSplitter
import re import re
from typing import Any, List, Optional from typing import Any, List, Optional
@ -9,8 +9,12 @@ from chatchat.utils import build_logger
logger = build_logger() logger = build_logger()
# Newline-run markers used as the default separators of
# ChineseRecursiveTextSplitter, listed from coarsest to finest level.
# _split_text inserts these runs in front of detected headings so that the
# recursive splitter can cut the text on heading boundaries.
# NOTE: "SEPARATOE" is a misspelling of "SEPARATOR"; the names are kept as-is
# because they are referenced elsewhere in this module.

# 10 newlines: inserted before level-1 headings (preface, "1 Title",
# chapter headings) and in place of the manual-split marker.
First_SEPARATOE = "\n\n\n\n\n\n\n\n\n\n"
# 8 newlines: inserted before level-2 headings (tables such as "A.2",
# "1.2 Title", article headings, enumerated Chinese-numeral items).
Second_SEPARATOE = "\n\n\n\n\n\n\n\n"
# 6 newlines: inserted before level-3 headings ("1.2.3 Title" and
# parenthesised Chinese numerals).
Third_SEPARATOE = "\n\n\n\n\n\n"
# 4 newlines: no active rule inserts this run (the level-4 heading rule in
# _split_text is commented out).
Fourth_SEPARATOE = "\n\n\n\n"
def _split_text_with_regex_from_end( def _customer_split_text_with_regex_from_end(
text: str, separator: str, keep_separator: bool text: str, separator: str, keep_separator: bool
) -> List[str]: ) -> List[str]:
# Now that we have the separator, split the text # Now that we have the separator, split the text
@ -28,6 +32,9 @@ def _split_text_with_regex_from_end(
splits = list(text) splits = list(text)
return [s for s in splits if s != ""] return [s for s in splits if s != ""]
def customerLen(text: str) -> int:
    """Return the number of non-whitespace characters in *text*.

    Used as the splitter's length function so that the newline runs
    inserted as heading separators (and any other whitespace) do not count
    towards the chunk size.

    Args:
        text: The string to measure.

    Returns:
        The length of ``text`` after every whitespace character has been
        removed.
    """
    # ``\s`` already covers ``\n``, so a single class is sufficient.
    return len(re.sub(r"\s+", "", text))
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter): class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
def __init__( def __init__(
@ -40,21 +47,45 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
"""Create a new TextSplitter.""" """Create a new TextSplitter."""
super().__init__(keep_separator=keep_separator, **kwargs) super().__init__(keep_separator=keep_separator, **kwargs)
self._separators = separators or [ self._separators = separators or [
"\n\n", First_SEPARATOE,
"\n", Second_SEPARATOE,
"。||", Third_SEPARATOE,
"\.\s|\!\s|\?\s", Fourth_SEPARATOE
"|;\s",
"|,\s",
] ]
self._is_separator_regex = is_separator_regex self._is_separator_regex = is_separator_regex
self.is_recursive = False
self._length_function = customerLen
def _split_text(self, text: str, separators: List[str]) -> List[str]: def _split_text(self, text: str, separators: List[str]) -> List[str]:
"""Split incoming text and return chunks.""" """Split incoming text and return chunks."""
#print(f"***********************************ChineseRecursiveTextSplitter***********************************")
final_chunks = [] final_chunks = []
# Get appropriate separator to use # Get appropriate separator to use
separator = separators[-1] separator = separators[-1]
new_separators = [] new_separators = []
if self.is_recursive == False:
#一级目录
text = re.sub(r'(\n+前\s+言\n+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) #通过前言分块
text = re.sub(r'(\n+\d+[^\S\n]+[^\s\.]+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) #通过1 这样的
text = re.sub(r'(手工分段\*\*\s*)', r"\n\n\n\n\n\n\n\n\n\n", text) # 将“手工分段**”替换
text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 章
#二级目录
text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)', r"\n\n\n\n\n\n\n\n\1", text) # 通过表 A.2
text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\n\n\1", text) # 通过\n1.2 这样的章和节来分块
text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
text = re.sub(r'(\n+第\s*\S+\s*条(:|))', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
#三级目录
text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\1", text) # 再通过 1.2.3
text = re.sub(r'(\n+((一)|(二)|(三)|(四)|(五)|(六)|(七)|(八)|(九)|(十)|(十一)|(十二)|(十三)|(十四)|(十五)|(十六)|(十七)|(十八)|(十九)|(二十)))', r"\n\n\n\n\n\n\1", text)
text = re.sub(r'(\n+(\(一\)|\(二\)|\(三\)|\(四\)|\(五\)|\(六\)|\(七\)|\(八\)|\(九\)|\(十\)|\(十一\)|\(十二\)|\(十三\)|\(十四\)|\(十五\)|\(十六\)|\(十七\)|\(十八\)|\(十九\)|\(二十\)))', r"\n\n\n\n\n\n\1", text)
# 不支持对四级目录分块,如果需要通过手工分段来实现
# text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\1", text) # 再通过 1.2.3
text = text.rstrip() # 段尾如果有多余的\n就去掉它
self.is_recursive = True
for i, _s in enumerate(separators): for i, _s in enumerate(separators):
_separator = _s if self._is_separator_regex else re.escape(_s) _separator = _s if self._is_separator_regex else re.escape(_s)
if _s == "": if _s == "":
@ -62,37 +93,58 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
break break
if re.search(_separator, text): if re.search(_separator, text):
separator = _s separator = _s
new_separators = separators[i + 1 :] new_separators = separators[i + 1:]
break break
_separator = separator if self._is_separator_regex else re.escape(separator) _separator = separator if self._is_separator_regex else re.escape(separator)
splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator) splits = _customer_split_text_with_regex_from_end(text, _separator, self._keep_separator)
# Now go merging things, recursively splitting longer texts. # Now go merging things, recursively splitting longer texts.
_good_splits = [] _good_splits = []
_separator = "" if self._keep_separator else separator _separator = "" if self._keep_separator else separator
for s in splits: for s in splits:
#print(f"***s:{s},len:{self._length_function(s)}")
if self._length_function(s) < self._chunk_size: if self._length_function(s) < self._chunk_size:
_good_splits.append(s) _good_splits.append(s)
#print(f"***_good_splits.append(s):{s}")
else: else:
if _good_splits: if _good_splits:
#print(f"***_merge_splits(s):{s}")
merged_text = self._merge_splits(_good_splits, _separator) merged_text = self._merge_splits(_good_splits, _separator)
#print(f"***after _merge_splits,merged_text:{merged_text}")
final_chunks.extend(merged_text) final_chunks.extend(merged_text)
_good_splits = [] _good_splits = []
if not new_separators: if not new_separators:
final_chunks.append(s) final_chunks.append(s)
#print(f"***final_chunks.append(s)")
else: else:
#print(f"***下一级_split_text(s)")
other_info = self._split_text(s, new_separators) other_info = self._split_text(s, new_separators)
final_chunks.extend(other_info) final_chunks.extend(other_info)
if _good_splits: if _good_splits:
#print(f"***22_merge_splits(s):{s}")
merged_text = self._merge_splits(_good_splits, _separator) merged_text = self._merge_splits(_good_splits, _separator)
#print(f"***22after _merge_splits,merged_text:{merged_text}")
final_chunks.extend(merged_text) final_chunks.extend(merged_text)
return [
re.sub(r"\n{2,}", "\n", chunk.strip())
for chunk in final_chunks
if chunk.strip() != ""
]
final_chunks = [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""]
#将两行以内并且字数小于25和下面的分块合并
return_chunks = []
temp_sencond = ""
for chunk in final_chunks:
if temp_sencond =="":
if len(chunk.splitlines()) <= 2 and len(chunk) <= 25:
temp_sencond = chunk
else:
return_chunks.append(chunk)
else:
return_chunks.append(temp_sencond + "\n" + chunk)
temp_sencond = ""
if temp_sencond !="":
return_chunks.append(temp_sencond)
return return_chunks
if __name__ == "__main__": if __name__ == "__main__":
text_splitter = ChineseRecursiveTextSplitter( text_splitter = ChineseRecursiveTextSplitter(