Customizable Fastai+PyTorch implementation of sparse model training methods (SET, SNFS, RigL).

Warning: this repo is undergoing active development

Getting Started

Install

pip install fastsparse

Sparse Algorithms

This network implements the following sparse algorithms:

Abbr. Sparse Algorithm in FastSparse Notes
static sparsity baseline omit DynamicSparseTrainingCallback
SET Sparse Evolutionary Training (Jan 2019) DynamicSparseTrainingCallback(**SET_presets)
SNFS Sparse Networks From Scratch (Jul 2019) DynamicSparseTrainingCallback(**SNFS_presets) *redistribution not implemented
RigL Rigged Lottery (Nov 2019) DynamicSparseTrainingCallback(**RigL_presets)

*Authors of the RigL paper demonstrate that using SNFS + Erdos-Renyi-Kernel distribution - redistribution outperforms SNFS + uniform sparsity + redistribution (at least on the measured benchmarks).

Fastai demo

With just 4 additional lines of code, you can train your model using the latest dynamic sparse training techniques. This example achieves >99% accuracy on MNIST using a ResNet34 with only 1% of the weights.

# (0) install the library
# ! pip install fastsparse 

from fastai.vision.all import *

# (1) import this package
import fastsparse as sparse

path = untar_data(URLs.MNIST)
dls = ImageDataLoaders.from_folder(path, 'training', 'testing')
learn = cnn_learner(dls, resnet34, metrics=error_rate, pretrained=False)

# (2) sparsify initial model + enforce masks
sparse_hooks = sparse.sparsify_model(learn.model, 
                                     model_sparsity=0.99,
                                     sparse_f=sparse.erdos_renyi_sparsity)

# (3) schedule dynamic mask updates
cbs = [sparse.DynamicSparseTrainingCallback(**sparse.SNFS_presets, 
                                            batches_per_update=32)]

learn.fit_one_cycle(5, cbs=cbs)

# (4) remove hooks that enforce masks
sparse_hooks.remove()

Simply omit the DynamicSparseTrainingCallback to train a fixed-sparsity model as a baseline.

PyTorch demo (not implemented yet)

import torch
from torchvision import models

data = ...
model = ...
opt = ...
opt = DynamicSparseTrainingOptimizerWrapper(model, opt, **RigL_kwargs)

### Modified training step
# sparse_opt.step(...) will determine whether to:
#  (A) take a regular opt step, or
#  (B) update network connectivity
def sparse_train_step(model, xb, yb, loss_func, sparse_opt, step, pct_train):
    preds = model(xb)
    loss = loss_func(preds, yb)
    loss.backward()
    sparse_opt.step(step, pct_train)
    sparse_opt.zero_grad()

Save/Reload demo

Here is an example of saving a model and reloading it to resume training.

from fastai.vision.all import *
from fastsparse import *

path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_TINY))
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
sparse_hooks = sparsify_model(learn.model, model_sparsity=0.9, sparse_f=erdos_renyi_sparsity)
dst_kwargs = {**SNFS_presets, **{'batches_per_update': 8}}
cbs = DynamicSparseTrainingCallback(**dst_kwargs)

learn.fit_flat_cos(5, cbs=cbs)

# (0) save model as usual (masks are stored automatically)
save_model('sparse_tiny_mnist', learn.model, learn.opt)
epoch train_loss valid_loss accuracy time
0 0.334119 0.679619 0.505007 00:03
1 0.271645 0.555170 0.848355 00:02
2 0.237115 0.072088 0.978541 00:02
3 0.220553 0.044927 0.987124 00:02
4 0.174585 0.006496 1.000000 00:02
# (1) then recreate learner as usual

from fastai.vision.all import *
from fastsparse import *

path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_TINY))
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)

# (2) re-sparsify model (this adds the masks to the parameters)
sparse_hooks = sparsify_model(learn.model, model_sparsity=0.9, sparse_f=erdos_renyi_sparsity) # <-- initial sparsity + enforce masks

# (3) load model as usual
load_model('sparse_tiny_mnist', learn.model, learn.opt)

# (5) check validation loss & accuracy to verify we've loaded it successfully
val_loss, val_acc = learn.validate()
print(f'validation loss: {val_loss}, validation accuracy: {val_acc}')

# (4) optionally, continue training; otherwise remove sparsity-preserving hooks
sparse_hooks.remove()
/home/dc/anaconda3/envs/fastai/lib/python3.8/site-packages/fastai/learner.py:53: UserWarning: Could not load the optimizer state.
  if with_opt: warn("Could not load the optimizer state.")
validation loss: 0.006496043410152197, validation accuracy: 1.0

Training with Large Batch Sizes

Authors of the Rigged Lottery paper hypothesize that the effectiveness of using the gradient magnitude for determining which connections to grow is partly due to their large batch size (4096 for ImageNet). Those without access to multi-gpu clusters can achieve effective batch sizes of this size by using fastai's GradientAccumulation callback, which has been tested to be compatible with this package's DynamicSparseTrainingCallback.

Training with Small # of Epochs

Dynamic sparse training algorithms work by modifying the network connectivity during training, dropping some weights and allowing others to regrow. By default, network connectivity is modified at the end of each epoch. When training with few epochs, however, there will be few chances to explore which weights to connect. To update more frequently, in DynamicSparseTrainingCallback, set batches_per_update to a smaller # of batches than occur in one training epoch. Varying the number of batches per update trades off the frequency of updates with stability in making good updates.

Customization

There are many ways to implement and test your own dynamic sparse algorithms using FastSparse.

Custom Initial Sparsity Distribution:

Define your own initial sparsity distribution by setting sparsify_method in sparsify_model to a custom function. For example, this function (included in library) will keep the first layer dense and set the remaining layers to a fixed sparsity.

def first_layer_dense_uniform(params:list, model_sparsity:float):
    sparsities = [1.] + [model_sparsity] * (len(params) - 1)
    return sparsities

Custom Drop Criterion

While published papers like SNFS and RigL refer to 'drop criterion', this library implements the reverse, a 'keep criterion'. This is a function that returns a score for each weight, where the largest M scores will be and M is determined by the decay schedule. For example, both Sparse Networks From Scratch and Rigged Lottery both use the magnitude of the weights (in FastSparse: weight_magnitude).

This can easily be customized in FastSparse by defining your own keep score function:

def custom_keep_scoring_function(param, opt):
    score = ...
    assert param.shape == score.shape
    return score

Then pass your custom function into the sparse training callback:

DynamicSparseTrainingCallback(..., keep_score_f=custom_keep_scoring_function)

Custom Grow Criterion

The grow criterion is a function that returns a score for each weight, where the largest N scores will be and N is determined by the decay schedule. For example, Sparse Networks From Scrath grows weights according to the momentum of the gradient, while Rigged Lottery uses the magnitude of the gradient (in FastSparse, gradient_momentum and gradient_magnitude respectively).

def custom_grow_scoring_function(param, opt):
    score = ...
    assert param.shape == score.shape
    return score

Then pass your custom function into the sparse training callback:

DynamicSparseTrainingCallback(..., grow_score_f=custom_grow_scoring_function)

Replication Results

In machine learning, is very easy for seemingly insignificant differences in algorithmic implementation to have a noticeable impact on final results. Therefore, this section compares results from this implementation to results reported in published papers.

TODO...

Under-The-Hood Details

Here's what's going on.

When you run sparsify_model(learn.model, 0.9), this adds sparse masks and add pre_forward hooks to enforce masks on weights during forward pass.

By default, a uniform sparsity distribution is used. Change the sparsity distribution to Erdos-Renyi with sparsify_model(learn.model, 0.9, sparse_init_f=erdos_renyi), or pass in your custom function (see Customization

To avoid adding pre_forward hooks, use sparsify_model(learn.model, 0.9, enforce_masks=False).

When you add the DynamicSparseTrainingCallback callback, ... TODO complete section