get_default_dims

xarray_einstats.linalg.get_default_dims(da1_dims, d2_dims=None)

Get the dimensions corresponding to the matrices.

Parameters:
da1_dims : list of str

Dimension names of the first (or only) input.

da2_dims : list of str, optional

Dimension names of the second input. Used only by functions that take two inputs; otherwise it keeps its default value of None.

Returns:
list of str

The dimensions corresponding to the matrix axes. Must be an iterable containing exactly two strings.
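When writing a substitute, it can help to check its return value against this contract before monkeypatching it in. The sketch below uses a hypothetical helper, `check_default_dims`, which is not part of xarray_einstats. Besides the documented requirement of two strings, it adds an extra check of its own: both names must appear in the input dimensions.

```python
from __future__ import annotations

from collections.abc import Iterable, Sequence


def check_default_dims(result: Iterable[str], da1_dims: Sequence[str]) -> list[str]:
    """Hypothetical helper: validate the output of a get_default_dims substitute."""
    dims = list(result)
    # Documented contract: an iterable containing two strings
    if len(dims) != 2 or not all(isinstance(dim, str) for dim in dims):
        raise ValueError(f"Expected two dimension names, got {dims!r}")
    # Extra sanity check (an assumption, not documented): both names exist in the input
    missing = [dim for dim in dims if dim not in da1_dims]
    if missing:
        raise ValueError(f"Dimensions {missing} not found in {tuple(da1_dims)}")
    return dims
```

For instance, you could call it once on the output of your substitute for a sample DataArray's `dims` before assigning the substitute to `linalg.get_default_dims`.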

Warning

By default, dims is a required argument of the functions in the linalg module. This function is only a placeholder: unless it is monkeypatched, calling a linalg function without dims raises an error stating that dims is a required argument.

It is documented here to show how to write and configure a substitute function.

Examples

By default, xarray_einstats requires the dims argument for functions in the linalg module; omitting it raises a TypeError:

from xarray_einstats import linalg, tutorial
da = tutorial.generate_matrices_dataarray(5)
linalg.inv(da)
Traceback (most recent call last):
  ...
TypeError: inv missing required argument dims. You must monkeypatch xarray_einstats.linalg.get_default_dims for dims=None to be supported

You need to pass the dimensions corresponding to the matrix axes explicitly:

linalg.inv(da, dims=["dim", "dim2"])
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2

However, in many cases it will be possible to identify those dimensions from the list of all dimension names in the input.

Here we show how to monkeypatch get_default_dims to change the default behaviour. If you follow a naming convention for the dimensions corresponding to the matrix axes, you can integrate that logic into xarray_einstats and avoid repeating dims, especially when chaining several linear algebra operations. In this example the convention is that the matrix axes are named <name> and <name>2, such as dim and dim2:

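def get_default_dims(dims1, dims2=None):
    # Convention: the matrix axes are named "<name>" and "<name>2"
    if dims2 is not None:
        # Only single-input functions (e.g. inv) can use this default
        raise TypeError("Default dims only valid for single input functions")
    # Find every dimension whose "<name>2" counterpart is also present
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        # Zero or several candidate pairs: refuse to guess
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]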

linalg.get_default_dims = get_default_dims
linalg.inv(da)
<xarray.DataArray (batch: 10, experiment: 3, dim: 4, dim2: 4)> Size: 4kB
0.5087 -0.6454 -0.4175 0.1449 -0.1026 ... -0.1792 -0.6749 -0.2189 0.2564 1.403
Dimensions without coordinates: batch, experiment, dim, dim2
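Because the substitute only receives dimension names, its behaviour can be checked on plain tuples without building a DataArray. The sketch below repeats the convention from the example above (with dims2 defaulting to None) and probes which inputs it accepts; the dimension names used in the checks are made up for illustration.

```python
def get_default_dims(dims1, dims2=None):
    # Same convention as above: matrix axes are "<name>" and "<name>2"
    if dims2 is not None:
        raise TypeError("Default dims only valid for single input functions")
    matrix_dims = [dim for dim in dims1 if f"{dim}2" in dims1]
    if len(matrix_dims) != 1:
        raise TypeError("Unable to guess default matrix dims")
    dim = matrix_dims[0]
    return [dim, f"{dim}2"]


# Unambiguous: exactly one "<name>"/"<name>2" pair
print(get_default_dims(("batch", "experiment", "dim", "dim2")))  # ['dim', 'dim2']

# Order does not matter, only the presence of the pair
print(get_default_dims(("row2", "batch", "row")))  # ['row', 'row2']

# No pair, two pairs, or a second input: the convention refuses to guess
cases = [
    (("a", "b"), None),
    (("x", "x2", "y", "y2"), None),
    (("dim", "dim2"), ("dim", "dim2")),
]
for dims1, dims2 in cases:
    try:
        get_default_dims(dims1, dims2)
    except TypeError as err:
        print(f"{dims1}, {dims2}: {err}")
```

Note that a name such as dim22 next to dim and dim2 forms a second pair, so the guess fails and dims must be passed explicitly.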

You can still pass dims explicitly to override these defaults.