From 24a280ce8cf860070d497880f7ecd09ad6930b5a Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Wed, 9 Aug 2023 23:09:24 +0800
Subject: [PATCH] re-add zh_title_enhance.py

---
 configs/__init__.py               |  1 +
 configs/model_config.py.example   |  5 ++
 server/knowledge_base/utils.py    | 10 ++--
 text_splitter/__init__.py         |  3 +-
 text_splitter/zh_title_enhance.py | 99 +++++++++++++++++++++++++++++++
 5 files changed, 113 insertions(+), 5 deletions(-)
 create mode 100644 configs/__init__.py
 create mode 100644 text_splitter/zh_title_enhance.py

diff --git a/configs/__init__.py b/configs/__init__.py
new file mode 100644
index 0000000..0bed9b6
--- /dev/null
+++ b/configs/__init__.py
@@ -0,0 +1 @@
+from .model_config import *
\ No newline at end of file
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 1183432..7454cab 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -291,3 +291,8 @@ kbs_config = {
         "secure": False,
     }
 }
+
+# Whether to enable Chinese title enhancement, and its related configuration.
+# An added title-detection step decides which text chunks are titles and marks them in the metadata;
+# each chunk is then joined with its nearest preceding title to enrich the text information.
+ZH_TITLE_ENHANCE = False
\ No newline at end of file
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 6d52360..a863acb 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,10 +1,9 @@
-from typing import Union
 import os
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from configs.model_config import (embedding_model_dict, KB_ROOT_PATH, EMBEDDING_MODEL, kbs_config)
 from functools import lru_cache
-import langchain.document_loaders
 import sys
+from text_splitter import zh_title_enhance


 def validate_kb_name(knowledge_base_id: str) -> bool:
@@ -73,11 +72,14 @@ class KnowledgeFile:
         # TODO: choose the text_splitter according to the file format
         self.text_splitter_name = "CharacterTextSplitter"

-    def file2text(self):
+    def file2text(self, using_zh_title_enhance):
         DocumentLoader = getattr(sys.modules['langchain.document_loaders'], self.document_loader_name)
         loader = DocumentLoader(self.filepath)

         # TODO: choose the text_splitter according to the file format
         TextSplitter = getattr(sys.modules['langchain.text_splitter'], self.text_splitter_name)
         text_splitter = TextSplitter(chunk_size=250, chunk_overlap=200)
-        return loader.load_and_split(text_splitter)
+        docs = loader.load_and_split(text_splitter)
+        if using_zh_title_enhance:
+            docs = zh_title_enhance(docs)
+        return docs
diff --git a/text_splitter/__init__.py b/text_splitter/__init__.py
index 7dfd66a..1c4b665 100644
--- a/text_splitter/__init__.py
+++ b/text_splitter/__init__.py
@@ -1 +1,2 @@
-from .MyTextSplitter import MyTextSplitter
\ No newline at end of file
+from .MyTextSplitter import MyTextSplitter
+from .zh_title_enhance import zh_title_enhance
\ No newline at end of file
diff --git a/text_splitter/zh_title_enhance.py b/text_splitter/zh_title_enhance.py
new file mode 100644
index 0000000..7f8c548
--- /dev/null
+++ b/text_splitter/zh_title_enhance.py
@@ -0,0 +1,99 @@
+from langchain.docstore.document import Document
+import re
+
+
+def under_non_alpha_ratio(text: str, threshold: float = 0.5):
+    """Checks whether the proportion of alphabetic characters in the text snippet falls below a
+    given threshold. This helps prevent text like "-----------BREAK---------" from being tagged
+    as a title or narrative text. The ratio does not count spaces.
+
+    Parameters
+    ----------
+    text
+        The input string to test
+    threshold
+        If the proportion of alphabetic characters (ignoring spaces) is below this
+        threshold, the function returns True
+    """
+    if len(text) == 0:
+        return False
+
+    alpha_count = len([char for char in text if char.strip() and char.isalpha()])
+    total_count = len([char for char in text if char.strip()])
+    try:
+        ratio = alpha_count / total_count
+        return ratio < threshold
+    except:
+        return False
+
+
+def is_possible_title(
+        text: str,
+        title_max_word_length: int = 20,
+        non_alpha_threshold: float = 0.5,
+) -> bool:
+    """Checks to see if the text passes all of the checks for a valid title.
+
+    Parameters
+    ----------
+    text
+        The input text to check
+    title_max_word_length
+        The maximum number of words a title can contain
+    non_alpha_threshold
+        The minimum proportion of alphabetic characters the text needs to be considered a title
+    """
+
+    # Text of length 0 is definitely not a title
+    if len(text) == 0:
+        print("Not a title. Text is empty.")
+        return False
+
+    # Text that ends with a punctuation mark is not a title
+    ENDS_IN_PUNCT_PATTERN = r"[^\w\s]\Z"
+    ENDS_IN_PUNCT_RE = re.compile(ENDS_IN_PUNCT_PATTERN)
+    if ENDS_IN_PUNCT_RE.search(text) is not None:
+        return False
+
+    # The text length must not exceed the configured maximum, 20 by default
+    # NOTE(robinson) - splitting on spaces here instead of word tokenizing because it
+    # is less expensive and actual tokenization doesn't add much value for the length check
+    if len(text) > title_max_word_length:
+        return False
+
+    # The proportion of non-alphabetic characters must not be too high, otherwise it is not a title
+    if under_non_alpha_ratio(text, threshold=non_alpha_threshold):
+        return False
+
+    # NOTE(robinson) - Prevent flagging salutations like "To My Dearest Friends," as titles
+    if text.endswith((",", ".", ",", "。")):
+        return False
+
+    if text.isnumeric():
+        print(f"Not a title. Text is all numeric:\n\n{text}")  # type: ignore
+        return False
+
+    # The leading characters should contain a numeral, within the first 5 characters by default
+    if len(text) < 5:
+        text_5 = text
+    else:
+        text_5 = text[:5]
+    alpha_in_text_5 = sum(list(map(lambda x: x.isnumeric(), list(text_5))))
+    if not alpha_in_text_5:
+        return False
+
+    return True
+
+
+def zh_title_enhance(docs: Document) -> Document:
+    title = None
+    if len(docs) > 0:
+        for doc in docs:
+            if is_possible_title(doc.page_content):
+                doc.metadata['category'] = 'cn_Title'
+                title = doc.page_content
+            elif title:
+                doc.page_content = f"下文与({title})有关。{doc.page_content}"
+        return docs
+    else:
+        print("File does not exist")
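
Usage sketch (reviewer note, not part of the patch): a minimal example of how the new zh_title_enhance function is expected to behave on a list of langchain Documents once the patch is applied. The two sample document strings are illustrative assumptions; the imports are the ones introduced above, and running the snippet assumes the repository root is on PYTHONPATH so that the text_splitter package resolves.

    # Minimal sketch: apply zh_title_enhance to a hand-built list of Documents.
    from langchain.docstore.document import Document
    from text_splitter import zh_title_enhance

    docs = [
        # Short, numbered heading: should be tagged with metadata['category'] = 'cn_Title'.
        Document(page_content="第一章 总则"),
        # Ordinary body text: should be prefixed with the nearest preceding title.
        Document(page_content="本办法适用于全体员工,自发布之日起施行。"),
    ]

    enhanced = zh_title_enhance(docs)
    print(enhanced[0].metadata)      # expected: {'category': 'cn_Title'}
    print(enhanced[1].page_content)  # expected: 下文与(第一章 总则)有关。本办法适用于全体员工...

Note that a heading only passes is_possible_title if a numeral (Arabic or Chinese, e.g. "一") appears within its first five characters, so purely descriptive headings without numbering will not be tagged as titles.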