Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(gh-13291): Add exponential distribution functions: cdf, logcdf, sf, logsf, and ppf #26259

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ jax.scipy.stats.expon

logpdf
pdf
logcdf
cdf
logsf
sf
ppf

jax.scipy.stats.gamma
~~~~~~~~~~~~~~~~~~~~~
Expand Down
199 changes: 198 additions & 1 deletion jax/_src/scipy/stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import lax

import jax.numpy as jnp
from jax import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike

Expand Down Expand Up @@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of logpdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
Expand Down Expand Up @@ -73,6 +80,196 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logpdf(x, loc, scale))


def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential cumulative density function.

JAX implementation of :obj:`scipy.stats.expon` ``cdf``.

The cdf is defined as

.. math::

f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y

where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.

Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter

Returns:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(
lax.lt(x, loc),
jnp.zeros_like(neg_scaled_x),
lax.neg(lax.expm1(neg_scaled_x)),
)


def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log cumulative density function.

JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.

The cdf is defined as

.. math::

f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y

where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.

Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter

Returns:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.logcdf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(
lax.lt(x, loc),
jnp.full_like(neg_scaled_x, -jnp.inf),
lax.log1p(lax.neg(lax.exp(neg_scaled_x))),
)
jakevdp marked this conversation as resolved.
Show resolved Hide resolved


def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log survival function.

JAX implementation of :obj:`scipy.stats.expon` ``logsf``.

The survival function is defined as

.. math::

f_{sf}(x) = 1 - f_{cdf}(x)

where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.

Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter

Returns:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(lax.lt(x, loc), jnp.zeros_like(neg_scaled_x), neg_scaled_x)


def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.

JAX implementation of :obj:`scipy.stats.expon` ``sf``.

The survival function is defined as

.. math::

f_{sf}(x) = 1 - f_{cdf}(x)

where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.

Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter

Returns:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logsf(x, loc, scale))


def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.

JAX implementation of :obj:`scipy.stats.expon` ``ppf``.

The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.

Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter

Returns:
array of pdf values.

See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
neg_scaled_q = lax.div(lax.sub(loc, q), scale)
return jnp.where(
jnp.isnan(q) | (q < 0) | (q > 1),
jnp.nan,
lax.neg(lax.log1p(neg_scaled_q)),
)
5 changes: 5 additions & 0 deletions jax/scipy/stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.scipy.stats.expon import (
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,
)
80 changes: 80 additions & 0 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,86 @@ def args_maker():
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logcdf
lax_fun = lsp_stats.expon.logcdf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.cdf
lax_fun = lsp_stats.expon.cdf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.sf
lax_fun = lsp_stats.expon.sf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logsf
lax_fun = lsp_stats.expon.logsf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponPpf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.ppf
lax_fun = lsp_stats.expon.ppf

def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
return [q, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testGammaLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
Expand Down