
Optimizing NumPy's einsum function


NumPy's einsum is a powerful function for contracting tensors of arbitrary dimension and index. However, it is only optimized to contract two terms at a time: given more operands, it does not break the expression into pairwise contractions, which results in non-optimal scaling.

For example, consider the following index transformation, where summation over the repeated indices i, j, k and l is implied: M_{pqrs} = C_{pi} C_{qj} I_{ijkl} C_{rk} C_{sl}
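Written out as explicit loops, the expression has eight distinct indices, which is where the N^8 cost of evaluating it in a single pass comes from. The sketch below is only an illustration: the function name transform_loops is hypothetical, it assumes a square C and an (N, N, N, N) I as in the example, and it is practical only for very small N.

```python
import itertools

import numpy as np


def transform_loops(I, C):
    """Hypothetical reference version of M_pqrs = sum_ijkl C_pi C_qj I_ijkl C_rk C_sl."""
    n = C.shape[0]  # assumes square C and an (n, n, n, n) I, as in the example
    M = np.zeros((n, n, n, n))
    # Eight indices (p, q, r, s, i, j, k, l), each taking n values, so the
    # accumulation line below runs n**8 times.
    for p, q, r, s, i, j, k, l in itertools.product(range(n), repeat=8):
        M[p, q, r, s] += C[p, i] * C[q, j] * I[i, j, k, l] * C[r, k] * C[s, l]
    return M


n = 3
C = np.random.rand(n, n)
I = np.random.rand(n, n, n, n)
M = transform_loops(I, C)
print(np.allclose(M, np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)))  # True
```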

Consider two different algorithms:

import numpy as np
N = 10
C = np.random.rand(N, N)
I = np.random.rand(N, N, N, N)

def naive(I, C):
    # N^8 scaling
    return np.einsum('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

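def optimized(I, C):
    # N^5 scaling: transform one index of I at a time,
    # storing each partially transformed tensor in K
    K = np.einsum('pi,ijkl->pjkl', C, I)  # i -> p
    K = np.einsum('qj,pjkl->pqkl', C, K)  # j -> q
    K = np.einsum('rk,pqkl->pqrl', C, K)  # k -> r
    K = np.einsum('sl,pqrl->pqrs', C, K)  # l -> s
    return K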

The einsum function does not consider building intermediate arrays; therefore, helping einsum out by building these intermediate arrays yourself can result in considerable cost savings, even for small N (N=10):

>>> np.allclose(naive(I, C), optimized(I, C))
True

%timeit naive(I, C)
1 loops, best of 3: 1.18 s per loop

%timeit optimized(I, C)
1000 loops, best of 3: 612 µs per loop
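The N^8 and N^5 labels in the code comments come from counting distinct indices: evaluated as a single loop nest, an expression with k distinct indices, each of extent N, costs on the order of N^k operations. A minimal sketch of that count follows; loop_scaling is a hypothetical helper that assumes an explicit '->' output and no '...' broadcasting.

```python
def loop_scaling(subscripts):
    """Count the distinct indices in an einsum string (hypothetical helper).

    Assumes an explicit '->' output and no '...' broadcasting. Every output
    index must also appear among the inputs, so counting input letters suffices.
    """
    inputs = subscripts.split('->')[0]
    return len(set(inputs.replace(',', '')))


naive_expr = 'pi,qj,ijkl,rk,sl->pqrs'
steps = ['pi,ijkl->pjkl', 'qj,pjkl->pqkl', 'rk,pqkl->pqrl', 'sl,pqrl->pqrs']

print(loop_scaling(naive_expr))          # 8
print([loop_scaling(s) for s in steps])  # [5, 5, 5, 5]
```

Replacing one N^8 loop nest with four N^5 steps changes the cost by roughly a factor of N^3, ignoring constant factors; at N = 10 that is 10^3 = 1000.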

The index transformation is a well-known contraction that leads to straightforward intermediates. It becomes more complicated when the C matrices do not all have the same shape; in that case, the order in which the indices are transformed matters greatly. Logic can be written to optimize that ordering, but doing so by hand is a lot of time and effort for a single expression.
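To see why the order matters once the C matrices have different shapes, the sketch below counts multiply-adds for transforming the indices of I one at a time, in every possible order. The helper transform_cost and the example extents (every index of I of extent 10, transformed into extents 2, 5, 20 and 40) are hypothetical and chosen only for illustration.

```python
from itertools import permutations
from math import prod


def transform_cost(dims, new_dims, order):
    """Multiply-adds for transforming a tensor's indices one at a time (hypothetical).

    dims:     starting extents of the tensor's indices, e.g. (n, n, n, n) for I_ijkl
    new_dims: extent each index becomes, i.e. the row count of its C matrix
    order:    index positions in the order they are transformed
    """
    dims = list(dims)
    total = 0
    for pos in order:
        # There are prod(dims) / dims[pos] * new_dims[pos] output elements,
        # each a sum over dims[pos] terms.
        total += prod(dims) * new_dims[pos]
        dims[pos] = new_dims[pos]
    return total


dims = (10, 10, 10, 10)
new_dims = (2, 5, 20, 40)
costs = {order: transform_cost(dims, new_dims, order) for order in permutations(range(4))}
best = min(costs, key=costs.get)
worst = max(costs, key=costs.get)
print(best, costs[best])    # (0, 1, 2, 3) 130000
print(worst, costs[worst])  # (3, 2, 1, 0) 1680000
```

Transforming the indices that shrink first keeps the intermediates small; in this example the worst order costs about 13 times as much as the best.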

The opt_einsum package is a drop-in replacement for the np.einsum function and can handle all of this logic for you:

from opt_einsum import contract

contract('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

The call above automatically finds the optimal contraction order (here, identical to that of the optimized function above) and computes the products for you. In this case, it even uses np.dot under the hood to exploit any vendor BLAS functionality that your NumPy build has!
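A pairwise contraction such as ijkl,pi->jklp, the first step in the path report shown later, can be written as a single matrix product by flattening the untouched indices j, k and l into one axis; that is what allows a BLAS GEMM routine to do the work. The sketch below uses a hypothetical function name and only illustrates the idea; it is not opt_einsum's internal code.

```python
import numpy as np


def first_step_gemm(I, C):
    """K[j, k, l, p] = sum_i I[i, j, k, l] * C[p, i] via one matrix product (hypothetical)."""
    n_i, n_j, n_k, n_l = I.shape
    # Flatten (j, k, l) into one axis: I_mat has shape (n_i, n_j*n_k*n_l)
    I_mat = I.reshape(n_i, -1)
    # (n_j*n_k*n_l, n_i) @ (n_i, n_p) -> (n_j*n_k*n_l, n_p): a single GEMM
    K_mat = np.dot(I_mat.T, C.T)
    # Restore the (j, k, l) axes; p is the last axis, matching 'jklp'
    return K_mat.reshape(n_j, n_k, n_l, C.shape[0])


N = 10
C = np.random.rand(N, N)
I = np.random.rand(N, N, N, N)
K = first_step_gemm(I, C)
print(np.allclose(K, np.einsum('ijkl,pi->jklp', I, C)))  # True
```

Because C only needs C.shape[1] to match the contracted extent, the same function also works when C is not square.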

We can then view more details about the optimized contraction order:

>>> from opt_einsum import contract_path

>>> path_info = contract_path('pi,qj,ijkl,rk,sl->pqrs', C, C, I, C, C)

>>> print(path_info[0])
[(0, 2), (0, 3), (0, 2), (0, 1)]

>>> print(path_info[1])
  Complete contraction:  pi,qj,ijkl,rk,sl->pqrs
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  8.000e+08
  Optimized FLOP count:  8.000e+05
   Theoretical speedup:  1000.000
  Largest intermediate:  1.000e+04 elements
--------------------------------------------------------------------------------
scaling   BLAS                  current                                remaining
--------------------------------------------------------------------------------
   5      GEMM            ijkl,pi->jklp                      qj,rk,sl,jklp->pqrs
   5      GEMM            jklp,qj->klpq                         rk,sl,klpq->pqrs
   5      GEMM            klpq,rk->lpqr                            sl,lpqr->pqrs
   5      GEMM            lpqr,sl->pqrs                               pqrs->pqrs
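The list printed by path_info[0] is read step by step: each tuple names positions in the current list of operands, those operands are removed and contracted, and the result is appended to the end of the list. A minimal sketch of that bookkeeping with plain np.einsum follows; run_path is a hypothetical name, it assumes an explicit '->' output with no '...' and a path that reduces everything to one operand, and it is not opt_einsum's own implementation.

```python
import numpy as np


def run_path(subscripts, path, *operands):
    """Evaluate an einsum expression pairwise, following a contraction path (hypothetical).

    Assumes an explicit '->' output, no '...' broadcasting, and a path that
    leaves a single operand. Returns the result and the einsum string of each step.
    """
    inputs, output = subscripts.split('->')
    terms = inputs.split(',')
    ops = list(operands)
    steps = []
    for pair in path:
        # Pop the highest position first so the other position stays valid
        positions = sorted(pair, reverse=True)
        in_terms = [terms.pop(i) for i in positions]
        in_ops = [ops.pop(i) for i in positions]
        # Keep an index only if the output or a remaining operand still needs it
        needed = set(output).union(*terms)
        new_term = ''.join(sorted(set(''.join(in_terms)) & needed))
        step = ','.join(in_terms) + '->' + new_term
        steps.append(step)
        # The intermediate goes to the end of the operand list
        ops.append(np.einsum(step, *in_ops))
        terms.append(new_term)
    # Permute the last remaining operand into the requested output order
    return np.einsum(terms[0] + '->' + output, ops[0]), steps


N = 5
C = np.random.rand(N, N)
I = np.random.rand(N, N, N, N)
expr = 'pi,qj,ijkl,rk,sl->pqrs'
M, steps = run_path(expr, [(0, 2), (0, 3), (0, 2), (0, 1)], C, C, I, C, C)
print(steps)  # ['ijkl,pi->jklp', 'jklp,qj->klpq', 'klpq,rk->lpqr', 'lpqr,sl->pqrs']
print(np.allclose(M, np.einsum(expr, C, C, I, C, C)))  # True
```

In this example the step strings it records happen to match the "current" column of the report above, because the kept indices are sorted alphabetically.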


Download files


Source Distribution

opt_einsum-2.0.0.tar.gz (26.8 kB)

Built Distribution

opt_einsum-2.0.0-py2.py3-none-any.whl (29.5 kB), for Python 2 and Python 3

File details

Details for the file opt_einsum-2.0.0.tar.gz.

File metadata

  • Download URL: opt_einsum-2.0.0.tar.gz
  • Size: 26.8 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No

File hashes

Hashes for opt_einsum-2.0.0.tar.gz
Algorithm Hash digest
SHA256 b1a67039dae6a7d2aa47a522ac19bef381fff53f9b243a50af3827ed9539bfcc
MD5 b15085d33d80be8520da6eec1197fa0e
BLAKE2b-256 dea2ae05b691dbd14e2a9ac146ec8a43118f4f681cab09d60b528c3b9bb67b68


File details

Details for the file opt_einsum-2.0.0-py2.py3-none-any.whl.

File hashes

Hashes for opt_einsum-2.0.0-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 d991589c4cf82c3581b8c44a33edefe7fc1d245055cf9f3ffbb2299bb29f49dd
MD5 0c37f770823426d821e15c183b72d333
BLAKE2b-256 6bc8e4c0cd49d054073b704de1d9827386570183b8e91b619b0d7188acf6fbf6

