修复测试文件 (#1467)

Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
zR 2023-09-13 17:12:05 +08:00 committed by GitHub
parent dc413120e2
commit 769d75d784
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 40 additions and 85 deletions

View File

@ -144,7 +144,7 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
--- ---
### 知识库Text Splitter类型支持 ### Text Splitter 个性化支持
本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下: 本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下:
@ -242,10 +242,8 @@ embedding_model_dict = {
```python ```python
text_splitter_dict = { text_splitter_dict = {
"ChineseRecursiveTextSplitter": { "ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法 "source": "huggingface", ## 选择tiktoken则使用openai的方法,不填写则默认为字符长度切割方法。
"tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。 "tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。
"chunk_size": 250, # 知识库中单段文本长度
"overlap_size": 50, # 知识库中相邻文本重合长度
} }
} }
``` ```

View File

@ -123,11 +123,11 @@ LLM_DEVICE = "auto"
text_splitter_dict = { text_splitter_dict = {
"ChineseRecursiveTextSplitter": { "ChineseRecursiveTextSplitter": {
"source": "huggingface", ## 选择tiktoken则使用openai的方法 "source": "huggingface",
"tokenizer_name_or_path": "gpt2", "tokenizer_name_or_path": "gpt2",
}, },
"SpacyTextSplitter": { "SpacyTextSplitter": {
"source": "huggingface", "source": "",
"tokenizer_name_or_path": "", "tokenizer_name_or_path": "",
}, },
"RecursiveCharacterTextSplitter": { "RecursiveCharacterTextSplitter": {
@ -149,8 +149,6 @@ text_splitter_dict = {
# TEXT_SPLITTER 名称 # TEXT_SPLITTER 名称
TEXT_SPLITTER_NAME = "SpacyTextSplitter" TEXT_SPLITTER_NAME = "SpacyTextSplitter"
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter) # 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
CHUNK_SIZE = 250 CHUNK_SIZE = 250

View File

@ -7,12 +7,10 @@ from .my_splitter import MySplitter
``` ```
2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示: 2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示:
```python ```python
MySplitter": { MySplitter: {
"source": "huggingface", ## 选择tiktoken则使用openai的方法 "source": "huggingface", ## 选择tiktoken则使用openai的方法
"tokenizer_name_or_path": "your tokenizer", "tokenizer_name_or_path": "your tokenizer", # 如果选择huggingface则使用huggingface的方法,部分tokenizer需要从Huggingface下载
"chunk_size": 250, }
"overlap_size": 50,
},
TEXT_SPLITTER_NAME = "MySplitter" TEXT_SPLITTER_NAME = "MySplitter"
``` ```
完成上述步骤后,就能使用自己的分词器了。 完成上述步骤后,就能使用自己的分词器了。

View File

@ -199,7 +199,7 @@ def make_text_splitter(
text_splitter_module = importlib.import_module('langchain.text_splitter') text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, splitter_name) TextSplitter = getattr(text_splitter_module, splitter_name)
if text_splitter_dict[splitter_name]["source"] == "tiktoken": if text_splitter_dict[splitter_name]["source"] == "tiktoken": ## 从tiktoken加载
try: try:
text_splitter = TextSplitter.from_tiktoken_encoder( text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
@ -213,7 +213,7 @@ def make_text_splitter(
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap chunk_overlap=chunk_overlap
) )
elif text_splitter_dict[splitter_name]["source"] == "huggingface": elif text_splitter_dict[splitter_name]["source"] == "huggingface": ## 从huggingface加载
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \ text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
llm_model_dict[LLM_MODEL]["local_model_path"] llm_model_dict[LLM_MODEL]["local_model_path"]
@ -221,8 +221,8 @@ def make_text_splitter(
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast from transformers import GPT2TokenizerFast
from langchain.text_splitter import CharacterTextSplitter from langchain.text_splitter import CharacterTextSplitter
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
else: else: ## 字符长度加载
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
text_splitter_dict[splitter_name]["tokenizer_name_or_path"], text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
trust_remote_code=True) trust_remote_code=True)
@ -231,6 +231,18 @@ def make_text_splitter(
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap chunk_overlap=chunk_overlap
) )
else:
try:
text_splitter = TextSplitter(
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except:
text_splitter = TextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except Exception as e: except Exception as e:
print(e) print(e)
text_splitter_module = importlib.import_module('langchain.text_splitter') text_splitter_module = importlib.import_module('langchain.text_splitter')

View File

@ -6,87 +6,36 @@ import sys
sys.path.append("../..") sys.path.append("../..")
from configs.model_config import ( from configs.model_config import (
CHUNK_SIZE, CHUNK_SIZE,
OVERLAP_SIZE, OVERLAP_SIZE
text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
) )
import langchain.document_loaders
import importlib
from server.knowledge_base.utils import make_text_splitter
def test_different_splitter( def text(splitter_name):
splitter_name,
chunk_size: int = CHUNK_SIZE,
chunk_overlap: int = OVERLAP_SIZE,
):
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
headers_to_split_on=headers_to_split_on)
else:
try: ## 优先使用用户自定义的text_splitter
text_splitter_module = importlib.import_module('text_splitter')
TextSplitter = getattr(text_splitter_module, splitter_name)
except: ## 否则使用langchain的text_splitter
text_splitter_module = importlib.import_module('langchain.text_splitter')
TextSplitter = getattr(text_splitter_module, splitter_name)
if text_splitter_dict[splitter_name]["source"] == "tiktoken":
try:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
pipeline="zh_core_web_sm",
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
except:
text_splitter = TextSplitter.from_tiktoken_encoder(
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
elif text_splitter_dict[splitter_name]["source"] == "huggingface":
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
llm_model_dict[LLM_MODEL]["local_model_path"]
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
from transformers import GPT2TokenizerFast
from langchain.text_splitter import CharacterTextSplitter
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer
else:
tokenizer = AutoTokenizer.from_pretrained(
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
trust_remote_code=True)
text_splitter = TextSplitter.from_huggingface_tokenizer(
tokenizer=tokenizer,
chunk_size=chunk_size,
chunk_overlap=chunk_overlap
)
return text_splitter
if __name__ == "__main__":
from langchain import document_loaders from langchain import document_loaders
# 使用DocumentLoader读取文件 # 使用DocumentLoader读取文件
filepath = "../../knowledge_base/samples/content/test.txt" filepath = "../../knowledge_base/samples/content/test.txt"
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True) loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
docs = loader.load() docs = loader.load()
text_splitter = test_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE) text_splitter = make_text_splitter(splitter_name, CHUNK_SIZE, OVERLAP_SIZE)
# 使用text_splitter进行分词 if splitter_name == "MarkdownHeaderTextSplitter":
if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter":
split_docs = text_splitter.split_text(docs[0].page_content) split_docs = text_splitter.split_text(docs[0].page_content)
for doc in docs: for doc in docs:
# 如果文档有元数据
if doc.metadata: if doc.metadata:
doc.metadata["source"] = os.path.basename(filepath) doc.metadata["source"] = os.path.basename(filepath)
else: else:
split_docs = text_splitter.split_documents(docs) split_docs = text_splitter.split_documents(docs)
return docs
# 打印分词结果
for doc in split_docs:
print(doc)
import pytest
@pytest.mark.parametrize(
    "splitter_name",
    [
        "ChineseRecursiveTextSplitter",
        "SpacyTextSplitter",
        "RecursiveCharacterTextSplitter",
        "MarkdownHeaderTextSplitter",
    ],
)
def test_different_splitter(splitter_name):
    """Smoke-test one configured text splitter.

    Calls ``text(splitter_name)`` and asserts that the value it returns is
    not ``None``. Any ``Exception`` raised along the way (a failed assertion
    included) is reported through ``pytest.fail`` with the splitter name in
    the message.
    """
    try:
        loaded = text(splitter_name)
        assert loaded is not None
    except Exception as err:
        # Report through pytest.fail so the failing splitter is named explicitly.
        failure = f"test_different_splitter failed with {splitter_name}, error: {str(err)}"
        pytest.fail(failure)