Tools for loading, augmenting and writing 3D medical images in PyTorch.

TorchIO

torchio is a Python package containing a set of tools to efficiently read, sample and write 3D medical images in deep learning applications written in PyTorch, including intensity and spatial transforms for data augmentation and preprocessing. Transforms include typical computer vision operations such as random affine transformations, and also domain-specific ones such as simulation of intensity artifacts caused by MRI magnetic field inhomogeneity or by motion during k-space acquisition.

This package has been greatly inspired by NiftyNet.

Installation

$ pip install torchio

Features

Data handling

ImagesDataset

ImagesDataset is a reader of medical images that directly inherits from torch.utils.data.Dataset. It can be used with a torch.utils.data.DataLoader for efficient reading and data augmentation.

It receives a list of subjects, where each subject is a list of torchio.Image instances. Each image path must have the suffix .nii, .nii.gz or .nrrd.

import torchio
from torchio import Image

subject_a = [
    Image('t1', '~/Dropbox/MRI/t1.nii.gz', torchio.INTENSITY),
    Image('label', '~/Dropbox/MRI/t1_seg.nii.gz', torchio.LABEL),
]
subject_b = [
    Image('t1', '/tmp/colin27_t1_tal_lin.nii.gz', torchio.INTENSITY),
    Image('label', '/tmp/colin27_seg1.nii.gz', torchio.LABEL),
]
subjects_list = [subject_a, subject_b]
subjects_dataset = torchio.ImagesDataset(subjects_list)
subject_sample = subjects_dataset[0]
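
The suffix rule above can be checked before building a dataset. The sketch below uses hypothetical helpers, has_supported_suffix and check_subject_paths, which are not part of torchio. Note that .nii.gz has to be matched on the full file name, because pathlib.Path.suffix only reports the final .gz.

```python
from pathlib import Path
from typing import Iterable, Union

# Suffixes listed in the text above (matched case-sensitively here).
SUPPORTED_SUFFIXES = ('.nii', '.nii.gz', '.nrrd')


def has_supported_suffix(path: Union[str, Path]) -> bool:
    """Return True if the file name ends with one of the supported suffixes."""
    # str.endswith accepts a tuple, so '.nii.gz' is matched as a whole.
    return Path(path).name.endswith(SUPPORTED_SUFFIXES)


def check_subject_paths(paths: Iterable[Union[str, Path]]) -> None:
    """Raise ValueError listing every path with an unsupported suffix."""
    bad = [str(p) for p in paths if not has_supported_suffix(p)]
    if bad:
        raise ValueError(f'Unsupported file suffix: {bad}')
```

For example, check_subject_paths(['/tmp/colin27_t1_tal_lin.nii.gz', '/tmp/colin27_seg1.nii.gz']) returns silently, while a .png path would raise.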

Samplers

torchio includes grid, uniform and label patch samplers. There is also an aggregator that assembles patch-wise predictions back into a full-volume output for dense prediction. For more information about patch-based training, see the NiftyNet documentation.

import torch
import torchio

CHANNELS_DIMENSION = 1
patch_overlap = 4
grid_sampler = torchio.inference.GridSampler(
    input_array,  # some NumPy array
    patch_size=128,
    patch_overlap=patch_overlap,
)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=4)
aggregator = torchio.inference.GridAggregator(
    input_array,
    patch_overlap=patch_overlap,
)

model.eval()  # model is some trained torch.nn.Module
with torch.no_grad():
    for patches_batch in patch_loader:
        input_tensor = patches_batch['image']
        locations = patches_batch['location']
        logits = model(input_tensor)
        # Keep the channels dimension so labels have the same layout as the input patches
        labels = logits.argmax(dim=CHANNELS_DIMENSION, keepdim=True)
        aggregator.add_batch(labels, locations)

output_array = aggregator.output_array
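
The grid sampler and aggregator pair can be pictured with a small NumPy sketch. It assumes cubic patches of side patch_size placed with stride patch_size - overlap, plus one extra patch flush against each border, and it averages overlapping predictions. The function names are hypothetical, and torchio's GridSampler and GridAggregator may place patches or combine overlaps differently (for instance by cropping the overlap instead of averaging).

```python
import itertools
from typing import List, Sequence, Tuple

import numpy as np


def grid_starts(length: int, patch_size: int, overlap: int) -> List[int]:
    """Start indices along one axis so that patches cover [0, length)."""
    step = patch_size - overlap
    if step <= 0 or patch_size > length:
        raise ValueError('need overlap < patch_size <= length')
    starts = list(range(0, length - patch_size + 1, step))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)  # last patch flush with the border
    return starts


def grid_locations(shape: Sequence[int], patch_size: int, overlap: int) -> List[Tuple[int, ...]]:
    """All 3D patch corners: the Cartesian product of per-axis starts."""
    per_axis = [grid_starts(n, patch_size, overlap) for n in shape]
    return list(itertools.product(*per_axis))


def extract_patches(volume: np.ndarray, patch_size: int, overlap: int):
    """Cut a 3D volume into overlapping cubic patches."""
    locations = grid_locations(volume.shape, patch_size, overlap)
    patches = [
        volume[i:i + patch_size, j:j + patch_size, k:k + patch_size]
        for i, j, k in locations
    ]
    return patches, locations


def aggregate_patches(patches, locations, shape, patch_size: int) -> np.ndarray:
    """Paste patches back, averaging wherever they overlap."""
    total = np.zeros(shape, dtype=float)
    count = np.zeros(shape, dtype=float)
    for patch, (i, j, k) in zip(patches, locations):
        window = (slice(i, i + patch_size), slice(j, j + patch_size), slice(k, k + patch_size))
        total[window] += patch
        count[window] += 1
    return total / count  # every voxel is covered, so count > 0
```

If each patch is passed through unchanged, aggregate_patches(*extract_patches(volume, 4, 1), volume.shape, 4) reproduces the original volume, which is a useful sanity check for any sampler/aggregator pair.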

Queue

A Queue (or buffer) of patches can be used for randomized patch-based sampling during training. This interactive animation can be used to understand how the queue works.

import torch
import torchio

patches_queue = torchio.Queue(
    subjects_dataset=subjects_dataset,  # instance of torchio.ImagesDataset
    queue_length=300,
    samples_per_volume=10,
    patch_size=96,
    sampler_class=torchio.sampler.ImageSampler,
    num_workers=4,
    shuffle_subjects=True,
    shuffle_patches=True,
)
patches_loader = torch.utils.data.DataLoader(patches_queue, batch_size=4)

num_epochs = 20
for epoch_index in range(num_epochs):
    for patches_batch in patches_loader:
        # 't1' is the image name used when defining the subjects above
        inputs = patches_batch['t1']['data']
        logits = model(inputs)  # model is some torch.nn.Module
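
The idea behind the queue can be sketched in plain Python. This is a simplified model, not torchio's implementation: subjects are optionally shuffled, only queue_length // samples_per_volume subjects are loaded at a time, samples_per_volume patches are drawn from each, and the buffer is shuffled so that consecutive patches tend to come from different subjects. sample_patch is a hypothetical callable standing in for the patch sampler.

```python
import random
from typing import Callable, Iterator, List, Optional, Sequence, TypeVar

S = TypeVar('S')  # subject type
P = TypeVar('P')  # patch type


def iterate_patch_queue(
    subjects: Sequence[S],
    sample_patch: Callable[[S, random.Random], P],
    queue_length: int,
    samples_per_volume: int,
    shuffle_subjects: bool = True,
    shuffle_patches: bool = True,
    seed: Optional[int] = None,
) -> Iterator[P]:
    """Yield one epoch of patches through a bounded, shuffled buffer."""
    if not 0 < samples_per_volume <= queue_length:
        raise ValueError('need 0 < samples_per_volume <= queue_length')
    rng = random.Random(seed)
    order: List[S] = list(subjects)
    if shuffle_subjects:
        rng.shuffle(order)
    # The buffer never holds more than queue_length patches.
    subjects_per_fill = queue_length // samples_per_volume
    for start in range(0, len(order), subjects_per_fill):
        buffer = [
            sample_patch(subject, rng)
            for subject in order[start:start + subjects_per_fill]
            for _ in range(samples_per_volume)
        ]
        if shuffle_patches:
            rng.shuffle(buffer)  # mix patches from different subjects
        while buffer:
            yield buffer.pop()
```

The trade-off is memory against randomness: a longer queue mixes patches from more subjects but keeps more patches in memory at once.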

Transforms

The transforms module is modeled on torchvision.transforms. Each transform takes as input a sample generated by an ImagesDataset and returns the transformed sample.
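
As a rough illustration of that contract, the sketch below treats a sample as a dictionary mapping image names to {'data': ..., 'type': ...} entries and chains transforms the way torchvision.transforms.Compose does. The exact sample layout and the 'intensity'/'label' strings are assumptions for illustration (torchio uses its own constants such as torchio.INTENSITY), and compose and scale_intensities are hypothetical names.

```python
from typing import Callable, Dict, Iterable

import numpy as np

# Assumed sample layout: image name -> {'data': array, 'type': str}
Sample = Dict[str, dict]
Transform = Callable[[Sample], Sample]


def compose(transforms: Iterable[Transform]) -> Transform:
    """Chain transforms so each one receives the previous one's output."""
    transforms = list(transforms)

    def apply(sample: Sample) -> Sample:
        for transform in transforms:
            sample = transform(sample)
        return sample

    return apply


def scale_intensities(factor: float) -> Transform:
    """Multiply intensity images by factor, leaving label maps untouched."""
    def transform(sample: Sample) -> Sample:
        out = {}
        for name, image in sample.items():
            image = dict(image)  # shallow copy so the input sample is not modified
            if image['type'] == 'intensity':
                image['data'] = image['data'] * factor
            out[name] = image
        return out
    return transform
```

Skipping label maps matters for intensity transforms: noise or rescaling applied to a segmentation would corrupt its integer labels.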

Intensity

MRI k-space motion artifacts

Magnetic resonance images suffer from motion artifacts when the subject moves during image acquisition. This transform follows Shaw et al., 2019 to simulate motion artifacts for data augmentation.
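
A heavily simplified version of the k-space idea: compute the k-space of a moved copy of the volume and splice its later-acquired lines into the k-space of the static volume. This sketch assumes a single rigid translation (implemented as a circular shift with np.roll) and a sequential line ordering along one axis, and the function name is hypothetical. The approach of Shaw et al. also models rotations and several motion events.

```python
import numpy as np


def simulate_translation_motion(image, shift, fraction_moved, axis=0):
    """Splice k-space lines of a translated copy into the static k-space.

    Lines along `axis` are assumed to be acquired in order (low index first,
    after centering with fftshift); the last `fraction_moved` of them are
    taken from the moved volume.
    """
    if not 0.0 <= fraction_moved <= 1.0:
        raise ValueError('fraction_moved must be in [0, 1]')
    if len(shift) != image.ndim:
        raise ValueError('shift needs one entry per image axis')
    axes = tuple(range(image.ndim))
    moved = np.roll(image, shift, axis=axes)  # circular rigid translation
    k_static = np.fft.fftshift(np.fft.fftn(image))
    k_moved = np.fft.fftshift(np.fft.fftn(moved))

    num_lines = image.shape[axis]
    first_moved_line = int(round(num_lines * (1.0 - fraction_moved)))
    index = [slice(None)] * image.ndim
    index[axis] = slice(first_moved_line, None)

    k_mixed = k_static.copy()
    k_mixed[tuple(index)] = k_moved[tuple(index)]
    corrupted = np.fft.ifftn(np.fft.ifftshift(k_mixed))
    return np.abs(corrupted)  # magnitude image
```

With fraction_moved=0 the image comes back unchanged and with fraction_moved=1 it is simply translated; only intermediate values produce ghosting, because the k-space then mixes two inconsistent object positions.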

MRI magnetic field inhomogeneity

MRI magnetic field inhomogeneity creates low-frequency intensity variations across the image, known as a bias field. This transform is very similar to the one in NiftyNet.
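
One common way to model such a field is the exponential of a low-order polynomial in normalized voxel coordinates, multiplied into the image. The sketch below follows that approach; the function name, the order and coefficient_range parameters and their defaults are assumptions, not torchio's API.

```python
import numpy as np


def polynomial_bias_field(shape, order=3, coefficient_range=0.5, seed=None):
    """Smooth multiplicative field: exp of a random 3D polynomial of total degree <= order."""
    rng = np.random.default_rng(seed)
    # Normalized coordinates in [-1, 1] keep high-order terms bounded.
    x, y, z = np.meshgrid(*(np.linspace(-1.0, 1.0, n) for n in shape), indexing='ij')
    log_field = np.zeros(shape)
    for i in range(order + 1):
        for j in range(order + 1 - i):
            for k in range(order + 1 - i - j):
                coefficient = rng.uniform(-coefficient_range, coefficient_range)
                log_field += coefficient * x**i * y**j * z**k
    return np.exp(log_field)  # exp keeps the field strictly positive
```

A corrupted volume is then image * polynomial_bias_field(image.shape, seed=0). Keeping the polynomial order low is what makes the variation slow across the volume.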

Gaussian noise

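Adds noise sampled from a normal distribution with mean 0 and a standard deviation sampled from a uniform distribution in the range std_range. It is often applied after ZNormalization, because the output of ZNormalization is standardized (zero mean and unit standard deviation over the foreground), which gives std_range a consistent meaning across images.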
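
The noise model described above fits in a few lines of NumPy. The function name and the seed parameter are hypothetical; only the sampling scheme (standard deviation drawn uniformly from std_range, zero-mean Gaussian noise added voxel-wise) comes from the description.

```python
import numpy as np


def random_noise(image, std_range=(0.0, 0.25), seed=None):
    """Add zero-mean Gaussian noise whose std is drawn from U(std_range)."""
    rng = np.random.default_rng(seed)
    std = rng.uniform(*std_range)  # one std per call, shared by all voxels
    return image + rng.normal(0.0, std, size=image.shape)
```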

Normalization

Histogram standardization

Implementation of the method described in "New variants of a method of MRI scale standardization" (Nyúl et al., 2000), adapted from NiftyNet.
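
The core of the Nyúl method is a piecewise-linear mapping between intensity landmarks learned from training images. The sketch below is simplified under stated assumptions: fixed percentile landmarks, no foreground mask, a standard scale of [0, 100], and clamping (rather than linear extrapolation) outside the outermost landmarks because it relies on np.interp. All names are hypothetical.

```python
import numpy as np

# Assumed landmark set: fixed percentiles between 1 and 99.
PERCENTILES = np.array([1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99])


def image_landmarks(image):
    """Intensity values at the chosen percentiles (no foreground mask)."""
    return np.percentile(image, PERCENTILES)


def train_standard_landmarks(images, s_min=0.0, s_max=100.0):
    """Average landmarks after mapping each image's [p1, p99] onto [s_min, s_max]."""
    mapped = []
    for image in images:
        landmarks = image_landmarks(image)
        low, high = landmarks[0], landmarks[-1]
        mapped.append(s_min + (landmarks - low) / (high - low) * (s_max - s_min))
    return np.mean(mapped, axis=0)


def standardize(image, standard_landmarks):
    """Piecewise-linear map from this image's landmarks to the standard ones."""
    return np.interp(image, image_landmarks(image), standard_landmarks)
```

Because the mapping is defined through percentiles, two scans that differ only by a linear intensity change end up on the same scale after standardization.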

Z-normalization

This transform first selects the voxels whose intensity is greater than the mean intensity of the image, as an approximation of the foreground. The mean of these foreground voxels is then subtracted from the whole image, and the result is divided by their standard deviation.
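
The same procedure as a minimal NumPy sketch (the function name is hypothetical):

```python
import numpy as np


def z_normalize(image):
    """Standardize using statistics of voxels brighter than the global mean."""
    foreground = image[image > image.mean()]  # rough foreground estimate
    if foreground.size == 0 or foreground.std() == 0:
        raise ValueError('cannot normalize: foreground is empty or constant')
    return (image - foreground.mean()) / foreground.std()
```

After this step the selected foreground voxels have mean 0 and standard deviation 1; background voxels end up below zero.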

Rescale

Spatial

Flip

Reverse the order of elements in an image along the given axes.
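
In NumPy terms, flipping is np.flip. The sketch below adds the random part used for augmentation; the per-axis flip_probability parameter is an assumption for illustration, not necessarily how RandomFlip decides whether to flip.

```python
import numpy as np


def random_flip(image, axes=(0,), flip_probability=0.5, seed=None):
    """Flip each listed axis independently with probability flip_probability."""
    rng = np.random.default_rng(seed)
    flip_axes = tuple(a for a in axes if rng.random() < flip_probability)
    return np.flip(image, axis=flip_axes) if flip_axes else image
```

For label maps, the same axes must be flipped as for the corresponding intensity image, so in practice the random decision is drawn once per sample and applied to every image in it.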

Affine transform

B-spline dense elastic deformation

Example

This example shows the improvement in performance when multiple workers are used to load and preprocess the volumes.

import time
import multiprocessing as mp

from tqdm import trange

import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.transforms import Compose

from torchio import ImagesDataset, Queue
from torchio.sampler import ImageSampler
from torchio.utils import create_dummy_dataset
from torchio.transforms import (
    ZNormalization,
    RandomNoise,
    RandomFlip,
    RandomAffine,
)


# Define training and patches sampling parameters
num_epochs = 4
patch_size = 128
queue_length = 400
samples_per_volume = 10
batch_size = 4

class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=1,
            out_channels=3,
            kernel_size=3,
        )
    def forward(self, x):
        return self.conv(x)

model = Network()

# Create a dummy dataset in the temporary directory, for this example
subjects_list = create_dummy_dataset(
    num_images=100,
    size_range=(193, 229),
    force=False,
)

# Each element of subjects_list is a list of torchio.Image instances:
# subject_images = [
#     torchio.Image('one_image', path_to_one_image, torchio.INTENSITY),
#     torchio.Image('another_image', path_to_another_image, torchio.INTENSITY),
#     torchio.Image('a_label', path_to_a_label, torchio.LABEL),
# ]

# Define transforms for data normalization and augmentation
transforms = (
    ZNormalization(),
    RandomNoise(std_range=(0, 0.25)),
    RandomAffine(scales=(0.9, 1.1), degrees=10),
    RandomFlip(axes=(0,)),
)
transform = Compose(transforms)
subjects_dataset = ImagesDataset(subjects_list, transform)


# Run a benchmark for different numbers of workers
workers = range(mp.cpu_count() + 1)
for num_workers in workers:
    print('Number of workers:', num_workers)

    # Define the dataset as a queue of patches
    queue_dataset = Queue(
        subjects_dataset,
        queue_length,
        samples_per_volume,
        patch_size,
        ImageSampler,
        num_workers=num_workers,
    )
    batch_loader = DataLoader(queue_dataset, batch_size=batch_size)

    start = time.time()
    for epoch_index in trange(num_epochs, leave=False):
        for batch in batch_loader:
            # The keys of batch have been defined in create_dummy_dataset()
            inputs = batch['one_modality']['data']
            targets = batch['segmentation']['data']
            # Forward pass only: no loss or backward pass is computed in this benchmark
            logits = model(inputs)
    print('Time:', int(time.time() - start), 'seconds')
    print()

Output:

Number of workers: 0
Time: 394 seconds

Number of workers: 1
Time: 372 seconds

Number of workers: 2
Time: 278 seconds

Number of workers: 3
Time: 259 seconds

Number of workers: 4
Time: 242 seconds
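
The timings above can be turned into speedups relative to num_workers=0. Using only the reported numbers, four workers give roughly a 1.6x speedup, well short of linear scaling.

```python
# Wall-clock times (seconds) from the output above, keyed by num_workers.
timings = {0: 394, 1: 372, 2: 278, 3: 259, 4: 242}

baseline = timings[0]
speedups = {workers: baseline / seconds for workers, seconds in timings.items()}

for workers, speedup in speedups.items():
    print(f'{workers} workers: {speedup:.2f}x relative to 0 workers')
```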

History

0.2.0 (2019-12-06)

  • First release on PyPI.

0.3.0 (2019-12-21)

  • Add Rescale transform
  • Add support for multimodal data and missing modalities

0.4.0 (2019-12-29)

  • Add MRI k-space motion artifact augmentation

0.5.0 (2020-01-01)

  • Add bias field transform

0.6.0 (2020-01-02)

  • Add support for NRRD

0.7.0 (2020-01-02)

  • Make transforms use PyTorch tensors consistently

This version: 0.8.2

Download files

Source distribution: torchio-0.8.2.tar.gz (32.3 kB)

Built distribution: torchio-0.8.2-py2.py3-none-any.whl (30.1 kB), for Python 2 and Python 3

File details

Details for the file torchio-0.8.2.tar.gz.

File metadata

  • Download URL: torchio-0.8.2.tar.gz
  • Size: 32.3 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.1

File hashes

Hashes for torchio-0.8.2.tar.gz
Algorithm Hash digest
SHA256 4ae96f13c3c9f717683629bc1daa09d46aa591857491f5e6d676c549aad20a7a
MD5 a92826f3b869b9d9b8719fa0f7fd8861
BLAKE2b-256 d058a3865f7e020697acf46e1174b6e5b3b0aedfcf499bc65ce83fe763ed7a07

File details

Details for the file torchio-0.8.2-py2.py3-none-any.whl.

File metadata

  • Download URL: torchio-0.8.2-py2.py3-none-any.whl
  • Size: 30.1 kB
  • Tags: Python 2, Python 3
  • Uploaded using Trusted Publishing? No
  • Uploaded via: twine/3.1.1 pkginfo/1.5.0.1 requests/2.22.0 setuptools/44.0.0 requests-toolbelt/0.9.1 tqdm/4.41.1 CPython/3.7.1

File hashes

Hashes for torchio-0.8.2-py2.py3-none-any.whl
Algorithm Hash digest
SHA256 c16e06c706998d91699fa1e660951ebf456a057ccc19bf62f8725a9f068c642b
MD5 bb34f0f120fc6f887a67465ef82cb6ef
BLAKE2b-256 5b6a07557e4d225cb9f9e1997241d5ed6df6b66ae8297eb925e4fafabb84384d
