Source code for BaryonForge.utils.Cache


import numpy as np
import pyccl as ccl
from collections import OrderedDict
from ..Profiles.Base import BaseBFGProfiles

__all__ = ['SimpleArrayCache', 'CachedProfile']

[docs]class SimpleArrayCache: """ A lightweight LRU-style cache designed for functions whose inputs include NumPy arrays. Unlike ``functools.lru_cache``, this cache supports unhashable arguments (e.g. ``numpy.ndarray``) by constructing a stable byte-based key from the array contents. The cache stores results keyed by a tuple that is dynamically generated according to the function being cached: - floats, ints, and strings are saved as single values - For each array argument, we store: * its shape, * its dtype, * its raw byte buffer from ``.tobytes()``. - Any lists and tuples are converted to arrays and follow the above - Other objects (custom classes) are converted to string representations When used as a decorator, the cache wraps a function of the form ``func(*args)`` and automatically caches its return value based on these arguments. Repeated calls with identical inputs return the cached result without re-evaluating the function. Parameters ---------- maxsize : int, optional Maximum number of cached entries to store. The cache evicts the least recently used (LRU) entry when the limit is exceeded. Default is 64 Notes ----- - This cache treats the *contents* of arrays as part of the key, so even small differences in floating-point values produce distinct cache entries. - Custom classes are converted to a string and used verbatim. Two custom objects that print identically will collide. - The cache is implemented using ``collections.OrderedDict`` and maintains LRU behavior manually. Examples -------- >>> cached_func = SimpleArrayCache(maxsize=64)(func) """ def __init__(self, maxsize = 32): self.maxsize = maxsize self._store = OrderedDict() def _key(self, *args): key = [] for a in args: if isinstance(a, (int, float, str)): key.append(a) elif isinstance(a, (list, tuple)): a = np.array(a) key.append(a.shape) key.append(a.dtype.str) key.append(a.tobytes()) elif isinstance(a, (np.ndarray)): key.append(a.shape) key.append(a.dtype.str) key.append(a.tobytes()) else: key.append(str(a)) return tuple(key)
[docs] def get(self, *args): k = self._key(*args) if k in self._store: self._store.move_to_end(k) return self._store[k] return None
[docs] def set(self, value, *args): k = self._key(*args) self._store[k] = value self._store.move_to_end(k) if len(self._store) > self.maxsize: self._store.popitem(last=False)
def __call__(self, func): def cached_func(*args): cached = self.get(*args) if cached is not None: return cached val = func(*args) self.set(val, *args) return val return cached_func
[docs]class CachedProfile(BaseBFGProfiles): """ A class that will cache the profile evaluations for the real, projected, and fourier methods. This class will take in a BaryonForge (BFG) class and cache its results. It is useful for halo model P(k) calculations, where the same masses, redshifts, wavenumbers/radii are evaluated many times. See also `TabulatedProfile` if you want to only store a sparser grid. Parameters ---------- Profile : object A profile that we want to cache. Can either be a vanilla CCL profile or a BaryonForge Profile. """ def __init__(self, Profile, maxsize = 64, methods = ['real', 'projected', 'fourier']): assert isinstance(methods, list), f"You passed methods = {methods}, but we need a list of strings" self.Profile = Profile self.maxsize = maxsize self.methods = methods for m in self.methods: setattr(self, m, SimpleArrayCache(self.maxsize)(getattr(self.Profile, m))) #We just set this to the same as the inputted profile. super().__init__(mass_def = self.Profile.mass_def) self.update_precision_fftlog(**self.Profile.precision_fftlog.to_dict()) def __getattr__(self, key): safe_keys = self.methods + ['Profile', 'maxsize'] if key in safe_keys: return object.__getattribute__(self, key) else: return getattr(object.__getattribute__(self, 'Profile'), key) def __str_prf__(self): return f"Cached[{self.Profile.__str_prf__()}]" def __str_par__(self): return self.Profile.__str_par__()
class CachedHODProfile(CachedProfile, ccl.halos.profiles.hod.HaloProfileHOD): def __init__(self, Profile, maxsize = 64, methods = ['get_normalization', '_fourier_variance', '_fourier', 'fourier', 'real']): self.Profile = Profile self.maxsize = maxsize self.methods = methods for m in self.methods: setattr(self, m, SimpleArrayCache(self.maxsize)(getattr(self.Profile, m))) #We just set this to the same as the inputted profile. ccl.halos.profiles.hod.HaloProfileHOD.__init__(self, mass_def = self.Profile.mass_def) self.update_precision_fftlog(**self.Profile.precision_fftlog.to_dict())