import jax
import jax.numpy as jnp
import jax.scipy as jscipy
from functools import partial
from hmfast.utils import newton_root
class MassDefinition:
    """
    Mass definition for halos specified by an overdensity threshold and a reference density.

    For example, :math:`M_{200c}` corresponds to
    ``MassDefinition(delta=200, reference="critical")``, while
    :math:`M_{200m}` corresponds to
    ``MassDefinition(delta=200, reference="mean")``. The special value
    ``delta='vir'`` denotes the redshift-dependent virial overdensity and can
    only be used with ``reference='critical'``.

    Instances compare equal and hash by ``(type, delta, reference)``. They are
    handed to ``jax.jit`` as *static* arguments (see :meth:`r_delta`), and the
    jit cache is keyed on the hash/equality of static arguments. Value-based
    hashing therefore lets equal definitions share one compiled function and
    makes a reassignment of ``delta`` or ``reference`` trigger a retrace
    instead of silently reusing the previously compiled result. Because the
    hash follows the current values, do not reassign attributes while an
    instance is stored in a ``set`` or used as a ``dict`` key.

    Attributes
    ----------
    delta : int, float, or str
        Overdensity threshold used to define the halo boundary. This can be a
        numeric value such as ``200`` or ``500``, or the string ``'vir'`` for
        the redshift-dependent virial overdensity. The value ``'vir'`` is only
        valid with ``reference='critical'``.
    reference : str
        Reference density associated with ``delta``, either ``'critical'`` or
        ``'mean'``.

    Raises
    ------
    ValueError
        If an invalid combination of `delta` and `reference` is provided, or if either
        parameter is set to an unsupported value.
    """

    def __init__(self, delta=200, reference="critical"):
        self._delta = None
        self._reference = None
        # The reference is assigned first because the ``delta`` setter reads
        # it to validate ``delta='vir'``.
        self.reference = reference
        self.delta = delta

    def __repr__(self):
        return f"{type(self).__name__}(delta={self._delta!r}, reference={self._reference!r})"

    def __eq__(self, other):
        # Exact-type comparison so that a subclass with different behaviour
        # never shares a jit cache entry with the base class.
        if type(other) is not type(self):
            return NotImplemented
        return (self._delta, self._reference) == (other._delta, other._reference)

    def __hash__(self):
        return hash((type(self), self._delta, self._reference))

    # Ensure that reference is only ever critical or mean
    @property
    def reference(self):
        """Reference density, either ``'critical'`` or ``'mean'``."""
        return self._reference

    @reference.setter
    def reference(self, value):
        value = str(value).lower()
        if value not in ("critical", "mean"):
            raise ValueError("reference must be either 'critical' or 'mean'")
        # Prevent changing reference if delta == "vir"
        if getattr(self, "_delta", None) == "vir" and value != "critical":
            raise ValueError("'vir' is only allowed with 'critical' reference")
        self._reference = value

    @property
    def delta(self):
        """Overdensity threshold: a Python number or the string ``'vir'``."""
        return self._delta

    @delta.setter
    def delta(self, value):
        if isinstance(value, str):
            value = value.lower()
            # If 'vir', reference must be 'critical'
            if value == "vir":
                if getattr(self, "_reference", None) != "critical":
                    raise ValueError("'vir' is only allowed with 'critical' reference")
                self._delta = value
                return
        # Otherwise, it must be numeric
        if isinstance(value, (int, float)):
            self._delta = value
            return
        raise ValueError("delta must be numeric or 'vir'")

    def _tree_flatten(self):
        # The object carries no array leaves: ``delta`` (a Python number or
        # 'vir') and ``reference`` (a string) are both stored as static
        # auxiliary data, so the children tuple is empty.
        children = ()
        aux_data = (self.delta, self.reference)
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        # ``children`` is always empty (see ``_tree_flatten``); the instance
        # is rebuilt, and re-validated, from ``(delta, reference)``.
        return cls(*aux_data)

    def _delta_vir_to_crit(self, cosmology, z):
        """
        Compute the virial overdensity with respect to the critical density.

        Evaluates ``18*pi**2 + 82*x - 39*x**2`` with
        ``x = cosmology.omega_m(z) - 1``.
        """
        omega_m = cosmology.omega_m(z)
        x = omega_m - 1.0
        return 18.0 * jnp.pi**2 + 82.0 * x - 39.0 * x**2

    def _delta_numeric(self, cosmology, z):
        """
        Return the numeric overdensity threshold at redshift ``z``.

        For ``delta='vir'`` this is the virial overdensity relative to the
        critical density; otherwise the stored number is returned unchanged.

        Raises
        ------
        ValueError
            If ``delta='vir'`` is combined with a non-critical reference.
        """
        if self.delta == "vir":
            if self.reference != "critical":
                raise ValueError("virial overdensity only defined w.r.t. critical density")
            return self._delta_vir_to_crit(cosmology, z)
        return self.delta

    def _convert_reference(self, cosmology, z, delta, from_ref="critical", to_ref="mean"):
        """
        Convert an overdensity threshold between critical and mean references.

        Critical-to-mean divides by ``cosmology.omega_m(z)`` and mean-to-critical
        multiplies by it. When both references are identical, ``delta`` is
        only broadcast to the shape of ``z``.

        Raises
        ------
        ValueError
            If the two references differ and are not a critical/mean pair.
        """
        z = jnp.asarray(z)
        delta = jnp.asarray(delta)
        if from_ref == to_ref:
            return jnp.broadcast_to(delta, z.shape)
        omega_m = cosmology.omega_m(z)
        if from_ref == "critical" and to_ref == "mean":
            return delta / omega_m
        if from_ref == "mean" and to_ref == "critical":
            return delta * omega_m
        raise ValueError("from_ref and to_ref must be 'critical' or 'mean'")

    @partial(jax.jit, static_argnums=(0,))
    def r_delta(self, cosmology, m, z):
        """
        Compute the halo radius :math:`r_\\Delta` associated with a halo mass.

        .. math::

            r_\\Delta = \\left[\\frac{3M}{4\\pi \\Delta \\rho_{\\mathrm{ref}}(z)}\\right]^{1/3}

        Parameters
        ----------
        cosmology : Cosmology
            Cosmology object used to evaluate the reference density.
        m : float or array-like
            Halo mass enclosed within the overdensity radius, in
            :math:`M_\\odot`.
        z : float or array-like
            Redshift at which to compute the radius.

        Returns
        -------
        float or array-like
            Radius :math:`r_\\Delta` within which the mean enclosed density is
            :math:`\\Delta \\rho_{\\mathrm{ref}}(z)`, in physical :math:`\\mathrm{Mpc}`.
            With shape :math:`(N_m, N_z)`, where singleton dimensions get
            squeezed before return.
        """
        delta, reference = self.delta, self.reference
        m = jnp.asarray(m)
        z = jnp.asarray(z)
        # Scalars and 1-D inputs are laid out on an (N_m, N_z) grid; inputs of
        # higher rank are left to ordinary broadcasting.
        if m.ndim <= 1 and z.ndim <= 1:
            m, z = jnp.atleast_1d(m)[:, None], jnp.atleast_1d(z)[None, :]
        rho_ref = cosmology.critical_density(z)
        if delta == "vir":
            delta = self._delta_vir_to_crit(cosmology, z)
        if reference == "mean":
            # The mean reference density is the critical density scaled by omega_m(z).
            rho_ref = rho_ref * cosmology.omega_m(z)
        return jnp.squeeze((3.0 * m / (4.0 * jnp.pi * delta * rho_ref)) ** (1.0 / 3.0))
# Register MassDefinition with JAX's pytree machinery. The flatten and
# unflatten hooks are the class's own ``_tree_flatten`` / ``_tree_unflatten``;
# ``_tree_flatten`` returns no children, so the whole object travels as
# auxiliary data.
jax.tree_util.register_pytree_node(
    MassDefinition,
    MassDefinition._tree_flatten,
    MassDefinition._tree_unflatten,
)
@partial(jax.jit, static_argnums=(3, 4, 6))
def _solve_m_delta_nfw(cosmology, m, z, mass_def_old, mass_def_new, c_old, max_iter=20):
    """
    Solve for halo masses in a new spherical-overdensity definition.

    For every ``(m, z)`` pair the target mass is the root of

    ``m / m_new - f(c) / f(c * (m_new / m * d_old / d_new) ** (1 / 3))``

    with ``f(x) = log1p(x) - x / (1 + x)``, where ``d_old`` and ``d_new`` are
    the two thresholds expressed relative to the critical density.

    Parameters
    ----------
    cosmology : Cosmology
        Cosmology object passed to the mass definitions to evaluate
        thresholds and reference conversions.
    m : array-like
        Halo mass(es) in the original definition.
    z : array-like
        Redshift(s).
    mass_def_old : MassDefinition
        Original mass definition (static under ``jit``).
    mass_def_new : MassDefinition
        Target mass definition (static under ``jit``).
    c_old : array-like
        Concentration in the original definition. It is squeezed, reshaped
        to ``(-1, N_z)`` and broadcast to ``(N_m, N_z)``.
    max_iter : int, optional
        Forwarded to ``newton_root`` as ``max_iter`` (static under ``jit``).

    Returns
    -------
    float or array-like
        Masses in the target definition on an :math:`(N_m, N_z)` grid, with
        singleton dimensions squeezed.
    """
    m = jnp.atleast_1d(m)
    z = jnp.atleast_1d(z)
    grid_shape = (len(m), len(z))

    # Concentration laid out on the (N_m, N_z) grid.
    conc = jnp.squeeze(jnp.asarray(c_old)).reshape((-1, grid_shape[1]))
    conc = jnp.broadcast_to(conc, grid_shape)

    def critical_threshold(mass_def):
        # Threshold of ``mass_def`` at each redshift, relative to critical density.
        raw = mass_def._delta_numeric(cosmology, z)
        return mass_def._convert_reference(
            cosmology, z, raw, from_ref=mass_def.reference, to_ref="critical"
        )

    delta_src = critical_threshold(mass_def_old)
    delta_dst = critical_threshold(mass_def_new)

    # Redshifts at which both definitions coincide need no solve.
    unchanged = jnp.isclose(delta_src, delta_dst) & (mass_def_old.reference == mass_def_new.reference)

    m_grid = jnp.broadcast_to(m[:, None], grid_shape)
    # Starting point handed to the root finder.
    guess = m[:, None] * (delta_src / delta_dst)[None, :] ** 0.2

    def nfw_shape(x):
        return jnp.log1p(x) - x / (1.0 + x)

    def convert_one(m_src, c_src, m_start, d_src, d_dst, skip):
        def residual(m_dst):
            scaled_c = c_src * (m_dst / m_src * d_src / d_dst) ** (1 / 3)
            return m_src / m_dst - nfw_shape(c_src) / nfw_shape(scaled_c)

        return jax.lax.cond(
            skip,
            lambda _: m_src,
            lambda _: newton_root(residual, x0=m_start, max_iter=max_iter),
            None,
        )

    def per_redshift(values):
        # Repeat a per-redshift vector along the mass axis and flatten it.
        return jnp.broadcast_to(values[None, :], grid_shape).ravel()

    solved = jax.vmap(convert_one)(
        m_grid.ravel(),
        conc.ravel(),
        guess.ravel(),
        per_redshift(delta_src),
        per_redshift(delta_dst),
        per_redshift(unchanged),
    )
    return jnp.squeeze(solved.reshape(grid_shape))
@partial(jax.jit, static_argnums=(3, 4, 6))
def _convert_m_delta(cosmology, m, z, mass_def_old, mass_def_new, c_old, max_iter=20):
    """
    Convert halo masses between two spherical-overdensity definitions.

    The conversion assumes an NFW profile and requires the input concentration
    :math:`c_{\\Delta}` for the original mass definition.

    This function is a thin wrapper around ``_solve_m_delta_nfw``; the two
    previously carried identical copies of the solver.

    Parameters
    ----------
    cosmology : Cosmology
        Cosmology object used to evaluate :math:`\\Omega_m(z)` for reference-density
        conversions and virial overdensities.
    m : array-like
        Halo mass in the original definition, :math:`M_{\\Delta}`, in
        physical :math:`M_{\\odot}`.
    z : array-like
        Redshift(s).
    mass_def_old : MassDefinition
        Original mass definition specifying :math:`\\Delta` and its reference density.
    mass_def_new : MassDefinition
        Target mass definition specifying :math:`\\Delta'` and its reference density.
    c_old : array-like
        Halo concentration :math:`c_{\\Delta}` in the original definition,
        evaluated for the input halo masses in physical :math:`M_{\\odot}`.
    max_iter : int, optional
        Maximum number of root-finder iterations.

    Returns
    -------
    float or array-like
        Halo mass in the target definition, :math:`M_{\\Delta'}`, in physical
        :math:`M_{\\odot}`, with shape :math:`(N_m, N_z)`, where singleton
        dimensions get squeezed before return.
    """
    # Arguments are forwarded positionally so that the solver's
    # ``static_argnums=(3, 4, 6)`` applies to them directly.
    return _solve_m_delta_nfw(cosmology, m, z, mass_def_old, mass_def_new, c_old, max_iter)
def mass_translator(mass_def_old, mass_def_new, concentration, max_iter=20):
    """
    Build a mass-conversion callable for fixed source and target definitions.

    Conversions between overdensity thresholds :math:`\\Delta` are performed
    by solving

    .. math::

        \\frac{M_{\\Delta}}{M_{\\Delta'}} =
        \\frac{f(c_{\\Delta})}
        {f\\!\\left[c_{\\Delta}
        \\frac{r_{\\Delta'}}{r_{\\Delta}}\\right]},

    where

    .. math::

        f(x) = \\ln(1+x) - \\frac{x}{1+x}.

    Reference-only conversions use

    .. math::

        \\Delta_{\\mathrm{m}}(z) = \\frac{\\Delta_{\\mathrm{c}}(z)}{\\Omega_m(z)},

    while virial overdensities are computed from

    .. math::

        \\Delta_{\\mathrm{vir,c}}(z) = 18\\pi^2 + 82x - 39x^2,
        \\qquad x = \\Omega_m(z) - 1.

    The comparison of the two definitions happens once, when this factory is
    called; the returned callable does not re-check it.

    Parameters
    ----------
    mass_def_old : MassDefinition
        Source mass definition.
    mass_def_new : MassDefinition
        Target mass definition.
    concentration : Concentration
        Concentration relation calibrated for the source mass definition.
        When the returned callable is evaluated, this object is used to
        compute :math:`c_{\\Delta}` internally for ``mass_def_old``.
    max_iter : int, optional
        Maximum number of root-finder iterations.

    Returns
    -------
    callable
        Function ``f(cosmology, m, z)`` that converts masses from
        ``mass_def_old`` to ``mass_def_new``.
    """
    definitions_match = (
        mass_def_old.delta == mass_def_new.delta
        and mass_def_old.reference == mass_def_new.reference
    )

    if not definitions_match:
        # General case: evaluate the source concentration, then run the solver.
        @jax.jit
        def f(cosmology, m, z):
            c_source = concentration.c_delta(cosmology, m, z, mass_definition=mass_def_old)
            return _solve_m_delta_nfw(
                cosmology,
                m,
                z,
                mass_def_old=mass_def_old,
                mass_def_new=mass_def_new,
                c_old=c_source,
                max_iter=max_iter,
            )

        return f

    # Identical definitions: return the input masses on the (N_m, N_z) grid
    # without touching the cosmology or the concentration relation.
    @jax.jit
    def f(cosmology, m, z):
        masses = jnp.atleast_1d(m)
        redshifts = jnp.atleast_1d(z)
        grid = jnp.broadcast_to(masses[:, None], (len(masses), len(redshifts)))
        return jnp.squeeze(grid)

    return f
# Names exported by a star-import of this module. The two underscore-prefixed
# solver functions are listed explicitly, so they are exported as well.
__all__ = [
    "MassDefinition",
    "mass_translator",
    "_convert_m_delta",
    "_solve_m_delta_nfw",
]