Core functionality for sparsifying dense modules & models.

Sparsify Module

For sparsifying a single module.

When a parameter and a buffer in a module follow the naming convention {p_name} and {p_name}_mask, respectively, the buffer is assumed to be the mask for the parameter. For example, masked Linear and ConvNd layers will typically have a parameter named weight and a buffer named weight_mask. A parameter may optionally also have a sparsity buffer (e.g. weight_sparsity for ConvNd), which is used by the DynamicSparseTrainingCallback.
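For illustration, here is the convention hand-rolled on a plain nn.Linear (the helpers below create these buffers for you; the 0.8 sparsity value is arbitrary):

m = nn.Linear(5, 10)                                         # parameter: weight
m.register_buffer('weight_mask', torch.ones_like(m.weight))  # buffer: weight_mask
m.register_buffer('weight_sparsity', tensor(0.8))            # optional buffer: weight_sparsity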

sparse_mask[source]

sparse_mask(sizes, sparsity)

sparse_mask_like[source]

sparse_mask_like(param, sparsity)

mask = sparse_mask((10,5), 0.8)
test_eq(10, int(mask.sum()))
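sparse_mask_like presumably builds a mask shaped like an existing parameter by delegating to sparse_mask; a quick check under that assumption:

p = nn.Linear(5, 10).weight
mask = sparse_mask_like(p, 0.8)
test_eq(p.shape, mask.shape)      # mask matches the parameter's shape
test_eq(10, int(mask.sum()))      # and keeps 20% of the 50 connections, as above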

maybe_float[source]

maybe_float(num)

sparse_params[source]

sparse_params(module)

Returns a list of all (param, mask, sparsity) tuples in a module.

s, m = 0.8, nn.Linear(5,10)
m.register_buffer('weight_mask', sparse_mask_like(m.weight, s))  # mask for the weight
m.register_buffer('weight_sparsity', tensor(s))                  # optional sparsity buffer
m.register_buffer('bias_mask', sparse_mask_like(m.bias, s))      # mask for the bias
param_mask_sparsity = sparse_params(m)
test_eq(2, len(param_mask_sparsity))  # one tuple for weight, one for bias

apply_masks[source]

apply_masks(module, *args, inplace=True)

apply_masks(m)
test_eq(10, m.weight.abs().gt(0).sum())  # 20% of the 50 weights remain non-zero

is_sparseable_module[source]

is_sparseable_module(m, additional_types=[])
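As a quick check (assuming the default sparseable types are the weighted layers described above, i.e. Linear and ConvNd, and not activations):

test_eq(True, is_sparseable_module(nn.Linear(5, 10)))
test_eq(True, is_sparseable_module(nn.Conv2d(3, 32, 3)))
test_eq(False, is_sparseable_module(nn.ReLU()))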

sparseable_modules[source]

sparseable_modules(model, additional_types=[])

def test_model():
    return nn.Sequential(
        nn.Conv2d(3,32,3), nn.ReLU(), 
        nn.Conv2d(32,128,3), nn.ReLU(), 
        nn.Conv2d(128,512,3), nn.ReLU(), AdaptiveAvgPool(), Flatten(),
        nn.Linear(512, 10))

model = test_model()
s_mods = sparseable_modules(model)
test_eq(4, len(s_mods))

Sparse Distributions

For determining the layer-wise sparsity of a list of modules.

Uniform Distribution

All layers have the same percentage of connections removed.

uniform_sparsity[source]

uniform_sparsity(params, model_sparsity)
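Under a uniform distribution, every parameter should simply receive the target model sparsity. A sanity check of that behaviour (assuming uniform_sparsity returns one sparsity value per parameter):

u_params = L(sparseable_modules(test_model())).map(lambda m: m.weight)
test_eq([0.9] * len(u_params), uniform_sparsity(u_params, 0.9))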

First-Layer-Dense Uniform Distribution

Uniform sparsity except for the first layer, which is dense.

first_layer_dense_uniform[source]

first_layer_dense_uniform(params, model_sparsity)
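Grounded only in the description above: the first parameter should come out dense and the remaining parameters should share a single sparsity value (whether that shared value is exactly model_sparsity or is rescaled to preserve the overall budget is left to the implementation, so it is not pinned down here):

fld_params = L(sparseable_modules(test_model())).map(lambda m: m.weight)
fld = first_layer_dense_uniform(fld_params, 0.9)
test_eq(0., fld[0])        # first layer stays dense
test_eq(fld[1], fld[-1])   # remaining layers share one sparsity value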

Erdos-Renyi (Kernel) Distribution

For a fixed overall sparsity, the Erdos-Renyi sparsity distribution allocates more connections to smaller layers and fewer to larger layers than a uniform sparsity distribution does.

erdos_renyi_sparsity[source]

erdos_renyi_sparsity(params, model_sparsity, include_kernel=True, erk_power_scale=1.0)

Returns a list of sparsities in the same order as params. The sparsities follow the Erdos-Renyi (Kernel) distribution, scaled so that the model has the same total number of non-zero parameters as one with uniform sparsity; that is, for some scaling factor $\epsilon$: $\epsilon \, (p_1 N_1 + p_2 N_2) = (1 - \text{model\_sparsity}) \, (N_1 + N_2)$.

Args:
    params: list of all sparseable parameters
    model_sparsity: target overall sparsity, between 0 and 1
    include_kernel: if True, kernel dimensions are included in the scaling (e.g. for ConvNd layers)
    erk_power_scale: a scale < 1 softens the Erdos-Renyi distribution (i.e. brings it closer to uniform)

Returns: a list of sparsities, one per parameter in params.

model = test_model()
s_params = L(sparseable_modules(model)).map(lambda m: m.weight)

sparsities = erdos_renyi_sparsity(s_params, 0.9)
n_nonzeros = sum([(1-s) * p.numel() for p, s in zip(s_params, sparsities)])
test_close(n_nonzeros, 0.1 * sum([p.numel() for p in s_params]), eps=len(s_params))
# test_eq([0., 0., 0., 0.], sparsities) # TODO: calc sparsities by hand and compare
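Relative to a uniform distribution, ERK should also make the largest layer the sparsest one. A hedged check of that ordering on the test model (the 128->512 conv holds the most weights by far):

# the layer with the most weights should receive the highest sparsity
biggest = max(range(len(s_params)), key=lambda i: s_params[i].numel())
test_eq(max(sparsities), sparsities[biggest])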

Sparsify Model

For sparsifying an entire model.

sparsify_model[source]

sparsify_model(model, model_sparsity, sparse_f=uniform_sparsity, enforce_mask=True)

Adds a sparse mask for each sparseable module's weight in model and applies the mask to the weights.

If enforce_mask is True, a forward_pre_hook will be registered on each module to apply the weight mask before every forward pass of the module.

On choosing sparse_f: per the RigL paper, uniform_sparsity produces a model requiring fewer FLOPs, while erdos_renyi_sparsity tends to produce a better-performing model.

Returns a fastai Hooks object. You can remove the hooks after training by calling hooks.remove().

model = test_model()
s_mods = sparseable_modules(model)
n_params = sum(m.weight.numel() for m in s_mods)
sparsify_model(model, 0.9, sparse_f=uniform_sparsity)
n_nonzeros = sum(m.weight.abs().gt(0).sum() for m in s_mods)
# increase `eps` to account for rounding to nearest whole weight
test_close(n_nonzeros, 0.1 * n_params, eps=len(s_mods))
p = s_mods[0].weight
test_close(p.abs().gt(0).sum(), 0.1 * p.numel(), eps=1)

model = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.Linear(50,1))
hooks = sparsify_model(model, 0.9)   # enforce_mask=True registers forward_pre_hooks
model(torch.rand(10,1))
test_eq(10, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))   # 10 of 100 weights survive
hooks.remove()
# with the hooks removed, the masks are no longer re-applied on forward
for i in (0,2): model[i].weight.data = torch.ones_like(model[i].weight)
model(torch.rand(10,1))
test_eq(100, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))  # all 100 weights stay dense

Sparse Training

Drop/Grow Heuristics

random_score[source]

random_score(p, **kwargs)

weight_magnitude[source]

weight_magnitude(p, **kwargs)

gradient_magnitude[source]

gradient_magnitude(p, **kwargs)

gradient_momentum[source]

gradient_momentum(p, opt, **kwargs)

Calculates the momentum of the gradient for a parameter p from the opt state.

Dynamic Sparse Training Callback

top_k_mask[source]

top_k_mask(t, n_keep)

Returns a mask with n_keep ones corresponding to the largest values in t.

t = torch.linspace(-0.9, 0.9, 20).reshape(4,5)
mask = top_k_mask(t, 5)
test_eq(0, mask[:3].sum())
test_eq(5, mask[3:].sum())
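top_k_mask combined with a score function from above gives the basic drop step: score every weight, then keep the top n. A hand-rolled sketch, assuming weight_magnitude(p) returns the element-wise absolute value of p:

p = nn.Linear(5, 10).weight
keep_mask = top_k_mask(weight_magnitude(p), 10)      # keep the 10 strongest connections
test_eq(10, int(keep_mask.sum()))
test_eq(10, int((p * keep_mask).abs().gt(0).sum()))  # the other 40 are dropped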

class DynamicSparseTrainingCallback[source]

DynamicSparseTrainingCallback(sparse_modules=None, batches_per_update=None, initial_drop_grow_pct=0.3, stop_pct=0.75, keep_score_f=weight_magnitude, grow_score_f=gradient_magnitude) :: Callback

Dynamically updates the network connectivity during training.

First, let's test the callback on a toy model:

from fastai.test_utils import *
model = nn.Sequential(nn.Linear(1,32), nn.ReLU(), nn.Linear(32,32), nn.ReLU(), nn.Linear(32,1))
learn = synth_learner(data=synth_dbunch(bs=100), model=model)
sparse_hooks = sparsify_model(learn.model, 0.8, sparse_f=first_layer_dense_uniform)
cbs = DynamicSparseTrainingCallback(batches_per_update=None, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit(10, lr=1e-2, cbs=cbs)
epoch train_loss valid_loss time
0 6.367298 5.362422 00:00
1 4.590237 3.676831 00:00
2 3.637887 2.383177 00:00
3 2.929754 1.341745 00:00
4 2.343258 0.674901 00:00
5 1.849479 0.260025 00:00
6 1.453780 0.182112 00:00
7 1.152167 0.085386 00:00
8 0.918357 0.049471 00:00
9 0.733873 0.029181 00:00

Now, let's test a slightly more realistic use case: MNIST_TINY on ResNet18.

from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_TINY))
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
sparse_hooks = sparsify_model(learn.model, 0.9, erdos_renyi_sparsity)
cbs = DynamicSparseTrainingCallback(batches_per_update=8, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit_one_cycle(5, 1e-2, cbs=cbs)

test_close(1, learn.final_record[-1], eps=0.02) # better than 98% accuracy

for m in sparseable_modules(learn.model):
    for p, mask, s in sparse_params(m):
        n_alive = p.abs().gt(0).sum()
        n_total = p.numel()    
        test_close(s, 1 - n_alive / n_total, eps=0.01) # layer sparsity = target sparsity
epoch train_loss valid_loss accuracy time
0 0.779406 0.861908 0.505007 00:03
1 0.952576 0.216129 0.931330 00:02
2 0.668288 2.507498 0.746781 00:02
3 0.468921 0.129384 0.967096 00:02
4 0.348578 0.024292 0.991416 00:02

Preset Definitions

Sparse Evolutionary Training (SET)

Sparse Networks From Scratch (SNFS)

Rigged Lottery (RigL)
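The headings above name the published dynamic sparse training methods. Grounded in those papers rather than in this library's preset definitions (which are not shown here), the three variants all drop connections by weight magnitude but differ in how they regrow them: SET grows random connections, SNFS grows by gradient momentum, and RigL grows by instantaneous gradient magnitude. A sketch of how DynamicSparseTrainingCallback could be configured to approximate each:

# hypothetical configurations; the library's own presets may set additional options
set_cb  = DynamicSparseTrainingCallback(keep_score_f=weight_magnitude, grow_score_f=random_score)
snfs_cb = DynamicSparseTrainingCallback(keep_score_f=weight_magnitude, grow_score_f=gradient_momentum)
rigl_cb = DynamicSparseTrainingCallback(keep_score_f=weight_magnitude, grow_score_f=gradient_magnitude)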

Export