diff --git a/README.md b/README.md index de1240a..75c0536 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch --- -### 知识库Text Splitter类型支持 +### Text Splitter 个性化支持 本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下: @@ -242,10 +242,8 @@ embedding_model_dict = { ```python text_splitter_dict = { "ChineseRecursiveTextSplitter": { - "source": "huggingface", ## 选择tiktoken则使用openai的方法 + "source": "huggingface", ## 选择tiktoken则使用openai的方法,不填写则默认为字符长度切割方法。 "tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。 - "chunk_size": 250, # 知识库中单段文本长度 - "overlap_size": 50, # 知识库中相邻文本重合长度 } } ``` diff --git a/configs/model_config.py.example b/configs/model_config.py.example index f501dca..50806bf 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -123,11 +123,11 @@ LLM_DEVICE = "auto" text_splitter_dict = { "ChineseRecursiveTextSplitter": { - "source": "huggingface", ## 选择tiktoken则使用openai的方法 + "source": "huggingface", "tokenizer_name_or_path": "gpt2", }, "SpacyTextSplitter": { - "source": "huggingface", + "source": "", "tokenizer_name_or_path": "", }, "RecursiveCharacterTextSplitter": { @@ -149,8 +149,6 @@ text_splitter_dict = { # TEXT_SPLITTER 名称 TEXT_SPLITTER_NAME = "SpacyTextSplitter" - - # 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) CHUNK_SIZE = 250 diff --git a/docs/splitter.md b/docs/splitter.md index 144d827..f44a7e8 100644 --- a/docs/splitter.md +++ b/docs/splitter.md @@ -7,12 +7,10 @@ from .my_splitter import MySplitter ``` 2. 
修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示: ```python -MySplitter": { +"MySplitter": { "source": "huggingface", ## 选择tiktoken则使用openai的方法 - "tokenizer_name_or_path": "your tokenizer", - "chunk_size": 250, - "overlap_size": 50, - }, + "tokenizer_name_or_path": "your tokenizer", #如果选择huggingface则使用huggingface的方法,部分tokenizer需要从Huggingface下载 + } TEXT_SPLITTER_NAME = "MySplitter" ``` 完成上述步骤后,就能使用自己的分词器了。 diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py index bda727c..f8b102a 100644 --- a/server/knowledge_base/utils.py +++ b/server/knowledge_base/utils.py @@ -199,7 +199,7 @@ def make_text_splitter( text_splitter_module = importlib.import_module('langchain.text_splitter') TextSplitter = getattr(text_splitter_module, splitter_name) - if text_splitter_dict[splitter_name]["source"] == "tiktoken": + if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载 try: text_splitter = TextSplitter.from_tiktoken_encoder( encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], @@ -213,7 +213,7 @@ chunk_size=chunk_size, chunk_overlap=chunk_overlap ) - elif text_splitter_dict[splitter_name]["source"] == "huggingface": + elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载 if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \ llm_model_dict[LLM_MODEL]["local_model_path"] @@ -221,8 +221,8 @@ if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": from transformers import GPT2TokenizerFast from langchain.text_splitter import CharacterTextSplitter - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer - else: + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") + else: ## 从指定的tokenizer加载 tokenizer = AutoTokenizer.from_pretrained( text_splitter_dict[splitter_name]["tokenizer_name_or_path"], 
trust_remote_code=True) @@ -231,6 +231,18 @@ def make_text_splitter( chunk_size=chunk_size, chunk_overlap=chunk_overlap ) + else: + try: + text_splitter = TextSplitter( + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + except: + text_splitter = TextSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) except Exception as e: print(e) text_splitter_module = importlib.import_module('langchain.text_splitter') diff --git a/tests/custom_splitter/test_different_splitter.py b/tests/custom_splitter/test_different_splitter.py index 5c80cc1..9993a9c 100644 --- a/tests/custom_splitter/test_different_splitter.py +++ b/tests/custom_splitter/test_different_splitter.py @@ -6,87 +6,36 @@ import sys sys.path.append("../..") from configs.model_config import ( CHUNK_SIZE, - OVERLAP_SIZE, - text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME + OVERLAP_SIZE ) -import langchain.document_loaders -import importlib +from server.knowledge_base.utils import make_text_splitter -def test_different_splitter( - splitter_name, - chunk_size: int = CHUNK_SIZE, - chunk_overlap: int = OVERLAP_SIZE, -): - if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 - headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] - text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter( - headers_to_split_on=headers_to_split_on) - - else: - - try: ## 优先使用用户自定义的text_splitter - text_splitter_module = importlib.import_module('text_splitter') - TextSplitter = getattr(text_splitter_module, splitter_name) - except: ## 否则使用langchain的text_splitter - text_splitter_module = importlib.import_module('langchain.text_splitter') - TextSplitter = getattr(text_splitter_module, splitter_name) - - if text_splitter_dict[splitter_name]["source"] == "tiktoken": - try: - text_splitter = TextSplitter.from_tiktoken_encoder( - encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], - 
pipeline="zh_core_web_sm", - chunk_size=chunk_size, - chunk_overlap=chunk_overlap - ) - except: - text_splitter = TextSplitter.from_tiktoken_encoder( - encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], - chunk_size=chunk_size, - chunk_overlap=chunk_overlap - ) - elif text_splitter_dict[splitter_name]["source"] == "huggingface": - if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": - text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \ - llm_model_dict[LLM_MODEL]["local_model_path"] - - if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": - from transformers import GPT2TokenizerFast - from langchain.text_splitter import CharacterTextSplitter - tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer - else: - tokenizer = AutoTokenizer.from_pretrained( - text_splitter_dict[splitter_name]["tokenizer_name_or_path"], - trust_remote_code=True) - text_splitter = TextSplitter.from_huggingface_tokenizer( - tokenizer=tokenizer, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap - ) - - return text_splitter - - -if __name__ == "__main__": +def text(splitter_name): from langchain import document_loaders # 使用DocumentLoader读取文件 filepath = "../../knowledge_base/samples/content/test.txt" loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True) docs = loader.load() - text_splitter = test_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE) - # 使用text_splitter进行分词 - - if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter": + text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE) + if splitter_name == "MarkdownHeaderTextSplitter": split_docs = text_splitter.split_text(docs[0].page_content) for doc in docs: - # 如果文档有元数据 if doc.metadata: doc.metadata["source"] = os.path.basename(filepath) else: split_docs = text_splitter.split_documents(docs) + return docs - # 打印分词结果 - for doc in split_docs: - print(doc) + + + +import pytest 
+@pytest.mark.parametrize("splitter_name", ["ChineseRecursiveTextSplitter", "SpacyTextSplitter", "RecursiveCharacterTextSplitter","MarkdownHeaderTextSplitter"]) +def test_different_splitter(splitter_name): + try: + docs = text(splitter_name) + assert docs is not None + except Exception as e: + pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}")