import os
import numpy as np
import jax
import jax.numpy as jnp
import mcfit
from functools import partial
from hmfast.download import _get_default_data_path
from hmfast.utils import lambertw, Const
from hmfast.halos.massdef import MassDefinition
from hmfast.halos.profiles import HaloProfile, HankelTransform
[docs]
class DensityProfile(HaloProfile):
    """
    Common base class for the ICM density profiles defined in this module.

    It adds no behaviour of its own; concrete subclasses are expected to
    implement both :meth:`real` and :meth:`fourier`.
    """
[docs]
class B16DensityProfile(DensityProfile):
    """
    Electron density profile from `Battaglia et al. (2016) <https://ui.adsabs.harvard.edu/abs/2016JCAP...08..058B/abstract>`_.

    The profile is evaluated as a function of the comoving radius
    :math:`r`, while its shape is defined using the physical
    :math:`200c` radius:

    .. math::

        \\rho_{\\mathrm{gas,free}}(r)
        = f_b f_{\\mathrm{free}} \\rho_{\\mathrm{crit}}(z) \\, C
        \\left(\\frac{x_{200c}}{x_c}\\right)^{\\gamma}
        \\left[1 + \\left(\\frac{x_{200c}}{x_c}\\right)^{\\alpha}\\right]^{-\\frac{\\beta+\\gamma}{\\alpha}}
        \\tag{1}

    where :math:`x_{200c} = r / r_{200c}` and :math:`r_{200c}` has the same
    units as :math:`r`. With :math:`x_c = 0.5` and
    :math:`\\gamma = -0.2` fixed, the mass- and redshift-dependent parameters
    obey

    .. math::

        X(M_{200c}, z) = A_X
        \\left(\\frac{M_{200c}}{10^{14} M_\\odot}\\right)^{\\alpha_m^X}
        (1 + z)^{\\alpha_z^X}
        \\tag{2}

    where :math:`X \\in \\{C, \\alpha, \\beta\\}` and :math:`M_{200c}` is the
    halo mass in :math:`M_\\odot`, exactly as passed to :meth:`real`.

    The Fourier-space density profile used by the halo model is evaluated as

    .. math::

        u_k(k, M, z) =
        4 \\pi \\, r_\\Delta^3 \\, (1+z)^3
        \\int dx \\, x^2 \\, \\rho(x, M, z)
        \\, \\frac{\\sin\\!\\left[(k r_\\Delta) x\\right]}
        {(k r_\\Delta) x}
        \\tag{3}

    where :math:`x = r / [(1+z) r_\\Delta]`.

    Attributes
    ----------
    x : jnp.ndarray
        Dimensionless radial grid :math:`x = r / [(1+z) r_\\Delta]` used to tabulate the profile and define the Hankel transform.
    A_rho0 : float
        Amplitude :math:`A_C` controlling the normalization of the density profile.
    A_alpha : float
        Amplitude :math:`A_\\alpha` controlling the transition width.
    A_beta : float
        Amplitude :math:`A_\\beta` controlling the outer slope.
    alpha_m_rho0 : float
        Mass-scaling exponent :math:`\\alpha_m^C`.
    alpha_m_alpha : float
        Mass-scaling exponent :math:`\\alpha_m^\\alpha`.
    alpha_m_beta : float
        Mass-scaling exponent :math:`\\alpha_m^\\beta`.
    alpha_z_rho0 : float
        Redshift-scaling exponent :math:`\\alpha_z^C`.
    alpha_z_alpha : float
        Redshift-scaling exponent :math:`\\alpha_z^\\alpha`.
    alpha_z_beta : float
        Redshift-scaling exponent :math:`\\alpha_z^\\beta`.
    """

    def __init__(self, x=None,
                 A_rho0=4000.0, A_alpha=0.88, A_beta=3.83,
                 alpha_m_rho0=0.29, alpha_m_alpha=-0.03, alpha_m_beta=0.04,
                 alpha_z_rho0=-0.66, alpha_z_alpha=0.19, alpha_z_beta=-0.025,
                 ):
        # Assigning through the property also builds the Hankel transform.
        self.x = x if x is not None else jnp.logspace(-4, 1, 256)

        self.A_rho0, self.A_alpha, self.A_beta = A_rho0, A_alpha, A_beta
        self.alpha_m_rho0, self.alpha_m_alpha, self.alpha_m_beta = alpha_m_rho0, alpha_m_alpha, alpha_m_beta
        self.alpha_z_rho0, self.alpha_z_alpha, self.alpha_z_beta = alpha_z_rho0, alpha_z_alpha, alpha_z_beta

    @property
    def x(self):
        """Dimensionless radial grid used to tabulate the profile."""
        return self._x

    @x.setter
    def x(self, value):
        """Store the grid and rebuild the Hankel transform that depends on it."""
        self._x = value
        self._hankel = HankelTransform(self._x, nu=0.5)

    def _tree_flatten(self):
        """Split the instance into pytree leaves (parameters) and aux data."""
        # Dynamic calibration parameters
        leaves = (
            self.A_rho0, self.A_alpha, self.A_beta,
            self.alpha_m_rho0, self.alpha_m_alpha, self.alpha_m_beta,
            self.alpha_z_rho0, self.alpha_z_alpha, self.alpha_z_beta
        )
        # Static metadata.
        # NOTE(review): JAX expects aux data to be hashable and comparable,
        # which an array grid is not -- confirm instances are never passed
        # through jit as a traced (non-static) argument.
        aux_data = (self._x, self._hankel)
        return (leaves, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, leaves):
        """Rebuild an instance from aux data and leaves without running ``__init__``."""
        x, hankel = aux_data
        obj = cls.__new__(cls)
        # Unpack leaves back into attributes
        (obj.A_rho0, obj.A_alpha, obj.A_beta,
         obj.alpha_m_rho0, obj.alpha_m_alpha, obj.alpha_m_beta,
         obj.alpha_z_rho0, obj.alpha_z_alpha, obj.alpha_z_beta) = leaves
        # Reuse the existing grid and transform rather than rebuilding them.
        obj._x = x
        obj._hankel = hankel
        return obj

    def update(self, A_rho0=None, A_alpha=None, A_beta=None,
               alpha_m_rho0=None, alpha_m_alpha=None, alpha_m_beta=None,
               alpha_z_rho0=None, alpha_z_alpha=None, alpha_z_beta=None):
        """
        Return a new profile instance with updated Battaglia density parameters.

        Parameters
        ----------
        A_rho0, A_alpha, A_beta, alpha_m_rho0, alpha_m_alpha, alpha_m_beta, alpha_z_rho0, alpha_z_alpha, alpha_z_beta : float, optional
            Replacement values for the corresponding class attributes. Any argument left as ``None`` keeps its current value.

        Returns
        -------
        B16DensityProfile
            New profile instance with updated parameters. The radial grid
            and Hankel transform of the current instance are reused.
        """
        # Listed in the same order as the leaves returned by _tree_flatten.
        overrides = (
            A_rho0, A_alpha, A_beta,
            alpha_m_rho0, alpha_m_alpha, alpha_m_beta,
            alpha_z_rho0, alpha_z_alpha, alpha_z_beta,
        )
        current, aux_data = self._tree_flatten()
        new_leaves = tuple(
            old if new is None else new
            for new, old in zip(overrides, current)
        )
        return self._tree_unflatten(aux_data, new_leaves)

    @staticmethod
    def get_params(model_key="agn"):
        """
        Return a preset parameter dictionary (Table 2 values).

        Parameters
        ----------
        model_key : str
            Preset name, matched case-insensitively: ``"agn"`` or ``"shock"``.

        Returns
        -------
        dict
            Mapping from constructor keyword names to preset values.

        Raises
        ------
        ValueError
            If ``model_key`` is not one of the known presets.
        """
        presets = {
            "agn": {
                'A_rho0': 4000.0, 'A_alpha': 0.88, 'A_beta': 3.83,
                'alpha_m_rho0': 0.29, 'alpha_m_alpha': -0.03, 'alpha_m_beta': 0.04,
                'alpha_z_rho0': -0.66, 'alpha_z_alpha': 0.19, 'alpha_z_beta': -0.025
            },
            "shock": {
                'A_rho0': 1.9e4, 'A_alpha': 0.70, 'A_beta': 4.43,
                'alpha_m_rho0': 0.09, 'alpha_m_alpha': -0.017, 'alpha_m_beta': 0.005,
                'alpha_z_rho0': -0.95, 'alpha_z_alpha': 0.27, 'alpha_z_beta': 0.037
            }
        }
        key = model_key.lower()
        if key not in presets:
            raise ValueError(f"Model {model_key} not recognized. Choose 'agn' or 'shock'.")
        return presets[key]

    @partial(jax.jit, static_argnums=(0,))
    def real(self, halo_model, r, m, z):
        """
        Compute the electron-density profile.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology.
        r : float or jnp.ndarray
            Comoving radius or radii in :math:`\\mathrm{Mpc}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Electron-density profile with shape :math:`(N_r, N_m, N_z)`,
            where singleton dimensions get squeezed before return.
        """
        cparams = halo_model.cosmology._cosmo_params()
        f_b = cparams["Omega_b"] / cparams["Omega0_m"]
        # Fixed (non-calibrated) constants of the profile.
        f_free = 1.0
        gamma = -0.2
        xc = 0.5

        # Ensure 1D and setup broadcasting shapes
        r, m, z = jnp.atleast_1d(r), jnp.atleast_1d(m), jnp.atleast_1d(z)
        r_b, m_b, z_b = r[:, None, None], m[None, :, None], z[None, None, :]

        # NOTE(review): the radius comes from whatever mass definition the
        # halo model carries; the docstring assumes 200c -- confirm callers
        # configure it accordingly.
        r_200c = jnp.reshape(halo_model.mass_definition.r_delta(halo_model.cosmology, m, z), (len(m), len(z)))
        # Dividing the comoving r by (1 + z) matches it to the halo radius.
        x_200c = r_b / ((1.0 + z_b) * r_200c[None, :, :])

        # Critical density broadcast to (1, 1, Nz) in physical units.
        rho_crit_z = jnp.atleast_1d(halo_model.cosmology.critical_density(z))[None, None, :]

        # Masses are used as given (M_sun); no h rescaling is applied here.
        m_200c_msun = m_b
        mass_ratio = m_200c_msun / 1e14

        # Compute Shape Parameters (Equations A1, A2 from B16)
        rho0 = self.A_rho0 * mass_ratio**self.alpha_m_rho0 * (1 + z_b)**self.alpha_z_rho0
        alpha = self.A_alpha * mass_ratio**self.alpha_m_alpha * (1 + z_b)**self.alpha_z_alpha
        beta = self.A_beta * mass_ratio**self.alpha_m_beta * (1 + z_b)**self.alpha_z_beta

        # Profile Shape Function (Nx, Nm, Nz)
        p_x = (x_200c / xc)**gamma * (1 + (x_200c / xc)**alpha)**(-(beta + gamma) / alpha)
        rho_gas = rho0 * rho_crit_z * f_b * f_free * p_x

        # Truncate at r_200c so the real-space and Fourier-space profiles
        # describe the same finite halo.
        rho_gas = jnp.where(x_200c <= 1.0, rho_gas, 0.0)
        return jnp.squeeze(rho_gas)

    @partial(jax.jit, static_argnums=(0,))
    def fourier(self, halo_model, k, m, z):
        """
        Compute the projected Fourier-space density profile for halo-model calculations.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology and halo-radius relation.
        k : float or jnp.ndarray
            Comoving wavenumber(s) in :math:`\\mathrm{Mpc}^{-1}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Transformed profile with shape :math:`(N_k, N_m, N_z)`, where
            singleton dimensions get squeezed before return.
        """
        k, m, z = jnp.atleast_1d(k), jnp.atleast_1d(m), jnp.atleast_1d(z)

        # Halo radius on the (Nm, Nz) grid.
        r_delta = jnp.reshape(halo_model.mass_definition.r_delta(halo_model.cosmology, m, z), (len(m), len(z)))
        d_A_z = jnp.atleast_1d(halo_model.cosmology.angular_diameter_distance(z))
        # Multipole scale of the halo radius: ell_delta = d_A / r_delta.
        ell_delta = d_A_z[None, :] / r_delta

        # Requested wavenumbers are mapped to multipoles via
        # ell = k * chi - 1/2, with chi = (1 + z) * d_A.
        chi = d_A_z * (1 + z)
        ell_target = k[:, None] * chi[None, :] - 0.5

        prefactor = 4 * jnp.pi * r_delta**3 * (1 + z)[None, :]**3
        # Comoving radii corresponding to the dimensionless grid x.
        r = self.x[:, None, None] * r_delta[None, :, :] * (1.0 + z[None, None, :])

        # _u_k_hankel is not defined in this class (presumably inherited
        # from HaloProfile); it returns the transform's own k grid and values.
        k_native, u_k_native = self._u_k_hankel(halo_model, self.x, r, m, z)
        u_k_native = jnp.reshape(u_k_native, (len(k_native), len(m), len(z)))
        u_ell_native = u_k_native * jnp.sqrt(jnp.pi / (2 * k_native[:, None, None]))
        # Native dimensionless wavenumbers expressed as multipoles.
        ell_native = k_native[:, None, None] * ell_delta[None, :, :]
        u_ell_val = prefactor[None, :, :] * u_ell_native

        def interp_single_column(target_x, native_x, native_y):
            return jnp.interp(target_x, native_x, native_y)

        # Inner vmap runs over masses (targets shared), outer over redshifts.
        vmapped_interp = jax.vmap(
            jax.vmap(interp_single_column, in_axes=(None, 1, 1), out_axes=1),
            in_axes=(1, 2, 2), out_axes=2
        )
        return jnp.squeeze(vmapped_interp(ell_target, ell_native, u_ell_val))
# Register B16DensityProfile as a JAX pytree: its calibration parameters are
# the leaves, and the radial grid plus Hankel transform travel as aux data.
jax.tree_util.register_pytree_node(
    B16DensityProfile,
    B16DensityProfile._tree_flatten,
    B16DensityProfile._tree_unflatten,
)
class _NFWDensityProfile(DensityProfile):
    """
    Electron density profile based on `Navarro, Frenk & White (1997) <https://ui.adsabs.harvard.edu/abs/1997ApJ...490..493N/abstract>`_.

    The profile is a function of the comoving radius :math:`r` and is the
    NFW matter density rescaled by the cosmic baryon fraction,

    .. math::

        \\rho_e(r, M, z) = f_b \\, f_{\\mathrm{free}} \\, \\rho_{\\mathrm{NFW}}(r)
        \\tag{1}

    where

    .. math::

        \\rho_{\\mathrm{NFW}}(r)
        = \\frac{\\rho_s}{x_s \\left(1+x_s\\right)^2}
        \\tag{2}

    .. math::

        \\rho_s = \\frac{M}{4\\pi r_s^3}
        \\left[\\ln(1+c_\\Delta) - \\frac{c_\\Delta}{1+c_\\Delta}\\right]^{-1}
        \\tag{3}

    with :math:`x_s = r / r_s`, where :math:`r_s` has the same units as
    :math:`r`, and :math:`r_s = r_\\Delta / c_\\Delta`.

    The Fourier-space density profile used by the halo model is evaluated as

    .. math::

        u_k(k, M, z) =
        4 \\pi \\, r_s^3 \\, (1+z)^3
        \\int dx \\, x^2 \\, \\rho(x, M, z)
        \\, \\frac{\\sin\\!\\left[(k r_s) x\\right]}
        {(k r_s) x}
        \\tag{4}

    where :math:`x = r / r_s`, :math:`r_s` has the same units as
    :math:`r`, and :math:`f_{\\mathrm{free}} = 1`. Any kSZ-specific weighting
    is applied by the tracer kernel rather than by this profile.

    Attributes
    ----------
    x : jnp.ndarray
        Dimensionless radial grid :math:`x = r / r_s` used to tabulate the profile and define the Hankel transform, with :math:`r_s` expressed in the same units as :math:`r`.
    """

    def __init__(self, x=None):
        if x is None:
            # Default grid: 256 log-spaced points between 1e-4 and 1.
            x = jnp.logspace(jnp.log10(1e-4), jnp.log10(1.0), 256)
        self.x = x

    @property
    def x(self):
        """Dimensionless radial grid used to tabulate the profile."""
        return self._x

    @x.setter
    def x(self, value):
        """
        Store the new grid and rebuild the Hankel transform object right away.
        """
        self._x = value
        self._hankel = HankelTransform(self._x, nu=0.5)

    @partial(jax.jit, static_argnums=(0,))
    def real(self, halo_model, r, m, z):
        """
        Evaluate the electron-density profile in real space.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology, halo radius, and concentration model.
        r : float or jnp.ndarray
            Comoving radius or radii in :math:`\\mathrm{Mpc}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Electron-density profile with shape :math:`(N_r, N_m, N_z)`,
            where singleton dimensions get squeezed before return.
        """
        cosmo = halo_model.cosmology
        params = cosmo._cosmo_params()
        hubble = params["h"]

        radii = jnp.atleast_1d(r)
        masses = jnp.atleast_1d(m)
        redshifts = jnp.atleast_1d(z)
        grid_shape = (len(masses), len(redshifts))

        # Mass and halo radius are rescaled by h for the internal computation.
        mass_h = masses * hubble
        baryon_fraction = params["Omega_b"] / params["Omega0_m"]

        halo_radius = jnp.reshape(
            halo_model.mass_definition.r_delta(cosmo, masses, redshifts),
            grid_shape,
        ) * hubble
        concentration = jnp.reshape(
            halo_model.concentration.c_delta(
                cosmo,
                masses,
                redshifts,
                mass_definition=halo_model.mass_definition,
            ),
            grid_shape,
        )

        # Scale radius on the (Nm, Nz) grid and the matching radial variable.
        scale_radius = halo_radius / concentration
        x_s = radii[:, None, None] * hubble / (
            (1.0 + redshifts[None, None, :]) * scale_radius[None, :, :]
        )

        # Characteristic density rho_s from the NFW mass normalisation.
        nfw_norm = jnp.log(1 + concentration) - concentration / (1 + concentration)
        rho_s = mass_h[:, None] / (4 * jnp.pi * scale_radius**3 * nfw_norm)

        untruncated = baryon_fraction * rho_s[None, :, :] / (x_s * (1 + x_s)**2)

        # Cut the profile at r_delta (x_s = c_delta) so that real and Fourier
        # space share the same finite-mass NFW definition.
        inside_halo = x_s <= concentration[None, :, :]
        return jnp.squeeze(jnp.where(inside_halo, untruncated, 0.0))

    @partial(jax.jit, static_argnums=(0,))
    def fourier(self, halo_model, k, m, z):
        """
        Evaluate the projected Fourier-space density profile for halo-model calculations.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology, halo radius, and concentration model.
        k : float or jnp.ndarray
            Comoving wavenumber(s) in :math:`\\mathrm{Mpc}^{-1}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Transformed profile with shape :math:`(N_k, N_m, N_z)`, where
            singleton dimensions get squeezed before return.
        """
        cosmo = halo_model.cosmology
        wavenumbers = jnp.atleast_1d(k)
        masses = jnp.atleast_1d(m)
        redshifts = jnp.atleast_1d(z)
        grid_shape = (len(masses), len(redshifts))

        halo_radius = jnp.reshape(
            halo_model.mass_definition.r_delta(cosmo, masses, redshifts),
            grid_shape,
        )
        concentration = jnp.reshape(
            halo_model.concentration.c_delta(
                cosmo,
                masses,
                redshifts,
                mass_definition=halo_model.mass_definition,
            ),
            grid_shape,
        )
        scale_radius = halo_radius / concentration

        ang_dist = jnp.atleast_1d(cosmo.angular_diameter_distance(redshifts))
        # Multipole scale of the scale radius.
        ell_scale = ang_dist[None, :] / scale_radius

        # Map the requested wavenumbers to multipoles, ell = k * chi - 1/2.
        comoving_dist = ang_dist * (1 + redshifts)
        ell_wanted = wavenumbers[:, None] * comoving_dist[None, :] - 0.5

        norm = 4 * jnp.pi * scale_radius**3 * (1 + redshifts)[None, :]**3
        # Comoving radii corresponding to the dimensionless grid.
        radii = self.x[:, None, None] * scale_radius[None, :, :] * (1.0 + redshifts[None, None, :])

        k_grid, u_k = self._u_k_hankel(halo_model, self.x, radii, masses, redshifts)
        u_k = jnp.reshape(u_k, (len(k_grid), len(masses), len(redshifts)))
        u_ell = u_k * jnp.sqrt(jnp.pi / (2 * k_grid[:, None, None]))
        ell_grid = k_grid[:, None, None] * ell_scale[None, :, :]
        u_ell_scaled = norm[None, :, :] * u_ell

        def _interp_1d(query, xp, fp):
            return jnp.interp(query, xp, fp)

        # Inner map runs over masses (shared targets), outer over redshifts.
        over_mass = jax.vmap(_interp_1d, in_axes=(None, 1, 1), out_axes=1)
        over_mass_and_z = jax.vmap(over_mass, in_axes=(1, 2, 2), out_axes=2)
        return jnp.squeeze(over_mass_and_z(ell_wanted, ell_grid, u_ell_scaled))
class _BCMDensityProfile(DensityProfile):
    """
    Electron density profile from `Schneider et al. (2019) <https://ui.adsabs.harvard.edu/abs/2019JCAP...03..020S/abstract>`_,
    also known as the Baryon Correction Model (BCM).

    The profile is evaluated as a function of the comoving radius
    :math:`r`, with shape defined relative to the physical virial radius
    through :math:`x_{\\mathrm{vir}} = r / r_{\\mathrm{vir}}`, where
    :math:`r_{\\mathrm{vir}}` has the same units as :math:`r`:

    .. math::

        \\rho_{\\mathrm{gas}}(r, M, z)
        = \\frac{f_b - f_\\star(M)}
        {\\left(1 + 10 x_{\\mathrm{vir}}\\right)^{\\beta_M(M, z)}
        \\left[1 + \\left(\\frac{x_{\\mathrm{vir}}}{\\theta_{\\mathrm{ej}}}\\right)^\\gamma\\right]^{(\\delta - \\beta_M(M, z))/\\gamma}}
        \\tag{1}

    where

    .. math::

        f_\\star(M) = f_{\\star, M_s}
        \\left(\\frac{M}{M_s}\\right)^{-\\eta_\\star}
        \\tag{2}

    .. math::

        \\beta_M(M, z) =
        \\frac{3 (M / M_c(z))^\\mu}{1 + (M / M_c(z))^\\mu}
        \\tag{3}

    .. math::

        \\log_{10} M_c(z) = \\log_{10} M_c \\, (1+z)^{\\nu_{\\log_{10} M_c}}
        \\tag{4}

    In the implementation, :math:`M_s = 2.5 \\times 10^{11} \\, M_\\odot / h`
    and :math:`f_{\\star, M_s} = 0.055` are fixed constants, and the input
    mass is multiplied by :math:`h` before it enters Eqs. (2) and (3).

    The Fourier-space density profile used by the halo model is evaluated as

    .. math::

        u_k(k, M, z) =
        4 \\pi \\, r_{\\mathrm{vir}}^3 \\, (1+z)^3
        \\int dx \\, x^2 \\, \\rho(x, M, z)
        \\, \\frac{\\sin\\!\\left[(k r_{\\mathrm{vir}}) x\\right]}
        {(k r_{\\mathrm{vir}}) x}
        \\tag{5}

    where :math:`x = r / r_{\\mathrm{vir}}`, :math:`r_{\\mathrm{vir}}` has the
    same units as :math:`r`, and :math:`f_{\\mathrm{free}} = 1`. Any
    kSZ-specific weighting is applied by the tracer kernel rather than by this
    profile.

    Attributes
    ----------
    x : jnp.ndarray
        Dimensionless radial grid :math:`x = r / r_{\\mathrm{vir}}` used to tabulate the profile and define the Hankel transform, with :math:`r_{\\mathrm{vir}}` expressed in the same units as :math:`r`.
    log10Mc : float
        Characteristic mass scale :math:`\\log_{10} M_c` entering the slope :math:`\\beta_M` (Eqs. 3-4).
    theta_ej : float
        Ejection-radius parameter :math:`\\theta_{\\mathrm{ej}}` in units of the virial radius.
    eta_star : float
        Stellar-fraction parameter :math:`\\eta_\\star`.
    delta : float
        Exponent :math:`\\delta`; for :math:`x_{\\mathrm{vir}} \\gg \\theta_{\\mathrm{ej}}` Eq. (1) falls off as :math:`x_{\\mathrm{vir}}^{-\\delta}`.
    gamma : float
        Exponent :math:`\\gamma` setting the sharpness of the transition around :math:`\\theta_{\\mathrm{ej}}`.
    mu : float
        Exponent :math:`\\mu` of the mass ratio :math:`M / M_c(z)` in :math:`\\beta_M`.
    nu_log10Mc : float
        Redshift exponent :math:`\\nu_{\\log_{10} M_c}` of the characteristic mass scale.
    """

    def __init__(self, x=None,
                 log10Mc=13.25, theta_ej=4.711, eta_star=0.2,
                 delta=7.0, gamma=2.5, mu=1.0, nu_log10Mc=-0.038,
                 ):
        # Assigning through the property also builds the Hankel transform.
        self.x = x if x is not None else jnp.logspace(-4, 1, 256)

        self.log10Mc, self.theta_ej, self.eta_star = log10Mc, theta_ej, eta_star
        self.delta, self.gamma, self.mu, self.nu_log10Mc = delta, gamma, mu, nu_log10Mc

    @property
    def x(self):
        """Dimensionless radial grid used to tabulate the profile."""
        return self._x

    @x.setter
    def x(self, value):
        """Store the grid and rebuild the Hankel transform that depends on it."""
        self._x = value
        self._hankel = HankelTransform(self._x, nu=0.5)

    def _tree_flatten(self):
        """Split the instance into pytree leaves (parameters) and aux data."""
        # Dynamic calibration parameters
        leaves = (
            self.log10Mc, self.theta_ej, self.eta_star,
            self.delta, self.gamma, self.mu, self.nu_log10Mc
        )
        # Static metadata.
        # NOTE(review): JAX expects aux data to be hashable and comparable,
        # which an array grid is not -- confirm instances are never passed
        # through jit as a traced (non-static) argument.
        aux_data = (self._x, self._hankel)
        return (leaves, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, leaves):
        """Rebuild an instance from aux data and leaves without running ``__init__``."""
        x, hankel = aux_data
        obj = cls.__new__(cls)
        # Unpack leaves back into attributes
        (obj.log10Mc, obj.theta_ej, obj.eta_star,
         obj.delta, obj.gamma, obj.mu, obj.nu_log10Mc) = leaves
        # Reuse the existing grid and transform rather than rebuilding them.
        obj._x = x
        obj._hankel = hankel
        return obj

    def update(self, log10Mc=None, theta_ej=None, eta_star=None,
               delta=None, gamma=None, mu=None, nu_log10Mc=None):
        """
        Return a new profile instance with updated BCM parameters.

        Parameters
        ----------
        log10Mc, theta_ej, eta_star, delta, gamma, mu, nu_log10Mc : float, optional
            Replacement values for the corresponding class attributes. Any argument left as ``None`` keeps its current value.

        Returns
        -------
        _BCMDensityProfile
            New profile instance with updated parameters. The radial grid
            and Hankel transform of the current instance are reused.
        """
        # Listed in the same order as the leaves returned by _tree_flatten.
        overrides = (log10Mc, theta_ej, eta_star, delta, gamma, mu, nu_log10Mc)
        current, aux_data = self._tree_flatten()
        new_leaves = tuple(
            old if new is None else new
            for new, old in zip(overrides, current)
        )
        return self._tree_unflatten(aux_data, new_leaves)

    @partial(jax.jit, static_argnums=(0,))
    def real(self, halo_model, r, m, z):
        """
        Compute the gas-density profile.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology and virial radius.
        r : float or jnp.ndarray
            Comoving radius or radii in :math:`\\mathrm{Mpc}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Gas-density profile with shape :math:`(N_r, N_m, N_z)`, where
            singleton dimensions get squeezed before return.
        """
        cparams = halo_model.cosmology._cosmo_params()
        f_b = cparams["Omega_b"] / cparams["Omega0_m"]

        # Broadcasting shapes: (Nx, 1, 1), (1, Nm, 1), (1, 1, Nz)
        r, m, z = jnp.atleast_1d(r), jnp.atleast_1d(m), jnp.atleast_1d(z)
        # Masses are converted to M_sun/h for comparison with M_s and M_c.
        m_internal = m * cparams["h"]
        rb, mb, zb = r[:, None, None], m_internal[None, :, None], z[None, None, :]

        # The model is calibrated on the virial radius, so it always uses the
        # "vir" definition regardless of halo_model.mass_definition.
        r_vir = jnp.reshape(MassDefinition("vir", "critical").r_delta(halo_model.cosmology, m, z), (len(m), len(z)))
        x_vir = rb / ((1.0 + zb) * r_vir[None, :, :])

        # Redshift-dependent characteristic mass (Eq. 4): the scaling is
        # applied to log10(Mc), not to Mc itself.
        mc_z_log = self.log10Mc * (1. + zb)**self.nu_log10Mc
        mc = 10.**mc_z_log

        # Stellar fraction (Eq. 2). It must use the broadcast mass `mb`
        # (shape (1, Nm, 1), M_sun/h) so that it shares units with `ms` and
        # varies along the mass axis; a bare 1-D mass array would broadcast
        # against the trailing redshift axis instead.
        ms = 2.5e11  # M_sun/h, fixed value
        fstar_ms = 0.055  # Fixed value
        f_star = fstar_ms * (mb / ms)**(-self.eta_star)
        num = f_b - f_star

        # beta_m scaling (Mass dependent slope)
        m_ratio_mu = (mb / mc)**self.mu
        beta_m = 3. * m_ratio_mu / (1. + m_ratio_mu)

        # First denominator factor of Eq. (1): (1 + 10 x_vir)^beta_M.
        denom1 = (1. + 10. * x_vir)**beta_m

        # Second denominator factor of Eq. (1): transition at theta_ej.
        scaled_r = x_vir / self.theta_ej
        denom2 = (1. + (scaled_r)**self.gamma)**((self.delta - beta_m) / self.gamma)
        return jnp.squeeze(num / (denom1 * denom2))

    @partial(jax.jit, static_argnums=(0,))
    def fourier(self, halo_model, k, m, z):
        """
        Compute the projected Fourier-space gas-density profile for halo-model calculations.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model providing the cosmology and virial radius.
        k : float or jnp.ndarray
            Comoving wavenumber(s) in :math:`\\mathrm{Mpc}^{-1}`.
        m : float or jnp.ndarray
            Halo mass or masses in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Transformed profile with shape :math:`(N_k, N_m, N_z)`, where
            singleton dimensions get squeezed before return.
        """
        k, m, z = jnp.atleast_1d(k), jnp.atleast_1d(m), jnp.atleast_1d(z)

        # Virial radius on the (Nm, Nz) grid.
        r_vir = jnp.reshape(
            MassDefinition("vir", "critical").r_delta(
                halo_model.cosmology,
                m,
                z,
            ),
            (len(m), len(z)),
        )
        d_A_z = jnp.atleast_1d(halo_model.cosmology.angular_diameter_distance(z))
        # Multipole scale of the virial radius: ell_vir = d_A / r_vir.
        ell_vir = d_A_z[None, :] / r_vir

        # Requested wavenumbers are mapped to multipoles via
        # ell = k * chi - 1/2, with chi = (1 + z) * d_A.
        chi = d_A_z * (1 + z)
        ell_target = k[:, None] * chi[None, :] - 0.5

        prefactor = 4 * jnp.pi * r_vir**3 * (1 + z)[None, :]**3
        # Comoving radii corresponding to the dimensionless grid x.
        r = self.x[:, None, None] * r_vir[None, :, :] * (1.0 + z[None, None, :])

        # _u_k_hankel is not defined in this class (presumably inherited
        # from HaloProfile); it returns the transform's own k grid and values.
        k_native, u_k_native = self._u_k_hankel(halo_model, self.x, r, m, z)
        u_k_native = jnp.reshape(u_k_native, (len(k_native), len(m), len(z)))
        u_ell_native = u_k_native * jnp.sqrt(jnp.pi / (2 * k_native[:, None, None]))
        # Native dimensionless wavenumbers expressed as multipoles.
        ell_native = k_native[:, None, None] * ell_vir[None, :, :]
        u_ell_val = prefactor[None, :, :] * u_ell_native

        def interp_single_column(target_x, native_x, native_y):
            return jnp.interp(target_x, native_x, native_y)

        # Inner vmap runs over masses (targets shared), outer over redshifts.
        vmapped_interp = jax.vmap(
            jax.vmap(interp_single_column, in_axes=(None, 1, 1), out_axes=1),
            in_axes=(1, 2, 2), out_axes=2
        )
        return jnp.squeeze(vmapped_interp(ell_target, ell_native, u_ell_val))
# Register _BCMDensityProfile as a JAX pytree: its calibration parameters are
# the leaves, and the radial grid plus Hankel transform travel as aux data.
jax.tree_util.register_pytree_node(
    _BCMDensityProfile,
    _BCMDensityProfile._tree_flatten,
    _BCMDensityProfile._tree_unflatten,
)