from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.embeddings.base import Embeddings
from langchain.schema import Document
import threading
from configs import (EMBEDDING_MODEL, CHUNK_SIZE, CACHED_VS_NUM,
                     logger, log_verbose)
from server.utils import embedding_device, get_model_path
from contextlib import contextmanager
from collections import OrderedDict
from typing import List, Any, Union, Tuple


class ThreadSafeObject:
    """Wrap an arbitrary object with a reentrant lock and a "loaded" event.

    Used as a cache entry: readers call ``wait_for_loading`` to block until
    the producer thread has called ``finish_loading``; exclusive access to
    the wrapped object goes through the ``acquire`` context manager.
    """

    def __init__(self, key: Union[str, Tuple], obj: Any = None, pool: "CachePool" = None):
        self._obj = obj
        self._key = key
        self._pool = pool
        self._lock = threading.RLock()
        # Starts unset: a freshly created entry is "still loading" until
        # finish_loading() is called.
        self._loaded = threading.Event()

    def __repr__(self) -> str:
        cls = type(self).__name__
        return f"<{cls}: key: {self.key}, obj: {self._obj}>"

    @property
    def key(self):
        return self._key

    @contextmanager
    def acquire(self, owner: str = "", msg: str = ""):
        """Yield the wrapped object while holding the entry's lock.

        Also bumps this entry to most-recently-used in the owning pool.
        """
        owner = owner or f"thread {threading.get_native_id()}"
        # Acquire OUTSIDE the try: if acquire() itself fails we must not
        # run the finally-release for a lock we never obtained.
        self._lock.acquire()
        try:
            if self._pool is not None:
                # LRU bookkeeping: mark as most recently used.
                self._pool._cache.move_to_end(self.key)
            if log_verbose:
                logger.info(f"{owner} 开始操作:{self.key}。{msg}")
            yield self._obj
        finally:
            if log_verbose:
                logger.info(f"{owner} 结束操作:{self.key}。{msg}")
            self._lock.release()

    def start_loading(self):
        self._loaded.clear()

    def finish_loading(self):
        self._loaded.set()

    def wait_for_loading(self):
        self._loaded.wait()

    @property
    def obj(self):
        return self._obj

    @obj.setter
    def obj(self, val: Any):
        self._obj = val


class CachePool:
    """An LRU cache of ThreadSafeObject entries keyed by string or tuple.

    ``cache_num <= 0`` means unbounded; otherwise the least-recently-used
    entry is evicted once the cache grows past ``cache_num``.
    """

    def __init__(self, cache_num: int = -1):
        self._cache_num = cache_num
        self._cache = OrderedDict()
        # Guards check-then-insert sequences in subclasses (see
        # EmbeddingsPool.load_embeddings).
        self.atomic = threading.RLock()

    def keys(self) -> List[str]:
        return list(self._cache.keys())

    def _check_count(self):
        # Evict oldest entries until within capacity (only when bounded).
        if isinstance(self._cache_num, int) and self._cache_num > 0:
            while len(self._cache) > self._cache_num:
                self._cache.popitem(last=False)

    def get(self, key: str) -> ThreadSafeObject:
        """Return the entry for ``key`` (blocking until it finished loading),
        or None if absent."""
        if cache := self._cache.get(key):
            cache.wait_for_loading()
            return cache

    def set(self, key: str, obj: ThreadSafeObject) -> ThreadSafeObject:
        self._cache[key] = obj
        self._check_count()
        return obj

    def pop(self, key: str = None) -> ThreadSafeObject:
        """Remove and return the entry for ``key``; with no key, evict the
        least-recently-used entry (returned as a (key, value) pair)."""
        if key is None:
            return self._cache.popitem(last=False)
        else:
            return self._cache.pop(key, None)

    def acquire(self, key: Union[str, Tuple], owner: str = "", msg: str = ""):
        """Acquire exclusive access to the cached object for ``key``.

        Raises RuntimeError when the key is not cached.
        """
        cache = self.get(key)
        if cache is None:
            raise RuntimeError(f"请求的资源 {key} 不存在")
        elif isinstance(cache, ThreadSafeObject):
            self._cache.move_to_end(key)
            return cache.acquire(owner=owner, msg=msg)
        else:
            return cache

    def load_kb_embeddings(self,
                           kb_name: str = None,
                           embed_device: str = embedding_device()) -> Embeddings:
        """Return the Embeddings instance configured for knowledge base
        ``kb_name``, falling back to the global default model.

        NOTE(review): ``embed_device``'s default is evaluated once at class
        definition time (original behavior, kept for compatibility).
        """
        from server.db.repository.knowledge_base_repository import get_kb_detail

        kb_detail = get_kb_detail(kb_name=kb_name)
        # Guard against a missing KB record; removed stray debug print().
        embed_model = (kb_detail or {}).get("embed_model", EMBEDDING_MODEL)
        return embeddings_pool.load_embeddings(model=embed_model, device=embed_device)


class EmbeddingsPool(CachePool):
    """Cache of loaded embedding models, keyed by (model, device)."""

    def load_embeddings(self, model: str = None, device: str = None) -> Embeddings:
        """Return a (possibly cached) Embeddings instance for ``model`` on
        ``device``; both arguments fall back to the configured defaults.

        Thread-safe: the first caller loads the model while later callers
        block until loading finishes.
        """
        self.atomic.acquire()
        model = model or EMBEDDING_MODEL
        device = device or embedding_device()
        key = (model, device)
        if not self.get(key):
            item = ThreadSafeObject(key, pool=self)
            self.set(key, item)
            with item.acquire(msg="初始化"):
                # Release the pool lock before the slow model load so other
                # keys can be served; waiters on THIS key block in get().
                self.atomic.release()
                try:
                    if model == "text-embedding-ada-002":  # openai text-embedding-ada-002
                        embeddings = OpenAIEmbeddings(openai_api_key=get_model_path(model),
                                                      chunk_size=CHUNK_SIZE)
                    elif 'bge-' in model:
                        embeddings = HuggingFaceBgeEmbeddings(
                            model_name=get_model_path(model),
                            model_kwargs={'device': device},
                            query_instruction="为这个句子生成表示以用于检索相关文章:")
                        if model == "bge-large-zh-noinstruct":  # bge large -noinstruct embedding
                            embeddings.query_instruction = ""
                    else:
                        embeddings = HuggingFaceEmbeddings(model_name=get_model_path(model),
                                                           model_kwargs={'device': device})
                    item.obj = embeddings
                finally:
                    # BUGFIX: always unblock waiters, even when loading
                    # raised — otherwise every later get(key) would hang
                    # forever on wait_for_loading(). Drop the failed entry
                    # so a retry can attempt the load again.
                    if item.obj is None:
                        self.pop(key)
                    item.finish_loading()
        else:
            self.atomic.release()
        return self.get(key).obj


# Module-level singleton; cache_num=1 keeps only the most recently used
# embedding model resident.
embeddings_pool = EmbeddingsPool(cache_num=1)