增加标题增强文档功能
This commit is contained in:
parent
d096443b03
commit
74f4f8174d
|
|
@ -1,4 +1,4 @@
|
|||
import logging
|
||||
#ChineseRecursiveTextSplitter
|
||||
import re
|
||||
from typing import Any, List, Optional
|
||||
|
||||
|
|
@ -9,8 +9,12 @@ from chatchat.utils import build_logger
|
|||
|
||||
logger = build_logger()
|
||||
|
||||
First_SEPARATOE = "\n\n\n\n\n\n\n\n\n\n"
|
||||
Second_SEPARATOE = "\n\n\n\n\n\n\n\n"
|
||||
Third_SEPARATOE = "\n\n\n\n\n\n"
|
||||
Fourth_SEPARATOE = "\n\n\n\n"
|
||||
|
||||
def _split_text_with_regex_from_end(
|
||||
def _customer_split_text_with_regex_from_end(
|
||||
text: str, separator: str, keep_separator: bool
|
||||
) -> List[str]:
|
||||
# Now that we have the separator, split the text
|
||||
|
|
@ -28,6 +32,9 @@ def _split_text_with_regex_from_end(
|
|||
splits = list(text)
|
||||
return [s for s in splits if s != ""]
|
||||
|
||||
def customerLen(text:str)->int:
|
||||
length = len(re.sub(r'[\s\n]+', '', text))
|
||||
return length
|
||||
|
||||
class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
|
||||
def __init__(
|
||||
|
|
@ -40,21 +47,45 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
|
|||
"""Create a new TextSplitter."""
|
||||
super().__init__(keep_separator=keep_separator, **kwargs)
|
||||
self._separators = separators or [
|
||||
"\n\n",
|
||||
"\n",
|
||||
"。|!|?",
|
||||
"\.\s|\!\s|\?\s",
|
||||
";|;\s",
|
||||
",|,\s",
|
||||
First_SEPARATOE,
|
||||
Second_SEPARATOE,
|
||||
Third_SEPARATOE,
|
||||
Fourth_SEPARATOE
|
||||
]
|
||||
self._is_separator_regex = is_separator_regex
|
||||
self.is_recursive = False
|
||||
self._length_function = customerLen
|
||||
|
||||
def _split_text(self, text: str, separators: List[str]) -> List[str]:
|
||||
"""Split incoming text and return chunks."""
|
||||
#print(f"***********************************ChineseRecursiveTextSplitter***********************************")
|
||||
final_chunks = []
|
||||
# Get appropriate separator to use
|
||||
separator = separators[-1]
|
||||
new_separators = []
|
||||
if self.is_recursive == False:
|
||||
#一级目录
|
||||
text = re.sub(r'(\n+前\s+言\n+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) #通过前言分块
|
||||
text = re.sub(r'(\n+\d+[^\S\n]+[^\s\.]+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) #通过1 这样的
|
||||
text = re.sub(r'(手工分段\*\*\s*)', r"\n\n\n\n\n\n\n\n\n\n", text) # 将“手工分段**”替换
|
||||
text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text) # 通过第 章
|
||||
|
||||
#二级目录
|
||||
text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)', r"\n\n\n\n\n\n\n\n\1", text) # 通过表 A.2
|
||||
text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\n\n\1", text) # 通过\n1.2 这样的章和节来分块
|
||||
text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
|
||||
text = re.sub(r'(\n+第\s*\S+\s*条(:|:))', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
|
||||
text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\1", text) # 通过第 条
|
||||
|
||||
#三级目录
|
||||
text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\1", text) # 再通过 1.2.3
|
||||
text = re.sub(r'(\n+((一)|(二)|(三)|(四)|(五)|(六)|(七)|(八)|(九)|(十)|(十一)|(十二)|(十三)|(十四)|(十五)|(十六)|(十七)|(十八)|(十九)|(二十)))', r"\n\n\n\n\n\n\1", text)
|
||||
text = re.sub(r'(\n+(\(一\)|\(二\)|\(三\)|\(四\)|\(五\)|\(六\)|\(七\)|\(八\)|\(九\)|\(十\)|\(十一\)|\(十二\)|\(十三\)|\(十四\)|\(十五\)|\(十六\)|\(十七\)|\(十八\)|\(十九\)|\(二十\)))', r"\n\n\n\n\n\n\1", text)
|
||||
|
||||
# 不支持对四级目录分块,如果需要通过手工分段来实现
|
||||
# text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\1", text) # 再通过 1.2.3
|
||||
text = text.rstrip() # 段尾如果有多余的\n就去掉它
|
||||
self.is_recursive = True
|
||||
for i, _s in enumerate(separators):
|
||||
_separator = _s if self._is_separator_regex else re.escape(_s)
|
||||
if _s == "":
|
||||
|
|
@ -62,37 +93,58 @@ class ChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
|
|||
break
|
||||
if re.search(_separator, text):
|
||||
separator = _s
|
||||
new_separators = separators[i + 1 :]
|
||||
new_separators = separators[i + 1:]
|
||||
break
|
||||
|
||||
_separator = separator if self._is_separator_regex else re.escape(separator)
|
||||
splits = _split_text_with_regex_from_end(text, _separator, self._keep_separator)
|
||||
splits = _customer_split_text_with_regex_from_end(text, _separator, self._keep_separator)
|
||||
|
||||
# Now go merging things, recursively splitting longer texts.
|
||||
_good_splits = []
|
||||
_separator = "" if self._keep_separator else separator
|
||||
for s in splits:
|
||||
#print(f"***s:{s},len:{self._length_function(s)}")
|
||||
if self._length_function(s) < self._chunk_size:
|
||||
_good_splits.append(s)
|
||||
#print(f"***_good_splits.append(s):{s}")
|
||||
else:
|
||||
if _good_splits:
|
||||
#print(f"***_merge_splits(s):{s}")
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
#print(f"***after _merge_splits,merged_text:{merged_text}")
|
||||
final_chunks.extend(merged_text)
|
||||
_good_splits = []
|
||||
if not new_separators:
|
||||
final_chunks.append(s)
|
||||
#print(f"***final_chunks.append(s)")
|
||||
else:
|
||||
#print(f"***下一级_split_text(s)")
|
||||
other_info = self._split_text(s, new_separators)
|
||||
final_chunks.extend(other_info)
|
||||
if _good_splits:
|
||||
#print(f"***22_merge_splits(s):{s}")
|
||||
merged_text = self._merge_splits(_good_splits, _separator)
|
||||
#print(f"***22after _merge_splits,merged_text:{merged_text}")
|
||||
final_chunks.extend(merged_text)
|
||||
return [
|
||||
re.sub(r"\n{2,}", "\n", chunk.strip())
|
||||
for chunk in final_chunks
|
||||
if chunk.strip() != ""
|
||||
]
|
||||
|
||||
final_chunks = [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip()!=""]
|
||||
#将两行以内并且字数小于25,和下面的分块合并
|
||||
return_chunks = []
|
||||
temp_sencond = ""
|
||||
for chunk in final_chunks:
|
||||
if temp_sencond =="":
|
||||
if len(chunk.splitlines()) <= 2 and len(chunk) <= 25:
|
||||
temp_sencond = chunk
|
||||
else:
|
||||
return_chunks.append(chunk)
|
||||
else:
|
||||
return_chunks.append(temp_sencond + "\n" + chunk)
|
||||
temp_sencond = ""
|
||||
|
||||
if temp_sencond !="":
|
||||
return_chunks.append(temp_sencond)
|
||||
|
||||
return return_chunks
|
||||
|
||||
if __name__ == "__main__":
|
||||
text_splitter = ChineseRecursiveTextSplitter(
|
||||
|
|
|
|||
Loading…
Reference in New Issue