Skip to content

Conversation

@cijinsama
Copy link

Main changes:

  • enable deepcopy by adding a __deepcopy__ method to QConv2d_NoAct
  • enable vmap by changing the implementation of GradientCancellation and SignFunction

Previous error examples

Example of the deepcopy error:

from bitorch.models import Resnet18V2
model = Resnet18V2(input_shape=(3, 224, 224), num_classes=10)
import copy
copy.deepcopy(model)

The stderr is:

Traceback (most recent call last):
  File "/home/cijin/Code/bitorch/test.py", line 4, in <module>
    copy.deepcopy(model)
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 271, in _reconstruct
    state = deepcopy(state, memo)
            ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 146, in deepcopy
    y = copier(x, memo)
        ^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 231, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
                             ^^^^^^^^^^^^^^^^^^^^^
......
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 172, in deepcopy
    y = _reconstruct(x, memo, *rv)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/copy.py", line 272, in _reconstruct
    if hasattr(y, '__setstate__'):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Code/bitorch/bitorch/layers/extensions/layer_container.py", line 42, in __getattr__
    attr_value = getattr(self._layer_implementation, item)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Code/bitorch/bitorch/layers/extensions/layer_container.py", line 41, in __getattr__
    return self.__dict__[item]
           ~~~~~~~~~~~~~^^^^^^
KeyError: '_layer_implementation'

Example of the vmap error:

from bitorch.models import Resnet18V2
import torch
import torch.nn.functional as F
from opacus.grad_sample import GradSampleModule
from opacus.validators import ModuleValidator
model = Resnet18V2(input_shape=(3, 224, 224), num_classes=10)
model = GradSampleModule(ModuleValidator.fix(model))
batch_input = torch.randn(10, 224, 224, 3)
pred = model(batch_input)
F.nll_loss(pred, torch.randint(0, 10, (10,))).backward()

The stderr is:

Traceback (most recent call last):
  File "/home/cijin/Code/bitorch/test.py", line 10, in <module>
    F.nll_loss(pred, torch.randint(0, 10, (10,))).backward()
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/modules/module.py", line 98, in __call__
    return self.hook(module, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/opacus/grad_sample/grad_sample_module.py", line 338, in capture_backprops_hook
    grad_samples = grad_sampler_fn(module, activations, backprops)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/opacus/grad_sample/functorch.py", line 108, in ft_compute_per_sample_gradient
    per_sample_grads = layer.ft_compute_sample_grad(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/apis.py", line 203, in wrapped
    return vmap_impl(
           ^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 331, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 479, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/apis.py", line 399, in wrapper
    return eager_transforms.grad_impl(func, argnums, has_aux, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 1449, in grad_impl
    results = grad_and_value_impl(func, argnums, has_aux, args, kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 48, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py", line 1407, in grad_and_value_impl
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/opacus/grad_sample/functorch.py", line 85, in compute_loss_stateless_model
    output = flayer(params, batched_activations)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/opacus/grad_sample/functorch.py", line 50, in fmodel
    return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/_functorch/functional_call.py", line 148, in functional_call
    return nn.utils.stateless._functional_call(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/utils/stateless.py", line 298, in _functional_call
    return module(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Code/bitorch/bitorch/layers/qconv2d.py", line 106, in forward
    return super().forward(self.activation(input_tensor))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Code/bitorch/bitorch/layers/qactivation.py", line 84, in forward
    input_tensor = GradientCancellation.apply(input_tensor, self.gradient_cancellation_threshold)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/cijin/Application/miniconda3/envs/paraopt/lib/python3.11/site-packages/torch/autograd/function.py", line 578, in apply
    raise RuntimeError(
RuntimeError: In order to use an autograd.Function with functorch transforms (vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod. For more details, please see https://pytorch.org/docs/main/notes/extending.func.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant