
Wrapper around the pytorch-fid package to compute the Fréchet Inception Distance (FID) in-memory with PyTorch, given tensors of images.

Project description

pytorch-fid-wrapper

A simple wrapper around @mseitzer's great pytorch-fid work.

The goal is to compute the Fréchet Inception Distance between two sets of images in-memory using PyTorch.

Installation

PyPI

pip install pytorch-fid-wrapper

Requires (and will install) the same dependencies as pytorch-fid:

  • Python >= 3.5
  • Pillow
  • Numpy
  • Scipy
  • Torch
  • Torchvision

Usage

import pytorch_fid_wrapper as pfw

# Optional: set pfw's configuration with your parameters once and for all
pfw.set_config(batch_size=BATCH_SIZE, dims=DIMS, device=DEVICE)

# compute real_m and real_s only once, they will not change during training
real_images = my_validation_data # N x C x H x W tensor
real_m, real_s = pfw.get_stats(real_images)

# get the fake images your model currently generates
fake_images = my_model.compute_fake_images() # N x C x H x W tensor

# compute the fid score
val_fid = pfw.fid(fake_images, real_m=real_m, real_s=real_s)
# OR
new_real_data = some_other_validation_data # N x C x H x W tensor
val_fid = pfw.fid(fake_images, new_real_data)
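
For context, here is a minimal sketch of how these calls typically fit into a training loop. Everything named my_model, val_loader, train_one_epoch or n_epochs is a placeholder; only pfw.set_config, pfw.get_stats and pfw.fid come from this package:

import torch
import pytorch_fid_wrapper as pfw

pfw.set_config(batch_size=50, dims=2048, device="cuda:0")

# Cache the real statistics once: they do not change during training.
real_images = torch.cat([batch for batch in val_loader], dim=0)  # N x C x H x W
real_m, real_s = pfw.get_stats(real_images)

for epoch in range(n_epochs):
    train_one_epoch(my_model)  # placeholder for your own training step

    # Score the current generator against the cached real statistics.
    with torch.no_grad():
        fake_images = my_model.compute_fake_images()  # N x C x H x W
    val_fid = pfw.fid(fake_images, real_m=real_m, real_s=real_s)
    print(f"epoch {epoch}: FID = {val_fid:.2f}")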

Please refer to pytorch-fid for any documentation on the InceptionV3 implementation or FID calculations.
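
For reference, the FID between two sets of images compares the Gaussian statistics (mean and covariance) of their Inception features. A minimal NumPy/SciPy sketch of the standard formula is given below; it is illustrative only and is not pfw's or pytorch-fid's internal code:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
    # Discard the tiny imaginary parts introduced by numerical error.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)

Caching real_m and real_s, as in the usage above, avoids recomputing the real-set half of these statistics at every evaluation.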

Config

pfw.get_stats(...) and pfw.fid(...) need to know which block of the InceptionV3 model to use (dims), on which device to run inference (device) and with what batch size (batch_size); a short note on valid dims values follows the examples below.

Default values are stored in pfw.params: batch_size = 50, dims = 2048 and device = "cpu". If you want to override those, you have two options:

1/ override any of these parameters in the function calls. For instance:

pfw.fid(fake_images, new_real_data, device="cuda:0")

2/ override the params globally with pfw.set_config so they apply to all future calls without being passed again. For instance:

pfw.set_config(batch_size=100, dims=768, device="cuda:0")
...
pfw.fid(fake_images, new_real_data)
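
As a note on dims: the values map onto pytorch-fid's InceptionV3 feature blocks. The mapping below is taken from pytorch-fid and is an assumption about this wrapper rather than something documented here:

# Feature dimensionality -> InceptionV3 block (per pytorch-fid's
# InceptionV3.BLOCK_INDEX_BY_DIM; check pytorch-fid if in doubt):
#   64   -> first max-pooling features
#   192  -> second max-pooling features
#   768  -> pre-aux-classifier features
#   2048 -> final average-pooling features (the standard choice for FID)
pfw.set_config(batch_size=50, dims=2048, device="cpu")  # i.e. the documented defaults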

Recognition

Remember to cite their work if using pytorch-fid-wrapper or pytorch-fid:

@misc{Seitzer2020FID,
  author={Maximilian Seitzer},
  title={{pytorch-fid: FID Score for PyTorch}},
  month={August},
  year={2020},
  note={Version 0.1.1},
  howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
}

License

This implementation is licensed under the Apache License 2.0.

FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500

The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch-fid-wrapper-0.0.3.tar.gz (10.9 kB)

Uploaded Source

Built Distribution

pytorch_fid_wrapper-0.0.3-py3-none-any.whl (15.2 kB)

Uploaded Python 3

File details

Details for the file pytorch-fid-wrapper-0.0.3.tar.gz.

File metadata

  • Download URL: pytorch-fid-wrapper-0.0.3.tar.gz
  • Upload date:
  • Size: 10.9 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.4

File hashes

Hashes for pytorch-fid-wrapper-0.0.3.tar.gz
Algorithm Hash digest
SHA256 dc87d9f73a9608712c96dc9aed2cc776aab449c71ef7de4240749c9ad8453787
MD5 020b5e7bd4453a8932d4096069b29727
BLAKE2b-256 a667f38b5093bf3eb58307febbf3f87891d46358ba07a9628ca54d8f9f18a910

See more details on using hashes here.
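
If you want to check a downloaded file against the digests above, here is a minimal sketch using Python's standard hashlib; the file name and expected SHA256 are the ones listed for the source distribution, so adjust them for the wheel:

import hashlib

# Expected SHA256 for pytorch-fid-wrapper-0.0.3.tar.gz, as listed above.
EXPECTED_SHA256 = "dc87d9f73a9608712c96dc9aed2cc776aab449c71ef7de4240749c9ad8453787"

with open("pytorch-fid-wrapper-0.0.3.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()

print("OK" if digest == EXPECTED_SHA256 else "hash mismatch")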

File details

Details for the file pytorch_fid_wrapper-0.0.3-py3-none-any.whl.

File metadata

  • Download URL: pytorch_fid_wrapper-0.0.3-py3-none-any.whl
  • Upload date:
  • Size: 15.2 kB
  • Tags: Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.4

File hashes

Hashes for pytorch_fid_wrapper-0.0.3-py3-none-any.whl
Algorithm Hash digest
SHA256 456fe140b7ac0614bf95b7f5eda3f4afeb363e28a58ff771121eaad3844c766d
MD5 25cc5887e452cd8ce70ba3c539572f1d
BLAKE2b-256 c4a1d2178de9405f275b77e9308240bc57a6d7e5acb162c477b3dca7b4d36344

See more details on using hashes here.
