Source code for hmfast.tracers.cib

import os
import numpy as np
import jax
import jax.numpy as jnp

from hmfast.tracers.base_tracer import Tracer
from hmfast.halos.profiles import CIBProfile, S12CIBProfile
from hmfast.utils import lambertw, Const
from hmfast.download import _get_default_data_path

[docs] class CIBTracer(Tracer): """ Cosmic infrared background tracer. Attributes ---------- profile : CIBProfile Infrared emissivity profile used to model the cosmic infrared background signal. """ _required_profile_type = CIBProfile def __init__(self, profile=None): super().__init__(profile=profile or S12CIBProfile(nu=100)) # --- JAX PyTree Registration --- def _tree_flatten(self): # The Tracer's only dynamic component is the Profile PyTree leaves = (self.profile,) aux_data = None return (leaves, aux_data) @classmethod def _tree_unflatten(cls, aux_data, leaves): profile, = leaves obj = cls.__new__(cls) obj.profile = profile return obj
[docs] def update(self, profile=None): """ Return a new CIBTracer instance with updated attributes using PyTree logic. Parameters ---------- profile : CIBProfile, optional New CIB profile to use for the tracer. If None, the profile is unchanged. Returns ------- CIBTracer New tracer instance with updated attributes. """ flat, aux = self._tree_flatten() if profile is not None: flat = (profile,) return self._tree_unflatten(aux, flat)
[docs] def kernel(self, cosmology, z): chi = cosmology.angular_diameter_distance(z) * (1 + z) return 1.0 / ((1.0 + z) * chi**2)
jax.tree_util.register_pytree_node( CIBTracer, lambda obj: obj._tree_flatten(), lambda aux_data, children: CIBTracer._tree_unflatten(aux_data, children) )