diff --git a/README.md b/README.md
index 703c084..de1240a 100644
--- a/README.md
+++ b/README.md
@@ -144,6 +144,28 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
 
 ---
 
+### Supported Text Splitter types for the knowledge base
+
+This project supports the Text Splitters provided by [Langchain](https://api.python.langchain.com/en/latest/api_reference.html#module-langchain.text_splitter) as well as custom splitters built on top of them. The supported Text Splitter types are:
+
+- CharacterTextSplitter
+- LatexTextSplitter
+- MarkdownHeaderTextSplitter
+- MarkdownTextSplitter
+- NLTKTextSplitter
+- PythonCodeTextSplitter
+- RecursiveCharacterTextSplitter
+- SentenceTransformersTokenTextSplitter
+- SpacyTextSplitter
+
+The custom splitters already supported are:
+
+- AliTextSplitter
+- ChineseRecursiveTextSplitter
+- ChineseTextSplitter
+
+For how to use a custom splitter, or to contribute your own, see [here](docs/splitter.md).
+
 ## Docker Deployment
 
 🐳 Docker image address: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3`
@@ -215,6 +237,20 @@ embedding_model_dict = {
 }
 ```
 
+- Make sure the local tokenizer path has been filled in, e.g.:
+
+```python
+text_splitter_dict = {
+    "ChineseRecursiveTextSplitter": {
+        "source": "huggingface",  # if "tiktoken" is chosen, OpenAI's method is used
+        "tokenizer_name_or_path": "",  # leave empty to fall back to the LLM's own tokenizer
+        "chunk_size": 250,  # length of a single text segment in the knowledge base
+        "overlap_size": 50,  # overlap length between adjacent text segments in the knowledge base
+    }
+}
+```
+
 If you choose to use an OpenAI Embedding model, write the model's `key` into `embedding_model_dict`. To use this model you need access to OpenAI's official API, or a proxy.
 
 ### 4. Knowledge Base Initialization and Migration
 
diff --git a/configs/model_config.py.example b/configs/model_config.py.example
index 977c25f..f501dca 100644
--- a/configs/model_config.py.example
+++ b/configs/model_config.py.example
@@ -88,14 +88,14 @@ llm_model_dict = {
         "provider": "ChatGLMWorker",
         "version": "chatglm_pro",  # options include "chatglm_lite", "chatglm_std", "chatglm_pro"
     },
-"minimax-api": {
+    "minimax-api": {
         "api_base_url": "http://127.0.0.1:8888/v1",
         "group_id": "",
         "api_key": "",
         "is_pro": False,
         "provider": "MiniMaxWorker",
     },
-"xinghuo-api": {
+    "xinghuo-api": {
         "api_base_url": "http://127.0.0.1:8888/v1",
         "APPID": "",
         "APISecret": "",
@@ -115,9 +115,49 @@ HISTORY_LEN = 3
 
 TEMPERATURE = 0.7
 # TOP_P = 0.95  # ChatOpenAI does not support this parameter yet
+
 # LLM device. "auto" detects automatically; it can also be set manually to one of "cuda", "mps", or "cpu".
 LLM_DEVICE = "auto"
 
+# TextSplitter
+
+text_splitter_dict = {
+    "ChineseRecursiveTextSplitter": {
+        "source": "huggingface",  # if "tiktoken" is chosen, OpenAI's method is used
+        "tokenizer_name_or_path": "gpt2",
+    },
+    "SpacyTextSplitter": {
+        "source": "huggingface",
+        "tokenizer_name_or_path": "",
+    },
+    "RecursiveCharacterTextSplitter": {
+        "source": "tiktoken",
+        "tokenizer_name_or_path": "cl100k_base",
+    },
+    "MarkdownHeaderTextSplitter": {
+        "headers_to_split_on": [
+            ("#", "head1"),
+            ("##", "head2"),
+            ("###", "head3"),
+            ("####", "head4"),
+        ]
+    },
+}
+
+# Name of the text splitter to use
+TEXT_SPLITTER_NAME = "SpacyTextSplitter"
+
+# Length of a single text segment in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
+CHUNK_SIZE = 250
+
+# Overlap length between adjacent text segments in the knowledge base (does not apply to MarkdownHeaderTextSplitter)
+OVERLAP_SIZE = 50
+
 # Log storage path
 LOG_PATH = os.path.join(os.path.dirname(os.path.dirname(__file__)), "logs")
 if not os.path.exists(LOG_PATH):
@@ -132,6 +172,7 @@ if not os.path.exists(KB_ROOT_PATH):
 DB_ROOT_PATH = os.path.join(KB_ROOT_PATH, "info.db")
 SQLALCHEMY_DATABASE_URI = f"sqlite:///{DB_ROOT_PATH}"
 
+
 # Available vector store types and their configurations
 kbs_config = {
     "faiss": {
@@ -154,12 +195,6 @@ DEFAULT_VS_TYPE = "faiss"
 # Number of cached vector stores
 CACHED_VS_NUM = 1
 
-# Length of a single text segment in the knowledge base
-CHUNK_SIZE = 250
-
-# Overlap length between adjacent text segments in the knowledge base
-OVERLAP_SIZE = 50
-
 # Number of matched vectors to retrieve from the knowledge base
 VECTOR_SEARCH_TOP_K = 5
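For orientation, the `text_splitter_dict` entries above map one-to-one onto splitter construction calls. A minimal sketch of what the `RecursiveCharacterTextSplitter`/`tiktoken` entry resolves to at runtime, assuming `langchain` and `tiktoken` are installed (the sample text is made up):

```python
# Minimal sketch: what the "RecursiveCharacterTextSplitter" config entry builds.
# Assumes langchain + tiktoken are installed; the sample text is illustrative.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name="cl100k_base",  # tokenizer_name_or_path from text_splitter_dict
    chunk_size=250,               # CHUNK_SIZE
    chunk_overlap=50,             # OVERLAP_SIZE
)
print(splitter.split_text("Any long document text to be chunked goes here ..."))
```

Measuring chunks in tokens of the same encoding the downstream LLM uses keeps `CHUNK_SIZE` meaningful regardless of language.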
diff --git a/configs/server_config.py.example b/configs/server_config.py.example
index bc21d9a..3107c1d 100644
--- a/configs/server_config.py.example
+++ b/configs/server_config.py.example
@@ -59,7 +59,6 @@ FSCHAT_MODEL_WORKERS = {
         # "limit_worker_concurrency": 5,
         # "stream_interval": 2,
         # "no_register": False,
-        # "embed_in_truncate": False,
     },
     "baichuan-7b": {  # uses the IP and port from "default"
         "device": "cpu",
diff --git a/docs/splitter.md b/docs/splitter.md
new file mode 100644
index 0000000..144d827
--- /dev/null
+++ b/docs/splitter.md
@@ -0,0 +1,23 @@
+## Custom splitters
+
+### Where to write one, and which files to change
+1. Create a new file under the `text_splitter` folder, named after your splitter, e.g. `my_splitter.py`, then import your splitter in `__init__.py`, like so:
+```python
+from .my_splitter import MySplitter
+```
+2. Edit `configs/model_config.py`, add your splitter to `text_splitter_dict`, and point `TEXT_SPLITTER_NAME` at it, like so:
+```python
+"MySplitter": {
+    "source": "huggingface",  # if "tiktoken" is chosen, OpenAI's method is used
+    "tokenizer_name_or_path": "your tokenizer",
+    "chunk_size": 250,
+    "overlap_size": 50,
+},
+TEXT_SPLITTER_NAME = "MySplitter"
+```
+After these steps you can use your own splitter.
+
+### How to contribute your splitter
+1. Put your splitter's code file under the `text_splitter` folder, named after your splitter, e.g. `my_splitter.py`, then import it in `__init__.py`.
+2. Open a PR and describe the scenarios your splitter targets or the improvements it makes. A concrete example use case is very welcome.
+3. Add usage and support notes for your splitter to README.md.
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 5d1aac8..b26f2cb 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -50,6 +50,7 @@ async def knowledge_base_chat(query: str = Body(..., description="User input",
                               model_name: str = LLM_MODEL,
                               ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
+
         model = ChatOpenAI(
             streaming=True,
             verbose=True,
diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py
index 896d9b4..b2073a1 100644
--- a/server/knowledge_base/migrate.py
+++ b/server/knowledge_base/migrate.py
@@ -1,5 +1,4 @@
-from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE,
-                                  CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE,
+from configs.model_config import (EMBEDDING_MODEL, DEFAULT_VS_TYPE, ZH_TITLE_ENHANCE,
                                   logger, log_verbose)
 from server.knowledge_base.utils import (get_file_path, list_kbs_from_folder,
                                          list_files_from_folder, files2docs_in_thread,
@@ -38,8 +37,8 @@ def folder2db(
     mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"],
     vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE,
     embed_model: str = EMBEDDING_MODEL,
-    chunk_size: int = CHUNK_SIZE,
-    chunk_overlap: int = OVERLAP_SIZE,
+    chunk_size: int = -1,
+    chunk_overlap: int = -1,
     zh_title_enhance: bool = ZH_TITLE_ENHANCE,
 ):
     '''
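To make the `docs/splitter.md` recipe concrete, a contributed splitter could look like the sketch below. `MySplitter` is hypothetical and the splitting rule is a toy; it mirrors how the project's custom splitters subclass a Langchain splitter and override `split_text`:

```python
# text_splitter/my_splitter.py -- hypothetical example of a contributed splitter.
from typing import List

from langchain.text_splitter import CharacterTextSplitter


class MySplitter(CharacterTextSplitter):
    def __init__(self, sentence_size: int = 250, **kwargs):
        super().__init__(**kwargs)
        self.sentence_size = sentence_size

    def split_text(self, text: str) -> List[str]:
        # toy rule: cut at Chinese full stops, then cap each piece at sentence_size
        pieces = [p.strip() for p in text.split("。") if p.strip()]
        return [p[:self.sentence_size] for p in pieces]
```

Registering it is then just `from .my_splitter import MySplitter` in `text_splitter/__init__.py` plus the `text_splitter_dict` entry shown above.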
"SpacyTextSplitter" - text_splitter_module = importlib.import_module('langchain.text_splitter') try: - TextSplitter = getattr(text_splitter_module, splitter_name) - text_splitter = TextSplitter( - pipeline="zh_core_web_sm", - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) + if splitter_name == "MarkdownHeaderTextSplitter": # MarkdownHeaderTextSplitter特殊判定 + headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on'] + text_splitter = langchain.text_splitter.MarkdownHeaderTextSplitter( + headers_to_split_on=headers_to_split_on) + else: + + try: ## 优先使用用户自定义的text_splitter + text_splitter_module = importlib.import_module('text_splitter') + TextSplitter = getattr(text_splitter_module, splitter_name) + except: ## 否则使用langchain的text_splitter + text_splitter_module = importlib.import_module('langchain.text_splitter') + TextSplitter = getattr(text_splitter_module, splitter_name) + + if text_splitter_dict[splitter_name]["source"] == "tiktoken": + try: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + pipeline="zh_core_web_sm", + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + except: + text_splitter = TextSplitter.from_tiktoken_encoder( + encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) + elif text_splitter_dict[splitter_name]["source"] == "huggingface": + if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "": + text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \ + llm_model_dict[LLM_MODEL]["local_model_path"] + + if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2": + from transformers import GPT2TokenizerFast + from langchain.text_splitter import CharacterTextSplitter + tokenizer = GPT2TokenizerFast.from_pretrained("gpt2") ## 这里选择你用的tokenizer + else: + tokenizer = AutoTokenizer.from_pretrained( + text_splitter_dict[splitter_name]["tokenizer_name_or_path"], + trust_remote_code=True) + text_splitter = TextSplitter.from_huggingface_tokenizer( + tokenizer=tokenizer, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap + ) except Exception as e: - msg = f"查找分词器 {splitter_name} 时出错:{e}" - logger.error(f'{e.__class__.__name__}: {msg}', - exc_info=e if log_verbose else None) + print(e) + text_splitter_module = importlib.import_module('langchain.text_splitter') TextSplitter = getattr(text_splitter_module, "RecursiveCharacterTextSplitter") - text_splitter = TextSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - ) + text_splitter = TextSplitter(chunk_size=250, chunk_overlap=50) return text_splitter class KnowledgeFile: @@ -223,9 +256,7 @@ class KnowledgeFile: self.docs = None self.splited_docs = None self.document_loader_name = get_LoaderClass(self.ext) - - # TODO: 增加依据文件格式匹配text_splitter - self.text_splitter_name = None + self.text_splitter_name = TEXT_SPLITTER_NAME def file2docs(self, refresh: bool=False): if self.docs is None or refresh: @@ -249,7 +280,14 @@ class KnowledgeFile: if self.ext not in [".csv"]: if text_splitter is None: text_splitter = make_text_splitter(splitter_name=self.text_splitter_name, chunk_size=chunk_size, chunk_overlap=chunk_overlap) - docs = text_splitter.split_documents(docs) + if self.text_splitter_name == "MarkdownHeaderTextSplitter": + docs = text_splitter.split_text(docs[0].page_content) + for doc in docs: + # 如果文档有元数据 + if doc.metadata: + doc.metadata["source"] = os.path.basename(self.filepath) + else: + 
diff --git a/startup.py b/startup.py
index d045849..c3f1536 100644
--- a/startup.py
+++ b/startup.py
@@ -18,7 +18,7 @@ except:
     sys.path.append(os.path.dirname(os.path.dirname(__file__)))
 
 from configs.model_config import EMBEDDING_MODEL, llm_model_dict, LLM_MODEL, LOG_PATH, \
-    logger, log_verbose
+    logger, log_verbose, TEXT_SPLITTER_NAME
 from configs.server_config import (WEBUI_SERVER, API_SERVER, FSCHAT_CONTROLLER,
                                    FSCHAT_OPENAI_API, HTTPX_DEFAULT_TIMEOUT)
 from server.utils import (fschat_controller_address, fschat_model_worker_address,
@@ -464,7 +464,10 @@ def dump_server_info(after_start=False, args=None):
     models = [LLM_MODEL]
     if args and args.model_name:
         models = args.model_name
+
+    print(f"Current text splitter: {TEXT_SPLITTER_NAME}")
     print(f"Current LLM models: {models} @ {llm_device()}")
+
     for model in models:
         pprint(llm_model_dict[model])
     print(f"Current Embeddings model: {EMBEDDING_MODEL} @ {embedding_device()}")
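That `MarkdownHeaderTextSplitter` special case exists because the splitter consumes raw markdown and returns per-section documents whose metadata records the header trail, so a `source` key has to be added by hand. Roughly (the sample markdown is made up, the header keys follow the `headers_to_split_on` config, and the exact return type varies across Langchain versions):

```python
# Sketch: MarkdownHeaderTextSplitter output shape; sample markdown is made up.
from langchain.text_splitter import MarkdownHeaderTextSplitter

md = "# Intro\nSome intro text.\n## Details\nBody of the details section."
splitter = MarkdownHeaderTextSplitter(
    headers_to_split_on=[("#", "head1"), ("##", "head2")])
for doc in splitter.split_text(md):
    print(doc.metadata, "->", doc.page_content)
# e.g. {'head1': 'Intro'} -> Some intro text.
#      {'head1': 'Intro', 'head2': 'Details'} -> Body of the details section.
```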
diff --git a/tests/custom_splitter/text_different_splitter.py b/tests/custom_splitter/text_different_splitter.py
new file mode 100644
index 0000000..8cda776
--- /dev/null
+++ b/tests/custom_splitter/text_different_splitter.py
@@ -0,0 +1,89 @@
+import os
+import sys
+
+from transformers import AutoTokenizer
+
+sys.path.append("../..")
+from configs.model_config import (
+    CHUNK_SIZE,
+    OVERLAP_SIZE,
+    text_splitter_dict, llm_model_dict, LLM_MODEL, TEXT_SPLITTER_NAME
+)
+import langchain.document_loaders
+import importlib
+
+
+def text_different_splitter(splitter_name, chunk_size: int = CHUNK_SIZE,
+                            chunk_overlap: int = OVERLAP_SIZE):
+    if splitter_name == "MarkdownHeaderTextSplitter":  # special case: MarkdownHeaderTextSplitter
+        headers_to_split_on = text_splitter_dict[splitter_name]['headers_to_split_on']
+        from langchain.text_splitter import MarkdownHeaderTextSplitter
+        text_splitter = MarkdownHeaderTextSplitter(
+            headers_to_split_on=headers_to_split_on)
+    else:
+        try:  # prefer a user-defined splitter from the text_splitter package
+            text_splitter_module = importlib.import_module('text_splitter')
+            TextSplitter = getattr(text_splitter_module, splitter_name)
+        except (ImportError, AttributeError):  # otherwise fall back to langchain's splitter
+            text_splitter_module = importlib.import_module('langchain.text_splitter')
+            TextSplitter = getattr(text_splitter_module, splitter_name)
+
+        if text_splitter_dict[splitter_name]["source"] == "tiktoken":
+            try:
+                text_splitter = TextSplitter.from_tiktoken_encoder(
+                    encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    pipeline="zh_core_web_sm",
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
+                )
+            except Exception:  # the pipeline kwarg only applies to SpacyTextSplitter
+                text_splitter = TextSplitter.from_tiktoken_encoder(
+                    encoding_name=text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    chunk_size=chunk_size,
+                    chunk_overlap=chunk_overlap
+                )
+        elif text_splitter_dict[splitter_name]["source"] == "huggingface":
+            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "":
+                text_splitter_dict[splitter_name]["tokenizer_name_or_path"] = \
+                    llm_model_dict[LLM_MODEL]["local_model_path"]
+
+            if text_splitter_dict[splitter_name]["tokenizer_name_or_path"] == "gpt2":
+                from transformers import GPT2TokenizerFast
+                tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")  # pick the tokenizer you use here
+            else:
+                tokenizer = AutoTokenizer.from_pretrained(
+                    text_splitter_dict[splitter_name]["tokenizer_name_or_path"],
+                    trust_remote_code=True)
+            text_splitter = TextSplitter.from_huggingface_tokenizer(
+                tokenizer=tokenizer,
+                chunk_size=chunk_size,
+                chunk_overlap=chunk_overlap
+            )
+
+    return text_splitter
+
+
+if __name__ == "__main__":
+    from langchain import document_loaders
+
+    # read the file with a DocumentLoader
+    filepath = "../../knowledge_base/samples/content/test.txt"
+    loader = document_loaders.UnstructuredFileLoader(filepath, autodetect_encoding=True)
+    docs = loader.load()
+
+    # split with the text splitter
+    text_splitter = text_different_splitter(TEXT_SPLITTER_NAME, CHUNK_SIZE, OVERLAP_SIZE)
+    if TEXT_SPLITTER_NAME == "MarkdownHeaderTextSplitter":
+        split_docs = text_splitter.split_text(docs[0].page_content)
+        for doc in split_docs:
+            # if the document carries metadata
+            if doc.metadata:
+                doc.metadata["source"] = os.path.basename(filepath)
+    else:
+        split_docs = text_splitter.split_documents(docs)
+
+    # print the split results
+    for doc in split_docs:
+        print(doc)
diff --git a/text_splitter/chinese_text_splitter.py b/text_splitter/chinese_text_splitter.py
index d6294ae..4107b25 100644
--- a/text_splitter/chinese_text_splitter.py
+++ b/text_splitter/chinese_text_splitter.py
@@ -1,11 +1,10 @@
 from langchain.text_splitter import CharacterTextSplitter
 import re
 from typing import List
-from configs.model_config import CHUNK_SIZE
 
 
 class ChineseTextSplitter(CharacterTextSplitter):
-    def __init__(self, pdf: bool = False, sentence_size: int = CHUNK_SIZE, **kwargs):
+    def __init__(self, pdf: bool = False, sentence_size: int = 250, **kwargs):
         super().__init__(**kwargs)
         self.pdf = pdf
         self.sentence_size = sentence_size
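Finally, since `ChineseTextSplitter` no longer reads `CHUNK_SIZE` from the config, `sentence_size` is now a plain constructor default of 250, and callers who want a different sentence length must pass it explicitly; a small sketch (the sample text is made up):

```python
# Sketch: sentence_size is now an ordinary constructor default (250).
from text_splitter.chinese_text_splitter import ChineseTextSplitter

splitter = ChineseTextSplitter(pdf=False, sentence_size=100)  # explicit override
for piece in splitter.split_text("这是第一句话。这是第二句话!这是第三句话?"):
    print(piece)
```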