xarray_einstats.linalg.get_default_dims

xarray_einstats.linalg.get_default_dims(da1_dims, da2_dims=None)

Get the dimensions corresponding to the matrices.

Parameters
da1_dims : list of str

Dimension names of the first (or only) input.

da2_dims : list of str, optional

Dimension names of the second input. Used only by functions with two inputs; otherwise it keeps its default value of None.

Returns
list of str

The names of the dimensions corresponding to the matrix axes. Must be an iterable containing exactly two strings.
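When writing a substitute, it can help to check its return value against this contract before monkeypatching it in. The helper below, check_default_dims, is hypothetical and not part of xarray_einstats; it is a minimal sketch that only enforces the "iterable of two strings" requirement stated above, and it deliberately rejects a bare string as a design choice to catch mistakes early.

```python
from collections.abc import Iterable


def check_default_dims(dims):
    """Hypothetical helper: validate a get_default_dims return value.

    Accepts any iterable of exactly two strings and returns it as a list.
    """
    # Reject bare strings: "ab" is iterable, but almost certainly a mistake here.
    if isinstance(dims, str) or not isinstance(dims, Iterable):
        raise TypeError(f"Expected an iterable of two strings, got {dims!r}")
    dims = list(dims)
    if len(dims) != 2 or not all(isinstance(dim, str) for dim in dims):
        raise TypeError(f"Expected exactly two dimension names, got {dims!r}")
    return dims


print(check_default_dims(("dim", "dim2")))  # ['dim', 'dim2']
```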

Warning

The dims argument is required for functions in the linalg module. Unless it is monkeypatched, this function is only a placeholder: it raises an error indicating that dims is a required argument.

It is documented here to show how to write and configure a substitute function.
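The traceback in the example below shows the mechanism: the linalg functions call get_default_dims through a helper that turns a MissingMonkeypatchError into a TypeError. The following standard-library sketch mimics that pattern with a stand-in module object so it runs without xarray_einstats. The names fake_linalg, placeholder_get_default_dims and attempt_default_dims are hypothetical, and the exception class is defined locally rather than imported from the library.

```python
import types


class MissingMonkeypatchError(Exception):
    """Local stand-in for the exception named in the traceback."""


def placeholder_get_default_dims(da1_dims, da2_dims=None):
    # Placeholder hook: never guesses, it only signals that no substitute exists.
    raise MissingMonkeypatchError("get_default_dims has not been monkeypatched")


# Hypothetical module object standing in for xarray_einstats.linalg.
fake_linalg = types.ModuleType("fake_linalg")
fake_linalg.get_default_dims = placeholder_get_default_dims


def attempt_default_dims(func_name, da1_dims, da2_dims=None):
    # The hook is looked up on the module at call time, so replacing the
    # attribute (monkeypatching) changes the behaviour of every caller.
    try:
        return fake_linalg.get_default_dims(da1_dims, da2_dims)
    except MissingMonkeypatchError:
        raise TypeError(
            f"{func_name} missing required argument dims. You must monkeypatch "
            "get_default_dims for dims=None to be supported"
        ) from None


try:
    attempt_default_dims("inv", ("batch", "dim", "dim2"))
except TypeError as err:
    print(err)  # prints the TypeError message

# Monkeypatch the hook (here: a hypothetical "last two dims" rule).
fake_linalg.get_default_dims = lambda da1_dims, da2_dims=None: list(da1_dims)[-2:]
print(attempt_default_dims("inv", ("batch", "dim", "dim2")))  # ['dim', 'dim2']
```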

Examples

By default, xarray_einstats requires the dims argument for functions in the linalg module. Omitting it raises a TypeError:

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)
linalg.inv(da)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [1], in <cell line: 3>()
      1 from xarray_einstats import linalg, tutorial
      2 da = tutorial.generate_matrices_dataarray(5)
----> 3 linalg.inv(da)

File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.2.2/lib/python3.9/site-packages/xarray_einstats/linalg.py:690, in inv(da, dims, **kwargs)
    685 """Wrap :func:`numpy.linalg.inv`.
    686 
    687 Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
    688 """
    689 if dims is None:
--> 690     dims = _attempt_default_dims("inv", da.dims)
    691 return xr.apply_ufunc(
    692     np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims], **kwargs
    693 )

File ~/checkouts/readthedocs.org/user_builds/xarray-einstats/envs/v0.2.2/lib/python3.9/site-packages/xarray_einstats/linalg.py:106, in _attempt_default_dims(func, da1_dims, da2_dims)
    104     aux = get_default_dims(da1_dims, da2_dims)
    105 except MissingMonkeypatchError:
--> 106     raise TypeError(
    107         f"{func} missing required argument dims. You must monkeypatch "
    108         "xarray_einstats.linalg.get_default_dims for dims=None to be supported"
    109     ) from None
    110 return aux

TypeError: inv missing required argument dims. You must monkeypatch xarray_einstats.linalg.get_default_dims for dims=None to be supported

You need to pass the dimensions corresponding to the matrix axes explicitly:

linalg.inv(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2
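Conceptually, dims splits the input's dimension names into the two matrix (core) dimensions and the remaining batch dimensions, which the operation loops over. In the output above, batch and experiment come first and the matrix dimensions last, consistent with xarray.apply_ufunc moving core dimensions to the last axes. The standard-library sketch below, built around the hypothetical helper split_matrix_dims, only illustrates that bookkeeping on dimension names; it does not reproduce xarray_einstats internals.

```python
def split_matrix_dims(all_dims, matrix_dims):
    """Hypothetical helper: separate batch dims from the matrix (core) dims."""
    matrix_dims = list(matrix_dims)
    missing = [dim for dim in matrix_dims if dim not in all_dims]
    if missing:
        raise ValueError(f"Matrix dims {missing} not found in {tuple(all_dims)}")
    # Batch dims keep their original relative order.
    batch_dims = [dim for dim in all_dims if dim not in matrix_dims]
    # Matrix dims go last, in the order they were given.
    return batch_dims, matrix_dims


batch, matrix = split_matrix_dims(("batch", "experiment", "dim", "dim2"), ["dim", "dim2"])
print(batch + matrix)  # ['batch', 'experiment', 'dim', 'dim2']
```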

However, in many cases it will be possible to identify those dimensions from the list of all dimension names in the input.

Here we show how to monkeypatch get_default_dims to get a different default behaviour. If you follow a convention to label the dimensions corresponding to the matrix axes, you can integrate this logic into xarray_einstats, which avoids unnecessary repetition, especially when chaining several linear algebra operations. The example below assumes that a matrix dimension named <dim> is always paired with a second one named <dim>2 (here dim and dim2):

def get_default_dims(dims1, dims2=None):
    # Only single-input functions get a default
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    # Convention: a matrix dimension "<dim>" is paired with "<dim>2"
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]

# Replace the placeholder in the linalg module
linalg.get_default_dims = get_default_dims
linalg.inv(da)
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2
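The convention above is only one option. As a hypothetical alternative, a substitute could mirror NumPy's convention for stacked matrices and treat the last two dimensions as the matrix axes. This is a sketch under that assumption, and the name get_default_dims_last_two is made up; note that xarray dimension order can change after operations such as transpose, so this default is more fragile than a naming convention.

```python
def get_default_dims_last_two(da1_dims, da2_dims=None):
    """Hypothetical substitute: use the last two dimensions as the matrix axes."""
    if da2_dims is not None:
        # Like the example above, only single-input functions get a default.
        raise TypeError("Default dims only valid for single input functions")
    dims = list(da1_dims)
    if len(dims) < 2:
        raise TypeError("Unable to guess default matrix dims")
    return dims[-2:]


print(get_default_dims_last_two(("batch", "experiment", "dim", "dim2")))  # ['dim', 'dim2']
```

It would be installed the same way, with linalg.get_default_dims = get_default_dims_last_two.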

You can still use dims explicitly to override those defaults.
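Assigning to linalg.get_default_dims changes the default for the rest of the session. If you only want a substitute inside one block of code (for example in a test), unittest.mock.patch.object can replace the attribute and restore the original on exit. The sketch below uses a hypothetical stand-in module, fake_linalg, so it runs without xarray_einstats; with the library installed you would pass linalg as the target instead.

```python
import types
from unittest import mock


def placeholder(da1_dims, da2_dims=None):
    # Stand-in for the library's placeholder, which refuses to guess.
    raise TypeError("dims is a required argument")


def guess_dims(da1_dims, da2_dims=None):
    # Naming convention from the example above: "<dim>" pairs with "<dim>2".
    matrix_dims = [dim for dim in da1_dims if f"{dim}2" in da1_dims]
    if da2_dims is not None or len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    return [matrix_dims[0], f"{matrix_dims[0]}2"]


# Hypothetical module object standing in for xarray_einstats.linalg.
fake_linalg = types.ModuleType("fake_linalg")
fake_linalg.get_default_dims = placeholder

with mock.patch.object(fake_linalg, "get_default_dims", guess_dims):
    # Inside the block, the substitute is active.
    print(fake_linalg.get_default_dims(("batch", "dim", "dim2")))  # ['dim', 'dim2']

# On exit, the original placeholder is restored.
print(fake_linalg.get_default_dims is placeholder)  # True
```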