
numpy einsum

1. examples

  • tensor product (pairwise matrix products)
  • Given an array of shape [n,d,d], i.e. a list of \(n\) \(d\times d\) matrices, find the \((n,n,d,d)\) table of all pairwise matrix products
# k appears in both inputs but not in the output, so it is summed over: tensorprod[i,l,j,m] = sum_k matrices[i,j,k] * matrices[l,k,m]
# the output order "iljm" makes i vary slowest, then l, then j, then m, so tensorprod[i,l] is the d x d product matrices[i] @ matrices[l]
tensorprod = np.einsum("ijk,lkm->iljm", matrices, matrices)  # already shape (n,n,d,d), so no reshape is needed
  • This is equivalent to
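matrices.repeat(n, axis=0).reshape((n,n,d,d)) @ matrices  # entry [i,l] is matrices[i] @ matrices[l]; the (n,d,d) right operand broadcasts over the first axis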
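A quick way to convince yourself that the einsum, the `repeat`/`reshape` version, and a plain double loop all build the same table is to compare them on a small random input. The sizes `n = 3`, `d = 2` and the helper name `pairwise_products_loop` are illustrative choices, not part of the original notes; the `matrices[:, None] @ matrices[None, :]` line is an extra broadcasting formulation of the same product, added here only for comparison.

```python
import numpy as np

def pairwise_products_loop(matrices):
    """Hypothetical reference helper: table[i, l] = matrices[i] @ matrices[l]."""
    n, d, _ = matrices.shape
    table = np.empty((n, n, d, d), dtype=matrices.dtype)
    for i in range(n):
        for l in range(n):
            table[i, l] = matrices[i] @ matrices[l]
    return table

rng = np.random.default_rng(0)
n, d = 3, 2  # small illustrative sizes
matrices = rng.standard_normal((n, d, d))

# einsum: contract the shared index k; output axes ordered i, l, j, m
via_einsum = np.einsum("ijk,lkm->iljm", matrices, matrices)

# repeat each matrix n times along axis 0, view as (n, n, d, d), then
# matmul against the (n, d, d) stack, which broadcasts as (1, n, d, d)
via_repeat = matrices.repeat(n, axis=0).reshape((n, n, d, d)) @ matrices

# same idea with explicit broadcasting: batch shapes (n, 1) and (1, n) -> (n, n)
via_broadcast = matrices[:, None] @ matrices[None, :]

reference = pairwise_products_loop(matrices)
print(via_einsum.shape)  # (3, 3, 2, 2)
print(np.allclose(via_einsum, reference),
      np.allclose(via_repeat, reference),
      np.allclose(via_broadcast, reference))  # True True True
```

The broadcasting form avoids the extra copy that `repeat` makes, but all three agree up to floating-point rounding.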
  • sum over the last two axes
diff_sum = np.einsum("ijklm->ijk", diff)  # diff is 5-D; l and m are absent from the output, so einsum sums over them
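As a sanity check for this reduction, the sketch below compares the einsum against `ndarray.sum` over the last two axes and against an ellipsis subscript that works for any number of leading axes. The shape `(2, 3, 4, 5, 6)` for `diff` is a made-up example, since the notes do not say where `diff` comes from.

```python
import numpy as np

rng = np.random.default_rng(1)
diff = rng.standard_normal((2, 3, 4, 5, 6))  # hypothetical 5-D array

# indices l and m do not appear after "->", so they are summed over
diff_sum = np.einsum("ijklm->ijk", diff)

# equivalent reductions
by_sum = diff.sum(axis=(-2, -1))
by_ellipsis = np.einsum("...lm->...", diff)  # same reduction for any leading rank

print(diff_sum.shape)  # (2, 3, 4)
print(np.allclose(diff_sum, by_sum), np.allclose(diff_sum, by_ellipsis))  # True True
```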

2. helpful links

Created: 2024-07-15 Mon 01:28