Add a custom text splitter and custom title enhancement

weiweiw 2025-01-07 16:36:02 +08:00
parent fd1a46ffd9
commit 773f9f275a
5 changed files with 456 additions and 3 deletions

View File

@@ -8,6 +8,9 @@ from uuid import UUID
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentAction, AgentFinish
from langchain_core.outputs import LLMResult
+from chatchat.utils import build_logger
+logger = build_logger()

def dumps(obj: Dict) -> str:
@@ -31,6 +34,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.out = True
logger.info(f"init....")

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
@@ -41,6 +45,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        }
        self.done.clear()
        self.queue.put_nowait(dumps(data))
logger.info(f"prompts:{prompts}")

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        special_tokens = ["\nAction:", "\nObservation:", "<|observation|>"]
@@ -79,6 +84,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        }
        self.done.clear()
        self.queue.put_nowait(dumps(data))
logger.info(f"messages:{messages}")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        data = {
@@ -86,6 +92,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"text": response.generations[0][0].message.content, "text": response.generations[0][0].message.content,
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
logger.info(f"response:{response.json}")

    async def on_llm_error(
        self, error: Exception | KeyboardInterrupt, **kwargs: Any
@@ -114,6 +121,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"tool_input": input_str, "tool_input": input_str,
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
logger.info(f"input_str:{input_str}")

    async def on_tool_end(
        self,
@@ -132,6 +140,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        }
        # self.done.clear()
        self.queue.put_nowait(dumps(data))
logger.info(f"output:{output}")

    async def on_tool_error(
        self,
@@ -151,6 +160,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
        }
        # self.done.clear()
        self.queue.put_nowait(dumps(data))
logger.error(f"error:{error.__class__}")

    async def on_agent_action(
        self,
@@ -168,6 +178,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"text": action.log, "text": action.log,
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
logger.error(f"tool_name:{action.tool},tool_input:{ action.tool_input}")

    async def on_agent_finish(
        self,
@@ -188,6 +199,7 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
"text": finish.return_values["output"], "text": finish.return_values["output"],
} }
self.queue.put_nowait(dumps(data)) self.queue.put_nowait(dumps(data))
logger.error(f"data:{data}")

    async def on_chain_end(
        self,
@@ -200,3 +212,4 @@ class AgentExecutorAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    ) -> None:
        self.done.set()
        self.out = True
logger.info(f"outputs:{outputs}")

View File

@@ -5,10 +5,12 @@ import numpy as np
import tqdm
from langchain_community.document_loaders.unstructured import UnstructuredFileLoader
from PIL import Image
+import re
from chatchat.settings import Settings
from chatchat.server.file_rag.document_loaders.ocr import get_ocr
+from chatchat.utils import build_logger
+logger = build_logger()

class RapidOCRPDFLoader(UnstructuredFileLoader):
    def _get_elements(self) -> List:
@@ -53,9 +55,11 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
                )
                b_unit.refresh()
                text = page.get_text("")
-                resp += text + "\n"
+                # resp += text + "\n"
+                text_lines = text.strip().split("\n")
+                logger.info(f"****page:{i+1}****, text lines: {text_lines}")
                img_list = page.get_image_info(xrefs=True)
+                ocr_result = []
                for img in img_list:
                    if xref := img.get("xref"):
                        bbox = img["bbox"]
@@ -86,8 +90,20 @@ class RapidOCRPDFLoader(UnstructuredFileLoader):
                        ocr_result = [line[1] for line in result]
-                        resp += "\n".join(ocr_result)
+                if len(ocr_result) > 0:
+                    resp += "\n".join(ocr_result)
+                else:
+                    if text_lines:
+                        # assume the page number sits on the last line; drop it
+                        if text_lines[-1].isdigit():
+                            text = "\n".join(text_lines[:-1])
+                            print("****** removed page number")
+                    resp += text + "\n"
                # update the progress bar
                b_unit.update(1)
+            # pad line-leading section numbers ("5", "5.1", "5.1.1") with a trailing space
+            resp = re.sub(r'((?<!.)\d+(?!\.|[a-zA-Z0-9]))', r"\1 ", resp)
+            resp = re.sub(r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+(?!\.|[a-zA-Z0-9]))', r"\1 ", resp)
+            resp = re.sub(r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+(?!\.|[a-zA-Z0-9]))', r"\1 ", resp)
            return resp

        text = pdf2text(self.file_path)
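
These substitutions fire only at the start of a line, because the lookbehind (?<!.) fails whenever any character other than a newline precedes the match. A minimal sketch of the intended effect, on an illustrative string that is not part of the commit:

    import re

    sample = "第5页\n5 技术要求\n5.1 一般要求"
    sample = re.sub(r'((?<!.)\d+(?!\.|[a-zA-Z0-9]))', r"\1 ", sample)
    sample = re.sub(r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+(?!\.|[a-zA-Z0-9]))', r"\1 ", sample)
    # the "5" inside "第5页" is untouched; the line-leading "5" and "5.1" now carry a
    # trailing space, presumably so the custom splitter's heading patterns match uniformly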

View File

@@ -0,0 +1,192 @@
import re
from typing import Any, List, Optional

from langchain.text_splitter import RecursiveCharacterTextSplitter

from chatchat.utils import build_logger

logger = build_logger()
# Runs of newlines are injected in front of detected headings and used as separators:
# the longer the run, the higher the heading level.
FIRST_SEPARATOR = "\n\n\n\n\n\n\n\n\n\n"
SECOND_SEPARATOR = "\n\n\n\n\n\n\n\n"
THIRD_SEPARATOR = "\n\n\n\n\n\n"
FOURTH_SEPARATOR = "\n\n\n\n"

def _customer_split_text_with_regex_from_end(
    text: str, separator: str, keep_separator: bool
) -> List[str]:
    # Now that we have the separator, split the text.
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({separator})", text)
            splits = ["".join(i) for i in zip(_splits[0::2], _splits[1::2])]
            if len(_splits) % 2 == 1:
                splits += _splits[-1:]
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
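
# Worked example (illustrative): with separator "\n\n\n\n" and keep_separator=True,
#   re.split("(\n\n\n\n)", "A\n\n\n\nB\n\n\n\nC") -> ["A", "\n\n\n\n", "B", "\n\n\n\n", "C"]
# and zipping the even/odd slices re-attaches each separator to the text before it:
#   ["A\n\n\n\n", "B\n\n\n\n", "C"]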

def customerLen(text: str) -> int:
    """Chunk length that ignores all whitespace, including newlines."""
    return len(re.sub(r'[\s\n]+', '', text))
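
# e.g. customerLen("5.1  一般要求\n") == 7, so chunk_size limits apply to visible
# characters only, not to the separator newlines injected below.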

class CustomerChineseRecursiveTextSplitter(RecursiveCharacterTextSplitter):
    def __init__(
        self,
        separators: Optional[List[str]] = None,
        keep_separator: bool = True,
        is_separator_regex: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or [
            FIRST_SEPARATOR,
            SECOND_SEPARATOR,
            THIRD_SEPARATOR,
            FOURTH_SEPARATOR,
        ]
        self._is_separator_regex = is_separator_regex
        self.is_recursive = False
        self._length_function = customerLen

    def _split_text(self, text: str, separators: List[str]) -> List[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        if not self.is_recursive:
            # Level-1 headings
            text = re.sub(r'(\n+前\s+言\n+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split at a preface ("前 言") heading
            text = re.sub(r'(\n+\d+[^\S\n]+[^\s\.]+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split at bare chapter numbers like "1 "
            text = re.sub(r'(手工分段\*\*\s*)', r"\n\n\n\n\n\n\n\n\n\n", text)  # replace the manual-break marker "手工分段**"
            text = re.sub(r'(\n+第\s*\S+\s*章\s+)', r"\n\n\n\n\n\n\n\n\n\n\1", text)  # split at "第…章" chapter headings
            # Level-2 headings
            text = re.sub(r'(\n+表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)', r"\n\n\n\n\n\n\n\n\1", text)  # split at table captions like "表 A.2"
            text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\n\n\1", text)  # split at section headings like "1.2 …"
            text = re.sub(r'(\n+第\s*\S+\s*条\s+)', r"\n\n\n\n\n\n\n\n\1", text)  # split at "第…条" article headings
            text = re.sub(r'(\n+第\s*\S+\s*条(:|))', r"\n\n\n\n\n\n\n\n\1", text)  # same, with an optional full-width colon
            text = re.sub(r'(\n+(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))', r"\n\n\n\n\n\n\n\n\1", text)  # split at ordinal headings like "一、"
            # Level-3 headings
            text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\n\n\1", text)  # split at sub-section headings like "1.2.3 …"
            text = re.sub(r'(\n+((一)|(二)|(三)|(四)|(五)|(六)|(七)|(八)|(九)|(十)|(十一)|(十二)|(十三)|(十四)|(十五)|(十六)|(十七)|(十八)|(十九)|(二十)))', r"\n\n\n\n\n\n\1", text)  # split at bare ordinals like "一"
            text = re.sub(r'(\n+(\(一\)|\(二\)|\(三\)|\(四\)|\(五\)|\(六\)|\(七\)|\(八\)|\(九\)|\(十\)|\(十一\)|\(十二\)|\(十三\)|\(十四\)|\(十五\)|\(十六\)|\(十七\)|\(十八\)|\(十九\)|\(二十\)))', r"\n\n\n\n\n\n\1", text)  # split at parenthesized ordinals like "(一)"
            # Level-4 headings are not split automatically; use the manual-break marker if needed.
            # text = re.sub(r'(\n+(?<!\.|[a-zA-Z0-9])[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))', r"\n\n\n\n\1", text)
            text = text.rstrip()  # drop trailing newlines at the end of the text
            self.is_recursive = True
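        # Illustrative effect: "…要求。\n5.1 一般要求\n5.1.1 智能模块…" now carries 8 newlines
        # before "5.1" and 6 before "5.1.1", so the loop below picks the longest separator
        # still present in the text and recurses level by level through the shorter ones.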
        for i, _s in enumerate(separators):
            _separator = _s if self._is_separator_regex else re.escape(_s)
            if _s == "":
                separator = _s
                break
            if re.search(_separator, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        _separator = separator if self._is_separator_regex else re.escape(separator)
        splits = _customer_split_text_with_regex_from_end(text, _separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    # descend to the next separator level
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        final_chunks = [re.sub(r"\n{2,}", "\n", chunk.strip()) for chunk in final_chunks if chunk.strip() != ""]
        # Merge a chunk of at most two lines and at most 25 characters (typically a bare
        # heading) into the chunk that follows it.
        return_chunks = []
        temp_second = ""
        for chunk in final_chunks:
            if temp_second == "":
                if len(chunk.splitlines()) <= 2 and len(chunk) <= 25:
                    temp_second = chunk
                else:
                    return_chunks.append(chunk)
            else:
                return_chunks.append(temp_second + "\n" + chunk)
                temp_second = ""
        if temp_second != "":
            return_chunks.append(temp_second)
        return return_chunks
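        # e.g. ["5.1 一般要求", "5.1.1 智能模块…(long body)…"] becomes one chunk
        # "5.1 一般要求\n5.1.1 智能模块…", re-attaching the bare heading to its body.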

if __name__ == "__main__":
    text_splitter = CustomerChineseRecursiveTextSplitter(
        keep_separator=True, is_separator_regex=True, chunk_size=50, chunk_overlap=0
    )
ls=["""
5 技术要求
5.1 一般要求
5.1.1 智能安全帽应符合GB 2811 中的基本性能要求
5.1.2 智能模块与安全帽本体之间的连接应牢固可靠且不得影响安全帽佩戴的稳定性及 正常防护功能
5.1.3 智能模块的外壳防护等级应符合 GB/T 4208 IP54 的要求
5.1.4 智能模块重量基础型不宜超过300g
5.1.5 智能模块应能存储不低于8h 的采集回传的位置信息文件
5.1.6 智能模块应具有低电量报警功能电量低于20%应能给出清晰的报警提示
5.1.7 电池应符合GB 31241中的相关要求支持智能模块持续工作时间不得小于8h
5.1.8 无线通信应符合 GB 21288 中电磁辐射局部暴露限值的规定
5.1.9 智能安全帽应配合管理系统使用管理系统的功能应符合附录B的要求
"""]
    for inum, temptext in enumerate(ls):
        print(f"************** segment {inum}")
        chunks = text_splitter.split_text(temptext)
        for i, chunk in enumerate(chunks):
            print(f"************** chunk {i}: {chunk}")

View File

@@ -0,0 +1,231 @@
import re
from typing import List

from langchain.docstore.document import Document

from chatchat.utils import build_logger

logger = build_logger()

def get_first_level_title(
    text: str,
) -> str:
    # Empty text is never a title.
    if len(text) == 0:
        print("Not a title. Text is empty.")
        return ""
    first_line = text.splitlines()[0]
    # A first line that ends in punctuation is not a title.
    ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
    ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
    if ENDS_IN_PUNCT_RE.search(first_line) is not None:
        return ""
    FIRST_TITLE = r'((?<!.)\d+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9])|((?<!.)第\s*\S+\s*章\s+\S+))'
    TITLE_PUNCT_RE = re.compile(FIRST_TITLE)
    if TITLE_PUNCT_RE.search(first_line) is not None:
        return first_line
    return ""

# Return the level-2 title, if any.
def get_second_level_title(
    text: str,
) -> str:
    # Empty text is never a title.
    if len(text) == 0:
        print("Not a title. Text is empty.")
        return ""
    splitlines = text.splitlines()
    first_line = splitlines[0]
    # A chunk may look like:
    #   3 ****
    #   3.1 *****
    #   3.1.1 *****
    #   --- another chunk starting at 3.1.2 ---
    # so the level-2 title may sit on the first or the second line.
    Second_TITLE = r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9])|(?<!.)第\s*\S+\s*条\s+|(?<!.)第\s*\S+\s*条(:|)|(?<!.)(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、))'
    TITLE_PUNCT_RE = re.compile(Second_TITLE)
    if TITLE_PUNCT_RE.search(first_line) is not None:
        return first_line
    if len(splitlines) > 1:
        second_line = splitlines[1]
        if TITLE_PUNCT_RE.search(second_line) is not None:
            return second_line
    return ""

# Judge whether the text starts with level-2 content.
def is_second_level_content(
    text: str,
) -> bool:
    if len(text) == 0:
        print("Not content. Text is empty.")
        return False
    first_line = text.splitlines()[0]
    Second_TITLE = r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))|(?<!.)(表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)|(?<!.)第\s*\S+\s*条\s+|(?<!.)第\s*\S+\s*条(:|)|(?<!.)(一、|二、|三、|四、|五、|六、|七、|八、|九、|十、|十一、|十二、|十三、|十四、|十五、|十六、|十七、|十八、|十九、|二十、)'
    TITLE_PUNCT_RE = re.compile(Second_TITLE)
    return TITLE_PUNCT_RE.search(first_line) is not None

# Judge whether the text starts with level-3 content.
def is_third_level_content(
    text: str,
) -> bool:
    if len(text) == 0:
        print("Not content. Text is empty.")
        return False
    first_line = text.splitlines()[0]
    Third_TITLE = r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))|((?<!.)表\s*[A-Za-z0-9]+(\s*\.\s*[A-Za-z0-9]+)*\s+)|((?<!.)(一)|(二)|(三)|(四)|(五)|(六)|(七)|(八)|(九)|(十)|(十一)|(十二)|(十三)|(十四)|(十五)|(十六)|(十七)|(十八)|(十九)|(二十))|((?<!.)(\(一\)|\(二\)|\(三\)|\(四\)|\(五\)|\(六\)|\(七\)|\(八\)|\(九\)|\(十\)|\(十一\)|\(十二\)|\(十三\)|\(十四\)|\(十五\)|\(十六\)|\(十七\)|\(十八\)|\(十九\)|\(二十\)))'
    TITLE_PUNCT_RE = re.compile(Third_TITLE)
    return TITLE_PUNCT_RE.search(first_line) is not None

def get_third_level_title(
    text: str,
) -> str:
    # Empty text is never a title.
    if len(text) == 0:
        print("Not a title. Text is empty.")
        return ""
    splitlines = text.splitlines()
    first_line = splitlines[0]
    # A chunk may look like:
    #   3 ****
    #   3.1 *****
    #   3.1.1 *****
    #   3.1.1.1 *****
    #   --- another chunk starting at 3.1.1.2 ---
    # so the level-3 title may sit on the first, second or third line.
    Third_TITLE = r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))'
    TITLE_PUNCT_RE = re.compile(Third_TITLE)
    if TITLE_PUNCT_RE.search(first_line) is not None:
        return first_line
    if len(splitlines) > 1:
        if TITLE_PUNCT_RE.search(splitlines[1]) is not None:
            return splitlines[1]
        if len(splitlines) > 2:
            if TITLE_PUNCT_RE.search(splitlines[2]) is not None:
                return splitlines[2]
    return ""

# Judge whether the text starts with level-4 content.
def is_fourth_level_content(
    text: str,
) -> bool:
    if len(text) == 0:
        print("Not content. Text is empty.")
        return False
    first_line = text.splitlines()[0]
    Fourth_TITLE = r'((?<!.)[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+\s*\.\s*[a-zA-Z0-9]+[^\S\n]+[^\s\.]+(?!\.|[a-zA-Z0-9]))'
    TITLE_PUNCT_RE = re.compile(Fourth_TITLE)
    return TITLE_PUNCT_RE.search(first_line) is not None

# Prepend the level-3 title to level-4 content that was split away from it.
def zh_third_title_enhance(docs: List[Document]) -> List[Document]:
    title = None
    if len(docs) > 0:
        for doc in docs:
            third_title = get_third_level_title(doc.page_content)
            if third_title:
                title = third_title
            elif title:
                if is_fourth_level_content(doc.page_content):
                    doc.page_content = f"{title} {doc.page_content}"
                else:
                    title = None
        return docs
    else:
        print("zh_third_title_enhance: no documents to process")
        return docs

# Prepend the level-2 title to level-3 content that was split away from it.
def zh_second_title_enhance(docs: List[Document]) -> List[Document]:
    title = None
    if len(docs) > 0:
        for doc in docs:
            logger.debug(f"zh_second_title_enhance: {doc}")
            second_title = get_second_level_title(doc.page_content)
            if second_title:
                title = second_title
                logger.debug(f"title: {title}")
            elif title:
                if is_third_level_content(doc.page_content):
                    doc.page_content = f"{title} {doc.page_content}"
                else:
                    title = None
        logger.debug(f"final title: {title}")
        return docs
    else:
        print("zh_second_title_enhance: no documents to process")
        return docs

# Prepend the level-1 title to level-2 content that was split away from it.
def zh_first_title_enhance(docs: List[Document]) -> List[Document]:
    title = None
    if len(docs) > 0:
        for doc in docs:
            logger.debug(f"zh_first_title_enhance: {doc}")
            first_title = get_first_level_title(doc.page_content)
            if first_title:
                title = first_title
                logger.debug(f"title: {title}")
            elif title:
                if is_second_level_content(doc.page_content):
                    doc.page_content = f"{title} {doc.page_content}"
                else:
                    title = None
        logger.debug(f"final title: {title}")
        return docs
    else:
        print("zh_first_title_enhance: no documents to process")
        return docs

if __name__ == "__main__":
    text = """1 总 则\n1.1 本导则是编制和审查城市电力网(以下简称城网)规划的指导性文件,其 适用范围为国家电网公司所属的各网省公司、城市供电公司。\n1.2 城网是城市行政区划内为城市供电的各级电压电网的总称。城网是电力系 统的主要负荷中心,作为城市的重要基础设施之一,与城市的社会经济发展密切 相关。各城市应根据《中华人民共和国城市规划法》和《中华人民共和国电力法》 的相关规定,编制城网规划,并纳入相应的城市总体规划和各地区详细规划中。\n1.3 城网规划是城市总体规划的重要组成部分,应与城市的各项发展规划相互 配合、同步实施,做到与城市规划相协调,落实规划中所确定的线路走廊和地下 通道、变电站和配电室站址等供电设施用地。\n1.4 城网规划的目的是通过科学的规划,建设网络坚强、结构合理、安全可靠、 运行灵活、节能环保、经济高效的城市电网,不断提高城网供电能力和电能质量, 以满足城市经济增长和社会发展的需要。 ' metadata={'source': '/home/bns001/Langchain-Chatchat_0.2.9/knowledge_base/test/content/资产全寿命周期管理体系实施指南.docx'}"""
    title = get_first_level_title(text)
    print(title)

View File

@@ -71,6 +71,7 @@ def search_docs(
    if kb is not None:
        if query:
            docs = kb.search_docs(query, top_k, score_threshold)
logger.info(f"search_docs, query:{query},top_k:{top_k},score_threshold:{score_threshold}")
            # data = [DocumentWithVSId(**x[0].dict(), score=x[1], id=x[0].metadata.get("id")) for x in docs]
            data = [DocumentWithVSId(**{"id": x.metadata.get("id"), **x.dict()}) for x in docs]
        elif file_name or metadata: