Wrapper around the pytorch-fid package to compute the Fréchet Inception Distance (FID) in memory with PyTorch, given tensors of images.
pytorch-fid-wrapper
A simple wrapper around @mseitzer's great pytorch-fid work.
The goal is to compute the Fréchet Inception Distance between two sets of images in-memory using PyTorch.
Installation
pip install pytorch-fid-wrapper
Requires (and will install), as pytorch-fid does:
- Python >= 3.5
- Pillow
- Numpy
- Scipy
- Torch
- Torchvision
Usage
import pytorch_fid_wrapper as pfw
# Optional: set pfw's configuration with your parameters once and for all
pfw.set_config(batch_size=BATCH_SIZE, dims=DIMS, device=DEVICE)
# compute real_m and real_s only once, they will not change during training
real_images = my_validation_data # N x C x H x W tensor
real_m, real_s = pfw.get_stats(real_images)
# get the fake images your model currently generates
fake_images = my_model.compute_fake_images() # N x C x H x W tensor
# compute the fid score
val_fid = pfw.fid(fake_images, real_m, real_s)
# OR
new_real_data = some_other_validation_data # N x C x H x W tensor
val_fid = pfw.fid(fake_images, new_real_data)
Please refer to pytorch-fid for any documentation on the InceptionV3 implementation or FID calculations.
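For intuition about what the score measures: FID is the Fréchet distance between two Gaussians fitted to InceptionV3 activations of each image set, ||mu_1 - mu_2||^2 + Tr(S_1 + S_2 - 2 (S_1 S_2)^(1/2)). With one-dimensional features this reduces to (mu_1 - mu_2)^2 + (sigma_1 - sigma_2)^2. The sketch below computes only that one-dimensional case with the standard library. `frechet_distance_1d` is a hypothetical helper, not part of pytorch-fid-wrapper or pytorch-fid, and the choice of the sample (N - 1) variance is an assumption.

```python
import math
import statistics


def frechet_distance_1d(xs, ys):
    """Frechet distance between Gaussians fitted to two 1-D samples.

    Illustrative only: real FID uses multi-dimensional InceptionV3
    activations and a matrix square root of the covariance product.
    """
    mu_x, mu_y = statistics.mean(xs), statistics.mean(ys)
    # Assumption: sample variance (N - 1 denominator).
    var_x, var_y = statistics.variance(xs), statistics.variance(ys)
    # 1-D case of Tr(S1 + S2 - 2 * sqrt(S1 * S2)).
    cov_term = var_x + var_y - 2.0 * math.sqrt(var_x * var_y)
    return (mu_x - mu_y) ** 2 + cov_term


real = [0.0, 1.0, 2.0, 3.0, 4.0]
shifted = [x + 2.0 for x in real]  # same spread, mean moved by 2

same = frechet_distance_1d(real, real)      # identical samples
moved = frechet_distance_1d(real, shifted)  # only the mean differs
```

Identical samples give a distance of zero, and shifting the mean by 2 without changing the spread gives a distance of 4, which is why lower FID values indicate generated images whose feature statistics are closer to the real ones.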
Config
pfw.get_stats(...) and pfw.fid(...) need to know which block of the InceptionV3 model to use (dims), on which device to run inference (device) and with what batch size (batch_size).
Default values are in pfw.params: batch_size = 50, dims = 2048 and device = "cpu". If you want to override those, you have two options:
1/ override any of these parameters in the function calls. For instance:
pfw.fid(fake_images, new_real_data, device="cuda:0")
2/ override the params globally with pfw.set_config and set them for all future calls without passing parameters again. For instance:
pfw.set_config(batch_size=100, dims=768, device="cuda:0")
...
pfw.fid(fake_images, new_real_data)
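The two override options follow a common pattern: module-level defaults, a global setter, and per-call arguments that take precedence for that call only. The following is a minimal sketch of that pattern in plain Python. It is not pfw's actual source; `_params`, `set_config` and `resolve` here are hypothetical stand-ins used only to show the precedence.

```python
# Hypothetical stand-in for a module-level configuration; NOT pfw's source.
_DEFAULTS = {"batch_size": 50, "dims": 2048, "device": "cpu"}
_params = dict(_DEFAULTS)


def set_config(batch_size=None, dims=None, device=None):
    """Globally override any subset of the parameters."""
    updates = {"batch_size": batch_size, "dims": dims, "device": device}
    for name, value in updates.items():
        if value is not None:
            _params[name] = value


def resolve(batch_size=None, dims=None, device=None):
    """Return the settings one call would use.

    Arguments passed to the call win over the global parameters,
    and they do not modify the global parameters.
    """
    overrides = {"batch_size": batch_size, "dims": dims, "device": device}
    return {k: (v if v is not None else _params[k]) for k, v in overrides.items()}


before = resolve()                    # defaults only
one_off = resolve(device="cuda:0")    # option 1: per-call override
set_config(batch_size=100, dims=768)  # option 2: global override
after = resolve()                     # later calls see the new globals
```

In this sketch the per-call `device="cuda:0"` affects only `one_off`, while the values passed to `set_config` persist for every later call.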
Recognition
Remember to cite the original pytorch-fid work if you use pytorch-fid-wrapper or pytorch-fid:
@misc{Seitzer2020FID,
author={Maximilian Seitzer},
title={{pytorch-fid: FID Score for PyTorch}},
month={August},
year={2020},
note={Version 0.1.1},
howpublished={\url{https://github.com/mseitzer/pytorch-fid}},
}
License
This implementation is licensed under the Apache License 2.0.
FID was introduced by Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler and Sepp Hochreiter in "GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium", see https://arxiv.org/abs/1706.08500
The original implementation is by the Institute of Bioinformatics, JKU Linz, licensed under the Apache License 2.0. See https://github.com/bioinf-jku/TTUR.
File details
Details for the file pytorch-fid-wrapper-0.0.2.tar.gz.
File metadata
- Download URL: pytorch-fid-wrapper-0.0.2.tar.gz
- Upload date:
- Size: 10.1 kB
- Tags: Source
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 696139f4ac5a3caa297c1b1641a034266d8e0fd04e953a66977b4c5ae9900919
MD5 | 4c944f5f2cbab713afe34be533f1af14
BLAKE2b-256 | cd6813dad97490577354c88d152825cc43453a7bf85e2e2efadd4a74951740c9
File details
Details for the file pytorch_fid_wrapper-0.0.2-py3-none-any.whl.
File metadata
- Download URL: pytorch_fid_wrapper-0.0.2-py3-none-any.whl
- Upload date:
- Size: 14.9 kB
- Tags: Python 3
- Uploaded using Trusted Publishing? No
- Uploaded via: twine/3.2.0 pkginfo/1.5.0.1 requests/2.24.0 setuptools/50.3.2 requests-toolbelt/0.9.1 tqdm/4.48.2 CPython/3.7.4
File hashes
Algorithm | Hash digest
---|---
SHA256 | 72380d575d1adcb8a7d6e6bfdbef0a33411a099831005d3ea640610725b7d0bb
MD5 | d236ec77d504d69de129099fc1bc90fe
BLAKE2b-256 | 763849c6ee4bef83b767ebe06a96735222e9ceabf05480d2e66045f6c37a71a3