diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index abdf5069ee08..dcbb673997ad 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -303,6 +303,11 @@ jax.scipy.stats.expon logpdf pdf + logcdf + cdf + logsf + sf + ppf jax.scipy.stats.gamma ~~~~~~~~~~~~~~~~~~~~~ diff --git a/jax/_src/scipy/stats/expon.py b/jax/_src/scipy/stats/expon.py index b09c52e97272..ba80fa6fbcb1 100644 --- a/jax/_src/scipy/stats/expon.py +++ b/jax/_src/scipy/stats/expon.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax import lax + import jax.numpy as jnp +from jax import lax from jax._src.numpy.util import promote_args_inexact from jax._src.typing import Array, ArrayLike @@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: array of logpdf values. See Also: + :func:`jax.scipy.stats.expon.cdf` :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` """ x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale) log_scale = lax.log(scale) @@ -73,6 +80,190 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: array of pdf values. See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` """ return lax.exp(logpdf(x, loc, scale)) + + +def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential cumulative density function. + + JAX implementation of :obj:`scipy.stats.expon` ``cdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y + + where :math:`f_{pdf}` is the exponential distribution probability density function, + :func:`jax.scipy.stats.expon.pdf`. + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` + """ + x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale) + neg_scaled_x = lax.div(lax.sub(loc, x), scale) + return jnp.where( + lax.lt(x, loc), + jnp.zeros_like(neg_scaled_x), + lax.neg(lax.expm1(neg_scaled_x)), + ) + + +def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential log cumulative density function. + + JAX implementation of :obj:`scipy.stats.expon` ``logcdf``. + + The cdf is defined as + + .. math:: + + f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y + + where :math:`f_{pdf}` is the exponential distribution probability density function, + :func:`jax.scipy.stats.expon.pdf`. + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` + """ + return lax.log1p(lax.neg(sf(x, loc, scale))) + + +def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential log survival function. + + JAX implementation of :obj:`scipy.stats.expon` ``logsf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the exponential cumulative distribution function, + :func:`jax.scipy.stats.expon.cdf`. + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` + """ + x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale) + neg_scaled_x = lax.div(lax.sub(loc, x), scale) + return jnp.where(lax.lt(x, loc), jnp.zeros_like(neg_scaled_x), neg_scaled_x) + + +def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential survival function. + + JAX implementation of :obj:`scipy.stats.expon` ``sf``. + + The survival function is defined as + + .. math:: + + f_{sf}(x) = 1 - f_{cdf}(x) + + where :math:`f_{cdf}(x)` is the exponential cumulative distribution function, + :func:`jax.scipy.stats.expon.cdf`. + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` + """ + return lax.exp(logsf(x, loc, scale)) + + +def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array: + r"""Exponential survival function. + + JAX implementation of :obj:`scipy.stats.expon` ``ppf``. + + The percent point function is defined as the inverse of the + cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`. + + Args: + x: arraylike, value at which to evaluate the PDF + loc: arraylike, distribution offset parameter + scale: arraylike, distribution scale parameter + + Returns: + array of pdf values. + + See Also: + :func:`jax.scipy.stats.expon.cdf` + :func:`jax.scipy.stats.expon.pdf` + :func:`jax.scipy.stats.expon.ppf` + :func:`jax.scipy.stats.expon.sf` + :func:`jax.scipy.stats.expon.logcdf` + :func:`jax.scipy.stats.expon.logpdf` + :func:`jax.scipy.stats.expon.logsf` + """ + q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale) + neg_scaled_q = lax.div(lax.sub(loc, q), scale) + return jnp.where( + jnp.isnan(q) | (q < 0) | (q > 1), + jnp.nan, + lax.neg(lax.log1p(neg_scaled_q)), + ) diff --git a/jax/scipy/stats/expon.py b/jax/scipy/stats/expon.py index 8f5c0a0680ce..40a1d081b900 100644 --- a/jax/scipy/stats/expon.py +++ b/jax/scipy/stats/expon.py @@ -16,6 +16,11 @@ # See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.expon import ( + cdf as cdf, + logcdf as logcdf, logpdf as logpdf, + logsf as logsf, pdf as pdf, + ppf as ppf, + sf as sf, ) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index 88a126c284a7..796d4490daea 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -523,6 +523,86 @@ def args_maker(): tol=1e-4) self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(3) + def testExponLogCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.expon.logcdf + lax_fun = lsp_stats.expon.logcdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testExponCdf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.expon.cdf + lax_fun = lsp_stats.expon.cdf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testExponSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.expon.sf + lax_fun = lsp_stats.expon.sf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testExponLogSf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.expon.logsf + lax_fun = lsp_stats.expon.logsf + + def args_maker(): + x, loc, scale = map(rng, shapes, dtypes) + return [x, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4 + ) + self._CompileAndCheck(lax_fun, args_maker) + + @genNamedParametersNArgs(3) + def testExponPpf(self, shapes, dtypes): + rng = jtu.rand_positive(self.rng()) + scipy_fun = osp_stats.expon.ppf + lax_fun = lsp_stats.expon.ppf + + def args_maker(): + q, loc, scale = map(rng, shapes, dtypes) + return [q, loc, scale] + + with jtu.strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy( + scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4 + ) + self._CompileAndCheck(lax_fun, args_maker) + @genNamedParametersNArgs(4) def testGammaLogPdf(self, shapes, dtypes): rng = jtu.rand_positive(self.rng())