Einops tutorial (ported)
This tutorial is a port of the einops basics tutorial using xarray-einstats.
Einops meets xarray!
We don’t write
y = x.transpose(0, 2, 3, 1)
nor do we write the more comprehensible alternative
y = rearrange(x, 'b c h w -> b h w c')
we write comprehensible code and use labeled arrays:
y_da = raw_rearrange(x_da, 'batch height width channel')
x_da is an xarray.DataArray whose dimensions are already labeled. We can therefore skip the left side of the pattern, which defines the names of the input dimensions.
xarray-einstats wraps einops functions to extend them to work on xarray objects.
What’s in this tutorial?
fundamentals: reordering, composition and decomposition of axes
how much you can do with a single operation!
# Examples are given for numpy. This code also sets up ipython/jupyter
# so that numpy arrays in the output are displayed as images
import numpy
import xarray
from utils import display_np_arrays_as_images

display_np_arrays_as_images()
The cell above configures Jupyter to display numpy arrays as images. This is a great visual aid for understanding the operations performed by einops. To take advantage of it, we use .values. In some specific cases we omit it in order to show the dimensions of the DataArray instead. If you are running this yourself, we encourage you to try both views.
Load a batch of images to play with
ds = xarray.open_zarr("einops-image.zarr").load()
ims = ds["ims"]
# There are 6 images of shape 96x96 with 3 color channels packed into the ims DataArray
ds
<xarray.Dataset>
Dimensions:  (batch: 6, height: 96, width: 96, channel: 3)
Dimensions without coordinates: batch, height, width, channel
Data variables:
    ims      (batch, height, width, channel) float64 1.0 0.902 ... 1.0 0.8039
display the first image (whole 4d tensor can’t be rendered)
second image in a batch
We’ll use three operations: rearrange, reduce and repeat.
from xarray_einstats.einops import raw_rearrange, raw_reduce #, repeat
rearrange, as its name suggests, rearranges elements.
Below, we swap height and width. This can be seen as transposing the first two dimensions. In xarray, however, the order of the dimensions should not matter by design. The exact code below should (and will) work if the input object has the same dimension names in a different order, in which case the operation might transpose the 1st and 3rd dims or even do nothing. That said, rearranging is still a valuable operation on xarray objects, especially right before accessing the underlying numpy or dask array.
By rearranging, we are enforcing the width, height, channel order.
raw_rearrange(ims.sel(batch=0), 'width height channel').values
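To see what this does to the underlying data, here is a minimal numpy sketch. It assumes the (batch, height, width, channel) layout shown above. Random values stand in for the images in einops-image.zarr, and ims_np is a hypothetical name for ims.values.

```python
import numpy as np

# Stand-in for ims.values: 6 images, 96x96 pixels, 3 color channels
rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))

# ims.sel(batch=0) has dims (height, width, channel)
first = ims_np[0]

# Enforcing the 'width height channel' order swaps axes 0 and 1
swapped = first.transpose(1, 0, 2)

print(swapped.shape)  # (96, 96, 3)
```

Because the images are square, the shape does not change. Only the element positions move: swapped[i, j] holds first[j, i].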
Composition of axes
Transposition is very common and useful, but let’s move on to other capabilities provided by einops.
einops allows seamlessly composing batch and height to a new height dimension, something that tends
to be tricky in xarray once we move from a standard
We just rendered all images by collapsing them into a 3d tensor!
da = raw_rearrange(ims, '(batch height) width channel')
da.values
But wait, dimensions must be named in xarray, so what happened to the stacked dimension we have just created?
xarray_einstats assigns it a name based on the names of the parent dimensions:
('batch-height', 'width', 'channel')
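Composing (batch height) amounts to a reshape of the adjacent batch and height axes: row i of the result is row i % 96 of image i // 96. The numpy sketch below uses random data as a stand-in. Its "-".join line only mimics the name shown above and is an illustrative assumption, not code taken from xarray_einstats.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)
dims = ("batch", "height", "width", "channel")

# '(batch height) width channel': batch and height are adjacent, so a reshape suffices
stacked = ims_np.reshape(6 * 96, 96, 3)

# Mimic the automatic name of the composed dimension (illustration only)
new_dims = ("-".join(dims[:2]),) + dims[2:]

print(stacked.shape)  # (576, 96, 3)
print(new_dims)       # ('batch-height', 'width', 'channel')
```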
We can also compose a new dimension from batch and width. This time we name it manually, which is strongly recommended over relying on the automatic names.
raw_rearrange(ims, '(batch width)=batched_widths channel').values
Note that here we have omitted the height dimension from the output expression as well as from the input.
xarray_einstats follows the xarray convention of adding new or modified dimensions (i.e. those present in the output expression) at the end.
Because dimensions are already named, we can omit dimensions from the output as well as from the input, as long as we don’t mind the new dimensions being moved to the right of the omitted ones.
The resulting dimensions are computed very simply. The length of a newly composed axis is the product of its components:
[6, 96, 96, 3] -> [96, (6 * 96), 3]
raw_rearrange(ims, '(batch width) channel').shape
(96, 576, 3)
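batch and width are not adjacent in the data, so numpy needs a transpose before the reshape. Moving the omitted height to the front also reproduces the "omitted dimensions stay on the left" behavior described above. This is a sketch on random stand-in data, not how xarray_einstats implements it.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# Move height to the front so that batch and width become neighbours, then merge them
moved = ims_np.transpose(1, 0, 2, 3)      # (height, batch, width, channel)
result = moved.reshape(96, 6 * 96, 3)     # (height, batch*width, channel)

print(result.shape)  # (96, 576, 3)
```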
We can compose more than two axes. Let’s flatten the 4d array into 1d; the resulting array has as many elements as the original.
raw_rearrange(ims, '(batch height width channel)').shape
Everything we have done so far could have been done with
transpose or with
stack, so choosing between those methods and einops is a matter of personal
choice. We’d recommend sticking with the original xarray methods, especially when working
with dask arrays: their defaults will be much more convenient than the lack of automatic dask handling in xarray_einstats.
The rearrangements below this point, however, can’t be reproduced with a single xarray method.
Decomposition of axis
Decomposition is the inverse process: we represent an axis as a combination of new axes.
There are always several possible decompositions, so we specify b1=2 to decompose
batch into two dimensions b1 and
b2 of lengths 2 and 3 respectively.
In addition, we need to specify the name of the dimension we want to decompose.
raw_rearrange(ims, '(b1 b2)=batch -> b1 b2 height width channel ', b1=2).shape
(2, 3, 96, 96, 3)
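In numpy terms, decomposition is a reshape in which the missing length is inferred: 6 // 2 = 3. The batch index then behaves like a two-digit number in base 3. Random stand-in data is used; this is an illustration, not the library's code.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))

# '(b1 b2)=batch' with b1=2: b2 is inferred from the total length
b1 = 2
b2 = ims_np.shape[0] // b1
split = ims_np.reshape(b1, b2, 96, 96, 3)

# batch index = b1_index * b2 + b2_index, like digits of a number
assert np.array_equal(split[1, 2], ims_np[1 * b2 + 2])
print(split.shape)  # (2, 3, 96, 96, 3)
```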
Finally, combine composition and decomposition:
da = raw_rearrange(ims, '(b1 b2)=batch -> (b1 height) (b2 width) channel ', b1=2)
da.values
Again, we skip naming the output dimensions, so xarray_einstats names them automatically:
('b1-height', 'b2-width', 'channel')
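Combining decomposition and composition takes three numpy steps: split batch, reorder the axes so each pair to be merged is adjacent, then merge. The result is a 2x3 grid of images. This sketch uses random stand-in data.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))

# '(b1 b2)=batch -> (b1 height) (b2 width) channel', b1=2
split = ims_np.reshape(2, 3, 96, 96, 3)        # (b1, b2, height, width, channel)
grid = split.transpose(0, 2, 1, 3, 4)          # (b1, height, b2, width, channel)
grid = grid.reshape(2 * 96, 3 * 96, 3)         # ((b1 height), (b2 width), channel)

print(grid.shape)  # (192, 288, 3)
```

Pixel (row, col) of the grid comes from image (row // 96) * 3 + col // 96.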
Slightly different composition: b1 is merged with width, b2 with height
… so letters are ordered by w then by h
raw_rearrange(ims, '(b1 b2)=batch -> (b2 height) (b1 width) channel ', b1=2).values
Move part of the width dimension to height. We could call this width-to-height, as the image width shrank by 2 and the height doubled.
But all pixels are the same!
Can you write the reverse operation (height-to-width)?
raw_rearrange(ims, '(w1 w2)=width -> (height w2) (batch w1) channel', w2=2).values
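Because width-to-height only rearranges elements, it can be undone exactly. The numpy sketch below performs the move, then splits the composed axes again and inverts the transpose. It uses random stand-in data and illustrates the idea only; it is not the einops solution to the exercise.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# width-to-height: '(w1 w2)=width -> (height w2) (batch w1) channel', w2=2
parts = ims_np.reshape(6, 96, 48, 2, 3)                      # (batch, height, w1, w2, channel)
moved = parts.transpose(1, 3, 0, 2, 4).reshape(96 * 2, 6 * 48, 3)

# Undo it: split the composed axes again and put every axis back in place
unpacked = moved.reshape(96, 2, 6, 48, 3)                    # (height, w2, batch, w1, channel)
restored = unpacked.transpose(2, 0, 3, 1, 4).reshape(6, 96, 96, 3)

print(moved.shape)  # (192, 288, 3)
```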
Order of axes matters
Compare the next two examples:
raw_rearrange(ims, '(batch width) channel').values
raw_rearrange(ims, '(width batch) channel').values
The order of axes in the composition is different.
The rule is the same as for the digits of a number: the leftmost digit is the most significant, while neighboring numbers differ in the rightmost digit (axis). You can also think of this as a lexicographic sort.
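The digit rule can be checked on a tiny array whose values encode their own position as 10 * batch + width. This is a hypothetical toy example, not the image data.

```python
import numpy as np

# Toy stand-in: 2 "images" (batch) of width 3, value = 10 * batch + width
a = np.array([[10 * b + w for w in range(3)] for b in range(2)])  # shape (2, 3)

batch_width = a.reshape(-1)        # '(batch width)': width varies fastest
width_batch = a.T.reshape(-1)      # '(width batch)': batch varies fastest

print(batch_width.tolist())  # [0, 1, 2, 10, 11, 12]
print(width_batch.tolist())  # [0, 10, 1, 11, 2, 12]
```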
And what if b1 and b2 are reordered before composing to width?
raw_rearrange(ims, '(b1 b2)=batch -> (b1 b2 width) channel ', b1=2).values  # produces 'einops'
raw_rearrange(ims, '(b1 b2)=batch -> (b2 b1 width) channel ', b1=2).values  # produces 'eoipns'
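Replacing each image by the letter it shows reproduces both orderings with plain numpy. This assumes, as the comments above indicate, that the six images spell "einops" in batch order.

```python
import numpy as np

letters = np.array(list("einops"))       # one letter per image in the batch
split = letters.reshape(2, 3)            # '(b1 b2)=batch' with b1=2

b1_b2 = "".join(split.reshape(-1))       # compose as (b1 b2 ...)
b2_b1 = "".join(split.T.reshape(-1))     # compose as (b2 b1 ...)

print(b1_b2)  # einops
print(b2_b1)  # eoipns
```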
From the einops documentation:
In einops-land you don’t need to guess what happened
x.mean(-1)
because you write what the operation does
reduce(x, 'b h w c -> b h w', 'mean')
if an axis is not present in the output — you guessed it — that axis is reduced.
When using xarray objects, you already don’t have to guess what happened. Much like with
rearrange, the first examples (three in this case) using reduce can be reproduced with a single xarray method. See the
operation above with pure xarray and with
raw_reduce:
ims.mean("channel")
raw_reduce(ims, "batch height width", "mean")
However, again much like with
rearrange, reduce also opens the door to many other operations
that go beyond what single xarray methods can do. These cases are worth making
and once it’s working the simple operations are possible automatically.
If you prefer thinking in terms of the dimensions that are kept instead of the ones that are reduced,
reduce can be more convenient even for simple operations.
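Here is a minimal numpy sketch of the "list what you keep" idea. reduce_keeping is a hypothetical helper written for this illustration, not part of xarray_einstats or einops. Unlike raw_reduce, it keeps the surviving axes in their original order and does not reorder them to match a pattern.

```python
import numpy as np

def reduce_keeping(arr, dims, keep, op="mean"):
    """Hypothetical helper: reduce every axis whose name is NOT in `keep`.

    `dims` names the axes of `arr`; kept axes stay in their original order.
    """
    axes = tuple(i for i, d in enumerate(dims) if d not in keep)
    return getattr(np, op)(arr, axis=axes)

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))
dims = ("batch", "height", "width", "channel")

avg_over_batch = reduce_keeping(ims_np, dims, ["height", "width", "channel"])
min_over_batch_channel = reduce_keeping(ims_np, dims, ["height", "width"], op="min")

print(avg_over_batch.shape)          # (96, 96, 3)
print(min_over_batch_channel.shape)  # (96, 96)
```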
# average over batch
raw_reduce(ims, 'height width channel', 'mean').values

# the previous is identical to the familiar:
ims.values.mean(axis=0)
# but is so much more readable

# Example of reducing several axes
# besides mean, there are also min, max, sum, prod
raw_reduce(ims, 'height width', 'min').values
# this is mean-pooling with 2x2 kernel
# image is split into 2x2 patches, each patch is averaged
raw_reduce(ims, '(h h2)=height (w w2)=width -> h (batch w) channel', 'mean', h2=2, w2=2).values
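Mean-pooling follows the same decompose-then-reduce idea in numpy. Splitting height and width exposes each 2x2 patch as two extra axes, which are then averaged away. This sketch uses random stand-in data; swapping mean for max gives the max-pooling variant in the next cell.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# '(h h2)=height (w w2)=width' with h2=w2=2: expose each 2x2 patch as its own axes
patches = ims_np.reshape(6, 48, 2, 48, 2, 3)   # (batch, h, h2, w, w2, channel)
pooled = patches.mean(axis=(2, 4))              # (batch, h, w, channel)

# '-> h (batch w) channel': h to the front, then merge batch and w
out = pooled.transpose(1, 0, 2, 3).reshape(48, 6 * 48, 3)
print(out.shape)  # (48, 288, 3)
```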
# max-pooling is similar
# result is not as smooth as for mean-pooling
raw_reduce(ims, '(h h2)=height (w w2)=width -> h (batch w) channel', 'max', h2=2, w2=2).values
# yet another example. Can you compute the result shape?
raw_reduce(ims, '(b1 b2)=batch -> (b2 height) (b1 width)', 'mean', b1=2).values
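One way to check your answer is to redo the computation in numpy. This assumes, as in the "average over batch" example, that dimensions absent from the output (here channel) are reduced. Random stand-in data is used.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# '(b1 b2)=batch -> (b2 height) (b1 width)', 'mean', b1=2
split = ims_np.reshape(2, 3, 96, 96, 3)   # (b1, b2, height, width, channel)
averaged = split.mean(axis=-1)            # channel is absent from the output, so it is reduced
out = averaged.transpose(1, 2, 0, 3).reshape(3 * 96, 2 * 96)

print(out.shape)  # (288, 192)
```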
We skip the section about numpy-like stacking and concatenating because they aren’t relevant to xarray objects.
We have also reimagined the next section. Since we are working with xarray objects, adding a new axis of length 1 to ensure broadcastability is not relevant either. We have therefore turned the section into a showcase of xarray’s automatic broadcasting and alignment.
Broadcasting and alignment
Because we are using xarray objects, we can operate directly between original and reduced inputs without adding new dimensions, be it with
[np.newaxis, :] or with
() in einops expressions. See for yourself:
# compute max in each image individually, then show a difference
x = raw_reduce(ims, 'batch channel', 'max') - ims
raw_rearrange(x, '(batch width) channel').values
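For contrast, plain numpy matches axes by position, not by name. The reduced array therefore has to keep length-1 placeholders (keepdims=True) before it can be subtracted. Without them, a (6, 3) array fails to broadcast against (6, 96, 96, 3). This sketch uses random stand-in data.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# Max per image and channel, reducing height and width.
# keepdims leaves length-1 axes in place so the result broadcasts against ims_np.
max_per_image = ims_np.max(axis=(1, 2), keepdims=True)  # (6, 1, 1, 3)
diff = max_per_image - ims_np

# Same layout as the '(batch width) channel' rendering above
rendered = diff.transpose(1, 0, 2, 3).reshape(96, 6 * 96, 3)
print(rendered.shape)  # (96, 576, 3)
```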
Coming soon; for now, jump to Fancy examples in random order.
The third operation we introduce is repeat:
# repeat along a new axis. New axis can be placed anywhere
repeat(ims, 'h w c -> h new_axis w c', new_axis=5).shape

# shortcut
repeat(ims, 'h w c -> h 5 w c').shape

# repeat along w (existing axis)
repeat(ims, 'h w c -> h (repeat w) c', repeat=3)

# repeat along two existing axes
repeat(ims, 'h w c -> (2 h) (2 w) c')

# order of axes matters as usual - you can repeat each element (pixel) 3 times
# by changing the order in parentheses
repeat(ims, 'h w c -> h (w repeat) c', repeat=3)
The repeat operation covers functionality identical to
numpy.tile, and actually more than that.
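The digit rule from earlier explains the two repeat patterns. When the repeat axis is outer, as in (repeat w), the whole row is copied, which is numpy.tile. When it is inner, as in (w repeat), each pixel is copied in place, which is numpy.repeat. The toy image below is a hypothetical stand-in.

```python
import numpy as np

# Toy stand-in image: height 2, width 3, 1 channel, pixel values 0..5
im = np.arange(6).reshape(2, 3, 1)

# 'h w c -> h (repeat w) c', repeat=3: the whole row is tiled 3 times
tiled = np.tile(im, (1, 3, 1))

# 'h w c -> h (w repeat) c', repeat=3: each pixel is repeated 3 times
repeated = np.repeat(im, 3, axis=1)

print(tiled[0, :, 0].tolist())     # [0, 1, 2, 0, 1, 2, 0, 1, 2]
print(repeated[0, :, 0].tolist())  # [0, 0, 0, 1, 1, 1, 2, 2, 2]
```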
Reduce ⇆ repeat
reduce and repeat are like opposites of each other: the first reduces the number of elements, the second increases it.
In the following example each image is repeated first, then we reduce over the new axis to get back the original tensor. Notice that the operation patterns are the “reverse” of each other.
repeated = repeat(ims, 'b h w c -> b h new_axis w c', new_axis=2)
reduced = reduce(repeated, 'b h new_axis w c -> b h w c', 'min')
assert numpy.array_equal(ims, reduced)
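The same round trip can be written in plain numpy on random stand-in data. np.repeat inserts the copies and min over the new axis removes them. Any reduction that returns the value when all inputs are equal (min, max, mean) would work here.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (b, h, w, c)

# 'b h w c -> b h new_axis w c', new_axis=2: insert an axis and repeat along it
repeated = np.repeat(ims_np[:, :, np.newaxis], 2, axis=2)   # (6, 96, 2, 96, 3)

# 'b h new_axis w c -> b h w c', 'min': reduce the new axis away again
reduced = repeated.min(axis=2)

assert np.array_equal(ims_np, reduced)
```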
Fancy examples in random order
(a.k.a. mad designer gallery)
# interweaving pixels of different pictures
# all letters are observable
raw_rearrange(ims, '(b1 b2)=batch -> (height b1) (width b2) channel ', b1=2).values
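In this pattern b1 and b2 are the last, fastest-varying part of each composed axis. Neighbouring pixels therefore come from different images, which is why all letters remain visible. The numpy sketch below uses random stand-in data and checks a few positions.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))

# '(b1 b2)=batch -> (height b1) (width b2) channel', b1=2
split = ims_np.reshape(2, 3, 96, 96, 3)                # (b1, b2, height, width, channel)
woven = split.transpose(2, 0, 3, 1, 4).reshape(96 * 2, 96 * 3, 3)

# Neighbouring pixels come from different images
assert np.array_equal(woven[0, 0], ims_np[0, 0, 0])   # b1=0, b2=0 -> batch 0
assert np.array_equal(woven[1, 0], ims_np[3, 0, 0])   # b1=1, b2=0 -> batch 3
assert np.array_equal(woven[0, 1], ims_np[1, 0, 0])   # b1=0, b2=1 -> batch 1
```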
# interweaving along vertical for couples of images
raw_rearrange(ims, '(b1 b2)=batch -> (height b1) (b2 width) channel', b1=2).values
# interweaving lines for couples of images
# exercise: achieve the same result without einops in your favourite framework
raw_reduce(ims, '(b1 b2)=batch -> height (b2 width) channel', 'max', b1=2).values
# color can be also composed into dimension
# ... while image is downsampled
raw_reduce(ims, '(h h2)=height (w w2)=width -> (channel h) (batch w)', 'mean', h2=2, w2=2).values
# disproportionate resize
raw_reduce(ims, '(h h4)=height (w w3)=width -> h (batch w)', 'mean', h4=4, w3=3).values
# split each image in two halves, compute mean of the two
raw_reduce(ims, '(h1 h2)=height -> h2 (batch width)', 'mean', h1=2).values
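In numpy, the two halves are the h1 axis after splitting height. Averaging over h1 blends the top half with the bottom half. Channel does not appear in the output and is averaged too, assuming the same "absent means reduced" behavior as the earlier reduce examples. Random stand-in data is used.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (batch, height, width, channel)

# '(h1 h2)=height -> h2 (batch width)', 'mean', h1=2
halves = ims_np.reshape(6, 2, 48, 96, 3)      # (batch, h1, h2, width, channel)
averaged = halves.mean(axis=(1, 4))           # h1 and channel are not in the output
out = averaged.transpose(1, 0, 2).reshape(48, 6 * 96)

print(out.shape)  # (48, 576)
```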
# split in small patches and transpose each patch
raw_rearrange(ims, '(h1 h2)=height (w1 w2)=width -> (h1 w2) (batch w1 h2) channel', h2=8, w2=8).values
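The index bookkeeping behind this pattern can be written out explicitly. Height and width split into 12 blocks of 8 pixels. Inside each block, h2 ends up on the horizontal axis and w2 on the vertical one, so each 8x8 patch is transposed in place. Random stand-in data is used.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))

# '(h1 h2)=height (w1 w2)=width -> (h1 w2) (batch w1 h2) channel', h2=8, w2=8
blocks = ims_np.reshape(6, 12, 8, 12, 8, 3)                 # (batch, h1, h2, w1, w2, channel)
out = blocks.transpose(1, 4, 0, 3, 2, 5).reshape(12 * 8, 6 * 12 * 8, 3)

# Inside every 8x8 patch, rows and columns have been swapped
b, h1, h2, w1, w2 = 2, 5, 3, 7, 6
assert np.array_equal(out[h1 * 8 + w2, b * 96 + w1 * 8 + h2], ims_np[b, h1 * 8 + h2, w1 * 8 + w2])
print(out.shape)  # (96, 576, 3)
```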
# stop me someone!
raw_rearrange(
    ims,
    '(h1 h2 h3)=height (w1 w2 w3)=width -> (h1 w2 h3) (batch w1 h2 w3) channel',
    h2=2, w2=2, w3=2, h3=2
).values
raw_rearrange(
    ims,
    '(b1 b2)=batch (h1 h2)=height (w1 w2)=width -> (h1 b1 h2) (w1 b2 w2) channel',
    h1=3, w1=3, b2=3
).values
# patterns can be arbitrarily complicated
raw_reduce(
    ims,
    '(b1 b2)=batch (h1 h2 h3)=height (w1 w2 w3)=width -> (h1 w1 h3) (b1 w2 h2 w3 b2) channel',
    'mean',
    h2=2, w1=2, w3=2, h3=2, b2=2
).values
# subtract background in each image individually and normalize
im2 = raw_reduce(ims, 'batch channel', 'max') - ims
im2 /= raw_reduce(im2, 'batch channel', 'max')
raw_rearrange(im2, '(batch width) channel').values
##### no repeat yet ####
# pixelate: first downscale by averaging, then upscale back using the same pattern
averaged = reduce(ims, 'b (h h2) (w w2) c -> b h w c', 'mean', h2=6, w2=8)
repeat(averaged, 'b h w c -> (h h2) (b w w2) c', h2=6, w2=8)
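Until repeat is available here, the pixelation can be sketched in numpy. Average each 6x8 block, copy every averaged value back into a 6x8 block with np.repeat, then lay the images side by side. Random stand-in data is used, and this is not the einops implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
ims_np = rng.random((6, 96, 96, 3))  # (b, h, w, c)

# 'b (h h2) (w w2) c -> b h w c', 'mean', h2=6, w2=8
averaged = ims_np.reshape(6, 16, 6, 12, 8, 3).mean(axis=(2, 4))   # (6, 16, 12, 3)

# 'b h w c -> (h h2) (b w w2) c', h2=6, w2=8: copy every averaged pixel into a 6x8 block
upscaled = np.repeat(np.repeat(averaged, 6, axis=1), 8, axis=2)   # (6, 96, 96, 3)
pixelated = upscaled.transpose(1, 0, 2, 3).reshape(96, 6 * 96, 3)

print(pixelated.shape)  # (96, 576, 3)
```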
raw_rearrange(ims, '(batch height) channel').values
# let's bring color dimension as part of horizontal axis
# at the same time horizontal axis is downsampled by 3x
raw_reduce(ims, '(h h2)=height (w w2)=width -> (h w2) (batch w channel)', 'mean', h2=3, w2=3).values
rearrange doesn’t change the number of elements and covers different numpy functions (like transpose and reshape)
reduce combines the same reordering syntax with reductions (mean, min, max, sum,
prod, and any others)
repeat additionally covers repeating and tiling
composition and decomposition of axes are a cornerstone; they can and should be used together
%load_ext watermark
%watermark -n -u -v -iv -w -p einops,xarray_einstats
Last updated: Sun Jun 19 2022

Python implementation: CPython
Python version       : 3.9.10
IPython version      : 8.4.0

einops         : 0.4.1
xarray_einstats: 0.3.0

xarray: 2022.3.0
numpy : 1.22.4

Watermark: 2.3.1