
xarray_einstats.linalg.get_default_dims(da1_dims, d2_dims=None)

Get the dimensions corresponding to the matrices.

Parameters:

da1_dims : list of str
    Dimension names of the first (or only) input.
da2_dims : list of str, optional
    Dimension names of the second input. Used only by functions with multiple inputs; otherwise it keeps its default value of None.

Returns:

list of str
    The dimensions corresponding to the matrix axes. Must be an iterable containing two strings.


dims is a required argument for functions in the linalg module. Unless it is monkeypatched, this function is only a placeholder: it raises an error, which the calling linalg function converts into a TypeError stating that dims is required.

It is documented here to show how to write and configure a substitute function.
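The placeholder mechanism can be illustrated with plain Python. The sketch below is not the library's code: it only mirrors the logic visible in the traceback of the first example, using a local stand-in for MissingMonkeypatchError, a stand-in attempt_default_dims for the private _attempt_default_dims, and a hypothetical inv_dims helper in place of a real linalg function. It also shows why assigning to linalg.get_default_dims works: the caller looks up the module-level name get_default_dims each time it runs, so rebinding that name changes which function is called.

```python
# Hypothetical, self-contained sketch of the placeholder pattern;
# none of these names are imported from xarray_einstats.


class MissingMonkeypatchError(Exception):
    """Local stand-in for the error raised by the unpatched placeholder."""


def get_default_dims(da1_dims, da2_dims=None):
    # Placeholder: no default convention exists, so signal that nothing was configured
    raise MissingMonkeypatchError("get_default_dims has not been monkeypatched")


def attempt_default_dims(func, da1_dims, da2_dims=None):
    # `get_default_dims` is resolved in the module globals at call time,
    # so rebinding that name (a monkeypatch) changes which function runs here
    try:
        return get_default_dims(da1_dims, da2_dims)
    except MissingMonkeypatchError:
        raise TypeError(
            f"{func} missing required argument dims. You must monkeypatch "
            "get_default_dims for dims=None to be supported"
        ) from None


def inv_dims(input_dims, dims=None):
    # Hypothetical stand-in for the `if dims is None` branch of a linalg function
    if dims is None:
        dims = attempt_default_dims("inv", input_dims)
    return dims


# Without a monkeypatch, omitting dims is an error
try:
    inv_dims(("batch", "dim", "dim2"))
except TypeError as err:
    print(err)

# Explicit dims never reach get_default_dims
print(inv_dims(("batch", "dim", "dim2"), dims=["dim", "dim2"]))


def last_two_dims(da1_dims, da2_dims=None):
    # Hypothetical convention: the matrix axes are the last two dimensions
    return list(da1_dims[-2:])


get_default_dims = last_two_dims  # rebinding the global acts as the monkeypatch
print(inv_dims(("batch", "dim", "dim2")))  # prints ['dim', 'dim2']
```

The "last two dimensions" rule is only an assumed example convention; any function with the same two-argument signature that returns two dimension names could be swapped in.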


By default, xarray_einstats requires the dims argument for functions in the linalg module. Omitting it raises a TypeError:

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)
linalg.inv(da)
TypeError                                 Traceback (most recent call last)
Cell In[1], line 3
      1 from xarray_einstats import linalg, tutorial
      2 da = tutorial.generate_matrices_dataarray(5)
----> 3 linalg.inv(da)

File ~/checkouts/, in inv(da, dims, **kwargs)
    704 """Wrap :func:`numpy.linalg.inv`.
    706 Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
    707 """
    708 if dims is None:
--> 709     dims = _attempt_default_dims("inv", da.dims)
    710 return xr.apply_ufunc(
    711     np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims], **kwargs
    712 )

File ~/checkouts/, in _attempt_default_dims(func, da1_dims, da2_dims)
    105     aux = get_default_dims(da1_dims, da2_dims)
    106 except MissingMonkeypatchError:
--> 107     raise TypeError(
    108         f"{func} missing required argument dims. You must monkeypatch "
    109         "xarray_einstats.linalg.get_default_dims for dims=None to be supported"
    110     ) from None
    111 return aux

TypeError: inv missing required argument dims. You must monkeypatch xarray_einstats.linalg.get_default_dims for dims=None to be supported

You need to pass the dimensions corresponding to the matrix axes explicitly:

linalg.inv(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2

However, in many cases those dimensions can be identified from the list of all dimension names in the input.

Here we show how to monkeypatch get_default_dims to get a different default behaviour. If you follow a naming convention for the dimensions corresponding to the matrix axes, you can encode it in xarray_einstats and avoid repeating dims, which is especially convenient when chaining several linear algebra operations. The following function assumes that the matrix dimensions are named <name> and <name>2, as dim and dim2 are in da:

def get_default_dims(dims1, dims2=None):
    # This convention only covers functions with a single input
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    # A dimension "<name>" is a matrix axis if "<name>2" is also present
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]

linalg.get_default_dims = get_default_dims
linalg.inv(da)
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)>
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2

You can still use dims explicitly to override those defaults.
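Because the substitute only receives dimension names (the call in the traceback passes da.dims), its convention can be checked on plain tuples without building a DataArray. The sketch below repeats the function so it runs on its own; try_guess is a hypothetical helper that returns either the guessed dims or the error message. Judging from the traceback, only the placeholder's own error is translated by the caller, so these TypeErrors would presumably propagate unchanged from a linalg call made without dims.

```python
def get_default_dims(dims1, dims2=None):
    # Convention: the matrix axes are named "<name>" and "<name>2"
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]


def try_guess(dims1, dims2=None):
    # Hypothetical helper: return the guessed dims, or the error message
    # if the naming convention does not apply
    try:
        return get_default_dims(dims1, dims2)
    except TypeError as err:
        return str(err)


print(try_guess(("batch", "experiment", "dim", "dim2")))  # ['dim', 'dim2']
print(try_guess(("batch", "row", "col")))  # no "<name>2" pair
print(try_guess(("a", "a2", "b", "b2")))  # two candidate pairs, ambiguous
print(try_guess(("dim", "dim2"), ("dim", "dim2")))  # two inputs
```

Raising on ambiguous or missing pairs, rather than picking one, keeps the convention from silently choosing the wrong axes; passing dims explicitly remains the way out in those cases.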