Getting started

Welcome to xarray-einstats!

xarray-einstats is an open source Python library that is part of the ArviZ project. It acts as a bridge between the xarray library for labelled arrays and libraries for raw arrays such as NumPy or SciPy.

One of xarray's main goals is “Compatibility with the broader ecosystem”, which is what allows xarray-einstats to act as this bridge with minimal code and duplication.

Overview

xarray-einstats provides wrappers for:

These wrappers have the same names and functionality as the original functions. The difference in behaviour is that the wrappers make no assumptions about the meaning of a dimension based on its position, nor do they have arguments like axis or axes. Instead, they have a dims argument that takes dimension names, rather than integers indicating the positions of the dimensions on which to act.
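The following sketch makes the difference between positions and names concrete. It shows how a name-based wrapper could be built on top of xarray.apply_ufunc. The helper skew_over is hypothetical and is not the xarray-einstats implementation. It declares the named dimensions as core dimensions, which apply_ufunc moves to the end, merges them into a single axis, and lets scipy.stats.skew reduce over that axis. The random data is only a stand-in with the same dimension names as the tutorial array.

```python
import numpy as np
import xarray as xr
from scipy import stats


def skew_over(da, dims):
    """Hypothetical name-based wrapper around scipy.stats.skew (not the library code)."""
    dims = [dims] if isinstance(dims, str) else list(dims)
    n = len(dims)

    def _skew_trailing(arr):
        # apply_ufunc has moved the core dims to the end: merge them into one axis
        flat = arr.reshape(arr.shape[: arr.ndim - n] + (-1,))
        return stats.skew(flat, axis=-1)

    # the dims given by name become core dims; all other dims are kept as they are
    return xr.apply_ufunc(_skew_trailing, da, input_core_dims=[dims])


rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.exponential(size=(10, 3, 4, 4)),
    dims=["batch", "experiment", "dim", "dim2"],
)
result = skew_over(da, dims=["batch", "dim2"])
print(result.dims)  # ('experiment', 'dim')
```

The caller never needs to know where batch or dim2 sit in the input. apply_ufunc takes care of the transposition.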

It also provides a handful of reimplemented functions:

These are only partially reimplemented, because the original functions don't yet support multidimensional and/or batched computations. Each of them shares its name with a function in NumPy or SciPy, but implements only a subset of its features. The goal is for these to eventually become wrappers too.

Using xarray-einstats

DataArray inputs

Functions in xarray-einstats are designed to work on DataArray objects.

Let’s load some example data:

from xarray_einstats import linalg, stats, tutorial

da = tutorial.generate_matrices_dataarray(4)
da
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
3.799 0.4308 3.24 0.1412 0.9402 0.7951 ... 0.6156 1.124 0.8559 2.108 0.7637
Dimensions without coordinates: batch, experiment, dim, dim2

and compute, for example, the skewness over the batch and dim2 dimensions:

stats.skew(da, dims=["batch", "dim2"])
<xarray.DataArray (experiment: 3, dim: 4)>
1.256 1.432 0.9728 1.762 1.612 1.188 1.033 2.388 2.196 1.455 1.631 1.373
Dimensions without coordinates: experiment, dim

Throughout the codebase, xarray-einstats uses dims as the argument that replaces both axis and axes. It also replaces the (..., M, M) convention used by NumPy's linear algebra functions.

The name dims follows xarray's dot, rather than the singular dim argument used, for example, in mean. Both a single dimension and multiple dimensions are valid inputs, and the plural name emphasizes that operations and reductions can act on several dimensions at the same time. In the linear algebra functions, dims is often restricted to a 2-element list. It indicates which dimensions define the matrices, and all the other dimensions are interpreted as batch dimensions.
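The following sketch illustrates how two named dimensions can map onto NumPy's (..., M, M) convention using xarray.apply_ufunc. The helper inv_over is hypothetical and is not the xarray-einstats implementation. The two matrix dimensions are passed as core dimensions, which apply_ufunc moves to the end, and every remaining dimension acts as a batch dimension. The data is random and only mimics the tutorial array's dimension names.

```python
import numpy as np
import xarray as xr


def inv_over(da, dims):
    """Hypothetical sketch: invert the matrices defined by two named dimensions."""
    if len(dims) != 2:
        raise ValueError("dims must name exactly two dimensions")
    dims = list(dims)
    # the matrix dims become core dims, moved to the end: NumPy's (..., M, M) layout;
    # every remaining dimension is treated as a batch dimension
    return xr.apply_ufunc(
        np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims]
    )


rng = np.random.default_rng(1)
da = xr.DataArray(
    rng.normal(size=(10, 3, 4, 4)),
    dims=["batch", "experiment", "dim", "dim2"],
)
inv = inv_over(da, dims=["dim", "dim2"])
print(inv.dims)  # ('batch', 'experiment', 'dim', 'dim2')
```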

Because only the dimension names matter, the two calls below are equivalent, even though the dimension order of their inputs differs. Thus,

linalg.det(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3)>
23.55 2.033 0.3923 -7.374 0.06645 ... 1.804 -0.1599 8.875 -0.04935 -8.428
Dimensions without coordinates: batch, experiment

returns the same result, up to the order of the output dimensions, as:

linalg.det(da.transpose("dim2", "experiment", "dim", "batch"), dims=["dim", "dim2"])
<xarray.DataArray (experiment: 3, batch: 10)>
23.55 -7.374 -5.617 -12.29 1.77 -0.6289 ... -11.07 -0.5096 -28.77 -0.1599 -8.428
Dimensions without coordinates: experiment, batch
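The two outputs above hold the same values, with the dimensions laid out in a different order. The following sketch shows one way to check this on random data. It uses numpy.linalg.det through xarray.apply_ufunc as a stand-in for linalg.det, so the snippet only needs NumPy and xarray. The idea is to transpose one result to the other's dimension order before comparing.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(2)
da = xr.DataArray(
    rng.normal(size=(10, 3, 4, 4)),
    dims=["batch", "experiment", "dim", "dim2"],
)


def det_dim_dim2(arr):
    # stand-in for a name-based determinant: "dim" and "dim2" define the matrices
    return xr.apply_ufunc(np.linalg.det, arr, input_core_dims=[["dim", "dim2"]])


res_a = det_dim_dim2(da)
res_b = det_dim_dim2(da.transpose("dim2", "experiment", "dim", "batch"))
print(res_a.dims, res_b.dims)  # ('batch', 'experiment') ('experiment', 'batch')

# identical values once the output dimensions are put in the same order
xr.testing.assert_allclose(res_a, res_b.transpose(*res_a.dims))
```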

Important

In xarray_einstats only the dimension names matter, not their order.

Dataset and GroupBy inputs

While DataArray is the base xarray object, other xarray objects are also key when using the library. Objects such as Dataset are implemented as collections of DataArray objects, and they all include a .map method to apply the same function to each of their child DataArrays.

ds = tutorial.generate_mcmc_like_dataset(9438)
ds
<xarray.Dataset>
Dimensions:  (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: plot_dim, match
Data variables:
    x_plot   (plot_dim) float64 0.0 0.5263 1.053 1.579 ... 8.947 9.474 10.0
    mu       (chain, draw, team) float64 0.2691 0.1617 0.4371 ... 0.4673 1.844
    sigma    (chain, draw) float64 1.939 1.435 0.5109 ... 0.594 1.54 1.257
    score    (chain, draw, match) int64 0 2 3 0 0 0 0 0 2 ... 1 0 1 1 1 2 4 0 2

We can use .map to apply the same function to all 4 child DataArrays in ds, but this is not always possible. .map applies the function provided to every child DataArray with the same **kwargs.

If we try doing:

ds.map(stats.circmean, dims=("chain", "draw"))
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[6], line 1
----> 1 ds.map(stats.circmean, dims=("chain", "draw"))

KeyError: 'chain'

we get an exception, because the chain and draw dimensions are not present in every child DataArray (x_plot only has plot_dim). Instead, we can apply the function only to the variables that have both the chain and draw dimensions:

ds_samples = ds[["mu", "sigma", "score"]]
ds_samples.map(stats.circmean, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions:  (team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
    mu       (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
    sigma    float64 0.8134
    score    (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94
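Rather than listing the variables by hand, they can also be selected by their dimensions. The following sketch assumes a small stand-in Dataset with the same layout as the tutorial one, filled with random values. vars_with_dims is a hypothetical helper. xarray's own mean replaces stats.circmean so that the snippet only depends on xarray.

```python
import numpy as np
import xarray as xr

# small stand-in with the same layout as the tutorial dataset (values are random)
rng = np.random.default_rng(3)
ds = xr.Dataset(
    {
        "x_plot": (("plot_dim",), np.linspace(0, 10, 20)),
        "mu": (("chain", "draw", "team"), rng.normal(size=(4, 10, 6))),
        "sigma": (("chain", "draw"), rng.exponential(size=(4, 10))),
        "score": (("chain", "draw", "match"), rng.poisson(1.0, size=(4, 10, 12))),
    }
)


def vars_with_dims(ds, dims):
    """Hypothetical helper: names of the data variables that have every dim in dims."""
    return [name for name, var in ds.data_vars.items() if set(dims) <= set(var.dims)]


selected = vars_with_dims(ds, ("chain", "draw"))
print(selected)  # ['mu', 'sigma', 'score']

# xarray's own mean stands in for stats.circmean; .map passes dims to every variable
means = ds[selected].map(lambda var, dims: var.mean(dim=dims), dims=("chain", "draw"))
```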

Attention

In general, you should prefer the .map method over passing non-DataArray objects to xarray_einstats functions directly. .map ensures that no unexpected broadcasting takes place between the child DataArrays, as the examples below show.

However, some reductions on non-DataArray inputs do not trigger any such broadcasting. This applies when every child DataArray has all the dimensions being reduced. We have included this behaviour in our test suite to ensure it stays this way.

For example, it is also possible to call the function on the Dataset directly:

stats.circmean(ds_samples, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions:  (team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
    mu       (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
    sigma    float64 0.8134
    score    (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94

Here, all child DataArrays have both the chain and draw dimensions, so, as expected, the result is the same. In some cases, however, not using .map triggers broadcasting operations that generally do not produce the desired output.

If we use the .map method, the function is applied to each child DataArray independently of the others:

ds.map(stats.rankdata)
<xarray.Dataset>
Dimensions:   (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * plot_dim  (plot_dim) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 7 8 9
  * team      (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * match     (match) int64 0 1 2 3 4 5 6 7 8 9 10 11
Data variables:
    x_plot    (plot_dim) float64 1.0 2.0 3.0 4.0 5.0 ... 17.0 18.0 19.0 20.0
    mu        (chain, draw, team) float64 65.0 41.0 89.0 ... 55.0 97.0 205.0
    sigma     (chain, draw) float64 33.0 30.0 15.0 5.0 ... 4.0 18.0 31.0 29.0
    score     (chain, draw, match) float64 105.0 401.0 457.0 ... 105.0 401.0

whereas without .map, extra broadcasting happens and every variable ends up with all five dimensions of the Dataset:

stats.rankdata(ds)
<xarray.Dataset>
Dimensions:   (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * plot_dim  (plot_dim) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 7 8 9
  * team      (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * match     (match) int64 0 1 2 3 4 5 6 7 8 9 10 11
Data variables:
    x_plot    (plot_dim, chain, draw, team, match) float64 1.44e+03 ... 5.616...
    mu        (plot_dim, chain, draw, team, match) float64 1.548e+04 ... 4.90...
    sigma     (plot_dim, chain, draw, team, match) float64 4.68e+04 ... 4.104...
    score     (plot_dim, chain, draw, team, match) float64 1.254e+04 ... 4.80...
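This kind of broadcasting is a general xarray behaviour when several variables go through one operation together. For example, Dataset.stack broadcasts every data variable that has at least one of the stacked dimensions against the dimensions it lacks. The following snippet only illustrates that mechanism with plain xarray. We assume, but do not verify here, that stats.rankdata follows exactly this code path.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "a": (("x",), np.arange(2.0)),
        "b": (("x", "y"), np.ones((2, 3))),
    }
)

# "a" only has the "x" dimension, but stacking ("x", "y") on the whole Dataset
# broadcasts it along "y" first, so both variables end up with 2 * 3 = 6 elements
stacked = ds.stack(z=("x", "y"))
print(stacked["a"].sizes["z"], stacked["b"].sizes["z"])  # 6 6
print(stacked["a"].values)  # [0. 0. 0. 1. 1. 1.]
```

Calling .stack on the DataArray ds["a"] alone raises an error instead, because "a" has no "y" dimension. This mirrors the KeyError shown earlier for stats.circmean.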

The behaviour on DataArrayGroupBy objects is very similar to what we have shown for Datasets. First, we relabel the team coordinate of mu so that some labels repeat:

da = ds["mu"].assign_coords(team=["a", "b", "b", "a", "c", "b"])
da
<xarray.DataArray 'mu' (chain: 4, draw: 10, team: 6)>
0.2691 0.1617 0.4371 0.4885 0.1836 2.149 ... 2.037 0.09032 0.2221 0.4673 1.844
Coordinates:
  * team     (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9

Applying a “group by” operation over the team coordinate generates a DataArrayGroupBy with 3 groups:

gb = da.groupby("team")
gb
DataArrayGroupBy, grouped over 'team'
3 groups with labels 'a', 'b', 'c'.

on which we can use .map to apply a function from xarray-einstats over all groups independently:

gb.map(stats.median_abs_deviation, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4, team: 3)>
0.3436 0.3758 0.2351 0.5221 0.5937 0.4158 ... 0.4314 0.212 0.3479 0.5708 0.2288
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * team     (team) object 'a' 'b' 'c'

which as expected has performed the operation group-wise, yielding a different result than either

stats.median_abs_deviation(da, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4)>
0.3444 0.5968 0.4553 0.4069
Coordinates:
  * chain    (chain) int64 0 1 2 3

or

stats.median_abs_deviation(da, dims="draw")
<xarray.DataArray 'mu' (chain: 4, team: 6)>
0.3452 0.3788 0.09536 0.3892 0.2351 ... 0.6554 0.2451 0.3832 0.2288 0.5281
Coordinates:
  * team     (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
  * chain    (chain) int64 0 1 2 3
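What .map does on the DataArrayGroupBy can be spelled out as an explicit loop. It applies the function to each group, then concatenates the results along a team dimension labelled with the group keys. The following sketch checks that equivalence on a random stand-in array. mad_over is a hypothetical helper built on scipy.stats.median_abs_deviation. It is used instead of the xarray-einstats function so the snippet only needs xarray and SciPy.

```python
import numpy as np
import xarray as xr
from scipy import stats

rng = np.random.default_rng(4)
# stand-in for ds["mu"] after relabelling, with repeated team labels
da = xr.DataArray(
    rng.exponential(size=(4, 10, 6)),
    dims=["chain", "draw", "team"],
    coords={"team": ["a", "b", "b", "a", "c", "b"]},
)


def mad_over(arr, dims):
    """Hypothetical helper: median absolute deviation over the named dims."""
    n = len(dims)

    def _mad_trailing(x):
        # merge the trailing core axes into one and reduce over it
        flat = x.reshape(x.shape[: x.ndim - n] + (-1,))
        return stats.median_abs_deviation(flat, axis=-1)

    return xr.apply_ufunc(_mad_trailing, arr, input_core_dims=[list(dims)])


gb = da.groupby("team")
by_map = gb.map(mad_over, dims=["draw", "team"])

# the same computation written as an explicit loop over the groups
labels, pieces = zip(*[(label, mad_over(group, ["draw", "team"])) for label, group in gb])
by_loop = xr.concat(pieces, dim="team").assign_coords(team=list(labels))
```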

See also

Check out the GroupBy: Group and Bin Data page in xarray’s documentation.