
Temporal fusion transformer for timeseries forecasting

Project description

PyTorch Forecasting aims to ease timeseries forecasting with neural networks. Specifically, it provides a class to wrap timeseries datasets and a number of PyTorch models.
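To illustrate what wrapping a timeseries dataset involves, here is a minimal, library-independent sketch of how one series can be cut into encoder/decoder samples. The helper `make_windows` is hypothetical and is not how `TimeSeriesDataSet` is implemented. The real class additionally handles multiple series (`group_ids`), categorical encodings and normalization. The `min_prediction_idx` argument only mimics the role of the parameter of the same name in the usage example: it drops samples whose first predicted time index lies before it.

```python
from typing import List, Sequence, Tuple


def make_windows(
    series: Sequence[float],
    max_encode_length: int,
    max_prediction_length: int,
    min_prediction_idx: int = 0,
) -> List[Tuple[List[float], List[float]]]:
    """Slide a fixed-size window over `series` (time index 0..n-1).

    Hypothetical helper: each sample is (encoder values, values to predict).
    """
    window = max_encode_length + max_prediction_length
    samples = []
    for start in range(len(series) - window + 1):
        first_prediction_idx = start + max_encode_length
        # skip samples whose forecast horizon starts before min_prediction_idx
        if first_prediction_idx < min_prediction_idx:
            continue
        encoder = list(series[start:first_prediction_idx])
        decoder = list(series[first_prediction_idx:start + window])
        samples.append((encoder, decoder))
    return samples


series = list(range(10))  # toy series with time index 0..9
windows = make_windows(series, max_encode_length=3, max_prediction_length=2)
print(len(windows))  # 6
print(windows[0])  # ([0, 1, 2], [3, 4])
```

Setting `min_prediction_idx=7` keeps only the two samples that predict time steps 7 and later. This is the same idea the usage example applies to build a validation set that forecasts only beyond the training period.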

Installation

If you are working on Windows, you first need to install PyTorch with

pip install torch -f https://download.pytorch.org/whl/torch_stable.html

Otherwise, you can proceed with

pip install pytorch-forecasting

Visit the documentation at https://pytorch-forecasting.readthedocs.io.

Available models

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
data = ...

# define dataset
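max_encode_length = 36  # maximum number of past time steps fed to the encoder
max_prediction_length = 6  # forecast horizon (number of predicted time steps)
training_cutoff = "YYYY-MM-DD"  # last date included in the training set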

training = TimeSeriesDataSet(
    data[lambda x: x.date <= training_cutoff],
    time_idx= ...,
    target= ...,
    group_ids=[ ... ],
    max_encode_length=max_encode_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[ ... ],
    static_reals=[ ... ],
    time_varying_known_categoricals=[ ... ],
    time_varying_known_reals=[ ... ],
    time_varying_unknown_categoricals=[ ... ],
    time_varying_unknown_reals=[ ... ],
)

# validation set: same encoders and scalers as training, but predictions start after the last training time step
validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)

# create dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)
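
# stop training when the validation loss stops improving, and log the learning rate
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateLogger()
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,  # train on CPU
    gradient_clip_val=0.1,  # clip gradients to stabilize training
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,  # use only 30 batches per training epoch
    callbacks=[lr_logger],
)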
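
# define network
tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,  # one output per quantile; QuantileLoss uses 7 quantiles by default
    loss=QuantileLoss(),
    log_interval=2,  # log predictions every 2 batches
    reduce_on_plateau_patience=4,  # reduce learning rate after 4 epochs without improvement
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")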

# find optimal learning rate
res = trainer.lr_find(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)  # plot loss vs. learning rate and mark the suggestion

# fit network
trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
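The network above is configured with `output_size=7` and `loss=QuantileLoss()`, so it emits one prediction per quantile. The plain-Python sketch below shows what a quantile (pinball) loss computes for a single target value. It assumes the library's default seven quantiles (0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98). Averaging over the quantiles is a simplification chosen here; the library works on tensors and may reduce differently.

```python
from typing import Sequence

# Assumed default quantiles: seven values, matching output_size=7 above.
QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]


def pinball_loss(
    target: float, predictions: Sequence[float], quantiles: Sequence[float] = QUANTILES
) -> float:
    """Average pinball loss over quantiles for one target value (illustrative sketch)."""
    if len(predictions) != len(quantiles):
        raise ValueError("need exactly one prediction per quantile")
    total = 0.0
    for q, pred in zip(quantiles, predictions):
        error = target - pred
        # under-prediction (error > 0) costs q per unit, over-prediction costs (1 - q)
        total += max(q * error, (q - 1) * error)
    return total / len(quantiles)


print(pinball_loss(10.0, [10.0] * 7))  # 0.0
print(pinball_loss(10.0, [8.0], quantiles=[0.9]))  # 1.8
```

The loss is asymmetric. For the 0.9 quantile, under-predicting by 2 costs 1.8 while over-predicting by 2 costs only about 0.2. This pushes that output toward the upper part of the predictive distribution.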
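The `EarlyStopping` callback above uses `monitor="val_loss"`, `min_delta=1e-4`, `patience=1` and `mode="min"`. The following is a simplified, framework-independent model of that rule, not PyTorch Lightning's code. It assumes an epoch counts as an improvement only if the monitored loss drops more than `min_delta` below the best value so far. It also assumes training stops once `patience` consecutive epochs pass without improvement. The function `early_stop_epoch` is hypothetical.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    val_losses: Sequence[float], min_delta: float = 1e-4, patience: int = 1
) -> Optional[int]:
    """Return the epoch index at which training would stop (mode="min"),
    or None if it never stops. Simplified model of early stopping."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # meaningful improvement: remember it and reset the counter
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


print(early_stop_epoch([1.0, 0.8, 0.79995, 0.7]))  # 2
```

With `patience=1`, a single epoch whose improvement stays below `min_delta` ends training (here, 0.79995 vs. 0.8). Larger patience values tolerate noisier validation curves.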
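`trainer.lr_find` runs a learning-rate range test. It trains while increasing the learning rate up to `max_lr`, records the loss, and `res.suggestion()` proposes a value from that curve. The sketch below illustrates the idea under stated assumptions. It uses an exponential schedule and picks the learning rate where the recorded loss falls most steeply. To our understanding this resembles the heuristic PyTorch Lightning uses, but that library smooths differences and skips points at the start and end of the curve. The helpers `exponential_lrs` and `suggest_lr` and the synthetic loss curve are hypothetical.

```python
from typing import List, Sequence


def exponential_lrs(min_lr: float, max_lr: float, num_steps: int) -> List[float]:
    """Learning rates growing geometrically from min_lr to max_lr (inclusive)."""
    ratio = (max_lr / min_lr) ** (1 / (num_steps - 1))
    return [min_lr * ratio**i for i in range(num_steps)]


def suggest_lr(lrs: Sequence[float], losses: Sequence[float]) -> float:
    """Return the learning rate at the steepest loss decrease
    (most negative forward difference). Simplified heuristic."""
    slopes = [losses[i + 1] - losses[i] for i in range(len(losses) - 1)]
    steepest = min(range(len(slopes)), key=lambda i: slopes[i])
    return lrs[steepest]


lrs = exponential_lrs(1e-5, 0.3, num_steps=8)
# synthetic loss curve: flat, then falling, then diverging at high learning rates
losses = [1.0, 1.0, 0.98, 0.9, 0.6, 0.5, 0.55, 2.0]
print(f"suggested learning rate: {suggest_lr(lrs, losses):.2e}")
```

The steepest-descent point sits before the loss minimum, where the loss still decreases quickly rather than where it is lowest. This is why the suggestion tends to be smaller than the learning rate at which the curve bottoms out.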


Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_forecasting-0.2.4.tar.gz (48.6 kB)

Uploaded Source

Built Distribution

pytorch_forecasting-0.2.4-py3-none-any.whl (52.4 kB)

Uploaded Python 3

File details

Details for the file pytorch_forecasting-0.2.4.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.2.4.tar.gz
  • Size: 48.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.9 CPython/3.7.7 Darwin/19.6.0

File hashes

Hashes for pytorch_forecasting-0.2.4.tar.gz
Algorithm    Hash digest
SHA256       7c76c20ef16c61b88c880691c4c1b027d14a9712731d96622acdb277e6346036
MD5          42766838ba43ae62f634779c8035da08
BLAKE2b-256  f6c596962fdffd8145c8868a1e75da4f2a29d18b925946b922039811ec37136d


File details

Details for the file pytorch_forecasting-0.2.4-py3-none-any.whl.

File hashes

Hashes for pytorch_forecasting-0.2.4-py3-none-any.whl
Algorithm    Hash digest
SHA256       27c052e624a0950e951c179f8d5eb0b77fd749e78ccaaa91998478a566e04ce4
MD5          6d31215cced5c20ffc75bb4650a2f1c1
BLAKE2b-256  d6d83f229bb163f52656357e8de0092fc2721fbd183be9dcd28955addf672340
