30 lines
1.0 KiB
Python
30 lines
1.0 KiB
Python
import os
|
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
|
from configs.model_config import (embedding_model_dict, KB_ROOT_PATH)
|
|
from functools import lru_cache
|
|
|
|
|
|
def validate_kb_name(knowledge_base_id: str) -> bool:
    """Return whether a knowledge-base name is free of path-traversal tricks.

    The name is later joined onto the knowledge-base root directory with
    ``os.path.join``, so anything that can climb out of that root (or
    replace it outright) must be rejected.

    Args:
        knowledge_base_id: Candidate knowledge-base name, typically
            supplied by a client.

    Returns:
        ``False`` if the name contains a ``..`` traversal (using either
        ``/`` or ``\\`` as separator) or starts with a path separator,
        ``True`` otherwise.
    """
    # Check for unexpected characters or path-attack keywords.
    # Backslashes are treated as separators so "..\\" is caught as well.
    normalized = knowledge_base_id.replace("\\", "/")

    # Original rule: any "../" sequence is a traversal attempt.
    if "../" in normalized:
        return False

    # A bare ".." or a trailing "/.." has no slash after it, yet still
    # resolves to a parent directory once joined onto the root.
    if normalized == ".." or normalized.endswith("/.."):
        return False

    # os.path.join discards the root when the joined part is absolute.
    if normalized.startswith("/"):
        return False

    return True
|
|
|
|
def get_kb_path(knowledge_base_name: str):
    """Return the path of the named knowledge base under ``KB_ROOT_PATH``."""
    kb_dir = os.path.join(KB_ROOT_PATH, knowledge_base_name)
    return kb_dir
|
|
|
|
def get_doc_path(knowledge_base_name: str):
    """Return the ``content`` sub-path of the named knowledge base."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "content")
|
|
|
|
def get_vs_path(knowledge_base_name: str):
    """Return the ``vector_store`` sub-path of the named knowledge base."""
    kb_dir = get_kb_path(knowledge_base_name)
    return os.path.join(kb_dir, "vector_store")
|
|
|
|
def get_file_path(knowledge_base_name: str, doc_name: str):
    """Return the path of ``doc_name`` inside the knowledge base's document path."""
    docs_dir = get_doc_path(knowledge_base_name)
    return os.path.join(docs_dir, doc_name)
|
|
|
|
@lru_cache(maxsize=1)
def load_embeddings(model: str, device: str):
    """Build a ``HuggingFaceEmbeddings`` instance for the given model key.

    ``model`` is looked up in ``embedding_model_dict`` and the resulting
    value is passed as ``model_name``; ``device`` is forwarded through
    ``model_kwargs``. The result is memoized with a cache of size one, so
    only the most recently requested ``(model, device)`` pair is kept.
    """
    model_name = embedding_model_dict[model]
    return HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs={'device': device},
    )
|