
Temporal Fusion Transformer for timeseries forecasting

Project description

Timeseries forecasting with PyTorch

Install with

pip install pytorch-forecasting
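
A quick sanity check that the install succeeded (not part of the original instructions, just a minimal import test):

python -c "import pytorch_forecasting"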

Available models

  • Temporal Fusion Transformer

Usage

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger

from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# load data
data = ...
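
# A hypothetical minimal dataset for illustration (column names are
# assumptions, not part of the original example): two series in long
# format with an integer time index and a target column.
import numpy as np
import pandas as pd

data = pd.DataFrame(
    dict(
        date=np.tile(pd.date_range("2020-01-01", periods=100), 2),
        time_idx=np.tile(np.arange(100), 2),
        series=np.repeat(["a", "b"], 100),
        value=np.random.randn(200).cumsum(),
    )
)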

# define dataset
max_encoder_length = 36
max_prediction_length = 6
training_cutoff = "YYYY-MM-DD"  # day for cutoff

training = TimeSeriesDataSet(
    data[lambda x: x.date < training_cutoff],
    time_idx= ...,
    target= ...,
    # weight="weight",
    group_ids=[ ... ],
    max_encoder_length=max_encoder_length,
    max_prediction_length=max_prediction_length,
    static_categoricals=[ ... ],
    static_reals=[ ... ],
    # "known" covariates are available into the future, "unknown" ones are observed only up to the present
    time_varying_known_categoricals=[ ... ],
    time_varying_known_reals=[ ... ],
    time_varying_unknown_categoricals=[ ... ],
    time_varying_unknown_reals=[ ... ],
)


validation = TimeSeriesDataSet.from_dataset(
    training, data, min_prediction_idx=training.index.time.max() + 1, stop_randomization=True
)
batch_size = 128
train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=2)
val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size, num_workers=2)


# stop training early once the validation loss stops improving
early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=1, verbose=False, mode="min")
lr_logger = LearningRateLogger()  # log the learning rate at each epoch
trainer = pl.Trainer(
    max_epochs=100,
    gpus=0,
    gradient_clip_val=0.1,
    early_stop_callback=early_stop_callback,
    limit_train_batches=30,
    callbacks=[lr_logger],
)


tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=0.03,
    hidden_size=32,
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=16,
    output_size=7,  # one output per predicted quantile
    loss=QuantileLoss(log_space=True),
    log_interval=2,
    reduce_on_plateau_patience=4
)
print(f"Number of parameters in network: {tft.size()/1e3:.1f}k")

# find optimal learning rate
res = trainer.lr_find(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader, early_stop_threshold=1000.0, max_lr=0.3,
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

trainer.fit(
    tft, train_dataloader=train_dataloader, val_dataloaders=val_dataloader,
)
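
Once training finishes, the best checkpoint written by PyTorch Lightning can be reloaded. This is a minimal sketch relying on standard Lightning behaviour, not part of the original example:

best_model_path = trainer.checkpoint_callback.best_model_path
best_tft = TemporalFusionTransformer.load_from_checkpoint(best_model_path)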



Download files

Download the file for your platform. If you're not sure which to choose, learn more about installing packages.

Source Distribution

pytorch_forecasting-0.2.0.tar.gz (45.0 kB)

Uploaded Source

Built Distribution

pytorch_forecasting-0.2.0-py3-none-any.whl (48.5 kB)

Uploaded Python 3

File details

Details for the file pytorch_forecasting-0.2.0.tar.gz.

File metadata

  • Download URL: pytorch_forecasting-0.2.0.tar.gz
  • Upload date:
  • Size: 45.0 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: poetry/1.0.9 CPython/3.7.7 Darwin/19.5.0

File hashes

Hashes for pytorch_forecasting-0.2.0.tar.gz

  • SHA256: a33a0d45caaf442d21166563c4ce129020c4a1f69ced14cbd4c4c46bb25734bc
  • MD5: 0990b6df406069e90192a2b534a8d007
  • BLAKE2b-256: 570a3aa6cfbd11930182b7a5309605817cb8a26a844d07d38eb0c59710ca1305

See more details on using hashes here.
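
To check a downloaded file against the digests above, a minimal sketch using only the Python standard library (assumes the sdist sits in the current directory):

import hashlib

expected = "a33a0d45caaf442d21166563c4ce129020c4a1f69ced14cbd4c4c46bb25734bc"
with open("pytorch_forecasting-0.2.0.tar.gz", "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == expected, "SHA256 mismatch"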

File details

Details for the file pytorch_forecasting-0.2.0-py3-none-any.whl.

File metadata

File hashes

Hashes for pytorch_forecasting-0.2.0-py3-none-any.whl

  • SHA256: b13352e093c303ff4c14cb9f980972e4b13e8de0ff30ba4011b72b099c156d3b
  • MD5: bf1ec6f620644a71b681bf2775b89d73
  • BLAKE2b-256: d4024383f1d3962c9f4b8f1d7196f7daaef12d7fd7377e08d337beb0e54ceb0d

See more details on using hashes here.
