Added custom text splitter adaptation (#1462)

* Added custom text splitter adaptation and test files

Co-authored-by: zR <zRzRzRzRzRzRzR>

parent c4cb4e19e5 · commit bfdbe69fa1

README.md: 36 lines changed
@@ -144,6 +144,28 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch

 ---

+### Knowledge Base Text Splitter Support
+
+This project supports the [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) Text Splitters as well as custom splitters improved on top of them. The supported Text Splitter types are:
+
+- CharacterTextSplitter
+- LatexTextSplitter
+- MarkdownHeaderTextSplitter
+- MarkdownTextSplitter
+- NLTKTextSplitter
+- PythonCodeTextSplitter
+- RecursiveCharacterTextSplitter
+- SentenceTransformersTokenTextSplitter
+- SpacyTextSplitter
+
+The custom splitters already supported are:
+
+- AliTextSplitter
+- ChineseRecursiveTextSplitter
+- ChineseTextSplitter
+
+For how to use a custom splitter and how to contribute your own, see [here](docs/splitter.md).
+
 ## Docker Deployment

 🐳 Docker image: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3`
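For orientation, here is a minimal, self-contained sketch of using one of the stock Langchain splitters from the list above directly, outside this project's config plumbing; the sample text is made up:

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Same chunking parameters this commit later sets as project defaults.
splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=50)

text = "A long document would go here. " * 40
for chunk in splitter.split_text(text):
    print(len(chunk), repr(chunk[:40]))
```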
@@ -215,6 +237,20 @@ embedding_model_dict = {
 }
 ```

+- Make sure the local tokenizer path is filled in, for example:
+
+```python
+text_splitter_dict = {
+    "ChineseRecursiveTextSplitter": {
+        "source": "huggingface",  # choose "tiktoken" to use OpenAI's method instead
+        "tokenizer_name_or_path": "",  # leave empty to default to the LLM's own tokenizer
+        "chunk_size": 250,  # length of a single text chunk in the knowledge base
+        "overlap_size": 50,  # overlap between adjacent chunks in the knowledge base
+    }
+}
+```
+
 If you choose an OpenAI Embedding model, write the model's `key` into `embedding_model_dict`. To use that model you need access to the official OpenAI API, or a proxy.

 ### 4. Knowledge Base Initialization and Migration
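As a sketch of what the `source` field selects — hedged: both class methods exist on Langchain's `TextSplitter` base class, and the tokenizer names are the ones from the snippet above:

```python
from transformers import AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter

# "source": "tiktoken" -> chunk length is measured with an OpenAI tiktoken encoding
tiktoken_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base",
    chunk_size=250,
    chunk_overlap=50,
)

# "source": "huggingface" -> chunk length is measured with a HuggingFace tokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
hf_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
    tokenizer=tokenizer,
    chunk_size=250,
    chunk_overlap=50,
)
```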
configs/model_config.py

@@ -115,9 +115,49 @@ HISTORY_LEN = 3
 TEMPERATURE = 0.7
 # TOP_P = 0.95  # ChatOpenAI does not support this parameter yet


 # Device to run the LLM on. "auto" auto-detects; it can also be set manually to "cuda", "mps", or "cpu".
 LLM_DEVICE = "auto"

+# TextSplitter
+text_splitter_dict = {
+    "ChineseRecursiveTextSplitter": {
+        "source": "huggingface",  # choose "tiktoken" to use OpenAI's method instead
+        "tokenizer_name_or_path": "gpt2",
+    },
+    "SpacyTextSplitter": {
+        "source": "huggingface",
+        "tokenizer_name_or_path": "",
+    },
+    "RecursiveCharacterTextSplitter": {
+        "source": "tiktoken",
+        "tokenizer_name_or_path": "cl100k_base",
+    },
+    "MarkdownHeaderTextSplitter": {
+        "headers_to_split_on": [
+            ("#", "head1"),
+            ("##", "head2"),
+            ("###", "head3"),
+            ("####", "head4"),
+        ]
+    },
+}
+
+# Name of the text splitter to use
+TEXT_SPLITTER_NAME = "SpacyTextSplitter"
+
+# Length of a single text chunk in the knowledge base (not used by MarkdownHeaderTextSplitter)
+CHUNK_SIZE = 250
+
+# Overlap between adjacent chunks in the knowledge base (not used by MarkdownHeaderTextSplitter)
+OVERLAP_SIZE = 50
+
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
 if not os.path.exists(LOG_PATH):
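A small illustration, not part of the commit, of what the `headers_to_split_on` entries above do: `MarkdownHeaderTextSplitter.split_text` returns documents whose metadata records the enclosing headers under the configured keys:

```python
from langchain.text_splitter import MarkdownHeaderTextSplitter

splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[
    ("#", "head1"),
    ("##", "head2"),
])

md = "# Guide\n\nIntro.\n\n## Setup\n\nInstall things."
for doc in splitter.split_text(md):
    # e.g. metadata == {"head1": "Guide", "head2": "Setup"}
    print(doc.metadata, "->", doc.page_content)
```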
@@ -132,6 +172,7 @@ if not os.path.exists(KB_ROOT_PATH):
 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
 SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"


 # Supported vector store types and their configuration
 kbs_config = {
     "faiss": {
@@ -154,12 +195,6 @@ DEFAULT_VS_TYPE = "faiss"
 # Number of cached vector stores
 CACHED_VS_NUM = 1

-# Length of a single text chunk in the knowledge base
-CHUNK_SIZE = 250
-
-# Overlap between adjacent chunks in the knowledge base
-OVERLAP_SIZE = 50
-
 # Number of matched vectors to retrieve from the knowledge base
 VECTOR_SEARCH_TOP_K = 5
configs/server_config.py

@@ -59,7 +59,6 @@ FSCHAT_MODEL_WORKERS = {
         # "limit_worker_concurrency": 5,
         # "stream_interval": 2,
         # "no_register": False,
-        # "embed_in_truncate": False,
     },
     "baichuan-7b": {  # uses the IP and port from "default"
         "device": "cpu",
docs/splitter.md (new file)

@@ -0,0 +1,23 @@
+## Custom Text Splitters
+
+### Where to write one, and which files to change
+
+1. Create a new file under the `text_splitter` folder named after your splitter, e.g. `my_splitter.py`, then import your splitter in `__init__.py`, as shown below:
+
+```python
+from .my_splitter import MySplitter
+```
+
+2. Edit `configs/model_config.py` and add your splitter to `text_splitter_dict`, as shown below:
+
+```python
+"MySplitter": {
+    "source": "huggingface",  # choose "tiktoken" to use OpenAI's method instead
+    "tokenizer_name_or_path": "your tokenizer",
+    "chunk_size": 250,
+    "overlap_size": 50,
+},
+TEXT_SPLITTER_NAME = "MySplitter"
+```
+
+After completing these steps you can use your own splitter.
+
+### How to contribute your splitter
+
+1. Put your splitter's code file under the `text_splitter` folder, named after your splitter, e.g. `my_splitter.py`, then import it in `__init__.py`.
+2. Open a PR explaining the scenario your splitter targets or what it improves. A concrete example use case is very welcome.
+3. Add usage and support notes for your splitter to README.md.
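To make step 1 of the new docs concrete, here is a minimal sketch of what `my_splitter.py` might contain; the class name matches the example above, and the blank-line splitting rule is purely illustrative:

```python
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class MySplitter(CharacterTextSplitter):
    """Illustrative custom splitter: cuts on blank lines."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def split_text(self, text: str) -> List[str]:
        # Replace this rule with whatever your documents need;
        # chunk_size/chunk_overlap handling can reuse the parent class.
        return [p.strip() for p in text.split("\n\n") if p.strip()]
```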
server/chat/knowledge_base_chat.py

@@ -50,6 +50,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               model_name: str = LLM_MODEL,
                               ) -> AsyncIterable[str]:
     callback = AsyncIteratorCallbackHandler()

     model = ChatOpenAI(
         streaming=True,
         verbose=True,
server/knowledge_base/migrate.py

@@ -1,5 +1,4 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                                  CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
+from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
                                   logger, log_verbose)
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,

@@ -38,8 +37,8 @@ def folder2db(
     mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
-    chunk_size: int = CHUNK_SIZE,
-    chunk_overlap: int = OVERLAP_SIZE,
+    chunk_size: int = -1,
+    chunk_overlap: int = -1,
     zh_title_enhance: bool = ZH_TITLE_ENHANCE,
 ):
     '''
server/knowledge_base/utils.py

@@ -1,15 +1,14 @@
 import os
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.embeddings import HuggingFaceBgeEmbeddings
+from transformers import AutoTokenizer
 from configs.model_config import (
-    embedding_model_dict,
     EMBEDDING_MODEL,
     KB_ROOT_PATH,
     CHUNK_SIZE,
     OVERLAP_SIZE,
     ZH_TITLE_ENHANCE,
-    logger, log_verbose,
+    logger, log_verbose, text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
 )
 import importlib
 from text_splitter import zh_title_enhance
@@ -178,7 +177,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri


 def make_text_splitter(
-        splitter_name: str = "SpacyTextSplitter",
+        splitter_name: str = TEXT_SPLITTER_NAME,
         chunk_size: int = CHUNK_SIZE,
         chunk_overlap: int = OVERLAP_SIZE,
 ):
@@ -186,23 +185,57 @@ def make_text_splitter(
     '''
     Get the specific text splitter according to the parameters.
     '''
     splitter_name = splitter_name or "SpacyTextSplitter"
-    text_splitter_module = importlib.import_module('langchain.text_splitter')
     try:
-        TextSplitter = getattr(text_splitter_module, splitter_name)
-        text_splitter = TextSplitter(
-            pipeline="zh_core_web_sm",
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
+        if splitter_name == "MarkdownHeaderTextSplitter":  # special case for MarkdownHeaderTextSplitter
+            headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
+            text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
+                headers_to_split_on=headers_to_split_on)
+        else:
+            try:  # prefer a user-defined text splitter
+                text_splitter_module = importlib.import_module('text_splitter')
+                TextSplitter = getattr(text_splitter_module, splitter_name)
+            except:  # otherwise use langchain's text splitter
+                text_splitter_module = importlib.import_module('langchain.text_splitter')
+                TextSplitter = getattr(text_splitter_module, splitter_name)
+
+            if text_splitter_dict[splitter_name]["source"] == "tiktoken":
+                try:
+                    text_splitter = TextSplitter.from_tiktoken_encoder(
+                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                        pipeline="zh_core_web_sm",
+                        chunk_size=chunk_size,
+                        chunk_overlap=chunk_overlap
+                    )
+                except:
+                    text_splitter = TextSplitter.from_tiktoken_encoder(
+                        encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                        chunk_size=chunk_size,
+                        chunk_overlap=chunk_overlap
+                    )
+            elif text_splitter_dict[splitter_name]["source"] == "huggingface":
+                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
+                        llm_model_dict[LLM_MODEL]["local_model_path"]
+
+                if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
+                    from transformers import GPT2TokenizerFast
+                    from langchain.text_splitter import CharacterTextSplitter
+                    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # pick the tokenizer you actually use here
+                else:
+                    tokenizer = AutoTokenizer.from_pretrained(
+                        text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                        trust_remote_code=True)
+                text_splitter = TextSplitter.from_huggingface_tokenizer(
+                    tokenizer=tokenizer,
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
+                )
     except Exception as e:
-        msg = f"查找分词器 {splitter_name} 时出错:{e}"
-        logger.error(f'{e.__class__.__name__}: {msg}',
-                     exc_info=e if log_verbose else None)
+        print(e)
+        text_splitter_module = importlib.import_module('langchain.text_splitter')
         TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
-        text_splitter = TextSplitter(
-            chunk_size=chunk_size,
-            chunk_overlap=chunk_overlap,
-        )
+        text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
     return text_splitter


 class KnowledgeFile:
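A hedged sketch of how the reworked `make_text_splitter` is called, assuming the project package is importable and the config above is filled in:

```python
from server.knowledge_base.utils import make_text_splitter

# With no arguments, the splitter named by TEXT_SPLITTER_NAME in
# configs/model_config.py is built with CHUNK_SIZE / OVERLAP_SIZE.
splitter = make_text_splitter()

# Or override the splitter and sizes explicitly:
splitter = make_text_splitter(
    splitter_name="ChineseRecursiveTextSplitter",
    chunk_size=250,
    chunk_overlap=50,
)
chunks = splitter.split_text("some long document text ...")
```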
@@ -223,9 +256,7 @@ class KnowledgeFile:
         self.docs = None
         self.splited_docs = None
         self.document_loader_name = get_LoaderClass(self.ext)
-        # TODO: match the text_splitter to the file format
-        self.text_splitter_name = None
+        self.text_splitter_name = TEXT_SPLITTER_NAME

     def file2docs(self, refresh: bool=False):
         if self.docs is None or refresh:
@@ -249,6 +280,13 @@ class KnowledgeFile:
         if self.ext not in [".csv"]:
             if text_splitter is None:
                 text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        if self.text_splitter_name == "MarkdownHeaderTextSplitter":
+            docs = text_splitter.split_text(docs[0].page_content)
+            for doc in docs:
+                if doc.metadata:  # if the document has metadata
+                    doc.metadata["source"] = os.path.basename(self.filepath)
+        else:
             docs = text_splitter.split_documents(docs)

         print(f"文档切分示例:{docs[0]}")
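And a sketch of the resulting `KnowledgeFile` flow; the constructor arguments and sample file name here are assumptions for illustration:

```python
from server.knowledge_base.utils import KnowledgeFile

# The splitter is now chosen globally via TEXT_SPLITTER_NAME rather than per file.
kf = KnowledgeFile(filename="test.md", knowledge_base_name="samples")
docs = kf.file2text()  # Markdown goes through the split_text branch above
print(docs[0])
```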
startup.py

@@ -18,7 +18,7 @@ except:

 sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
-    logger, log_verbose
+    logger, log_verbose, TEXT_SPLITTER_NAME
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                    FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,

@@ -464,7 +464,10 @@ def dump_server_info(after_start=False, args=None):
     models = [LLM_MODEL]
     if args and args.model_name:
         models = args.model_name

+    print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
     print(f"当前启动的LLM模型:{models} @ {llm_device()}")

     for model in models:
         pprint(llm_model_dict[model])
     print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
tests (new file: custom splitter test)

@@ -0,0 +1,89 @@
+import os
+
+from transformers import AutoTokenizer
+import sys
+
+sys.path.append("../..")
+from configs.model_config import (
+    CHUNK_SIZE,
+    OVERLAP_SIZE,
+    text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
+)
+import langchain.document_loaders
+import importlib
+
+
+def text_different_splitter(splitter_name, chunk_size: int = CHUNK_SIZE,
+                            chunk_overlap: int = OVERLAP_SIZE):
+    if splitter_name == "MarkdownHeaderTextSplitter":  # special case for MarkdownHeaderTextSplitter
+        headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
+        text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
+            headers_to_split_on=headers_to_split_on)
+    else:
+        try:  # prefer a user-defined text splitter
+            text_splitter_module = importlib.import_module('text_splitter')
+            TextSplitter = getattr(text_splitter_module, splitter_name)
+        except:  # otherwise use langchain's text splitter
+            text_splitter_module = importlib.import_module('langchain.text_splitter')
+            TextSplitter = getattr(text_splitter_module, splitter_name)
+
+        if text_splitter_dict[splitter_name]["source"] == "tiktoken":
+            try:
+                text_splitter = TextSplitter.from_tiktoken_encoder(
+                    encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    pipeline="zh_core_web_sm",
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
+                )
+            except:
+                text_splitter = TextSplitter.from_tiktoken_encoder(
+                    encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
+                )
+        elif text_splitter_dict[splitter_name]["source"] == "huggingface":
+            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+                text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
+                    llm_model_dict[LLM_MODEL]["local_model_path"]
+
+            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
+                from transformers import GPT2TokenizerFast
+                from langchain.text_splitter import CharacterTextSplitter
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # pick the tokenizer you actually use here
+            else:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    trust_remote_code=True)
+            text_splitter = TextSplitter.from_huggingface_tokenizer(
+                tokenizer=tokenizer,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+
+    return text_splitter
+
+
+if __name__ == "__main__":
+    from langchain import document_loaders
+
+    # read the file with a DocumentLoader
+    filepath = "../../knowledge_base/samples/content/test.txt"
+    loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
+    docs = loader.load()
+    text_splitter = text_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE)
+
+    # split with the text_splitter
+    if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter":
+        split_docs = text_splitter.split_text(docs[0].page_content)
+        for doc in docs:
+            if doc.metadata:  # if the document has metadata
+                doc.metadata["source"] = os.path.basename(filepath)
+    else:
+        split_docs = text_splitter.split_documents(docs)
+
+    # print the split results
+    for doc in split_docs:
+        print(doc)
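Note the relative paths in the test: it appends `../..` to `sys.path` and loads `../../knowledge_base/samples/content/test.txt`, so it is written to be run from the test file's own directory, two levels below the repository root.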
text_splitter/chinese_text_splitter.py

@@ -1,11 +1,10 @@
 from langchain.text_splitter import CharacterTextSplitter
 import re
 from typing import List
-from configs.model_config import CHUNK_SIZE


 class ChineseTextSplitter(CharacterTextSplitter):
-    def __init__(self, pdf: bool = False, sentence_size: int = CHUNK_SIZE, **kwargs):
+    def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
         super().__init__(**kwargs)
         self.pdf = pdf
         self.sentence_size = sentence_size
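A short usage sketch after this decoupling — `sentence_size` no longer tracks the config's `CHUNK_SIZE`, so callers that want a different value pass it explicitly; the sample sentence is made up:

```python
from text_splitter import ChineseTextSplitter

splitter = ChineseTextSplitter(pdf=False, sentence_size=250)
for sentence in splitter.split_text("第一句话。第二句话!第三句话?"):
    print(sentence)
```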