pruning/registry.py: 3 changes (2 additions & 1 deletion)
@@ -8,8 +8,9 @@

 from foundations.hparams import PruningHparams
 from pruning import sparse_global
+from pruning import sparse_layerwise

-registered_strategies = {'sparse_global': sparse_global.Strategy}
+registered_strategies = {'sparse_global': sparse_global.Strategy, 'sparse_layerwise': sparse_layerwise.Strategy}


 def get(pruning_hparams: PruningHparams):
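For context, here is a hypothetical call site showing how the newly registered strategy would be resolved. It assumes (the body of get is elided above) that registry.get looks the strategy up in registered_strategies by the hparams' pruning_strategy field and returns the Strategy class:

from pruning import registry
from pruning.sparse_layerwise import PruningHparams

# Hypothetical usage, not part of this PR: resolve the new strategy by name.
hparams = PruningHparams('sparse_layerwise', 0.2)  # (pruning_strategy, pruning_fraction)
strategy = registry.get(hparams)                   # presumably sparse_layerwise.Strategy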
pruning/sparse_layerwise.py: 48 additions (new file)
@@ -0,0 +1,48 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import models.base
from pruning import base
from pruning import sparse_global
from pruning.mask import Mask


PruningHparams = sparse_global.PruningHparams


class Strategy(base.Strategy):
    @staticmethod
    def get_pruning_hparams() -> type:
        return PruningHparams

    @staticmethod
    def prune(pruning_hparams: PruningHparams, trained_model: models.base.Model, current_mask: Mask = None):
        current_mask = Mask.ones_like(trained_model).numpy() if current_mask is None else current_mask.numpy()

        new_mask = Mask(current_mask)

        for k, v in current_mask.items():
            # Skip any layer named in the comma-separated ignore list.
            # (split(',') matches whole layer names, not substrings.)
            if pruning_hparams.pruning_layers_to_ignore and k in pruning_hparams.pruning_layers_to_ignore.split(','):
                continue

            # Determine the number of weights that need to be pruned in this layer.
            number_of_remaining_weights = np.sum(v)
            number_of_weights_to_prune = np.ceil(
                pruning_hparams.pruning_fraction * number_of_remaining_weights).astype(int)

            weights = trained_model.state_dict()[k].clone().cpu().detach().numpy()

            # Find the magnitude threshold among this layer's still-unpruned weights.
            weight_vector = weights[current_mask[k] == 1]
            threshold = np.sort(np.abs(weight_vector))[number_of_weights_to_prune]

            # Prune weights whose magnitude falls below the threshold; keep the rest.
            new_mask[k] = np.where(np.abs(weights) >= threshold, current_mask[k], np.zeros_like(v))

        return new_mask
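Unlike sparse_global, which pools the surviving weights of all layers and computes a single global threshold, this strategy computes the pruning budget and the threshold within each layer. A minimal self-contained sketch of that per-layer rule on toy values (no model or Mask involved): sort the surviving magnitudes, index with ceil(fraction * remaining), and keep only weights at or above that value.

import numpy as np

weights = np.array([0.05, -0.40, 0.10, 0.90, -0.20])  # toy layer weights
mask = np.ones_like(weights)                          # nothing pruned yet
fraction = 0.2

remaining = np.sum(mask)                              # 5 surviving weights
to_prune = np.ceil(fraction * remaining).astype(int)  # ceil(0.2 * 5) = 1
surviving = np.abs(weights[mask == 1])
threshold = np.sort(surviving)[to_prune]              # 2nd-smallest magnitude: 0.10

new_mask = np.where(np.abs(weights) >= threshold, mask, np.zeros_like(mask))
print(new_mask)                                       # [0. 1. 1. 1. 1.]: only 0.05 is pruned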
pruning/test/test_sparse_layerwise.py: 82 additions (new file)
@@ -0,0 +1,82 @@
# Copyright (c) Facebook, Inc. and its affiliates.

# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

import models.registry
from pruning.sparse_layerwise import Strategy
from pruning.sparse_layerwise import PruningHparams
from testing import test_case


class TestSparseLayerWise(test_case.TestCase):
    def setUp(self):
        super(TestSparseLayerWise, self).setUp()
        self.hparams = PruningHparams('sparse_layerwise', 0.2)

        model_hparams = models.registry.get_default_hparams('cifar_resnet_20').model_hparams
        self.model = models.registry.get(model_hparams)

    def test_get_pruning_hparams(self):
        self.assertTrue(issubclass(Strategy.get_pruning_hparams(), PruningHparams))

    def test_prune(self):
        m = Strategy.prune(self.hparams, self.model)

        # Check that the mask only contains entries for the prunable layers.
        self.assertEqual(set(m.keys()), set(self.model.prunable_layer_names))

        # Check that the masks are the same sizes as the tensors.
        for k in self.model.prunable_layer_names:
            self.assertEqual(list(m[k].shape), list(self.model.state_dict()[k].shape))

        # Check that the right fraction of weights was pruned in every prunable layer.
        m = m.numpy()
        for k, v in m.items():
            pruned_count = np.sum(1 - v)
            should_be_pruned = v.size * self.hparams.pruning_fraction
            self.assertTrue(should_be_pruned <= pruned_count <= should_be_pruned + 1)

        # Check the thresholding in every layer: each unpruned weight must have a
        # larger magnitude than every pruned weight.
        for k, v in m.items():
            pruned_weights = self.model.state_dict()[k].numpy()[m[k] == 0]
            threshold = np.max(np.abs(pruned_weights))
            unpruned_weights = self.model.state_dict()[k].numpy()[m[k] == 1]
            self.assertTrue(np.all(np.abs(unpruned_weights) > threshold))

    def test_iterative_pruning(self):
        m = Strategy.prune(self.hparams, self.model)
        m2 = Strategy.prune(self.hparams, self.model, m)

        # Ensure that all weights pruned in the first round are still pruned in the second.
        m, m2 = m.numpy(), m2.numpy()
        self.assertEqual(set(m.keys()), set(m2.keys()))
        for k in m:
            self.assertTrue(np.all(m[k] >= m2[k]))

        # After two rounds at fraction f, roughly 1 - (1 - f)^2 of each layer is pruned,
        # with at most one extra weight per round due to the ceiling.
        for k, v in m2.items():
            pruned_count = np.sum(1 - v)
            should_be_pruned = v.size * (1 - ((1 - self.hparams.pruning_fraction) ** 2))
            self.assertTrue(should_be_pruned <= pruned_count <= should_be_pruned + 2)

    def test_prune_layers_to_ignore(self):
        layers_to_ignore = sorted(self.model.prunable_layer_names)[:5]
        self.hparams.pruning_layers_to_ignore = ','.join(layers_to_ignore)

        m = Strategy.prune(self.hparams, self.model).numpy()

        # Ensure that the ignored layers were, indeed, ignored.
        for k in layers_to_ignore:
            self.assertTrue(np.all(m[k] == 1))

        # Ensure that the expected fraction was still pruned. Since pruning is layerwise,
        # ignored layers contribute nothing, so measure over the non-ignored layers only.
        pruned_layers = [k for k in m if k not in layers_to_ignore]
        total_pruned = np.sum([np.sum(1 - m[k]) for k in pruned_layers])
        total_weights = np.sum([m[k].size for k in pruned_layers])
        actual_fraction = float(total_pruned) / total_weights
        self.assertGreaterEqual(actual_fraction, self.hparams.pruning_fraction)
        self.assertGreater(self.hparams.pruning_fraction + 0.0001, actual_fraction)


test_case.main()
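A note on the slack in test_iterative_pruning: each round prunes ceil(fraction * remaining) weights per layer, so after two rounds roughly 1 - (1 - f)^2 of each layer is gone, plus at most one extra weight per round from the ceiling; hence the should_be_pruned + 2 upper bound. A quick arithmetic check of that factor:

f = 0.2                         # pruning_fraction used in these tests
survive_one = 1 - f             # 0.8 of a layer survives round one
survive_two = survive_one ** 2  # 0.64 survives round two
print(1 - survive_two)          # 0.36, the per-layer factor the test multiplies by v.size

Since the module ends with test_case.main(), the suite can presumably be run directly, e.g. with "python -m pruning.test.test_sparse_layerwise" from the repository root.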