Source code for hmfast.tracers.cmb_lensing

import os
import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from jax.scipy.special import sici, erf 

from hmfast.tracers.base_tracer import Tracer
from hmfast.download import _get_default_data_path
from hmfast.utils import Const
from hmfast.halos.profiles import MatterProfile, NFWMatterProfile
jax.config.update("jax_enable_x64", True)


[docs] class CMBLensingTracer(Tracer): """ CMB weak lensing tracer. Attributes ---------- profile : MatterProfile Matter profile used to model the CMB lensing convergence signal. """ _required_profile_type = MatterProfile def __init__(self, profile=None): super().__init__(profile=profile or NFWMatterProfile()) def _tree_flatten(self): # We treat the profile as the only leaf. leaves = (self.profile,) aux_data = None return (leaves, aux_data) @classmethod def _tree_unflatten(cls, aux_data, leaves): profile, = leaves obj = cls.__new__(cls) obj.profile = profile return obj
[docs] def update(self, profile=None): """ Return a new CMBLensingTracer instance with updated attributes using PyTree logic. Parameters ---------- profile : MatterProfile, optional New matter profile to use for the tracer. If None, the profile is unchanged. Returns ------- CMBLensingTracer New tracer instance with updated attributes. """ flat, aux = self._tree_flatten() if profile is not None: flat = (profile,) return self._tree_unflatten(aux, flat)
[docs] def kernel(self, cosmology, z): """ Compute the CMB lensing kernel :math:`W_{\\kappa,\\mathrm{cmb}}(z)` at redshift :math:`z`. The kernel is given by: .. math:: W_{\\kappa_{\\mathrm{CMB}}}(z) = \\frac{3}{2} \\Omega_m \\left(\\frac{H_0}{c}\\right)^2 \\frac{(1+z)}{\\chi(z)} \\frac{\\chi_* - \\chi(z)}{\\chi_*} where :math:`\\chi_*` is the comoving distance to the last scattering surface. Parameters ---------- cosmology : Cosmology Cosmology object with required methods and parameters. z : float or array_like Redshift(s) at which to compute the kernel. Returns ------- W_kappa_cmb : array_like CMB lensing kernel evaluated at redshift(s) :math:`z`. """ # Merge default parameters with input cparams = cosmology._cosmo_params() z = jnp.atleast_1d(z) # Ensure z is an array # Cosmological constants H0 = cosmology.H0 # Hubble constant in km/s/Mpc Omega_m = cparams["Omega0_m"] # Matter density parameter c_km_s = Const._c_ / 1e3 # Speed of light in km/s # Compute comoving distances in physical Mpc. chi_z = cosmology.angular_diameter_distance(z) * (1 + z) # Comoving distance to the last scattering surface (z ~ 1090) in physical Mpc. chi_z_cmb = cosmology.derived_parameters()["chi_star"] # Compute the CMB lensing kernel W_kappa_cmb = ( (3.0 / 2.0) * Omega_m * (H0/c_km_s)**2 * (1 + z) / chi_z * ((chi_z_cmb - chi_z) / chi_z_cmb) ) return W_kappa_cmb
jax.tree_util.register_pytree_node( CMBLensingTracer, lambda obj: obj._tree_flatten(), lambda aux_data, children: CMBLensingTracer._tree_unflatten(aux_data, children) )