parent
dc413120e2
commit
769d75d784
|
|
@ -144,7 +144,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### 知识库Text Splitter类型支持
|
### Text Splitter 个性化支持
|
||||||
|
|
||||||
本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下:
|
本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下:
|
||||||
|
|
||||||
|
|
@ -242,10 +242,8 @@ embedding_model_dict = {
|
||||||
```python
|
```python
|
||||||
text_splitter_dict = {
|
text_splitter_dict = {
|
||||||
"ChineseRecursiveTextSplitter": {
|
"ChineseRecursiveTextSplitter": {
|
||||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
"source": "huggingface", ## 选择tiktoken则使用openai的方法,不填写则默认为字符长度切割方法。
|
||||||
"tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。
|
"tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。
|
||||||
"chunk_size": 250, # 知识库中单段文本长度
|
|
||||||
"overlap_size": 50, # 知识库中相邻文本重合长度
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -123,11 +123,11 @@ LLM_DEVICE = "auto"
|
||||||
|
|
||||||
text_splitter_dict = {
|
text_splitter_dict = {
|
||||||
"ChineseRecursiveTextSplitter": {
|
"ChineseRecursiveTextSplitter": {
|
||||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
"source": "huggingface",
|
||||||
"tokenizer_name_or_path": "gpt2",
|
"tokenizer_name_or_path": "gpt2",
|
||||||
},
|
},
|
||||||
"SpacyTextSplitter": {
|
"SpacyTextSplitter": {
|
||||||
"source": "huggingface",
|
"source": "",
|
||||||
"tokenizer_name_or_path": "",
|
"tokenizer_name_or_path": "",
|
||||||
},
|
},
|
||||||
"RecursiveCharacterTextSplitter": {
|
"RecursiveCharacterTextSplitter": {
|
||||||
|
|
@ -149,8 +149,6 @@ text_splitter_dict = {
|
||||||
# TEXT_SPLITTER 名称
|
# TEXT_SPLITTER 名称
|
||||||
TEXT_SPLITTER_NAME = "SpacyTextSplitter"
|
TEXT_SPLITTER_NAME = "SpacyTextSplitter"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
|
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
|
||||||
CHUNK_SIZE = 250
|
CHUNK_SIZE = 250
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,12 +7,10 @@ from .my_splitter import MySplitter
|
||||||
```
|
```
|
||||||
2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示:
|
2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示:
|
||||||
```python
|
```python
|
||||||
MySplitter": {
|
MySplitter: {
|
||||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||||
"tokenizer_name_or_path": "your tokenizer",
|
"tokenizer_name_or_path": "your tokenizer", #如果选择huggingface则使用huggingface的方法,部分tokenizer需要从Huggingface下载
|
||||||
"chunk_size": 250,
|
}
|
||||||
"overlap_size": 50,
|
|
||||||
},
|
|
||||||
TEXT_SPLITTER_NAME = "MySplitter"
|
TEXT_SPLITTER_NAME = "MySplitter"
|
||||||
```
|
```
|
||||||
完成上述步骤后,就能使用自己的分词器了。
|
完成上述步骤后,就能使用自己的分词器了。
|
||||||
|
|
|
||||||
|
|
@ -199,7 +199,7 @@ def make_text_splitter(
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||||
|
|
||||||
if text_splitter_dict[splitter_name]["source"] == "tiktoken":
|
if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载
|
||||||
try:
|
try:
|
||||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
text_splitter = TextSplitter.from_tiktoken_encoder(
|
||||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||||
|
|
@ -213,7 +213,7 @@ def make_text_splitter(
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap
|
chunk_overlap=chunk_overlap
|
||||||
)
|
)
|
||||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface":
|
elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
||||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
||||||
llm_model_dict[LLM_MODEL]["local_model_path"]
|
llm_model_dict[LLM_MODEL]["local_model_path"]
|
||||||
|
|
@ -221,8 +221,8 @@ def make_text_splitter(
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
||||||
from transformers import GPT2TokenizerFast
|
from transformers import GPT2TokenizerFast
|
||||||
from langchain.text_splitter import CharacterTextSplitter
|
from langchain.text_splitter import CharacterTextSplitter
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer
|
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
else:
|
else: ## 字符长度加载
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||||
trust_remote_code=True)
|
trust_remote_code=True)
|
||||||
|
|
@ -231,6 +231,18 @@ def make_text_splitter(
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap
|
chunk_overlap=chunk_overlap
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
pipeline="zh_core_web_sm",
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
text_splitter = TextSplitter(
|
||||||
|
chunk_size=chunk_size,
|
||||||
|
chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||||
|
|
|
||||||
|
|
@ -6,87 +6,36 @@ import sys
|
||||||
sys.path.append("../..")
|
sys.path.append("../..")
|
||||||
from configs.model_config import (
|
from configs.model_config import (
|
||||||
CHUNK_SIZE,
|
CHUNK_SIZE,
|
||||||
OVERLAP_SIZE,
|
OVERLAP_SIZE
|
||||||
text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
|
|
||||||
)
|
)
|
||||||
import langchain.document_loaders
|
|
||||||
import importlib
|
|
||||||
|
|
||||||
|
from server.knowledge_base.utils import make_text_splitter
|
||||||
|
|
||||||
def test_different_splitter(
|
def text(splitter_name):
|
||||||
splitter_name,
|
|
||||||
chunk_size: int = CHUNK_SIZE,
|
|
||||||
chunk_overlap: int = OVERLAP_SIZE,
|
|
||||||
):
|
|
||||||
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
|
|
||||||
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
|
|
||||||
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
|
|
||||||
headers_to_split_on=headers_to_split_on)
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
try: ## 优先使用用户自定义的text_splitter
|
|
||||||
text_splitter_module = importlib.import_module('text_splitter')
|
|
||||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
|
||||||
except: ## 否则使用langchain的text_splitter
|
|
||||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
|
||||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
|
||||||
|
|
||||||
if text_splitter_dict[splitter_name]["source"] == "tiktoken":
|
|
||||||
try:
|
|
||||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
|
||||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
|
||||||
pipeline="zh_core_web_sm",
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
|
||||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap
|
|
||||||
)
|
|
||||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface":
|
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
|
||||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
|
||||||
llm_model_dict[LLM_MODEL]["local_model_path"]
|
|
||||||
|
|
||||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
|
||||||
from transformers import GPT2TokenizerFast
|
|
||||||
from langchain.text_splitter import CharacterTextSplitter
|
|
||||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer
|
|
||||||
else:
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
|
||||||
trust_remote_code=True)
|
|
||||||
text_splitter = TextSplitter.from_huggingface_tokenizer(
|
|
||||||
tokenizer=tokenizer,
|
|
||||||
chunk_size=chunk_size,
|
|
||||||
chunk_overlap=chunk_overlap
|
|
||||||
)
|
|
||||||
|
|
||||||
return text_splitter
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
from langchain import document_loaders
|
from langchain import document_loaders
|
||||||
|
|
||||||
# 使用DocumentLoader读取文件
|
# 使用DocumentLoader读取文件
|
||||||
filepath = "../../knowledge_base/samples/content/test.txt"
|
filepath = "../../knowledge_base/samples/content/test.txt"
|
||||||
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
|
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
|
||||||
docs = loader.load()
|
docs = loader.load()
|
||||||
text_splitter = test_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE)
|
text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE)
|
||||||
# 使用text_splitter进行分词
|
if splitter_name == "MarkdownHeaderTextSplitter":
|
||||||
|
|
||||||
if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter":
|
|
||||||
split_docs = text_splitter.split_text(docs[0].page_content)
|
split_docs = text_splitter.split_text(docs[0].page_content)
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
# 如果文档有元数据
|
|
||||||
if doc.metadata:
|
if doc.metadata:
|
||||||
doc.metadata["source"] = os.path.basename(filepath)
|
doc.metadata["source"] = os.path.basename(filepath)
|
||||||
else:
|
else:
|
||||||
split_docs = text_splitter.split_documents(docs)
|
split_docs = text_splitter.split_documents(docs)
|
||||||
|
return docs
|
||||||
|
|
||||||
# 打印分词结果
|
|
||||||
for doc in split_docs:
|
|
||||||
print(doc)
|
|
||||||
|
import pytest
|
||||||
|
@pytest.mark.parametrize("splitter_name", ["ChineseRecursiveTextSplitter", "SpacyTextSplitter", "RecursiveCharacterTextSplitter","MarkdownHeaderTextSplitter"])
|
||||||
|
def test_different_splitter(splitter_name):
|
||||||
|
try:
|
||||||
|
docs = text(splitter_name)
|
||||||
|
assert docs is not None
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"test_different_splitter failed with {splitter_name}, error: {str(e)}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue