Core functionality for sparsifying dense modules & models.
When a parameter and a buffer in a module follow the naming convention `{p_name}` and `{p_name}_mask` respectively, the buffer is assumed to be the mask for that parameter. For example, masked Linear and ConvNd layers will typically have a parameter named `weight` and a buffer named `weight_mask`. A module may additionally register a sparsity buffer for a parameter (e.g. `weight_sparsity` for ConvNd), which is used by the `DynamicSparseTrainingCallback`.
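To make the convention concrete, here is a rough sketch (illustrative only, not the library's actual implementation) of how a `sparse_params`-style helper could pair each parameter with its mask and optional sparsity buffer:

def _sparse_params_sketch(module):
    "Illustrative sketch: pair each parameter with its `{p_name}_mask` and optional `{p_name}_sparsity` buffer."
    buffers = dict(module.named_buffers())
    pairs = []
    for name, p in module.named_parameters():
        mask = buffers.get(f'{name}_mask')
        if mask is None: continue  # unmasked parameter -> skip
        pairs.append((p, mask, buffers.get(f'{name}_sparsity')))
    return pairs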
mask = sparse_mask((10,5), 0.8)
test_eq(10, int(mask.sum()))  # 80% sparsity on a 10x5 mask leaves 20% of 50 = 10 nonzero entries
s, m = 0.8, nn.Linear(5,10)
m.register_buffer('weight_mask', sparse_mask_like(m.weight, s))
m.register_buffer('weight_sparsity', tensor(s))
m.register_buffer('bias_mask', sparse_mask_like(m.bias, s))
param_mask_sparsity = sparse_params(m)
test_eq(2, len(param_mask_sparsity))  # both `weight` and `bias` carry a mask
apply_masks(m)
test_eq(10, m.weight.abs().gt(0).sum())  # 20% of the 50 weights remain nonzero
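`apply_masks` above effectively amounts to an in-place multiply of every masked parameter by its mask. A minimal sketch in terms of the hypothetical helper above (again, illustrative only):

import torch

def _apply_masks_sketch(module):
    "Illustrative sketch: zero out the masked-off entries of every masked parameter."
    with torch.no_grad():
        for p, mask, _ in _sparse_params_sketch(module):
            p.mul_(mask)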
def test_model():
    return nn.Sequential(
        nn.Conv2d(3,32,3), nn.ReLU(),
        nn.Conv2d(32,128,3), nn.ReLU(),
        nn.Conv2d(128,512,3), nn.ReLU(), AdaptiveAvgPool(), Flatten(),
        nn.Linear(512, 10))
model = test_model()
s_mods = sparseable_modules(model)
test_eq(4, len(s_mods))
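`sparseable_modules` finds the 3 Conv2d layers plus the final Linear layer. Conceptually (a sketch that simply assumes only Conv/Linear layers count as sparseable):

import torch.nn as nn

def _sparseable_modules_sketch(model):
    "Illustrative sketch: collect the modules whose weights are candidates for sparsification."
    return [m for m in model.modules() if isinstance(m, (nn.Linear, nn.Conv2d))]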
model = test_model()
s_params = L(sparseable_modules(model)).map(lambda m: m.weight)
sparsities = erdos_renyi_sparsity(s_params, 0.9)
n_nonzeros = sum([(1-s) * p.numel() for p, s in zip(s_params, sparsities)])
test_close(n_nonzeros, 0.1 * sum([p.numel() for p in s_params]), eps=len(s_params))
# test_eq([0., 0., 0., 0.], sparsities) # TODO: calc sparsities by hand and compare
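For context, the Erdős–Rényi scheme used in sparse-training methods such as SET and RigL makes each layer's density proportional to (n_in + n_out) / (n_in * n_out), so small layers keep relatively more of their weights than large ones. A sketch of how per-layer sparsities could be derived from a global target (not necessarily the library's exact formulation):

def _erdos_renyi_sketch(params, model_sparsity):
    "Illustrative sketch: layer density proportional to (n_in + n_out) / (n_in * n_out)."
    n_out = [p.shape[0] for p in params]
    n_in  = [p.numel() // p.shape[0] for p in params]         # fan-in (includes kernel dims for convs)
    raw   = [(i + o) / (i * o) for i, o in zip(n_in, n_out)]  # un-normalised layer densities
    # scale so the total nonzero count hits the global target (ignoring densities clipped at 1)
    target_nonzeros = (1 - model_sparsity) * sum(p.numel() for p in params)
    scale = target_nonzeros / sum(r * p.numel() for r, p in zip(raw, params))
    return [1 - min(1., scale * r) for r in raw]

The `test_close` above checks exactly this property: the individual layer sparsities differ, but the surviving weights sum to 10% of the total parameter count.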
model = test_model()
s_mods = sparseable_modules(model)
n_params = sum(m.weight.numel() for m in s_mods)
sparsify_model(model, 0.9, sparse_f=uniform_sparsity)
n_nonzeros = sum(m.weight.abs().gt(0).sum() for m in s_mods)
# increase `eps` to account for rounding to nearest whole weight
test_close(n_nonzeros, 0.1 * n_params, eps=len(s_mods))
p = s_mods[0].weight
test_close(p.abs().gt(0).sum(), 0.1 * p.numel(), eps=1)
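By contrast, `uniform_sparsity` assigns the same sparsity to every layer, which is why the first (smallest) layer alone is also 90% sparse here. A sketch would be as simple as:

def _uniform_sparsity_sketch(params, model_sparsity):
    "Illustrative sketch: every layer gets the same target sparsity."
    return [model_sparsity] * len(params)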
model = nn.Sequential(nn.Linear(1,50), nn.ReLU(), nn.Linear(50,1))
hooks = sparsify_model(model, 0.9)
model(torch.rand(10,1))
test_eq(10, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))  # 10% of the 100 weights survive the forward pass
hooks.remove()
for i in (0,2): model[i].weight.data = torch.ones_like(model[i].weight)  # manually re-densify the weights
model(torch.rand(10,1))
test_eq(100, sum([model[i].weight.abs().gt(0).sum() for i in (0,2)]))  # without the hooks, all 100 weights stay nonzero
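The hooks returned by `sparsify_model` re-apply the masks on each forward pass; once they are removed and the weights are overwritten, all 100 weights stay nonzero. A plausible shape for such a hook, sketched with the hypothetical helpers above rather than the library's code:

def _masking_pre_hook(module, inputs):
    "Illustrative sketch: re-apply the module's masks just before its forward pass."
    _apply_masks_sketch(module)

# e.g. handle = model[0].register_forward_pre_hook(_masking_pre_hook); later, handle.remove()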
t = torch.linspace(-0.9, 0.9, 20).reshape(4,5)
mask = top_k_mask(t, 5)
test_eq(0, mask[:3].sum())
test_eq(5, mask[3:].sum())
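`top_k_mask` keeps the k highest-scoring entries; here the five largest values of `t` all sit in the last row, so only that row's mask is set. A straightforward sketch (assumed, not the library's implementation):

import torch

def _top_k_mask_sketch(t, k):
    "Illustrative sketch: 1. at the k largest entries of `t`, 0. elsewhere."
    mask = torch.zeros_like(t).flatten()
    mask[t.flatten().topk(k).indices] = 1.
    return mask.reshape(t.shape)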
First, let's test the `DynamicSparseTrainingCallback` on a toy model:
from fastai.test_utils import *
model = nn.Sequential(nn.Linear(1,32), nn.ReLU(), nn.Linear(32,32), nn.ReLU(), nn.Linear(32,1))
learn = synth_learner(data=synth_dbunch(bs=100), model=model)
sparse_hooks = sparsify_model(learn.model, 0.8, sparse_f=first_layer_dense_uniform)
cbs = DynamicSparseTrainingCallback(batches_per_update=None, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit(10, lr=1e-2, cbs=cbs)
Now, let's test a slightly more realistic use case: MNIST_TINY on ResNet18.
from fastai.vision.all import *
dls = ImageDataLoaders.from_folder(untar_data(URLs.MNIST_TINY))
learn = cnn_learner(dls, resnet18, metrics=accuracy, pretrained=False)
sparse_hooks = sparsify_model(learn.model, 0.9, erdos_renyi_sparsity)
cbs = DynamicSparseTrainingCallback(batches_per_update=8, stop_pct=0.5, grow_score_f=gradient_momentum)
learn.fit_one_cycle(5, 1e-2, cbs=cbs)
test_close(1, learn.final_record[-1], eps=0.02) # better than 98% accuracy
for m in sparseable_modules(learn.model):
    for p, mask, s in sparse_params(m):
        n_alive = p.abs().gt(0).sum()
        n_total = p.numel()
        test_close(s, 1 - n_alive / n_total, eps=0.01)  # layer sparsity = target sparsity