import os
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from jax.tree_util import register_pytree_node_class
from hmfast.download import _get_default_data_path
from hmfast.halos.massdef import MassDefinition
from hmfast.halos.profiles import HaloProfile
[docs]
class MatterProfile(HaloProfile):
    """
    Base class for matter density profiles.

    Concrete matter profile classes derive from this class and must
    implement both :meth:`real` and :meth:`fourier`.
    """
[docs]
class NFWMatterProfile(MatterProfile):
    """
    NFW matter density profile, following
    `Navarro, Frenk & White (1997) <https://ui.adsabs.harvard.edu/abs/1997ApJ...490..493N/abstract>`_.

    In real space, the mass-weighted matter profile is

    .. math::

        u_r(r, M, z) = \\frac{1}{\\bar{\\rho}_{m,0}} \\,
        \\frac{\\rho_s}{(r/r_s) \\left(1+r/r_s\\right)^2}
        \\tag{1}

    .. math::

        \\rho_s = \\frac{M}{4\\pi r_s^3}
        \\left[\\ln(1+c) - \\frac{c}{1+c}\\right]^{-1}
        \\tag{2}

    with :math:`r_s = r_\\Delta / c`.

    In Fourier space, the mass-weighted matter profile is

    .. math::

        u(k, M, z)
        = \\frac{M}{\\bar{\\rho}_{m,0}}
        \\left[\\ln(1+c) - \\frac{c}{1+c}\\right]^{-1}
        \\Bigg[
        \\cos(q) \\left(\\mathrm{Ci}[(1+c)q] - \\mathrm{Ci}(q)\\right)
        + \\sin(q) \\left(\\mathrm{Si}[(1+c)q] - \\mathrm{Si}(q)\\right)
        - \\frac{\\sin(cq)}{(1+c)q}
        \\Bigg]
        \\tag{3}

    with :math:`q = k \\, r_s \\, (1+z)`.
    """

    def __init__(self):
        """Create the profile; nothing is stored and no parent initializer is called."""

    @partial(jax.jit, static_argnums=(0,))
    def real(self, halo_model, r, m, z):
        """
        Evaluate the mass-weighted NFW matter profile in real space.

        Implements Eqs. (1) and (2) of the class docstring: the unit-mass
        profile returned by ``self._u_r_nfw`` is scaled by
        :math:`M / \\bar{\\rho}_{m,0}`.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model supplying the cosmology, concentration relation, and
            halo radius.
        r : float or jnp.ndarray
            Radius or radii in :math:`\\mathrm{Mpc}`.
        m : float or jnp.ndarray
            Halo mass(es) in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Real-space profile on the :math:`(N_r, N_m, N_z)` grid, with
            singleton dimensions squeezed out before return.
        """
        cosmo = halo_model.cosmology._cosmo_params()

        # Promote scalars so that each input contributes an explicit axis.
        r, m, z = (jnp.atleast_1d(arg) for arg in (r, m, z))
        grid_shape = (len(r), len(m), len(z))

        # Mean matter density entering Eq. (1). No factor of h is applied
        # to either the mass or the density.
        mean_matter_density = cosmo["Rho_crit_0"] * cosmo["Omega0_m"]

        # Unit-mass (normalized) real-space profile on the (r, m, z) grid.
        unit_profile = jnp.reshape(
            self._u_r_nfw(halo_model, r, m, z), grid_shape
        )

        # Mass weighting varies along the mass axis only.
        mass_weight = (m / mean_matter_density)[None, :, None]
        return jnp.squeeze(mass_weight * unit_profile)

    @partial(jax.jit, static_argnums=(0,))
    def fourier(self, halo_model, k, m, z):
        """
        Evaluate the mass-weighted NFW matter profile in Fourier space.

        Implements Eq. (3) of the class docstring: the second output of
        ``self._u_k_nfw`` is scaled by :math:`M / \\bar{\\rho}_{m,0}`.

        Parameters
        ----------
        halo_model : HaloModel
            Halo model supplying the cosmology, concentration relation, and
            halo radius.
        k : float or jnp.ndarray
            Comoving wavenumber(s) in :math:`\\mathrm{Mpc}^{-1}`.
        m : float or jnp.ndarray
            Halo mass(es) in physical :math:`M_\\odot`.
        z : float or jnp.ndarray
            Redshift(s).

        Returns
        -------
        jnp.ndarray
            Fourier-space profile on the :math:`(N_k, N_m, N_z)` grid, with
            singleton dimensions squeezed out before return.
        """
        cosmo = halo_model.cosmology._cosmo_params()

        # Promote scalars so that each input contributes an explicit axis.
        k = jnp.atleast_1d(k)
        m = jnp.atleast_1d(m)
        z = jnp.atleast_1d(z)
        grid_shape = (len(k), len(m), len(z))

        # The helper returns a pair; only its second element (the profile)
        # is needed here.
        _, unit_profile = self._u_k_nfw(halo_model, k, m, z)
        unit_profile = jnp.reshape(unit_profile, grid_shape)

        # Mean matter density entering Eq. (3). No factor of h is applied
        # to either the mass or the density.
        mean_matter_density = cosmo["Rho_crit_0"] * cosmo["Omega0_m"]

        # Mass weighting varies along the mass axis only.
        mass_weight = (m / mean_matter_density)[None, :, None]
        return jnp.squeeze(unit_profile * mass_weight)