diff --git a/spec/draft/API_specification/statistical_functions.rst b/spec/draft/API_specification/statistical_functions.rst index 20e02b3f9..eb5e1a5d6 100644 --- a/spec/draft/API_specification/statistical_functions.rst +++ b/spec/draft/API_specification/statistical_functions.rst @@ -18,6 +18,7 @@ Objects in API :toctree: generated :template: method.rst + cumulative_prod cumulative_sum max mean diff --git a/src/array_api_stubs/_draft/statistical_functions.py b/src/array_api_stubs/_draft/statistical_functions.py index 4cb6de0a8..347a97fff 100644 --- a/src/array_api_stubs/_draft/statistical_functions.py +++ b/src/array_api_stubs/_draft/statistical_functions.py @@ -1,9 +1,71 @@ -__all__ = ["cumulative_sum", "max", "mean", "min", "prod", "std", "sum", "var"] +__all__ = [ + "cumulative_sum", + "cumulative_prod", + "max", + "mean", + "min", + "prod", + "std", + "sum", + "var", +] from ._types import Optional, Tuple, Union, array, dtype +def cumulative_prod( + x: array, + /, + *, + axis: Optional[int] = None, + dtype: Optional[dtype] = None, + include_initial: bool = False, +) -> array: + """ + Calculates the cumulative product of elements in the input array ``x``. + + Parameters + ---------- + x: array + input array. Should have one or more dimensions (axes). Should have a numeric data type. + axis: Optional[int] + axis along which a cumulative product must be computed. If ``axis`` is negative, the function must determine the axis along which to compute a cumulative product by counting from the last dimension. + + If ``x`` is a one-dimensional array, providing an ``axis`` is optional; however, if ``x`` has more than one dimension, providing an ``axis`` is required. + + dtype: Optional[dtype] + data type of the returned array. If ``None``, the returned array must have the same data type as ``x``, unless ``x`` has an integer data type supporting a smaller range of values than the default integer data type (e.g., ``x`` has an ``int16`` or ``uint32`` data type and the default integer data type is ``int64``). In those latter cases: + + - if ``x`` has a signed integer data type (e.g., ``int16``), the returned array must have the default integer data type. + - if ``x`` has an unsigned integer data type (e.g., ``uint16``), the returned array must have an unsigned integer data type having the same number of bits as the default integer data type (e.g., if the default integer data type is ``int32``, the returned array must have a ``uint32`` data type). + + If the data type (either specified or resolved) differs from the data type of ``x``, the input array should be cast to the specified data type before computing the product (rationale: the ``dtype`` keyword argument is intended to help prevent overflows). Default: ``None``. + + include_initial: bool + boolean indicating whether to include the initial value as the first value in the output. By convention, the initial value must be the multiplicative identity (i.e., one). Default: ``False``. + + Returns + ------- + out: array + an array containing the cumulative products. The returned array must have a data type as described by the ``dtype`` parameter above. + + Let ``N`` be the size of the axis along which to compute the cumulative product. The returned array must have a shape determined according to the following rules: + + - if ``include_initial`` is ``True``, the returned array must have the same shape as ``x``, except the size of the axis along which to compute the cumulative product must be ``N+1``. + - if ``include_initial`` is ``False``, the returned array must have the same shape as ``x``. + + Notes + ----- + + - When ``x`` is a zero-dimensional array, behavior is unspecified and thus implementation-defined. + + **Special Cases** + + For both real-valued and complex floating-point operands, special cases must be handled as if the operation is implemented by successive application of :func:`~array_api.multiply`. + """ + + def cumulative_sum( x: array, /,