增加了自定义分词器适配 (#1462)
* 添加了自定义分词器适配和测试文件 --------- Co-authored-by: zR <zRzRzRzRzRzRzR>
This commit is contained in:
parent
c4cb4e19e5
commit
bfdbe69fa1
36
README.md
36
README.md
|
|
@ -144,6 +144,28 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
|||
|
||||
---
|
||||
|
||||
### 知识库Text Splitter类型支持
|
||||
|
||||
本项目支持调用 [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) 的Text Splitter分词器以及基于此改进的自定义分词器,已支持的Text Splitter类型如下:
|
||||
|
||||
- CharacterTextSplitter
|
||||
- LatexTextSplitter
|
||||
- MarkdownHeaderTextSplitter
|
||||
- MarkdownTextSplitter
|
||||
- NLTKTextSplitter
|
||||
- PythonCodeTextSplitter
|
||||
- RecursiveCharacterTextSplitter
|
||||
- SentenceTransformersTokenTextSplitter
|
||||
- SpacyTextSplitter
|
||||
|
||||
已经支持的定制分词器如下:
|
||||
|
||||
- AliTextSplitter
|
||||
- ChineseRecursiveTextSplitter
|
||||
- ChineseTextSplitter
|
||||
|
||||
关于如何使用自定义分词器和贡献自己的分词器,可以参考[这里](docs/splitter.md)
|
||||
|
||||
## Docker 部署
|
||||
|
||||
🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)`
|
||||
|
|
@ -215,6 +237,20 @@ embedding_model_dict = {
|
|||
}
|
||||
```
|
||||
|
||||
- 请确认本地分词器路径是否已经填写,如:
|
||||
|
||||
```python
|
||||
text_splitter_dict = {
|
||||
"ChineseRecursiveTextSplitter": {
|
||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||
"tokenizer_name_or_path": "", ## 空格不填则默认使用大模型的分词器。
|
||||
"chunk_size": 250, # 知识库中单段文本长度
|
||||
"overlap_size": 50, # 知识库中相邻文本重合长度
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
如果你选择使用OpenAI的Embedding模型,请将模型的 ``key``写入 `embedding_model_dict`中。使用该模型,你需要能够访问OpenAI官的API,或设置代理。
|
||||
|
||||
### 4. 知识库初始化与迁移
|
||||
|
|
|
|||
|
|
@ -88,14 +88,14 @@ llm_model_dict = {
|
|||
"provider": "ChatGLMWorker",
|
||||
"version": "chatglm_pro", # 可选包括 "chatglm_lite", "chatglm_std", "chatglm_pro"
|
||||
},
|
||||
"minimax-api": {
|
||||
"minimax-api": {
|
||||
"api_base_url": "http://127.0.0.1:8888/v1",
|
||||
"group_id": "",
|
||||
"api_key": "",
|
||||
"is_pro": False,
|
||||
"provider": "MiniMaxWorker",
|
||||
},
|
||||
"xinghuo-api": {
|
||||
"xinghuo-api": {
|
||||
"api_base_url": "http://127.0.0.1:8888/v1",
|
||||
"APPID": "",
|
||||
"APISecret": "",
|
||||
|
|
@ -115,9 +115,49 @@ HISTORY_LEN = 3
|
|||
TEMPERATURE = 0.7
|
||||
# TOP_P = 0.95 # ChatOpenAI暂不支持该参数
|
||||
|
||||
|
||||
# LLM 运行设备。设为"auto"会自动检测,也可手动设定为"cuda","mps","cpu"其中之一。
|
||||
LLM_DEVICE = "auto"
|
||||
|
||||
# TextSplitter
|
||||
|
||||
text_splitter_dict = {
|
||||
"ChineseRecursiveTextSplitter": {
|
||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||
"tokenizer_name_or_path": "gpt2",
|
||||
},
|
||||
"SpacyTextSplitter": {
|
||||
"source": "huggingface",
|
||||
"tokenizer_name_or_path": "",
|
||||
},
|
||||
"RecursiveCharacterTextSplitter": {
|
||||
"source": "tiktoken",
|
||||
"tokenizer_name_or_path": "cl100k_base",
|
||||
},
|
||||
|
||||
"MarkdownHeaderTextSplitter": {
|
||||
"headers_to_split_on":
|
||||
[
|
||||
("#", "head1"),
|
||||
("##", "head2"),
|
||||
("###", "head3"),
|
||||
("####", "head4"),
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# TEXT_SPLITTER 名称
|
||||
TEXT_SPLITTER_NAME = "SpacyTextSplitter"
|
||||
|
||||
|
||||
|
||||
# 知识库中单段文本长度(不适用MarkdownHeaderTextSplitter)
|
||||
CHUNK_SIZE = 250
|
||||
|
||||
# 知识库中相邻文本重合长度(不适用MarkdownHeaderTextSplitter)
|
||||
OVERLAP_SIZE = 50
|
||||
|
||||
|
||||
# 日志存储路径
|
||||
LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
|
||||
if not os.path.exists(LOG_PATH):
|
||||
|
|
@ -132,6 +172,7 @@ if not os.path.exists(KB_ROOT_PATH):
|
|||
DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
|
||||
SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
|
||||
|
||||
|
||||
# 可选向量库类型及对应配置
|
||||
kbs_config = {
|
||||
"faiss": {
|
||||
|
|
@ -154,12 +195,6 @@ DEFAULT_VS_TYPE = "faiss"
|
|||
# 缓存向量库数量
|
||||
CACHED_VS_NUM = 1
|
||||
|
||||
# 知识库中单段文本长度
|
||||
CHUNK_SIZE = 250
|
||||
|
||||
# 知识库中相邻文本重合长度
|
||||
OVERLAP_SIZE = 50
|
||||
|
||||
# 知识库匹配向量数量
|
||||
VECTOR_SEARCH_TOP_K = 5
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,6 @@ FSCHAT_MODEL_WORKERS = {
|
|||
# "limit_worker_concurrency": 5,
|
||||
# "stream_interval": 2,
|
||||
# "no_register": False,
|
||||
# "embed_in_truncate": False,
|
||||
},
|
||||
"baichuan-7b": { # 使用default中的IP和端口
|
||||
"device": "cpu",
|
||||
|
|
|
|||
|
|
@ -0,0 +1,23 @@
|
|||
## 自定义分词器
|
||||
|
||||
### 在哪里写,哪些文件要改
|
||||
1. 在```text_splitter```文件夹下新建一个文件,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器,如下所示:
|
||||
```python
|
||||
from .my_splitter import MySplitter
|
||||
```
|
||||
2. 修改```config/model_config.py```文件,将您的分词器名字添加到```text_splitter_dict```中,如下所示:
|
||||
```python
|
||||
MySplitter": {
|
||||
"source": "huggingface", ## 选择tiktoken则使用openai的方法
|
||||
"tokenizer_name_or_path": "your tokenizer",
|
||||
"chunk_size": 250,
|
||||
"overlap_size": 50,
|
||||
},
|
||||
TEXT_SPLITTER_NAME = "MySplitter"
|
||||
```
|
||||
完成上述步骤后,就能使用自己的分词器了。
|
||||
|
||||
### 如何贡献您的分词器
|
||||
1. 将您的分词器所在的代码文件放在```text_splitter```文件夹下,文件名为您的分词器名字,比如`my_splitter.py`,然后在`__init__.py`中导入您的分词器。
|
||||
2. 发起PR,并说明您的分词器面向的场景或者改进之处。我们非常期待您能举例一个具体的应用场景。
|
||||
3. 在Readme.md中添加您的分词器的使用方法和支持说明。
|
||||
|
|
@ -50,6 +50,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
model_name: str = LLM_MODEL,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncIteratorCallbackHandler()
|
||||
|
||||
model = ChatOpenAI(
|
||||
streaming=True,
|
||||
verbose=True,
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE,
|
||||
CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
|
||||
from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose)
|
||||
from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
|
||||
list_files_from_folder,files2docs_in_thread,
|
||||
|
|
@ -38,8 +37,8 @@ def folder2db(
|
|||
mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
|
||||
vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
chunk_size: int = -1,
|
||||
chunk_overlap: int = -1,
|
||||
zh_title_enhance: bool = ZH_TITLE_ENHANCE,
|
||||
):
|
||||
'''
|
||||
|
|
|
|||
|
|
@ -1,15 +1,14 @@
|
|||
import os
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings import HuggingFaceBgeEmbeddings
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from configs.model_config import (
|
||||
embedding_model_dict,
|
||||
EMBEDDING_MODEL,
|
||||
KB_ROOT_PATH,
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
ZH_TITLE_ENHANCE,
|
||||
logger, log_verbose,
|
||||
logger, log_verbose, text_splitter_dict, llm_model_dict, LLM_MODEL,TEXT_SPLITTER_NAME
|
||||
)
|
||||
import importlib
|
||||
from text_splitter import zh_title_enhance
|
||||
|
|
@ -178,7 +177,7 @@ def get_loader(loader_name: str, file_path_or_content: Union[str, bytes, io.Stri
|
|||
|
||||
|
||||
def make_text_splitter(
|
||||
splitter_name: str = "SpacyTextSplitter",
|
||||
splitter_name: str = TEXT_SPLITTER_NAME,
|
||||
chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE,
|
||||
):
|
||||
|
|
@ -186,23 +185,57 @@ def make_text_splitter(
|
|||
根据参数获取特定的分词器
|
||||
'''
|
||||
splitter_name = splitter_name or "SpacyTextSplitter"
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
try:
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
text_splitter = TextSplitter(
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
|
||||
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
|
||||
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
|
||||
headers_to_split_on=headers_to_split_on)
|
||||
else:
|
||||
|
||||
try: ## 优先使用用户自定义的text_splitter
|
||||
text_splitter_module = importlib.import_module('text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
except: ## 否则使用langchain的text_splitter
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
|
||||
if text_splitter_dict[splitter_name]["source"] == "tiktoken":
|
||||
try:
|
||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
except:
|
||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface":
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
||||
llm_model_dict[LLM_MODEL]["local_model_path"]
|
||||
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
||||
from transformers import GPT2TokenizerFast
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
trust_remote_code=True)
|
||||
text_splitter = TextSplitter.from_huggingface_tokenizer(
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"查找分词器 {splitter_name} 时出错:{e}"
|
||||
logger.error(f'{e.__class__.__name__}: {msg}',
|
||||
exc_info=e if log_verbose else None)
|
||||
print(e)
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter")
|
||||
text_splitter = TextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
)
|
||||
text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50)
|
||||
return text_splitter
|
||||
|
||||
class KnowledgeFile:
|
||||
|
|
@ -223,9 +256,7 @@ class KnowledgeFile:
|
|||
self.docs = None
|
||||
self.splited_docs = None
|
||||
self.document_loader_name = get_LoaderClass(self.ext)
|
||||
|
||||
# TODO: 增加依据文件格式匹配text_splitter
|
||||
self.text_splitter_name = None
|
||||
self.text_splitter_name = TEXT_SPLITTER_NAME
|
||||
|
||||
def file2docs(self, refresh: bool=False):
|
||||
if self.docs is None or refresh:
|
||||
|
|
@ -249,7 +280,14 @@ class KnowledgeFile:
|
|||
if self.ext not in [".csv"]:
|
||||
if text_splitter is None:
|
||||
text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
|
||||
docs = text_splitter.split_documents(docs)
|
||||
if self.text_splitter_name == "MarkdownHeaderTextSplitter":
|
||||
docs = text_splitter.split_text(docs[0].page_content)
|
||||
for doc in docs:
|
||||
# 如果文档有元数据
|
||||
if doc.metadata:
|
||||
doc.metadata["source"] = os.path.basename(self.filepath)
|
||||
else:
|
||||
docs = text_splitter.split_documents(docs)
|
||||
|
||||
print(f"文档切分示例:{docs[0]}")
|
||||
if zh_title_enhance:
|
||||
|
|
@ -335,4 +373,4 @@ if __name__ == "__main__":
|
|||
pprint(docs[-1])
|
||||
|
||||
docs = kb_file.file2text()
|
||||
pprint(docs[-1])
|
||||
pprint(docs[-1])
|
||||
|
|
@ -18,7 +18,7 @@ except:
|
|||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||||
from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
|
||||
logger, log_verbose
|
||||
logger, log_verbose, TEXT_SPLITTER_NAME
|
||||
from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
|
||||
FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
|
||||
from server.utils import (fschat_controller_address, fschat_model_worker_address,
|
||||
|
|
@ -464,7 +464,10 @@ def dump_server_info(after_start=False, args=None):
|
|||
models = [LLM_MODEL]
|
||||
if args and args.model_name:
|
||||
models = args.model_name
|
||||
|
||||
print(f"当前使用的分词器:{TEXT_SPLITTER_NAME}")
|
||||
print(f"当前启动的LLM模型:{models} @ {llm_device()}")
|
||||
|
||||
for model in models:
|
||||
pprint(llm_model_dict[model])
|
||||
print(f"当前Embbedings模型: {EMBEDDING_MODEL} @ {embedding_device()}")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
import os
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
import sys
|
||||
|
||||
sys.path.append("../..")
|
||||
from configs.model_config import (
|
||||
CHUNK_SIZE,
|
||||
OVERLAP_SIZE,
|
||||
text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
|
||||
)
|
||||
import langchain.document_loaders
|
||||
import importlib
|
||||
|
||||
|
||||
def text_different_splitter(splitter_name, chunk_size: int = CHUNK_SIZE,
|
||||
chunk_overlap: int = OVERLAP_SIZE, ):
|
||||
if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定
|
||||
headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
|
||||
text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter(
|
||||
headers_to_split_on=headers_to_split_on)
|
||||
|
||||
else:
|
||||
|
||||
try: ## 优先使用用户自定义的text_splitter
|
||||
text_splitter_module = importlib.import_module('text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
except: ## 否则使用langchain的text_splitter
|
||||
text_splitter_module = importlib.import_module('langchain.text_splitter')
|
||||
TextSplitter = getattr(text_splitter_module, splitter_name)
|
||||
|
||||
if text_splitter_dict[splitter_name]["source"] == "tiktoken":
|
||||
try:
|
||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
pipeline="zh_core_web_sm",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
except:
|
||||
text_splitter = TextSplitter.from_tiktoken_encoder(
|
||||
encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
elif text_splitter_dict[splitter_name]["source"] == "huggingface":
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
|
||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
|
||||
llm_model_dict[LLM_MODEL]["local_model_path"]
|
||||
|
||||
if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
|
||||
from transformers import GPT2TokenizerFast
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer
|
||||
else:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
|
||||
trust_remote_code=True)
|
||||
text_splitter = TextSplitter.from_huggingface_tokenizer(
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
return text_splitter
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from langchain import document_loaders
|
||||
|
||||
# 使用DocumentLoader读取文件
|
||||
filepath = "../../knowledge_base/samples/content/test.txt"
|
||||
loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
|
||||
docs = loader.load()
|
||||
text_splitter = text_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE)
|
||||
# 使用text_splitter进行分词
|
||||
|
||||
if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter":
|
||||
split_docs = text_splitter.split_text(docs[0].page_content)
|
||||
for doc in docs:
|
||||
# 如果文档有元数据
|
||||
if doc.metadata:
|
||||
doc.metadata["source"] = os.path.basename(filepath)
|
||||
else:
|
||||
split_docs = text_splitter.split_documents(docs)
|
||||
|
||||
# 打印分词结果
|
||||
for doc in split_docs:
|
||||
print(doc)
|
||||
|
|
@ -1,11 +1,10 @@
|
|||
from langchain.text_splitter import CharacterTextSplitter
|
||||
import re
|
||||
from typing import List
|
||||
from configs.model_config import CHUNK_SIZE
|
||||
|
||||
|
||||
class ChineseTextSplitter(CharacterTextSplitter):
|
||||
def __init__(self, pdf: bool = False, sentence_size: int = CHUNK_SIZE, **kwargs):
|
||||
def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.pdf = pdf
|
||||
self.sentence_size = sentence_size
|
||||
|
|
|
|||
Loading…
Reference in New Issue