Getting started#
Welcome to xarray-einstats!#
xarray-einstats is an open source Python library that is part of the ArviZ project.
It acts as a bridge between the xarray library for labelled arrays and libraries for raw arrays such as NumPy or SciPy.
Xarray has “Compatibility with the broader ecosystem” as one of its main goals, which is what allows xarray-einstats to perform this bridge role with minimal code and duplication.
Overview#
xarray-einstats provides wrappers for:
Most of the functions in numpy.linalg
A subset of scipy.stats
rearrange and reduce from einops
These wrappers have the same names and functionality as the original functions.
The difference in behaviour is that the wrappers do not make assumptions about the meaning of a dimension based on its position, nor do they have arguments like axis or axes.
Instead, they have a dims argument that takes dimension names rather than integers indicating the positions of the dimensions on which to act.
It also provides a handful of re-implemented functions.
These are partially reimplemented because the original functions don’t yet support multidimensional and/or batched computations. They share their names with functions in NumPy or SciPy, but they only implement a subset of the features. Moreover, the goal is for those to eventually be wrappers too.
Using xarray-einstats#
DataArray inputs#
Functions in xarray-einstats are designed to work on DataArray objects.
Let’s load some example data:
from xarray_einstats import linalg, stats, tutorial
da = tutorial.generate_matrices_dataarray(4)
da
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
3.799 0.4308 3.24 0.1412 0.9402 0.7951 ... 0.6156 1.124 0.8559 2.108 0.7637
Dimensions without coordinates: batch, experiment, dim, dim2
and show an example:
stats.skew(da, dims=["batch", "dim2"])
<xarray.DataArray (experiment: 3, dim: 4)>
1.256 1.432 0.9728 1.762 1.612 1.188 1.033 2.388 2.196 1.455 1.631 1.373
Dimensions without coordinates: experiment, dim
xarray-einstats uses dims as an argument throughout the codebase as an alternative to both axis and axes, and also as an alternative to the (..., M, M) convention used by NumPy.
The use of dims follows dot, instead of the singular dim argument used, for example, in mean.
Both a single dimension and multiple dimensions are valid inputs, and using dims emphasizes the fact that operations and reductions can be performed over multiple dimensions at the same time.
Moreover, in linear algebra functions, dims is often restricted to a 2 element list, as it indicates which dimensions define the matrices, interpreting all the others as batch dimensions.
That means that the two calls below are equivalent even if the dimension order of the inputs is not the same, because their dimension names are the same. Thus,
linalg.det(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3)>
23.55 2.033 0.3923 -7.374 0.06645 ... 1.804 -0.1599 8.875 -0.04935 -8.428
Dimensions without coordinates: batch, experiment
returns the same values (with the batch dimensions in a different order) as:
linalg.det(da.transpose("dim2", "experiment", "dim", "batch"), dims=["dim", "dim2"])
<xarray.DataArray (experiment: 3, batch: 10)>
23.55 -7.374 -5.617 -12.29 1.77 -0.6289 ... -11.07 -0.5096 -28.77 -0.1599 -8.428
Dimensions without coordinates: experiment, batch
Important
In xarray_einstats, only the dimension names matter, not their order.
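The order independence described above can be reproduced with xarray's `apply_ufunc` and `numpy.linalg.det`. The helper `det_by_name` is a hypothetical name used for illustration; the sketch shows the general mechanism of naming the matrix dimensions and treating everything else as batch dimensions, and it is not claimed to be how the library is implemented.

```python
import numpy as np
import xarray as xr


def det_by_name(da, dims):
    """Hypothetical helper: determinant over the two named matrix dimensions."""
    # The two core dimensions are moved to the last two axes, which is the
    # (..., M, M) layout that numpy.linalg.det expects.
    return xr.apply_ufunc(np.linalg.det, da, input_core_dims=[list(dims)])


rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.normal(size=(10, 3, 4, 4)), dims=["batch", "experiment", "dim", "dim2"]
)

det_a = det_by_name(da, ["dim", "dim2"])
det_b = det_by_name(
    da.transpose("dim2", "experiment", "dim", "batch"), ["dim", "dim2"]
)

# Same values; only the order of the remaining batch dimensions differs.
xr.testing.assert_allclose(det_a, det_b.transpose(*det_a.dims))
print(det_a.dims, det_b.dims)
```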
Dataset and GroupBy inputs#
While the DataArray is the base xarray object, there are also other xarray objects that are key while using the library.
These other objects, such as Dataset, are implemented as a collection of DataArray objects, and all include a .map method to apply the same function to all their child DataArrays.
ds = tutorial.generate_mcmc_like_dataset(9438)
ds
<xarray.Dataset>
Dimensions:  (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: plot_dim, match
Data variables:
    x_plot   (plot_dim) float64 0.0 0.5263 1.053 1.579 ... 8.947 9.474 10.0
    mu       (chain, draw, team) float64 0.2691 0.1617 0.4371 ... 0.4673 1.844
    sigma    (chain, draw) float64 1.939 1.435 0.5109 ... 0.594 1.54 1.257
    score    (chain, draw, match) int64 0 2 3 0 0 0 0 0 2 ... 1 0 1 1 1 2 4 0 2
We can use map to apply the same function to all 4 child DataArrays in ds, but this will not always be possible.
When using .map, the function provided is applied to all child DataArrays with the same **kwargs.
If we try doing:
ds.map(stats.circmean, dims=("chain", "draw"))
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[6], line 1
----> 1 ds.map(stats.circmean, dims=("chain", "draw"))
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:5949, in Dataset.map(self, func, keep_attrs, args, **kwargs)
5947 if keep_attrs is None:
5948 keep_attrs = _get_keep_attrs(default=False)
-> 5949 variables = {
5950 k: maybe_wrap_array(v, func(v, *args, **kwargs))
5951 for k, v in self.data_vars.items()
5952 }
5953 if keep_attrs:
5954 for k, v in variables.items():
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:5950, in <dictcomp>(.0)
5947 if keep_attrs is None:
5948 keep_attrs = _get_keep_attrs(default=False)
5949 variables = {
-> 5950 k: maybe_wrap_array(v, func(v, *args, **kwargs))
5951 for k, v in self.data_vars.items()
5952 }
5953 if keep_attrs:
5954 for k, v in variables.items():
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray_einstats/stats.py:489, in circmean(da, dims, high, low, nan_policy, **kwargs)
487 if nan_policy is not None:
488 circmean_kwargs["nan_policy"] = nan_policy
--> 489 return _apply_reduce_func(stats.circmean, da, dims, kwargs, circmean_kwargs)
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray_einstats/stats.py:432, in _apply_reduce_func(func, da, dims, kwargs, func_kwargs)
430 if not isinstance(dims, str):
431 aux_dim = f"__aux_dim__:{','.join(dims)}"
--> 432 da = da.stack({aux_dim: dims})
433 core_dims = [aux_dim]
434 else:
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataarray.py:2739, in DataArray.stack(self, dimensions, create_index, index_cls, **dimensions_kwargs)
2674 def stack(
2675 self: T_DataArray,
2676 dimensions: Mapping[Any, Sequence[Hashable]] | None = None,
(...)
2679 **dimensions_kwargs: Sequence[Hashable],
2680 ) -> T_DataArray:
2681 """
2682 Stack any number of existing dimensions into a single new dimension.
2683
(...)
2737 DataArray.unstack
2738 """
-> 2739 ds = self._to_temp_dataset().stack(
2740 dimensions,
2741 create_index=create_index,
2742 index_cls=index_cls,
2743 **dimensions_kwargs,
2744 )
2745 return self._from_temp_dataset(ds)
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:4593, in Dataset.stack(self, dimensions, create_index, index_cls, **dimensions_kwargs)
4591 result = self
4592 for new_dim, dims in dimensions.items():
-> 4593 result = result._stack_once(dims, new_dim, index_cls, create_index)
4594 return result
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:4524, in Dataset._stack_once(self, dims, new_dim, index_cls, create_index)
4522 product_vars: dict[Any, Variable] = {}
4523 for dim in dims:
-> 4524 idx, idx_vars = self._get_stack_index(dim, create_index=create_index)
4525 if idx is not None:
4526 product_vars.update(idx_vars)
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:4480, in Dataset._get_stack_index(self, dim, multi, create_index)
4478 var = self._variables[dim]
4479 else:
-> 4480 _, _, var = _get_virtual_variable(self._variables, dim, self.dims)
4481 # dummy index (only `stack_coords` will be used to construct the multi-index)
4482 stack_index = PandasIndex([0], dim)
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.4.0/lib/python3.10/site-packages/xarray/core/dataset.py:178, in _get_virtual_variable(variables, key, dim_sizes)
176 split_key = key.split(".", 1)
177 if len(split_key) != 2:
--> 178 raise KeyError(key)
180 ref_name, var_name = split_key
181 ref_var = variables[ref_name]
KeyError: 'chain'
we get an exception. The chain and draw dimensions are not present in all child DataArrays.
Instead, we could apply it only to the variables that have both chain and draw dimensions.
ds_samples = ds[["mu", "sigma", "score"]]
ds_samples.map(stats.circmean, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions:  (team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
    mu       (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
    sigma    float64 0.8134
    score    (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94
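Rather than listing the variable names by hand, the subset can also be selected programmatically. The following is a small sketch on a synthetic Dataset (the data is made up and much simpler than the tutorial one), and it uses xarray's own `DataArray.mean` as a stand-in for a reducing function so that it runs with xarray alone.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Synthetic stand-in for the tutorial dataset; shapes chosen for illustration.
ds = xr.Dataset(
    {
        "x_plot": ("plot_dim", np.linspace(0, 10, 20)),
        "mu": (("chain", "draw", "team"), rng.exponential(size=(4, 10, 6))),
        "sigma": (("chain", "draw"), rng.exponential(size=(4, 10))),
    }
)

# Keep only the variables that contain every dimension we want to reduce.
required = {"chain", "draw"}
names = [name for name, var in ds.data_vars.items() if required <= set(var.dims)]

# .map forwards the same keyword arguments to the function for each variable.
reduced = ds[names].map(xr.DataArray.mean, dim=["chain", "draw"])
print(names)
```

Filtering first means `.map` never sees a variable that lacks one of the requested dimensions, which is what triggered the KeyError above.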
Attention
In general, you should prefer using the .map method over passing non-DataArray objects directly as input to xarray_einstats functions.
.map will ensure no unexpected broadcasting between the multiple child DataArrays takes place.
See the examples below.
However, if you are using functions that reduce dimensions on non-DataArray inputs whose child DataArrays all have all the dimensions to reduce, you will not trigger any such broadcasting, and we have included that behaviour in our test suite to ensure it stays this way.
It is also possible to do:
stats.circmean(ds_samples, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions:  (team: 6, match: 12)
Coordinates:
  * team     (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
    mu       (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
    sigma    float64 0.8134
    score    (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94
Here, all child DataArrays have both chain and draw dimensions, so as expected, the result is the same.
There are some cases, however, in which not using .map triggers some broadcasting operations which will generally not be the desired output.
If we use the .map method, the function is applied to each child DataArray independently from the others:
ds.map(stats.rankdata)
<xarray.Dataset>
Dimensions:   (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * plot_dim  (plot_dim) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 7 8 9
  * team      (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * match     (match) int64 0 1 2 3 4 5 6 7 8 9 10 11
Data variables:
    x_plot    (plot_dim) float64 1.0 2.0 3.0 4.0 5.0 ... 17.0 18.0 19.0 20.0
    mu        (chain, draw, team) float64 65.0 41.0 89.0 ... 55.0 97.0 205.0
    sigma     (chain, draw) float64 33.0 30.0 15.0 5.0 ... 4.0 18.0 31.0 29.0
    score     (chain, draw, match) float64 105.0 401.0 457.0 ... 105.0 401.0
whereas without using the .map method, extra broadcasting can happen:
stats.rankdata(ds)
<xarray.Dataset>
Dimensions:   (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
  * plot_dim  (plot_dim) int64 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
  * chain     (chain) int64 0 1 2 3
  * draw      (draw) int64 0 1 2 3 4 5 6 7 8 9
  * team      (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
  * match     (match) int64 0 1 2 3 4 5 6 7 8 9 10 11
Data variables:
    x_plot    (plot_dim, chain, draw, team, match) float64 1.44e+03 ... 5.616...
    mu        (plot_dim, chain, draw, team, match) float64 1.548e+04 ... 4.90...
    sigma     (plot_dim, chain, draw, team, match) float64 4.68e+04 ... 4.104...
    score     (plot_dim, chain, draw, team, match) float64 1.254e+04 ... 4.80...
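The shapes in the second output are consistent with every variable having been expanded to all the dimensions of the Dataset. The effect of such an expansion can be illustrated with `xarray.broadcast` on two small synthetic variables; this is only a sketch of what broadcasting does to shapes, under the assumption that a similar expansion is what happens internally.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Two synthetic variables that share no dimensions.
x_plot = xr.DataArray(np.linspace(0, 10, 20), dims=["plot_dim"])
sigma = xr.DataArray(rng.exponential(size=(4, 10)), dims=["chain", "draw"])

# Broadcasting expands each array to the union of all dimensions.
x_plot_b, sigma_b = xr.broadcast(x_plot, sigma)

print(x_plot.sizes, "->", x_plot_b.sizes)
print(sigma.sizes, "->", sigma_b.sizes)
```

After broadcasting, both arrays hold 20 × 4 × 10 = 800 values, so a function such as a ranking operates on many repeated copies of the original data instead of the 20 and 40 original values.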
The behaviour on DataArrayGroupBy, for example, is very similar to the examples we have shown for Datasets:
da = ds["mu"].assign_coords(team=["a", "b", "b", "a", "c", "b"])
da
<xarray.DataArray 'mu' (chain: 4, draw: 10, team: 6)>
0.2691 0.1617 0.4371 0.4885 0.1836 2.149 ... 2.037 0.09032 0.2221 0.4673 1.844
Coordinates:
  * team     (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
  * chain    (chain) int64 0 1 2 3
  * draw     (draw) int64 0 1 2 3 4 5 6 7 8 9
When we apply a “group by” operation over the team dimension, we generate a DataArrayGroupBy with 3 groups.
gb = da.groupby("team")
gb
DataArrayGroupBy, grouped over 'team'
3 groups with labels 'a', 'b', 'c'.
on which we can use .map to apply a function from xarray-einstats over all groups independently:
gb.map(stats.median_abs_deviation, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4, team: 3)>
0.3436 0.3758 0.2351 0.5221 0.5937 0.4158 ... 0.4314 0.212 0.3479 0.5708 0.2288
Coordinates:
  * chain    (chain) int64 0 1 2 3
  * team     (team) object 'a' 'b' 'c'
which, as expected, has performed the operation group-wise, yielding a different result than either
stats.median_abs_deviation(da, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4)>
0.3444 0.5968 0.4553 0.4069
Coordinates:
  * chain    (chain) int64 0 1 2 3
or
stats.median_abs_deviation(da, dims="draw")
<xarray.DataArray 'mu' (chain: 4, team: 6)>
0.3452 0.3788 0.09536 0.3892 0.2351 ... 0.6554 0.2451 0.3832 0.2288 0.5281
Coordinates:
  * team     (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
  * chain    (chain) int64 0 1 2 3
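The group-wise behaviour can be checked by hand on a small synthetic array. In the sketch below, `median_over` is a hypothetical helper built on xarray's own `median` and used as a stand-in for a reducing function that takes a dims argument; the data is random and only mirrors the shape and team labels of the example above.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)
# Synthetic array with repeated team labels, as in the example above.
da = xr.DataArray(
    rng.exponential(size=(4, 10, 6)),
    dims=["chain", "draw", "team"],
    coords={"team": ["a", "b", "b", "a", "c", "b"]},
)


def median_over(group, dims):
    """Hypothetical stand-in for a reducing function with a dims argument."""
    return group.median(dim=dims)


# Each group keeps a "team" dimension holding only its own members,
# so reducing over "team" inside .map reduces within the group only.
per_group = da.groupby("team").map(median_over, dims=["draw", "team"])

# Manual check for team "a", whose members sit at positions 0 and 3.
manual_a = np.median(da.values[:, :, [0, 3]].reshape(4, -1), axis=1)
print(np.allclose(per_group.sel(team="a").values, manual_a))
```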
See also
Check out the GroupBy: Group and Bin Data page on xarray’s documentation.