Tools for loading, augmenting and writing 3D medical images on PyTorch.

TorchIO

torchio is a Python package containing a set of tools to efficiently read, sample and write 3D medical images in deep learning applications written in PyTorch, including intensity and spatial transforms for data augmentation and preprocessing. Transforms include typical computer vision operations such as random affine transformations and also domain-specific ones such as simulation of intensity artifacts due to MRI magnetic field inhomogeneity or k-space motion artifacts.

This package has been greatly inspired by NiftyNet.

Jupyter notebook

The best way to quickly understand and try the library is the Jupyter notebook hosted by Google Colab. It includes many examples and visualization of most of the classes and even training of a 3D U-Net for brain segmentation of T1-weighted MRI with whole images and patch-based sampling.

Credits

If you like this repository, please click on Star!

If you used this package for your research, please cite this repository using the information available on its Zenodo entry or use this text:

Pérez-García, Fernando. (2020, January 15). fepegar/torchio: TorchIO: Tools for loading, augmenting and writing 3D medical images on PyTorch. Zenodo. http://doi.org/10.5281/zenodo.3598622

BibTeX entry:

@software{perez_garcia_fernando_2020_3598622,
  author       = {Pérez-García, Fernando},
  title        = {{fepegar/torchio: TorchIO: Tools for loading,
                   augmenting and writing 3D medical images on
                   PyTorch}},
  month        = jan,
  year         = 2020,
  publisher    = {Zenodo},
  doi          = {10.5281/zenodo.3598622},
  url          = {https://doi.org/10.5281/zenodo.3598622}
}

Installation

This package is on the Python Package Index (PyPI). To install the latest published version, just run the following command in a terminal:

$ pip install --upgrade torchio

Features

Medical image datasets

IXI

The Information eXtraction from Images (IXI) dataset contains "nearly 600 MR images from normal, healthy subjects", including "T1, T2 and PD-weighted images, MRA images and Diffusion-weighted images (15 directions)".

The usage is very similar to torchvision.datasets:

import torchio
import torchvision

transforms = [
    torchio.ToCanonical(),  # to RAS
    torchio.Resample((1, 1, 1)),  # to 1 mm iso
]

ixi_dataset = torchio.datasets.IXI(
    'path/to/ixi_root/',
    modalities=('T1', 'T2'),
    transform=torchvision.transforms.Compose(transforms),
    download=True,
)
print('Number of subjects in dataset:', len(ixi_dataset))  # 577

sample_subject = ixi_dataset[0]
print('Keys in subject sample:', tuple(sample_subject.keys()))  # ('T1', 'T2')
print('Shape of T1 data:', sample_subject['T1'][torchio.DATA].shape)  # [1, 180, 268, 268]
print('Shape of T2 data:', sample_subject['T2'][torchio.DATA].shape)  # [1, 241, 257, 188]

Tiny IXI

This is the dataset used in the notebook. It is a tiny version of IXI, containing 566 T1-weighted brain MR images and their corresponding brain segmentations, all with size (83 x 44 x 55).

Data handling

ImagesDataset

ImagesDataset is a reader of 3D medical images that directly inherits from torch.utils.data.Dataset. It can be used with a torch.utils.data.DataLoader for efficient loading and data augmentation.

It receives a list of subjects, where each subject is an instance of torchio.Subject containing instances of torchio.Image. The file format must be compatible with NiBabel or SimpleITK readers. An image path can also point to a directory containing DICOM files.

import torchio
from torchio import ImagesDataset, Image, Subject

subject_a = Subject(
    Image('t1', '~/Dropbox/MRI/t1.nrrd', torchio.INTENSITY),
    Image('label', '~/Dropbox/MRI/t1_seg.nii.gz', torchio.LABEL),
)
subject_b = Subject(
    Image('t1', '/tmp/colin27_t1_tal_lin.nii.gz', torchio.INTENSITY),
    Image('t2', '/tmp/colin27_t2_tal_lin.nii', torchio.INTENSITY),
    Image('label', '/tmp/colin27_seg1.nii.gz', torchio.LABEL),
)
subjects_list = [subject_a, subject_b]
subjects_dataset = ImagesDataset(subjects_list)
subject_sample = subjects_dataset[0]

Samplers and aggregators

TorchIO includes grid, uniform and label patch samplers. There is also an aggregator that assembles patch predictions into a dense output volume. For more information about patch-based training, see the NiftyNet docs.

import torch
import torch.nn as nn
import torchio

CHANNELS_DIMENSION = 1
patch_overlap = 4
patch_size = 128

grid_sampler = torchio.inference.GridSampler(
    input_data,  # some PyTorch tensor or NumPy array
    patch_size,
    patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = torchio.inference.GridAggregator(
    input_data,  # some PyTorch tensor or NumPy array
    patch_overlap,
)

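device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Module()  # placeholder: replace with a trained segmentation network
model.to(device)
model.eval()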
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['image'].to(device)
        locations = patches_batch['location']
        logits = model(input_tensor)
        labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
        outputs = labels
        aggregator.add_batch(outputs, locations)

output_tensor = aggregator.get_output_tensor()
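
To make the idea of a patch grid with overlap more concrete, here is a minimal sketch in plain Python. It is not the code used by GridSampler, and its exact overlap convention may differ; the helper names grid_starts and grid_locations are hypothetical. It only illustrates how a set of overlapping patch locations can cover a whole volume.

```python
from itertools import product


def grid_starts(length, patch_size, overlap):
    """Start indices of patches along one axis so the whole axis is covered."""
    if overlap >= patch_size:
        raise ValueError('overlap must be smaller than patch_size')
    if patch_size >= length:
        return [0]
    step = patch_size - overlap
    starts = list(range(0, length - patch_size + 1, step))
    # Add a last patch flush with the border if the regular grid left a gap
    if starts[-1] != length - patch_size:
        starts.append(length - patch_size)
    return starts


def grid_locations(shape, patch_size, overlap):
    """Return (i0, j0, k0, i1, j1, k1) corners for every patch in the grid."""
    per_axis = [grid_starts(n, patch_size, overlap) for n in shape]
    return [
        start + tuple(s + patch_size for s in start)
        for start in product(*per_axis)
    ]


locations = grid_locations((10, 10, 10), patch_size=6, overlap=2)
print(len(locations), locations[0], locations[-1])
```

An aggregator can then write each predicted patch back into an output volume at the corresponding location.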

Queue

A patches Queue (or buffer) can be used for randomized patch-based sampling during training. This interactive animation can be used to understand how the queue works.

import torch
import torchio

patches_queue = torchio.Queue(
    subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    max_length=300,
    samples_per_volume=10,
    patch_size=96,
    sampler_class=torchio.sampler.ImageSampler,
    num_workers=4,
    shuffle_subjects=True,
    shuffle_patches=True,
)
patches_loader = torch.utils.data.DataLoader(patches_queue, batch_size=4)

num_epochs = 20
for epoch_index in range(num_epochs):
    for patches_batch in patches_loader:
        logits = model(patches_batch)  # model is some torch.nn.Module
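
The buffering behaviour can be sketched with the standard library alone. This is a simplified illustration under the assumption that the buffer is refilled from as many subjects as fit in max_length; it is not the Queue implementation, and iterate_patches is a hypothetical name. Patches are represented by (subject, patch_index) pairs instead of real image data.

```python
import random


def iterate_patches(subjects, max_length, samples_per_volume, rng):
    """Yield patches from a buffer that is refilled whenever it runs empty."""
    subjects = list(subjects)
    rng.shuffle(subjects)  # corresponds to shuffle_subjects=True
    subjects_per_fill = max(1, max_length // samples_per_volume)
    for i in range(0, len(subjects), subjects_per_fill):
        # Fill the buffer with samples_per_volume patches from each subject
        buffer = [
            (subject, patch_index)
            for subject in subjects[i:i + subjects_per_fill]
            for patch_index in range(samples_per_volume)
        ]
        rng.shuffle(buffer)  # corresponds to shuffle_patches=True
        yield from buffer


rng = random.Random(0)
patches = list(
    iterate_patches(range(5), max_length=6, samples_per_volume=3, rng=rng)
)
print(patches)
```

A larger max_length mixes patches from more subjects in each batch, at the cost of more memory.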

Transforms

The transforms module should remind users of torchvision.transforms. TorchIO transforms take as input samples generated by an ImagesDataset.

A transform can be quickly applied to an image file using the command-line tool torchio-transform:

$ torchio-transform input.nii.gz RandomMotion output.nii.gz --kwargs "proportion_to_augment=1 num_transforms=4"

Augmentation

Intensity

MRI k-space motion artifacts

Magnetic resonance images suffer from motion artifacts when the subject moves during image acquisition. This transform follows Shaw et al., 2019 to simulate motion artifacts for data augmentation.

MRI k-space ghosting artifacts

Discrete "ghost" artifacts may occur along the phase-encode direction whenever the position or signal intensity of imaged structures within the field-of-view vary or move in a regular (periodic) fashion. Pulsatile flow of blood or CSF, cardiac motion, and respiratory motion are the most important patient-related causes of ghost artifacts in clinical MR imaging (From mriquestions.com).
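
A periodic modulation of k-space along one axis is enough to produce such ghosts. The NumPy sketch below attenuates every period-th k-space line; it is an illustration under that assumption, not necessarily how the TorchIO transform is implemented, and add_ghosts and its parameters are hypothetical names.

```python
import numpy as np


def add_ghosts(image, period, axis, attenuation):
    """Attenuate every `period`-th k-space line along `axis`."""
    spectrum = np.fft.fftn(image)
    lines = [slice(None)] * image.ndim
    lines[axis] = slice(0, None, period)
    spectrum[tuple(lines)] *= 1 - attenuation
    # The input is real, so the imaginary part is only numerical noise
    return np.real(np.fft.ifftn(spectrum))


image = np.zeros((32, 32))
image[4:12, 8:24] = 1.0
ghosted = add_ghosts(image, period=2, axis=0, attenuation=0.5)
```

With period=2 the result contains a faint negative copy of the object shifted by half the field of view along the chosen axis.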

MRI k-space spike artifacts

Also known as Herringbone artifact, crisscross artifact or corduroy artifact, it creates stripes in different directions in image space due to spikes in k-space.
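
The link between a k-space spike and image-space stripes can be reproduced in a few lines of NumPy. This is a sketch for illustration only; add_kspace_spike and its parameters are hypothetical and do not correspond to the TorchIO API.

```python
import numpy as np


def add_kspace_spike(image, position, intensity):
    """Add one spike to the centred spectrum and go back to image space."""
    spectrum = np.fft.fftshift(np.fft.fftn(image))
    # Spike amplitude is expressed relative to the largest spectral magnitude
    spectrum[position] += intensity * np.abs(spectrum).max()
    return np.real(np.fft.ifftn(np.fft.ifftshift(spectrum)))


image = np.zeros((32, 32))
image[8:24, 8:24] = 1.0
spiked = add_kspace_spike(image, position=(16, 20), intensity=0.5)
stripes = spiked - image  # sinusoidal pattern added by the spike
```

The distance of the spike from the centre of k-space sets the stripe frequency, and its direction sets the stripe orientation.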

MRI magnetic field inhomogeneity

MRI magnetic field inhomogeneity creates low-frequency intensity variations. This transform is very similar to the one in NiftyNet.
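
As a rough illustration, a smooth multiplicative field can be built from a polynomial in normalized coordinates. The sketch below uses a first-order polynomial for brevity, which is an assumption made here; the function polynomial_bias_field is hypothetical and is not the TorchIO or NiftyNet code.

```python
import numpy as np


def polynomial_bias_field(shape, coefficients):
    """Exponential of a first-order polynomial in coordinates in [-1, 1]."""
    axes = [np.linspace(-1, 1, n) for n in shape]
    grids = np.meshgrid(*axes, indexing='ij')
    log_field = sum(c * g for c, g in zip(coefficients, grids))
    return np.exp(log_field)  # always positive, equal to 1 where log_field is 0


image = np.ones((4, 5, 6))
field = polynomial_bias_field(image.shape, coefficients=(0.3, 0.0, -0.2))
biased = image * field
```

Sampling the coefficients randomly for each image turns this into an augmentation.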

Patch swap

Randomly swaps patches in the image. This is typically used in context restoration for self-supervised learning.
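
A minimal NumPy sketch of the operation follows. The function names and parameters are hypothetical, and the way the TorchIO transform picks patch locations may differ.

```python
import numpy as np


def swap_two_patches(image, corner_a, corner_b, patch_size):
    """Exchange two cubic patches given their starting corners."""
    region_a = tuple(slice(c, c + patch_size) for c in corner_a)
    region_b = tuple(slice(c, c + patch_size) for c in corner_b)
    result = image.copy()
    # Both reads use the original image; if regions overlap, region_b wins
    result[region_a] = image[region_b]
    result[region_b] = image[region_a]
    return result


def random_patch_swap(image, patch_size, num_swaps, rng):
    """Apply num_swaps swaps at uniformly sampled locations."""
    for _ in range(num_swaps):
        corner_a = tuple(
            int(rng.integers(0, n - patch_size + 1)) for n in image.shape
        )
        corner_b = tuple(
            int(rng.integers(0, n - patch_size + 1)) for n in image.shape
        )
        image = swap_two_patches(image, corner_a, corner_b, patch_size)
    return image


image = np.arange(64).reshape(4, 4, 4)
swapped = swap_two_patches(image, (0, 0, 0), (2, 2, 2), patch_size=2)
augmented = random_patch_swap(image, 2, 3, np.random.default_rng(0))
```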

Gaussian noise

Adds noise sampled from a normal distribution with mean 0 and standard deviation sampled from a uniform distribution in the range std_range. It is often used after ZNormalization, because the output of ZNormalization has zero mean.
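
The two sampling steps can be written out with NumPy as follows. This is a sketch, not the RandomNoise source; add_random_noise is a hypothetical helper that also returns the sampled standard deviation for inspection.

```python
import numpy as np


def add_random_noise(image, std_range, rng):
    """Sample std uniformly from std_range, then add zero-mean noise."""
    std = rng.uniform(*std_range)
    noise = rng.normal(0.0, std, size=image.shape)
    return image + noise, std


rng = np.random.default_rng(0)
image = np.zeros((20, 20, 20))
noisy, std = add_random_noise(image, std_range=(0.1, 0.25), rng=rng)
```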

Gaussian blurring

Blurs the image using a discrete Gaussian image filter.

Spatial

B-spline dense elastic deformation

Random elastic deformation

Flip

Reverse the order of elements in an image along the given axes.
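
In NumPy terms this corresponds to np.flip. The sketch below adds a per-axis flip probability, which is an assumption made for illustration; random_flip is a hypothetical name.

```python
import numpy as np


def random_flip(image, axes, flip_probability, rng):
    """Flip the image along each given axis with the given probability."""
    for axis in axes:
        if rng.random() < flip_probability:
            image = np.flip(image, axis=axis)
    return image


rng = np.random.default_rng(0)
image = np.arange(6).reshape(2, 3)
flipped = random_flip(image, axes=(1,), flip_probability=1.0, rng=rng)
unchanged = random_flip(image, axes=(0, 1), flip_probability=0.0, rng=rng)
```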

Affine transform

Random affine transformation of the image keeping center invariant.
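
Keeping the center invariant means that the linear part (rotation and scaling) is applied about the image center rather than about the origin. A small NumPy sketch of such a matrix is shown below; affine_about_center is a hypothetical helper, and the sampling ranges are only examples.

```python
import numpy as np


def affine_about_center(linear, center):
    """4x4 matrix applying `linear` (3x3) while keeping `center` fixed."""
    linear = np.asarray(linear, dtype=float)
    center = np.asarray(center, dtype=float)
    matrix = np.eye(4)
    matrix[:3, :3] = linear
    # x' = linear @ (x - center) + center
    matrix[:3, 3] = center - linear @ center
    return matrix


rng = np.random.default_rng(0)
angle = np.radians(rng.uniform(-10, 10))  # random rotation about the last axis
scale = rng.uniform(0.9, 1.1)  # random isotropic scaling
rotation = np.array([
    [np.cos(angle), -np.sin(angle), 0],
    [np.sin(angle), np.cos(angle), 0],
    [0, 0, 1],
])
center = np.array([32.0, 32.0, 16.0])
matrix = affine_about_center(scale * rotation, center)
```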

Preprocessing

Histogram standardization

Implementation of the method described in "New variants of a method of MRI scale standardization", adapted from NiftyNet.

Rescale

Rescale intensity values in an image to a certain range.
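
A linear min-max mapping is one straightforward way to do this. The sketch below is an illustration with hypothetical names (rescale, out_min_max) and is not taken from the TorchIO source.

```python
import numpy as np


def rescale(image, out_min_max=(-1.0, 1.0)):
    """Linearly map the intensity range of image to out_min_max."""
    out_min, out_max = out_min_max
    in_min, in_max = image.min(), image.max()
    if in_max == in_min:
        # Constant image: avoid division by zero
        return np.full(image.shape, float(out_min))
    unit = (image - in_min) / (in_max - in_min)  # values in [0, 1]
    return unit * (out_max - out_min) + out_min


rescaled = rescale(np.array([0.0, 5.0, 10.0]), out_min_max=(-1.0, 1.0))
```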

Z-normalization

This transform first extracts the values with intensity greater than the mean, which is an approximation of the foreground voxels. Then the foreground mean is subtracted from the image and it is divided by the foreground standard deviation.
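
The description above translates almost directly into NumPy. This is a sketch of the described steps rather than the library code; z_normalize_foreground is a hypothetical name.

```python
import numpy as np


def z_normalize_foreground(image):
    """Normalize using statistics of the voxels brighter than the mean."""
    foreground = image[image > image.mean()]
    return (image - foreground.mean()) / foreground.std()


image = np.array([0.0, 0.0, 0.0, 0.0, 10.0, 12.0, 14.0])
normalized = z_normalize_foreground(image)
foreground_values = normalized[image > image.mean()]
```

After the transform, the foreground voxels have zero mean and unit standard deviation, while background voxels take negative values.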

Resample

Resample images to a new voxel spacing using NiBabel.

Pad

Pad images, like in torchvision.transforms.Pad.

Crop

Crop images passing 1, 3, or 6 integers, as in Pad.
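
One plausible reading of the 1, 3, or 6 integers is sketched below: one value for all sides, one value per axis, or one value per side. This interpretation is an assumption made here, and parse_bounds, crop and pad are hypothetical helpers written with NumPy, not the TorchIO classes.

```python
import numpy as np


def parse_bounds(bounds):
    """Expand 1, 3 or 6 integers into (lo0, hi0, lo1, hi1, lo2, hi2)."""
    if isinstance(bounds, int):
        bounds = (bounds,)
    bounds = tuple(bounds)
    if len(bounds) == 1:
        return bounds * 6
    if len(bounds) == 3:
        return tuple(b for b in bounds for _ in range(2))
    if len(bounds) == 6:
        return bounds
    raise ValueError('bounds must contain 1, 3 or 6 integers')


def crop(image, bounds):
    """Remove the given number of voxels from each side of a 3D array."""
    lo0, hi0, lo1, hi1, lo2, hi2 = parse_bounds(bounds)
    s0, s1, s2 = image.shape
    return image[lo0:s0 - hi0, lo1:s1 - hi1, lo2:s2 - hi2]


def pad(image, bounds):
    """Add the given number of zero voxels to each side of a 3D array."""
    b = parse_bounds(bounds)
    return np.pad(image, list(zip(b[::2], b[1::2])), mode='constant')


image = np.zeros((10, 10, 10))
print(pad(image, 2).shape, crop(image, (1, 2, 3, 4, 0, 0)).shape)
```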

ToCanonical

Reorder the data so that it is closest to canonical NIfTI (RAS+) orientation.

CenterCropOrPad

Crops or pads image center to a target size, modifying the affine accordingly.
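
The following NumPy sketch shows one way to crop or pad around the center and shift the affine origin to match. It is an illustration under the assumption of a 4x4 voxel-to-world affine; center_crop_or_pad is a hypothetical function, not the TorchIO class.

```python
import numpy as np


def center_crop_or_pad(image, affine, target_shape):
    """Crop or zero-pad each axis around its center; update the origin."""
    offsets = []  # old-grid index of the new first voxel, per axis
    for axis, target in enumerate(target_shape):
        size = image.shape[axis]
        if size > target:  # crop
            start = (size - target) // 2
            image = np.take(image, np.arange(start, start + target), axis=axis)
            offsets.append(start)
        else:  # pad (does nothing when size == target)
            before = (target - size) // 2
            widths = [(0, 0)] * image.ndim
            widths[axis] = (before, target - size - before)
            image = np.pad(image, widths, mode='constant')
            offsets.append(-before)
    new_affine = affine.copy()
    # World position of the new voxel (0, 0, 0)
    new_affine[:3, 3] = (affine @ np.array(offsets + [1.0]))[:3]
    return image, new_affine


image = np.zeros((10, 4, 6))
affine = np.diag([2.0, 2.0, 2.0, 1.0])
output, new_affine = center_crop_or_pad(image, affine, (6, 8, 6))
```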

Others

Lambda

Applies a user-defined function as transform. For example, image intensity can be inverted with Lambda(lambda x: -x, types_to_apply=[torchio.INTENSITY]) and a mask can be negated with Lambda(lambda x: 1 - x, types_to_apply=[torchio.LABEL]).

Example

This example shows the improvement in performance when multiple workers are used to load and preprocess the volumes.

import time
import multiprocessing as mp

from tqdm import trange

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from torchio import ImagesDataset, Queue, DATA
from torchio.data.sampler import ImageSampler
from torchio.utils import create_dummy_dataset
from torchio.transforms import (
    ZNormalization,
    RandomNoise,
    RandomFlip,
    RandomAffine,
)


# Define training and patches sampling parameters
num_epochs = 4
patch_size = 128
queue_length = 400
samples_per_volume = 10
batch_size = 4

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
        )
    def forward(self, x):
        return self.conv(x)

model = Network()

# Create a dummy dataset in the temporary directory, for this example
subjects_list = create_dummy_dataset(
    num_images=100,
    size_range=(193, 229),
    force=False,
)

# Each element of subjects_list is an instance of torchio.Subject:
# subject = Subject(
#     torchio.Image('one_image', path_to_one_image, torchio.INTENSITY),
#     torchio.Image('another_image', path_to_another_image, torchio.INTENSITY),
#     torchio.Image('a_label', path_to_a_label, torchio.LABEL),
# )

# Define transforms for data normalization and augmentation
transforms = (
    ZNormalization(),
    RandomNoise(std_range=(0, 0.25)),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomFlip(axes=(0,)),
)
transform = Compose(transforms)
subjects_dataset = ImagesDataset(subjects_list, transform)


# Run a benchmark for different numbers of workers
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)

    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
        num_workers=num_workers,
    )
    batch_loader = DataLoader(queue_dataset, batch_size=batch_size)

    start = time.time()
    for epoch_index in trange(num_epochs, leave=False):
        for batch in batch_loader:
            # The keys of batch have been defined in create_dummy_dataset()
            inputs = batch['one_modality'][DATA]
            targets = batch['segmentation'][DATA]
            logits = model(inputs)
    print('Time:', int(time.time() - start), 'seconds')
    print()

Output:

Number of workers: 0
Time: 394 seconds

Number of workers: 1
Time: 372 seconds

Number of workers: 2
Time: 278 seconds

Number of workers: 3
Time: 259 seconds

Number of workers: 4
Time: 242 seconds

History

0.13.0 (24-02-2020)

  • Add Subject class
  • Add random blur transform
  • Add lambda transform
  • Add random patches swapping transform
  • Add MRI k-space ghosting artefact augmentation

0.12.0 (21-01-2020)

  • Add ToCanonical transform
  • Add CenterCropOrPad transform

0.11.0 (15-01-2020)

  • Add Resample transform

0.10.0 (15-01-2020)

  • Add Pad transform
  • Add Crop transform

0.9.0 (14-01-2020)

  • Add CLI tool to transform an image from file

0.8.0 (11-01-2020)

  • Add Image class

0.7.0 (02-01-2020)

  • Make transforms use PyTorch tensors consistently

0.6.0 (02-01-2020)

  • Add support for NRRD

0.5.0 (01-01-2020)

  • Add bias field transform

0.4.0 (29-12-2019)

  • Add MRI k-space motion artefact augmentation

0.3.0 (21-12-2019)

  • Add Rescale transform
  • Add support for multimodal data and missing modalities

0.2.0 (06-12-2019)

  • First release on PyPI.
