
Forecasting timeseries with PyTorch - dataloaders, normalizers, metrics and models

Project description

Pytorch Forecasting aims to ease timeseries forecasting with neural networks. Specifically, it provides a class to wrap timeseries datasets and a number of PyTorch models.

Installation

If you are working on Windows, you need to first install PyTorch with

pip install torch -f https://download.pytorch.org/whl/torch_stable.html

Otherwise, you can proceed with

pip install pytorch-forecasting

Visit the documentation at https://pytorch-forecasting.readthedocs.io.

Available models

  • Temporal Fusion Transformer (demonstrated in the usage example below)

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
data = ...
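# the DataFrame is expected in long format, i.e. one row per timeseries and
# time step; a minimal sketch with hypothetical column names (the names are
# assumptions for illustration, not required by the library):
#
#     import pandas as pd
#     data = pd.DataFrame({
#         "series_id": ["a", "a", "a", "b", "b", "b"],
#         "date": list(pd.date_range("2020-01-01", periods=3)) * 2,
#         "time_idx": [0, 1, 2, 0, 1, 2],
#         "value": [10.0, 11.0, 12.0, 20.0, 19.0, 18.0],
#     })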

# define dataset
max_encoder_length = 36
max_prediction_length = 6
training_cutoff = "YYYY-MM-DD"  # last day of the training data

training = TimeSeriesDataSet(
    data[lambda x: x.date <= training_cutoff],
    time_idx=...,  # column name of the integer time index
    target=...,  # column name of the target to forecast
    group_ids=[...],  # column name(s) identifying each timeseries
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[...],
    static_reals=[...],
    time_varying_known_categoricals=[...],
    time_varying_known_reals=[...],
    time_varying_unknown_categoricals=[...],
    time_varying_unknown_reals=[...],
)


# the validation set comprises the time steps after the training cutoff
validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
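# optional sanity check (an assumption about the batch layout: each batch is
# a tuple of an input dictionary of tensors and the target):
# x, y = next(iter(train_dataloader))
# print(x.keys())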


early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateLogger()  # log the learning rate at each epoch
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,  # run on CPU; set to the number of GPUs to train on GPU
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,  # use 30 batches per epoch
    callbacks=[lr_logger],
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,  # one output per quantile of the QuantileLoss below
    loss=QuantileLoss(),
    log_interval=2,
    reduce_on_plateau_patience=4,
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
res = trainer.lr_find(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()
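
# one possible next step: adopt the suggested learning rate before fitting
# (assumes the optimizer reads learning_rate from the module's hparams)
tft.hparams.learning_rate = res.suggestion()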

trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
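
After training, you can generate forecasts from the trained network. A minimal sketch, assuming the model's predict method accepts a dataloader:

# predict on the validation set
predictions = tft.predict(val_dataloader)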

Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_forecasting-0.3.0.tar.gz (54.2 kB)

Uploaded Source

Built Distribution

pytorch_forecasting-0.3.0-py3-none-any.whl (59.8 kB)

Uploaded Python 3

File details

Details for the file pytorch_forecasting-0.3.0.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.3.0.tar.gz
  • Upload date:
  • Size: 54.2 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.10 CPython/3.7.9 Linux/5.4.0-1022-azure

File hashes

Hashes for pytorch_forecasting-0.3.0.tar.gz

  • SHA256: f2a59ca742f6ba8defeab7be792de1a6fcbf470b913ed56dcb9bdb725ea630c7
  • MD5: da9b8ec67104e2fff058b16f60988fcb
  • BLAKE2b-256: d11048945ec0228e3403fdfe74888cb172ebdcadddd7aa234bbab90b9535c7af


File details

Details for the file pytorch_forecasting-0.3.0-py3-none-any.whl.

File metadata

File hashes

Hashes for pytorch_forecasting-0.3.0-py3-none-any.whl

  • SHA256: 9f2445d643efc35de17de53bcda328316aa6cabb7558dff278dbeeb6f4146fb3
  • MD5: 36371478692ffa553c961f17ac59612f
  • BLAKE2b-256: 08b05fc55160fcb06daa10fc19c1db6fa082fe2b7d2694a898f086655e6439f8

