TorchGeo: datasets, transforms, and models for geospatial data

TorchGeo

TorchGeo is a PyTorch domain library, similar to torchvision, that provides datasets, transforms, samplers, and pre-trained models specific to geospatial data.

The goal of this library is to make it simple:

  1. for machine learning experts to use geospatial data in their workflows, and
  2. for remote sensing experts to use their data in machine learning workflows.

See our installation instructions, documentation, and examples to learn how to use TorchGeo.

Installation

The recommended way to install TorchGeo is with pip:

$ pip install torchgeo

For conda and spack installation instructions, see the documentation.

Documentation

You can find the documentation for TorchGeo on ReadTheDocs.

Example Usage

The following sections give basic examples of what you can do with TorchGeo. For more examples, check out our tutorials.

First we'll import various classes and functions used in the following sections:

from torch.utils.data import DataLoader
from torchgeo.datasets import CDL, COWCDetection, Landsat7, Landsat8, stack_samples
from torchgeo.samplers import RandomGeoSampler

Benchmark datasets

TorchGeo includes a number of benchmark datasets: datasets that contain both input images and target labels. They cover tasks such as image classification, regression, semantic segmentation, object detection, instance segmentation, and change detection.

If you've used torchvision before, these datasets should seem very familiar. In this example, we'll create a dataset for the Cars Overhead With Context (COWC) car detection dataset. This dataset can be automatically downloaded, checksummed, and extracted, just like with torchvision.

dataset = COWCDetection(root="...", split="train", download=True, checksum=True)

This dataset can then be passed to a PyTorch data loader.

dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)

The only difference between a benchmark dataset in TorchGeo and a similar dataset in torchvision is that each dataset returns a dictionary with keys for each PyTorch Tensor.

for batch in dataloader:
    image = batch["image"]
    label = batch["label"]

    # train a model, or make predictions using a pre-trained model
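
The dictionary-per-sample convention can be illustrated without TorchGeo. The sketch below is a toy, not TorchGeo code: `ToyDataset` and `collate_dicts` are hypothetical names, and plain Python lists stand in for tensors. It shows how a list of sample dictionaries becomes one batch dictionary with the same keys, which is what lets a training loop index `batch["image"]` and `batch["label"]`.

```python
from typing import Any, Dict, List


class ToyDataset:
    """Hypothetical map-style dataset whose samples are dictionaries."""

    def __init__(self, length: int) -> None:
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # A real dataset would return tensors; lists keep the sketch dependency-free.
        return {"image": [float(index)] * 4, "label": index % 2}


def collate_dicts(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Group a list of sample dictionaries into one dictionary of lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


dataset = ToyDataset(length=6)
batch = collate_dicts([dataset[i] for i in range(3)])
print(sorted(batch))
print(batch["label"])
```

With real tensors, a data loader would stack the values under each key into one batched tensor rather than collecting them in a list.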

Geospatial datasets

Many remote sensing applications involve working with generic geospatial data, which can be challenging because it is so varied. Geospatial imagery is often multispectral, and the number of spectral bands and the spatial resolution differ from satellite to satellite. In addition, each file may be in a different coordinate reference system (CRS), so the data must be reprojected into a matching CRS.

In this example, we show how to work with geospatial data and sample small image patches from a combination of Landsat and Cropland Data Layer (CDL) data using TorchGeo. First, we assume that the user has Landsat 7 and 8 imagery downloaded. Since Landsat 8 has more spectral bands than Landsat 7, we'll only use the bands that both satellites have in common. We'll create a single dataset that includes all images from both Landsat 7 and 8 by taking the union of these two datasets.

landsat7 = Landsat7(root="...")
landsat8 = Landsat8(root="...", bands=["B2", "B3", "B4", "B5", "B6", "B7", "B8", "B9"])
landsat = landsat7 | landsat8

Next, we take the intersection of this dataset and the CDL dataset. We take the intersection instead of the union to ensure that we only sample from regions that have both Landsat and CDL data. Note that CDL data can be downloaded and checksummed automatically. Also note that each of these datasets may contain files in different CRSs or resolutions, but TorchGeo automatically ensures that a matching CRS and resolution are used.

cdl = CDL(root="...", download=True, checksum=True)
dataset = landsat & cdl
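
To see why the intersection restricts sampling to regions covered by both sources, consider a simplified model in which each dataset is reduced to one bounding box in a shared CRS. The sketch below makes that assumption; `Extent` is a hypothetical class, not part of TorchGeo, and real datasets index many files rather than a single box.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Extent:
    """Hypothetical axis-aligned extent in a shared CRS (not a TorchGeo class)."""

    minx: float
    maxx: float
    miny: float
    maxy: float

    def __or__(self, other: "Extent") -> "Extent":
        # Union: smallest extent that covers both inputs.
        return Extent(
            min(self.minx, other.minx),
            max(self.maxx, other.maxx),
            min(self.miny, other.miny),
            max(self.maxy, other.maxy),
        )

    def __and__(self, other: "Extent") -> Optional["Extent"]:
        # Intersection: the region covered by both inputs, or None if disjoint.
        minx, maxx = max(self.minx, other.minx), min(self.maxx, other.maxx)
        miny, maxy = max(self.miny, other.miny), min(self.maxy, other.maxy)
        if minx >= maxx or miny >= maxy:
            return None
        return Extent(minx, maxx, miny, maxy)


imagery = Extent(0, 10, 0, 10)  # where imagery exists
labels = Extent(5, 20, 5, 20)   # where labels exist
print(imagery | labels)
print(imagery & labels)
```

In this model, a patch drawn from `imagery & labels` always has both an image and a label, whereas a patch drawn from `imagery | labels` may have only one of them.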

This dataset can now be used with a PyTorch data loader. Unlike benchmark datasets, geospatial datasets often include very large images. For example, the CDL dataset consists of a single image covering the entire continental United States. In order to sample from these datasets using geospatial coordinates, TorchGeo defines a number of samplers. In this example, we'll use a random sampler that returns 256x256 pixel images and an epoch length of 10,000 images. We also use a custom collation function to combine each sample dictionary into a mini-batch of samples.

sampler = RandomGeoSampler(dataset, size=256, length=10000)
dataloader = DataLoader(dataset, batch_size=128, sampler=sampler, collate_fn=stack_samples)
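
The idea behind random geospatial sampling can be sketched in plain Python. This is an illustration under simplifying assumptions, not the TorchGeo implementation: `random_patches` is a hypothetical function, the dataset is treated as one rectangular extent, and `size` is expressed in the same units as the bounds.

```python
import random
from typing import Iterator, Tuple

Box = Tuple[float, float, float, float]  # (minx, maxx, miny, maxy)


def random_patches(
    bounds: Box, size: float, length: int, seed: int = 0
) -> Iterator[Box]:
    """Yield `length` random square boxes of side `size` that fit inside `bounds`."""
    minx, maxx, miny, maxy = bounds
    if maxx - minx < size or maxy - miny < size:
        raise ValueError("patch size exceeds the extent being sampled")
    rng = random.Random(seed)
    for _ in range(length):
        # Pick the lower-left corner so the whole patch stays inside the bounds.
        x = rng.uniform(minx, maxx - size)
        y = rng.uniform(miny, maxy - size)
        yield (x, x + size, y, y + size)


patches = list(random_patches(bounds=(0.0, 1000.0, 0.0, 1000.0), size=256.0, length=5))
print(len(patches))
```

Each yielded box plays the role of an index: the dataset is queried with coordinates rather than with an integer, and `length` fixes how many such queries make up one epoch.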

This data loader can now be used in your normal training/evaluation pipeline.

for batch in dataloader:
    image = batch["image"]
    mask = batch["mask"]

    # train a model, or make predictions using a pre-trained model

Train and test models using our PyTorch Lightning-based training script

We provide a script, train.py, for training models on a subset of the datasets. It uses the PyTorch Lightning LightningModules and LightningDataModules implemented under the torchgeo.trainers namespace. The train.py script is configurable via the command line and/or via YAML configuration files. See the conf/ directory for example configuration files that can be customized for different training runs.

$ python train.py config_file=conf/landcoverai.yaml
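
Layering command-line values over a file-based configuration can be sketched as follows. This is a generic illustration, not the logic of train.py: `apply_overrides` is a hypothetical helper, the keys `model.depth` and `data.batch_size` are made up for the example, and override values are kept as strings.

```python
import copy
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Return a copy of `config` with dotted `key=value` overrides applied."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            # Create intermediate sections that the base config lacks.
            node = node.setdefault(part, {})
        node[leaf] = value  # values stay strings in this sketch
    return result


# Hypothetical base configuration, as it might look after loading a YAML file.
base = {"model": {"depth": 18}, "data": {"batch_size": 32}}
merged = apply_overrides(base, ["data.batch_size=64", "model.name=toy"])
print(merged)
```

The base dictionary is left untouched, so the same file-based defaults can be reused across runs with different overrides.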

Citation

If you use this software in your work, please cite our paper:

@article{Stewart_TorchGeo_deep_learning_2021,
    author = {Stewart, Adam J. and Robinson, Caleb and Corley, Isaac A. and Ortiz, Anthony and Lavista Ferres, Juan M. and Banerjee, Arindam},
    journal = {arXiv preprint arXiv:2111.08872},
    month = {11},
    title = {{TorchGeo: deep learning with geospatial data}},
    url = {https://github.com/microsoft/torchgeo},
    year = {2021}
}

Contributing

This project welcomes contributions and suggestions. If you would like to submit a pull request, see our Contribution Guide for more information.

This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments.
