Langchain-Chatchat/embeddings/add_embedding_keywords.py

56 lines
2.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

'''
该功能是为了将关键词加入到embedding模型中以便于在embedding模型中进行关键词的embedding
该功能的实现是通过修改embedding模型的tokenizer来实现的
该功能仅仅对EMBEDDING_MODEL参数对应的的模型有效输出后的模型保存在原本模型
该功能的Idea由社区贡献感谢@CharlesJu1
保存的模型的位置位于原本嵌入模型的目录下,模型的名称为原模型名称+Merge_Keywords_时间戳
'''
import sys
sys.path.append("..")
import os
from safetensors.torch import save_model
from sentence_transformers import SentenceTransformer
from datetime import datetime
from configs import (
MODEL_PATH,
EMBEDDING_MODEL,
EMBEDDING_KEYWORD_FILE,
)
def add_keyword_to_model(model_name: str = EMBEDDING_MODEL, keyword_file: str = "", output_model_path: str = None):
key_words = []
with open(keyword_file, "r") as f:
for line in f:
key_words.append(line.strip())
model = SentenceTransformer(model_name)
word_embedding_model = model._first_module()
tokenizer = word_embedding_model.tokenizer
tokenizer.add_tokens(key_words)
word_embedding_model.auto_model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=32)
if output_model_path:
os.makedirs(output_model_path, exist_ok=True)
tokenizer.save_pretrained(output_model_path)
model.save(output_model_path)
safetensors_file = os.path.join(output_model_path, "model.safetensors")
metadata = {'format': 'pt'}
save_model(model, safetensors_file, metadata)
def add_keyword_to_embedding_model(path: str = EMBEDDING_KEYWORD_FILE):
keyword_file = os.path.join(path)
model_name = MODEL_PATH["embed_model"][EMBEDDING_MODEL]
model_parent_directory = os.path.dirname(model_name)
current_time = datetime.now().strftime('%Y%m%d_%H%M%S')
output_model_name = "{}_Merge_Keywords_{}".format(EMBEDDING_MODEL, current_time)
output_model_path = os.path.join(model_parent_directory, output_model_name)
add_keyword_to_model(model_name, keyword_file, output_model_path)
print("save model to {}".format(output_model_path))
if __name__ == '__main__':
add_keyword_to_embedding_model(EMBEDDING_KEYWORD_FILE)