Utility functions that print a summary of a model.
torch-inspect
torch-inspect is a collection of utility functions for inspecting low-level information about neural networks in PyTorch.
Features
Provides the helper function summary, which prints a Keras-style model summary.
Provides the helper function inspect, which returns an object with network summary information for programmatic access.
Supports RNN/LSTM layers.
The library has tests and reasonable code coverage.
Simple example
import torch.nn as nn
import torch.nn.functional as F
import torch_inspect as ti


class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 6 * 6, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        size = x.size()[1:]
        num_features = 1
        for s in size:
            num_features *= s
        return num_features


net = SimpleNet()
ti.summary(net, (1, 32, 32))
This produces the following output:
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [100, 6, 30, 30]              60
            Conv2d-2          [100, 16, 13, 13]             880
            Linear-3                 [100, 120]          69,240
            Linear-4                  [100, 84]          10,164
            Linear-5                  [100, 10]             850
================================================================
Total params: 81,194
Trainable params: 81,194
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.39
Forward/backward pass size (MB): 6.35
Params size (MB): 0.31
Estimated Total Size (MB): 7.05
----------------------------------------------------------------
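As a cross-check of the summary above, the short sketch below recomputes the output shapes, parameter counts and memory estimates by hand. It assumes float32 values (4 bytes each), a batch size of 100 (which matches the shapes shown in the table), and that the forward/backward figure is twice the summed layer outputs, as in pytorch-summary. None of this is taken from torch-inspect's source, and the helper names are hypothetical.

```python
# Hand-recomputation of the SimpleNet summary numbers.
# Assumptions (not from torch-inspect's source): float32 = 4 bytes,
# forward/backward size = 2 x sum of layer output elements, batch size 100.
BYTES_PER_FLOAT = 4
MB = 1024 ** 2
BATCH = 100


def conv2d_out(size, kernel):
    """Spatial size after a stride-1, unpadded convolution."""
    return size - kernel + 1


def pool_out(size, kernel):
    """Spatial size after non-overlapping max pooling (floor division)."""
    return size // kernel


def conv2d_params(c_in, c_out, kernel):
    return c_out * c_in * kernel * kernel + c_out  # weights + one bias per channel


def linear_params(n_in, n_out):
    return n_out * n_in + n_out  # weights + one bias per output


# Spatial sizes through the network, starting from a 32x32 input.
s1 = conv2d_out(32, 3)                # 30: Conv2d-1 output
s2 = conv2d_out(pool_out(s1, 2), 3)   # pool to 15, then 13: Conv2d-2 output
flat = 16 * pool_out(s2, 2) ** 2      # 16 * 6 * 6 = 576 inputs to fc1

# (name, output elements per sample, parameter count)
layers = [
    ("Conv2d-1", 6 * s1 * s1, conv2d_params(1, 6, 3)),
    ("Conv2d-2", 16 * s2 * s2, conv2d_params(6, 16, 3)),
    ("Linear-3", 120, linear_params(flat, 120)),
    ("Linear-4", 84, linear_params(120, 84)),
    ("Linear-5", 10, linear_params(84, 10)),
]

total_params = sum(p for _, _, p in layers)
input_mb = BATCH * 1 * 32 * 32 * BYTES_PER_FLOAT / MB
fwd_bwd_mb = 2 * BATCH * sum(n for _, n, _ in layers) * BYTES_PER_FLOAT / MB
params_mb = total_params * BYTES_PER_FLOAT / MB

print(f"Total params: {total_params:,}")
print(f"Input size (MB): {input_mb:.2f}")
print(f"Forward/backward pass size (MB): {fwd_bwd_mb:.2f}")
print(f"Params size (MB): {params_mb:.2f}")
print(f"Estimated Total Size (MB): {input_mb + fwd_bwd_mb + params_mb:.2f}")
```

Each Conv2d contributes out_channels × in_channels × k × k weights plus one bias per output channel, and each Linear contributes n_in × n_out weights plus n_out biases. The 16 × 6 × 6 = 576 flattened features are why fc1 is declared as nn.Linear(16 * 6 * 6, 120).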
For programmatic access to network information, use the inspect function:
info = ti.inspect(net, (1, 32, 32))
print(info)
[LayerInfo(name='Conv2d-1', input_shape=[100, 1, 32, 32], output_shape=[100, 6, 30, 30], trainable_params=60, non_trainable_params=0),
LayerInfo(name='Conv2d-2', input_shape=[100, 6, 15, 15], output_shape=[100, 16, 13, 13], trainable_params=880, non_trainable_params=0),
LayerInfo(name='Linear-3', input_shape=[100, 576], output_shape=[100, 120], trainable_params=69240, non_trainable_params=0),
LayerInfo(name='Linear-4', input_shape=[100, 120], output_shape=[100, 84], trainable_params=10164, non_trainable_params=0),
LayerInfo(name='Linear-5', input_shape=[100, 84], output_shape=[100, 10], trainable_params=850, non_trainable_params=0)]
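Because each entry exposes named fields, the result of inspect can be post-processed with ordinary Python. This sketch relies only on attribute access by the field names visible in the printed output. The LayerInfo namedtuple defined here is a hypothetical stand-in built from that output so the example runs without PyTorch, and param_report is an illustrative helper, not part of torch-inspect.

```python
from collections import namedtuple

# Hypothetical stand-in for torch-inspect's LayerInfo, using the field names
# shown in the printed output; with the real library, use ti.inspect(...) instead.
LayerInfo = namedtuple(
    "LayerInfo",
    ["name", "input_shape", "output_shape", "trainable_params", "non_trainable_params"],
)

info = [
    LayerInfo("Conv2d-1", [100, 1, 32, 32], [100, 6, 30, 30], 60, 0),
    LayerInfo("Conv2d-2", [100, 6, 15, 15], [100, 16, 13, 13], 880, 0),
    LayerInfo("Linear-3", [100, 576], [100, 120], 69240, 0),
    LayerInfo("Linear-4", [100, 120], [100, 84], 10164, 0),
    LayerInfo("Linear-5", [100, 84], [100, 10], 850, 0),
]


def param_report(layers):
    """Return (trainable total, non-trainable total, name of the largest layer)."""
    trainable = sum(layer.trainable_params for layer in layers)
    frozen = sum(layer.non_trainable_params for layer in layers)
    largest = max(
        layers, key=lambda layer: layer.trainable_params + layer.non_trainable_params
    )
    return trainable, frozen, largest.name


trainable, frozen, largest = param_report(info)
print(trainable, frozen, largest)
```

Assuming ti.inspect returns a list of such records, as the printed output suggests, the same helper could be applied directly as param_report(ti.inspect(net, (1, 32, 32))).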
Installation
Installation is simple:
$ pip install torch-inspect
Requirements
References and Thanks
This package is based on pytorch-summary and a PyTorch issue. Compared to pytorch-summary, torch-inspect supports RNNs/LSTMs and also provides programmatic access to the network summary information. Its more modular structure and its test suite make it easier to extend and to support more features.
Changes
0.0.3 (2019-09-22)
Added LSTM support
Fixed multi input/output support
Added more network test cases
Batch size no longer -1 by default
0.0.2 (2019-09-22)
Added batch norm support
Removed device parameter
0.0.1 (2019-09-1)
Initial release.