Getting started#
Welcome to xarray-einstats!#
xarray-einstats is an open source Python library that is part of the
ArviZ project.
It acts as a bridge between the xarray
library for labelled arrays and libraries for raw arrays
such as NumPy or SciPy.
One of xarray's main goals is “Compatibility with the broader ecosystem”,
which is what allows xarray-einstats to play this
bridge role with minimal code and duplication.
Overview#
xarray-einstats provides wrappers for:
Most of the functions in numpy.linalg
A subset of scipy.stats
rearrange and reduce from einops
These wrappers have the same names and functionality as the original functions.
The difference in behaviour is that the wrappers do not make assumptions
about the meaning of a dimension based on its position,
nor do they have arguments like axis or axes.
Instead, they have a dims argument that takes dimension names rather than
integers indicating the positions of the dimensions on which to act.
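As an illustration of this difference, the sketch below computes the same skewness with SciPy (positional axis) and with the xarray-einstats wrapper (dimension names). The array, its dimension names and the random data are made up for the example. Reducing over two dimensions at once with SciPy requires moving and flattening axes by hand, whereas the wrapper only needs the names.

```python
import numpy as np
import scipy.stats
import xarray as xr
from xarray_einstats import stats

# Made-up data: the dimension names are arbitrary labels chosen for this example
rng = np.random.default_rng(0)
values = rng.exponential(size=(10, 3, 4))
da = xr.DataArray(values, dims=["batch", "experiment", "dim"])

# SciPy: the dimension to reduce is identified by its position
skew_axis = scipy.stats.skew(values, axis=0)
# xarray-einstats: the dimension to reduce is identified by its name
skew_dims = stats.skew(da, dims="batch")

# Reducing over "batch" and "dim" together with SciPy means
# moving "experiment" to the front and flattening the rest manually
flat = np.moveaxis(values, 1, 0).reshape(3, -1)
skew_axis_multi = scipy.stats.skew(flat, axis=-1)
# With the wrapper, a list of names is enough
skew_dims_multi = stats.skew(da, dims=["batch", "dim"])

print(np.allclose(skew_dims.values, skew_axis))
print(np.allclose(skew_dims_multi.values, skew_axis_multi))
```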
It also provides a handful of re-implemented functions:
These are partially reimplemented because the original function doesn’t yet support multidimensional and/or batched computations. They share their name with a function in NumPy or SciPy, but they only implement a subset of its features. The goal is for those to eventually be wrappers too.
Using xarray-einstats#
DataArray inputs#
Functions in xarray-einstats are designed to work on DataArray objects.
Let’s load some example data:
from xarray_einstats import linalg, stats, tutorial
da = tutorial.generate_matrices_dataarray(4)
da
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
3.799 0.4308 3.24 0.1412 0.9402 0.7951 ... 0.6156 1.124 0.8559 2.108 0.7637
Dimensions without coordinates: batch, experiment, dim, dim2
and show an example:
stats.skew(da, dims=["batch", "dim2"])
<xarray.DataArray (experiment: 3, dim: 4)>
1.256 1.432 0.9728 1.762 1.612 1.188 1.033 2.388 2.196 1.455 1.631 1.373
Dimensions without coordinates: experiment, dim
xarray-einstats uses dims as the argument name throughout the codebase,
as an alternative to both axis and axes,
and also as an alternative to the (..., M, M) convention used by NumPy.
The use of dims follows dot, instead of the singular
dim argument used for example in mean.
Both a single dimension and multiple dimensions are valid inputs,
and using dims emphasizes the fact that operations
and reductions can be performed over multiple dimensions at the same time.
Moreover, in linear algebra functions, dims is often restricted
to a 2 element list as it indicates which dimensions define the matrices,
interpreting all the others as batch dimensions.
That means that the two calls below are equivalent even though the dimension order of the inputs differs, because their dimension names are the same. Thus,
linalg.det(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3)>
23.55 2.033 0.3923 -7.374 0.06645 ... 1.804 -0.1599 8.875 -0.04935 -8.428
Dimensions without coordinates: batch, experiment
returns the same as:
linalg.det(da.transpose("dim2", "experiment", "dim", "batch"), dims=["dim", "dim2"])
<xarray.DataArray (experiment: 3, batch: 10)>
23.55 -7.374 -5.617 -12.29 1.77 -0.6289 ... -11.07 -0.5096 -28.77 -0.1599 -8.428
Dimensions without coordinates: experiment, batch
Important
In xarray_einstats only the dimension names matter, not their order.
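For comparison with the (..., M, M) convention, the sketch below builds a small made-up array whose matrix dimensions are not the last two. With NumPy, the axes have to be moved so that the matrices sit in the last two positions; with xarray-einstats, naming the two matrix dimensions is enough. The array shape and dimension names are assumptions chosen for illustration.

```python
import numpy as np
import xarray as xr
from xarray_einstats import linalg

# Made-up data: 5 matrices of shape (4, 4), with the batch dimension in the middle
rng = np.random.default_rng(0)
values = rng.normal(size=(4, 5, 4))
da = xr.DataArray(values, dims=["dim", "batch", "dim2"])

# NumPy expects the matrices in the last two axes, i.e. shape (5, 4, 4)
det_numpy = np.linalg.det(np.moveaxis(values, 1, 0))

# xarray-einstats only needs the names of the two matrix dimensions;
# every other dimension ("batch" here) is treated as a batch dimension
det_named = linalg.det(da, dims=["dim", "dim2"])

print(det_named.dims)
print(np.allclose(det_named.values, det_numpy))
```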
Dataset and GroupBy inputs#
While the DataArray is the base xarray object, there are also
other xarray objects that are key while using the library.
These other objects, such as Dataset, are implemented as
a collection of DataArray objects, and all include a .map
method to apply the same function to all their child DataArrays.
ds = tutorial.generate_mcmc_like_dataset(9438)
ds
<xarray.Dataset>
Dimensions: (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Coordinates:
* team (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 5 6 7 8 9
Dimensions without coordinates: plot_dim, match
Data variables:
x_plot (plot_dim) float64 0.0 0.5263 1.053 1.579 ... 8.947 9.474 10.0
mu (chain, draw, team) float64 0.2691 0.1617 0.4371 ... 0.4673 1.844
sigma (chain, draw) float64 1.939 1.435 0.5109 ... 0.594 1.54 1.257
score (chain, draw, match) int64 0 2 3 0 0 0 0 0 2 ... 1 0 1 1 1 2 4 0 2
We can use map to apply the same function to
all 4 child DataArrays in ds, but this will not always be possible.
When using .map, the function provided is applied to all child DataArrays
with the same **kwargs.
If we try doing:
ds.map(stats.circmean, dims=("chain", "draw"))
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[6], line 1
----> 1 ds.map(stats.circmean, dims=("chain", "draw"))
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/dataset.py:6037, in Dataset.map(self, func, keep_attrs, args, **kwargs)
6035 if keep_attrs is None:
6036 keep_attrs = _get_keep_attrs(default=False)
-> 6037 variables = {
6038 k: maybe_wrap_array(v, func(v, *args, **kwargs))
6039 for k, v in self.data_vars.items()
6040 }
6041 if keep_attrs:
6042 for k, v in variables.items():
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/dataset.py:6038, in <dictcomp>(.0)
6035 if keep_attrs is None:
6036 keep_attrs = _get_keep_attrs(default=False)
6037 variables = {
-> 6038 k: maybe_wrap_array(v, func(v, *args, **kwargs))
6039 for k, v in self.data_vars.items()
6040 }
6041 if keep_attrs:
6042 for k, v in variables.items():
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray_einstats/stats.py:498, in circmean(da, dims, high, low, nan_policy, **kwargs)
496 if nan_policy is not None:
497 circmean_kwargs["nan_policy"] = nan_policy
--> 498 return _apply_reduce_func(stats.circmean, da, dims, kwargs, circmean_kwargs)
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray_einstats/stats.py:445, in _apply_reduce_func(func, da, dims, kwargs, func_kwargs)
443 else:
444 core_dims = [dims]
--> 445 out_da = xr.apply_ufunc(
446 func, da, input_core_dims=[core_dims], output_core_dims=[[]], kwargs=func_kwargs, **kwargs
447 )
448 return out_da
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/computation.py:1197, in apply_ufunc(func, input_core_dims, output_core_dims, exclude_dims, vectorize, join, dataset_join, dataset_fill_value, keep_attrs, kwargs, dask, output_dtypes, output_sizes, meta, dask_gufunc_kwargs, *args)
1195 # feed DataArray apply_variable_ufunc through apply_dataarray_vfunc
1196 elif any(isinstance(a, DataArray) for a in args):
-> 1197 return apply_dataarray_vfunc(
1198 variables_vfunc,
1199 *args,
1200 signature=signature,
1201 join=join,
1202 exclude_dims=exclude_dims,
1203 keep_attrs=keep_attrs,
1204 )
1205 # feed Variables directly through apply_variable_ufunc
1206 elif any(isinstance(a, Variable) for a in args):
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/computation.py:304, in apply_dataarray_vfunc(func, signature, join, exclude_dims, keep_attrs, *args)
299 result_coords, result_indexes = build_output_coords_and_indexes(
300 args, signature, exclude_dims, combine_attrs=keep_attrs
301 )
303 data_vars = [getattr(a, "variable", a) for a in args]
--> 304 result_var = func(*data_vars)
306 out: tuple[DataArray, ...] | DataArray
307 if signature.num_outputs > 1:
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/computation.py:672, in apply_variable_ufunc(func, signature, exclude_dims, dask, output_dtypes, vectorize, keep_attrs, dask_gufunc_kwargs, *args)
667 broadcast_dims = tuple(
668 dim for dim in dim_sizes if dim not in signature.all_core_dims
669 )
670 output_dims = [broadcast_dims + out for out in signature.output_core_dims]
--> 672 input_data = [
673 broadcast_compat_data(arg, broadcast_dims, core_dims)
674 if isinstance(arg, Variable)
675 else arg
676 for arg, core_dims in zip(args, signature.input_core_dims)
677 ]
679 if any(is_chunked_array(array) for array in input_data):
680 if dask == "forbidden":
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/computation.py:673, in <listcomp>(.0)
667 broadcast_dims = tuple(
668 dim for dim in dim_sizes if dim not in signature.all_core_dims
669 )
670 output_dims = [broadcast_dims + out for out in signature.output_core_dims]
672 input_data = [
--> 673 broadcast_compat_data(arg, broadcast_dims, core_dims)
674 if isinstance(arg, Variable)
675 else arg
676 for arg, core_dims in zip(args, signature.input_core_dims)
677 ]
679 if any(is_chunked_array(array) for array in input_data):
680 if dask == "forbidden":
File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.6.0/lib/python3.10/site-packages/xarray/core/computation.py:600, in broadcast_compat_data(variable, broadcast_dims, core_dims)
598 missing_core_dims = [d for d in core_dims if d not in set_old_dims]
599 if missing_core_dims:
--> 600 raise ValueError(
601 "operand to apply_ufunc has required core dimensions {}, but "
602 "some of these dimensions are absent on an input variable: {}".format(
603 list(core_dims), missing_core_dims
604 )
605 )
607 set_new_dims = set(new_dims)
608 unexpected_dims = [d for d in old_dims if d not in set_new_dims]
ValueError: operand to apply_ufunc has required core dimensions ['__aux_dim__:chain,draw'], but some of these dimensions are absent on an input variable: ['__aux_dim__:chain,draw']
we get an exception. The chain and draw dimensions are not present in all
child DataArrays. Instead, we could apply it only to the variables
that have both chain and draw dimensions.
ds_samples = ds[["mu", "sigma", "score"]]
ds_samples.map(stats.circmean, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions: (team: 6, match: 12)
Coordinates:
* team (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
mu (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
sigma float64 0.8134
score (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94
Attention
In general, you should prefer using the .map method over passing non-DataArray objects
directly as input to xarray_einstats functions.
.map will ensure no unexpected broadcasting between the multiple child DataArrays takes place.
See the examples below.
However, if you are using functions that reduce dimensions on non-DataArray inputs
whose child DataArrays all have all the dimensions to reduce, you will
not trigger any such broadcasting,
and we have included that behaviour in our test suite to ensure it stays this way.
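The condition described here (every child DataArray has all the dimensions to reduce) can be checked programmatically before deciding how to call a function. The sketch below is one possible way to do it: `vars_with_dims` is a hypothetical helper, not part of xarray or xarray-einstats, the dataset is a made-up miniature with the same layout as the tutorial one, and xarray's built-in mean is used as a stand-in for an xarray-einstats function.

```python
import numpy as np
import xarray as xr

# Made-up dataset mimicking the layout of the tutorial dataset
rng = np.random.default_rng(0)
ds = xr.Dataset(
    {
        "x_plot": ("plot_dim", np.linspace(0, 10, 20)),
        "mu": (("chain", "draw", "team"), rng.exponential(size=(4, 10, 6))),
        "sigma": (("chain", "draw"), rng.exponential(size=(4, 10))),
    }
)


def vars_with_dims(dataset, dims):
    """Return the names of the data variables that contain all of `dims`.

    Hypothetical helper written for this example.
    """
    required = set(dims)
    return [
        name
        for name, var in dataset.data_vars.items()
        if required.issubset(var.dims)
    ]


reduce_dims = ["chain", "draw"]
names = vars_with_dims(ds, reduce_dims)
print(names)

# Only the selected variables are reduced; x_plot is left out
ds_samples = ds[names]
result = ds_samples.map(lambda da: da.mean(dim=reduce_dims))
print(result)
```

If `names` contains every data variable, no variable lacks a dimension to reduce, which is the situation in which passing the Dataset directly is described as safe.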
It is also possible to do
stats.circmean(ds_samples, dims=("chain", "draw"))
<xarray.Dataset>
Dimensions: (team: 6, match: 12)
Coordinates:
* team (team) <U1 'a' 'b' 'c' 'd' 'e' 'f'
Dimensions without coordinates: match
Data variables:
mu (team) float64 0.8221 0.7376 0.6195 0.7485 0.7439 0.7818
sigma float64 0.8134
score (match) float64 0.7441 0.3923 0.9316 0.6107 ... 0.5814 0.9538 0.94
Here, all child DataArrays have both chain and draw dimensions,
so as expected, the result is the same.
There are some cases, however, in which not using .map triggers
some broadcasting operations which will generally not produce the desired
output.
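To make the effect of such broadcasting concrete, here is a toy sketch that uses SciPy's rankdata directly together with xr.broadcast. The two small arrays are made up for illustration, and the sketch only shows how broadcasting changes ranks; it is not a description of how xarray-einstats is implemented internally.

```python
import numpy as np
import xarray as xr
from scipy.stats import rankdata

# Two made-up variables that do not share any dimension
x_plot = xr.DataArray(np.array([0.0, 1.0, 2.0]), dims="plot_dim")
sigma = xr.DataArray(np.array([0.3, 0.1]), dims="chain")

# Ranking x_plot on its own: three distinct values, three distinct ranks
ranks_alone = rankdata(x_plot.values)

# Broadcasting x_plot against sigma repeats every value along "chain"
x_broadcast, _ = xr.broadcast(x_plot, sigma)

# Ranking the broadcast values: each value now appears twice,
# so the ranks are larger and tied values share an averaged rank
ranks_broadcast = rankdata(x_broadcast.values.ravel()).reshape(x_broadcast.shape)

print(x_broadcast.dims)
print(ranks_alone)
print(ranks_broadcast)
```

The broadcast variable has gained a dimension it never had, and its ranks no longer describe the original three values.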
If we use the .map method, the function is applied to each
child DataArray independently from the others:
ds.map(stats.rankdata)
<xarray.Dataset>
Dimensions: (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Dimensions without coordinates: plot_dim, chain, draw, team, match
Data variables:
x_plot (plot_dim) float64 1.0 2.0 3.0 4.0 5.0 ... 16.0 17.0 18.0 19.0 20.0
mu (chain, draw, team) float64 65.0 41.0 89.0 ... 55.0 97.0 205.0
sigma (chain, draw) float64 33.0 30.0 15.0 5.0 ... 4.0 18.0 31.0 29.0
score (chain, draw, match) float64 105.0 401.0 457.0 ... 105.0 401.0
whereas without using the .map attribute, extra broadcasting can happen:
stats.rankdata(ds)
<xarray.Dataset>
Dimensions: (plot_dim: 20, chain: 4, draw: 10, team: 6, match: 12)
Dimensions without coordinates: plot_dim, chain, draw, team, match
Data variables:
x_plot (plot_dim, chain, draw, team, match) float64 1.44e+03 ... 5.616e+04
mu (plot_dim, chain, draw, team, match) float64 1.548e+04 ... 4.908...
sigma (plot_dim, chain, draw, team, match) float64 4.68e+04 ... 4.104e+04
score (plot_dim, chain, draw, team, match) float64 1.254e+04 ... 4.806...
The behaviour on DataArrayGroupBy, for example, is very similar
to the examples we have shown for Datasets:
da = ds["mu"].assign_coords(team=["a", "b", "b", "a", "c", "b"])
da
<xarray.DataArray 'mu' (chain: 4, draw: 10, team: 6)>
0.2691 0.1617 0.4371 0.4885 0.1836 2.149 ... 2.037 0.09032 0.2221 0.4673 1.844
Coordinates:
* team (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
* chain (chain) int64 0 1 2 3
* draw (draw) int64 0 1 2 3 4 5 6 7 8 9
When we apply a “group by” operation over the team dimension, we generate a
DataArrayGroupBy with 3 groups.
gb = da.groupby("team")
gb
DataArrayGroupBy, grouped over 'team'
3 groups with labels 'a', 'b', 'c'.
on which we can use .map to apply a function from xarray-einstats over
all groups independently:
gb.map(stats.median_abs_deviation, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4, team: 3)>
0.3436 0.3758 0.2351 0.5221 0.5937 0.4158 ... 0.4314 0.212 0.3479 0.5708 0.2288
Coordinates:
* chain (chain) int64 0 1 2 3
* team (team) object 'a' 'b' 'c'
which as expected has performed the operation group-wise, yielding a different result than either
stats.median_abs_deviation(da, dims=["draw", "team"])
<xarray.DataArray 'mu' (chain: 4)>
0.3444 0.5968 0.4553 0.4069
Coordinates:
* chain (chain) int64 0 1 2 3
or
stats.median_abs_deviation(da, dims="draw")
<xarray.DataArray 'mu' (chain: 4, team: 6)>
0.3452 0.3788 0.09536 0.3892 0.2351 ... 0.6554 0.2451 0.3832 0.2288 0.5281
Coordinates:
* team (team) <U1 'a' 'b' 'b' 'a' 'c' 'b'
* chain (chain) int64 0 1 2 3
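As a cross-check of what "group-wise" means here, the sketch below rebuilds a similar array with made-up random values (so the numbers differ from the outputs above) and compares the 'a' entry of the grouped result with the same function applied by hand to the members of group 'a' only.

```python
import numpy as np
import xarray as xr
from xarray_einstats import stats

# Made-up data with the same layout and team labels as the example above
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.exponential(size=(4, 10, 6)),
    dims=["chain", "draw", "team"],
    coords={"team": ["a", "b", "b", "a", "c", "b"]},
)

# Group-wise reduction: one result per chain and per team label
grouped = da.groupby("team").map(
    stats.median_abs_deviation, dims=["draw", "team"]
)

# Manual equivalent for group 'a': keep only the positions labelled 'a',
# then reduce over "draw" and the remaining "team" entries
positions_a = np.flatnonzero(da["team"].values == "a")
manual_a = stats.median_abs_deviation(
    da.isel(team=positions_a), dims=["draw", "team"]
)

print(np.allclose(grouped.sel(team="a").values, manual_a.values))
```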
See also
Check out the GroupBy: Group and Bin Data page on xarray’s documentation.