Utility functions that print a summary of a model.


torch-inspect

torch-inspect is a collection of utility functions for inspecting low-level information about PyTorch neural networks.

Features

  • Provides a helper function, summary, that prints a Keras-style model summary.

  • Provides a helper function, inspect, that returns an object with network summary information for programmatic access.

  • The library has tests and reasonable code coverage.

Simple example

import torch.nn as nn
import torch.nn.functional as F
import torch_inspect as ti

class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = SimpleNet()
ti.summary(net, (1, 32, 32), device='cpu')

This produces the following output:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1            [-1, 6, 30, 30]              60
            Conv2d-2           [-1, 16, 13, 13]             880
            Linear-3                  [-1, 120]          69,240
            Linear-4                   [-1, 84]          10,164
            Linear-5                   [-1, 10]             850
================================================================
Total params: 81,194
Trainable params: 81,194
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.06
Params size (MB): 0.31
Estimated Total Size (MB): 0.38
----------------------------------------------------------------
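The figures in this summary can be checked by hand. The sketch below uses plain Python without PyTorch to rederive the spatial sizes (which explain why fc1 takes 16 * 6 * 6 = 576 inputs), the Param # column, and the size estimates. The helper names are hypothetical. The size formulas rest on an assumed convention: 4 bytes per float32 value, activations counted twice to cover gradients, and 1 MB = 1024 ** 2 bytes. That convention happens to reproduce the numbers above.

```python
BYTES_PER_VALUE = 4  # assumption: float32 parameters and activations
MB = 1024 ** 2       # assumption: "MB" means 2**20 bytes


def conv_out(size: int, kernel: int) -> int:
    """Height/width after an unpadded, stride-1 Conv2d."""
    return size - kernel + 1


def pool_out(size: int, kernel: int) -> int:
    """Height/width after max_pool2d whose stride equals the kernel (rounds down)."""
    return size // kernel


def conv2d_params(in_ch: int, out_ch: int, kernel: int) -> int:
    """Weights (out_ch * in_ch * kernel * kernel) plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def linear_params(in_features: int, out_features: int) -> int:
    """Weight matrix plus one bias per output feature."""
    return out_features * in_features + out_features


# Spatial sizes for a 1x32x32 input: 32 -> 30 -> 15 -> 13 -> 6.
h1 = conv_out(32, 3)                # conv1 output height/width
h2 = conv_out(pool_out(h1, 2), 3)   # conv2 output height/width
flat = 16 * pool_out(h2, 2) ** 2    # features entering fc1

# Elements produced by each module listed in the summary.
output_elems = [6 * h1 * h1, 16 * h2 * h2, 120, 84, 10]
params = [
    conv2d_params(1, 6, 3),
    conv2d_params(6, 16, 3),
    linear_params(flat, 120),
    linear_params(120, 84),
    linear_params(84, 10),
]
total_params = sum(params)

input_mb = 1 * 32 * 32 * BYTES_PER_VALUE / MB
# Activations are counted twice: once for the forward pass, once for gradients.
fwd_bwd_mb = 2 * sum(output_elems) * BYTES_PER_VALUE / MB
params_mb = total_params * BYTES_PER_VALUE / MB
total_mb = input_mb + fwd_bwd_mb + params_mb

print(h1, h2, flat)          # 30 13 576
print(params)                # [60, 880, 69240, 10164, 850]
print(f"{total_params:,}")   # 81,194
print(f"{input_mb:.2f} {fwd_bwd_mb:.2f} {params_mb:.2f} {total_mb:.2f}")  # 0.00 0.06 0.31 0.38
```

ReLU and pooling are called through torch.nn.functional here rather than as modules. They add no rows and no activation memory to the estimate, so only the five listed layers are counted.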

For programmatic access to network information, use the inspect function:

info = ti.inspect(net, (1, 32, 32), device='cpu')
print(info)
[LayerInfo(name='Conv2d-1', input_shape=[-1, 1, 32, 32], output_shape=[-1, 6, 30, 30], trainable_params=60, non_trainable_params=0),
 LayerInfo(name='Conv2d-2', input_shape=[-1, 6, 15, 15], output_shape=[-1, 16, 13, 13], trainable_params=880, non_trainable_params=0),
 LayerInfo(name='Linear-3', input_shape=[-1, 576], output_shape=[-1, 120], trainable_params=69240, non_trainable_params=0),
 LayerInfo(name='Linear-4', input_shape=[-1, 120], output_shape=[-1, 84], trainable_params=10164, non_trainable_params=0),
 LayerInfo(name='Linear-5', input_shape=[-1, 84], output_shape=[-1, 10], trainable_params=850, non_trainable_params=0)]
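Because inspect returns the summary as data rather than printing it, the results can be aggregated in code. This is a minimal sketch that assumes each returned item exposes the fields shown in its repr as attributes. LayerInfo below is a local namedtuple stand-in rebuilt from the printed output and is not imported from torch_inspect, so the snippet runs without PyTorch. param_totals and heaviest_layer are hypothetical helpers.

```python
from collections import namedtuple

# Hypothetical stand-in mirroring the repr printed above (not torch_inspect's class).
LayerInfo = namedtuple(
    "LayerInfo",
    ["name", "input_shape", "output_shape", "trainable_params", "non_trainable_params"],
)

info = [
    LayerInfo("Conv2d-1", [-1, 1, 32, 32], [-1, 6, 30, 30], 60, 0),
    LayerInfo("Conv2d-2", [-1, 6, 15, 15], [-1, 16, 13, 13], 880, 0),
    LayerInfo("Linear-3", [-1, 576], [-1, 120], 69240, 0),
    LayerInfo("Linear-4", [-1, 120], [-1, 84], 10164, 0),
    LayerInfo("Linear-5", [-1, 84], [-1, 10], 850, 0),
]


def param_totals(layers):
    """Return (trainable, non_trainable) parameter counts summed over all layers."""
    trainable = sum(layer.trainable_params for layer in layers)
    frozen = sum(layer.non_trainable_params for layer in layers)
    return trainable, frozen


def heaviest_layer(layers):
    """Return the layer holding the most parameters (trainable + non-trainable)."""
    return max(layers, key=lambda layer: layer.trainable_params + layer.non_trainable_params)


trainable, frozen = param_totals(info)
top = heaviest_layer(info)
print(trainable, frozen)                                     # 81194 0
print(top.name, f"{top.trainable_params / trainable:.0%}")   # Linear-3 85%
```

With the real library, the same helpers should accept info = ti.inspect(net, (1, 32, 32), device='cpu'), provided the returned records expose these attributes.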

Installation

Install the package from PyPI with pip:

$ pip install torch-inspect

Requirements

References and Thanks

This package is based on pytorch-summary and a PyTorch issue.

Changes


Download files

Source Distribution

torch-inspect-0.0.2.tar.gz (11.6 kB)

File details

Details for the file torch-inspect-0.0.2.tar.gz.

File metadata

  • Download URL: torch-inspect-0.0.2.tar.gz
  • Upload date:
  • Size: 11.6 kB
  • Tags: Source
  • Uploaded using Trusted Publishing? No
  • Uploaded via: Python-urllib/3.7

File hashes

Hashes for torch-inspect-0.0.2.tar.gz
Algorithm    Hash digest
SHA256       d3549269bc8ef5a75d51d3436d0af3df550a8efc69e5044a75a1093de678e6e6
MD5          9c7e09631cf8d3c7a3bb89408a064149
BLAKE2b-256  9b7e2150c68c2f961d35e57d4fd281678d98ee3a5f90c31ae5a979f3ed2bedbd
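The published SHA256 digest can be checked locally before installing. This is a minimal sketch using Python's standard hashlib module. The helper names are hypothetical, and the sketch assumes the tarball has already been downloaded under its published name.

```python
import hashlib
from typing import Iterable, Iterator

# SHA256 digest published above for torch-inspect-0.0.2.tar.gz.
EXPECTED_SHA256 = "d3549269bc8ef5a75d51d3436d0af3df550a8efc69e5044a75a1093de678e6e6"


def iter_file_chunks(path: str, chunk_size: int = 1 << 16) -> Iterator[bytes]:
    """Yield a file's bytes in fixed-size chunks so large files need not fit in memory."""
    with open(path, "rb") as fh:
        while True:
            chunk = fh.read(chunk_size)
            if not chunk:
                break
            yield chunk


def sha256_hex(chunks: Iterable[bytes]) -> str:
    """Hash a stream of byte chunks and return the lowercase hex digest."""
    digest = hashlib.sha256()
    for chunk in chunks:
        digest.update(chunk)
    return digest.hexdigest()


def matches_expected(chunks: Iterable[bytes], expected: str = EXPECTED_SHA256) -> bool:
    """True if the streamed content hashes to the expected SHA256 digest."""
    return sha256_hex(chunks) == expected.lower()
```

For example, matches_expected(iter_file_chunks("torch-inspect-0.0.2.tar.gz")) should return True for an intact download. pip's hash-checking mode, which uses --require-hashes together with --hash=sha256:... entries in a requirements file, performs the same check during installation.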
