xarray_einstats.linalg.get_default_dims

xarray_einstats.linalg.get_default_dims(da1_dims, da2_dims=None)

Get the dimensions corresponding to the matrices.

Parameters:
da1_dims : list of str
da2_dims : list of str, optional
    Used only in the case of multiple inputs; otherwise it keeps its default value of None.

Returns:
list of str

The names of the dimensions corresponding to the matrix axes. Must be an iterable containing two strings.

Warning

The dims argument is required by functions in the linalg module. Unless it is monkeypatched, this function is only a placeholder that raises an error indicating that dims is a required argument.

It is documented here to show how to write and configure a substitute function.
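To make the placeholder mechanism concrete, here is a simplified standalone model of how a linalg function could resolve its matrix dimensions. It is a sketch, not the library's code: `fake_linalg`, `attempt_default_dims` and `resolve_dims` are hypothetical names, `MissingMonkeypatchError` is a local class defined for the sketch, and it assumes only that the unpatched hook raises an internal exception which the caller turns into a TypeError, and that the hook is looked up on the module at call time.

```python
import types


class MissingMonkeypatchError(Exception):
    """Local stand-in: signals that no substitute hook has been installed."""


def placeholder_get_default_dims(da1_dims, da2_dims=None):
    # The unpatched hook never guesses; it only signals that it was not replaced
    raise MissingMonkeypatchError("get_default_dims has not been monkeypatched")


# Hypothetical stand-in for the xarray_einstats.linalg module namespace
fake_linalg = types.SimpleNamespace(get_default_dims=placeholder_get_default_dims)


def attempt_default_dims(func, da1_dims, da2_dims=None):
    # The hook is looked up at call time, so a monkeypatched version is used
    try:
        return fake_linalg.get_default_dims(da1_dims, da2_dims)
    except MissingMonkeypatchError:
        raise TypeError(
            f"{func} missing required argument dims. You must monkeypatch "
            "get_default_dims for dims=None to be supported"
        ) from None


def resolve_dims(func, da_dims, dims=None):
    # Explicit dims always win; the hook is consulted only when dims is None
    if dims is None:
        dims = attempt_default_dims(func, da_dims)
    return list(dims)


da_dims = ("batch", "experiment", "dim", "dim2")
print(resolve_dims("inv", da_dims, dims=["dim", "dim2"]))  # ['dim', 'dim2']
try:
    resolve_dims("inv", da_dims)  # no substitute installed yet
except TypeError as err:
    print(err)

# Install a substitute by assigning to the module attribute (monkeypatching)
fake_linalg.get_default_dims = lambda d1, d2=None: ["dim", "dim2"]
print(resolve_dims("inv", da_dims))  # ['dim', 'dim2']
```

The key design point is the call-time lookup: because the hook is fetched from the module each time, replacing the module attribute is enough to change the default for every later call.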

Examples

By default, xarray_einstats requires the dims argument for functions in the linalg module; omitting it raises a TypeError:

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)
linalg.inv(da)
TypeError: inv missing required argument dims. You must monkeypatch xarray_einstats.linalg.get_default_dims for dims=None to be supported

You need to pass the dimensions corresponding to the matrix axes explicitly:

linalg.inv(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2

However, in many cases it will be possible to identify those dimensions from the list of all dimension names in the input.

Here we show how to monkeypatch get_default_dims to get a different default behaviour. If you follow a convention for labelling the dimensions that correspond to the matrix axes, you can integrate that logic into xarray_einstats and avoid unnecessary repetition, especially when chaining several linear algebra operations:

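def get_default_dims(dims1, dims2=None):
    # Defaults are only guessed for single-input functions
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    # Convention: the matrix axes are named "<name>" and "<name>2"
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        # Zero or several candidate pairs: refuse to guess
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]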

linalg.get_default_dims = get_default_dims
linalg.inv(da)
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2
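The convention-based substitute is plain Python, so its guessing rule can be checked on tuples of dimension names without building any DataArray. The sketch below repeats it (with dims2 given a default of None so it can be called with a single argument) together with a hypothetical helper, try_guess, and shows that a guess only succeeds when exactly one "<name>"/"<name>2" pair is present.

```python
def get_default_dims(dims1, dims2=None):
    # Same rule as the substitute: "<name>" pairs with "<name>2"
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]


def try_guess(dims):
    # Hypothetical helper: the guessed pair, or the error message on failure
    try:
        return get_default_dims(dims)
    except TypeError as err:
        return str(err)


print(try_guess(("batch", "experiment", "dim", "dim2")))  # ['dim', 'dim2']
print(try_guess(("batch", "row", "col")))  # Unable to guess default matrix dims
print(try_guess(("x", "x2", "y", "y2")))  # Unable to guess default matrix dims
```

Raising a TypeError in the ambiguous cases, rather than picking one pair, seems a reasonable choice: the caller gets an explicit error and can fall back to passing dims, instead of silently operating on the wrong axes.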

You can still use dims explicitly to override those defaults.
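Assigning to linalg.get_default_dims changes the default for every later call in the session. If the naming convention only holds for part of a program, one option (an assumption here, not something the library provides) is to install the substitute temporarily and restore the previous hook afterwards. The sketch below uses a hypothetical context manager, patched_default_dims, and a stand-in namespace, fake_linalg, so it runs without xarray_einstats installed.

```python
import types
from contextlib import contextmanager


@contextmanager
def patched_default_dims(module, func):
    """Temporarily install ``func`` as ``module.get_default_dims``."""
    original = module.get_default_dims
    module.get_default_dims = func
    try:
        yield
    finally:
        # Restore the previous hook even if the body raised
        module.get_default_dims = original


def placeholder(da1_dims, da2_dims=None):
    # Stand-in for the unpatched hook
    raise NotImplementedError("placeholder")


def fixed_pair(da1_dims, da2_dims=None):
    # Trivial substitute used only for this demonstration
    return ["dim", "dim2"]


# Stand-in module object so the sketch runs without xarray_einstats
fake_linalg = types.SimpleNamespace(get_default_dims=placeholder)

with patched_default_dims(fake_linalg, fixed_pair):
    print(fake_linalg.get_default_dims(("dim", "dim2")))  # ['dim', 'dim2']
print(fake_linalg.get_default_dims is placeholder)  # True
```

With the real module, the call would presumably read `with patched_default_dims(linalg, get_default_dims): linalg.inv(da)`. The standard library's `unittest.mock.patch.object(linalg, "get_default_dims", get_default_dims)` can also be used as a context manager for the same purpose and is a common choice in test suites.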