2626import logging
2727import math
2828from packaging .version import parse
29- from typing import Any , TYPE_CHECKING
29+ from typing import Any , cast , TYPE_CHECKING
3030
3131import numpy as np
3232from tqdm .auto import trange
4141
4242 import torch
4343
44- from art .utils import CLASSIFIER_NEURALNETWORK_TYPE
44+ from art .utils import CLASSIFIER_NEURALNETWORK_TYPE , PYTORCH_OBJECT_DETECTOR_TYPE
4545
4646logger = logging .getLogger (__name__ )
4747
@@ -72,7 +72,7 @@ class AdversarialPatchPyTorch(EvasionAttack):
7272
7373 def __init__ (
7474 self ,
75- estimator : "CLASSIFIER_NEURALNETWORK_TYPE" ,
75+ estimator : "CLASSIFIER_NEURALNETWORK_TYPE | PYTORCH_OBJECT_DETECTOR_TYPE " ,
7676 rotation_max : float = 22.5 ,
7777 scale_min : float = 0.1 ,
7878 scale_max : float = 1.0 ,
@@ -91,7 +91,7 @@ def __init__(
9191 """
9292 Create an instance of the :class:`.AdversarialPatchPyTorch`.
9393
94- :param estimator: A trained estimator.
94+ :param estimator: A trained PyTorch estimator for classification or object detection .
9595 :param rotation_max: The maximum rotation applied to random patches. The value is expected to be in the
9696 range `[0, 180]`.
9797 :param scale_min: The minimum scaling applied to random patches. The value should be in the range `[0, 1]`,
@@ -183,7 +183,10 @@ def __init__(
183183 self ._optimizer = torch .optim .Adam ([self ._patch ], lr = self .learning_rate )
184184
185185 def _train_step (
186- self , images : "torch.Tensor" , target : "torch.Tensor" , mask : "torch.Tensor" | None = None
186+ self ,
187+ images : "torch.Tensor" ,
188+ target : "torch.Tensor" | list [dict [str , "torch.Tensor" ]],
189+ mask : "torch.Tensor" | None = None ,
187190 ) -> "torch.Tensor" :
188191 import torch
189192
@@ -227,7 +230,12 @@ def _predictions(
227230
228231 return predictions , target
229232
230- def _loss (self , images : "torch.Tensor" , target : "torch.Tensor" , mask : "torch.Tensor" | None ) -> "torch.Tensor" :
233+ def _loss (
234+ self ,
235+ images : "torch.Tensor" ,
236+ target : "torch.Tensor" | list [dict [str , "torch.Tensor" ]],
237+ mask : "torch.Tensor" | None ,
238+ ) -> "torch.Tensor" :
231239 import torch
232240
233241 if isinstance (target , torch .Tensor ):
@@ -475,13 +483,17 @@ def _random_overlay(
475483 return patched_images
476484
477485 def generate ( # type: ignore
478- self , x : np .ndarray , y : np .ndarray | None = None , ** kwargs
486+ self , x : np .ndarray , y : np .ndarray | list [ dict [ str , np . ndarray | "torch.Tensor" ]] | None = None , ** kwargs
479487 ) -> tuple [np .ndarray , np .ndarray ]:
480488 """
481489 Generate an adversarial patch and return the patch and its mask in arrays.
482490
483491 :param x: An array with the original input images of shape NCHW or input videos of shape NFCHW.
484- :param y: An array with the original true labels.
492+ :param y: True or target labels. For classifiers, an array of labels. For object detectors, a list of format
493+ `list[dict[str, np.ndarray | torch.Tensor]]`, one dict per input image, with the following fields:
494+
495+ - boxes [N, 4]: the boxes in [x1, y1, x2, y2] format, with 0 <= x1 < x2 <= W and 0 <= y1 < y2 <= H.
496+ - labels [N]: the labels for each image.
485497 :param mask: A boolean array of shape equal to the shape of a single samples (1, H, W) or the shape of `x`
486498 (N, H, W) without their channel dimensions. Any features for which the mask is True can be the
487499 center location of the patch during sampling.
@@ -499,12 +511,19 @@ def generate( # type: ignore
499511 if self .patch_location is not None and mask is not None :
500512 raise ValueError ("Masks can only be used if the `patch_location` is `None`." )
501513
502- if y is None : # pragma: no cover
503- logger .info ("Setting labels to estimator predictions and running untargeted attack because `y=None`." )
504- y = to_categorical (np .argmax (self .estimator .predict (x = x ), axis = 1 ), nb_classes = self .estimator .nb_classes )
505-
506514 if hasattr (self .estimator , "nb_classes" ):
507- y = check_and_transform_label_format (labels = y , nb_classes = self .estimator .nb_classes )
515+
516+ y_array : np .ndarray
517+
518+ if y is None : # pragma: no cover
519+ logger .info ("Setting labels to estimator classification predictions." )
520+ y_array = to_categorical (
521+ np .argmax (self .estimator .predict (x = x ), axis = 1 ), nb_classes = self .estimator .nb_classes
522+ )
523+ else :
524+ y_array = cast (np .ndarray , y )
525+
526+ y = check_and_transform_label_format (labels = y_array , nb_classes = self .estimator .nb_classes )
508527
509528 # check if logits or probabilities
510529 y_pred = self .estimator .predict (x = x [[0 ]])
@@ -513,6 +532,10 @@ def generate( # type: ignore
513532 self .use_logits = False
514533 else :
515534 self .use_logits = True
535+ else :
536+ if y is None : # pragma: no cover
537+ logger .info ("Setting labels to estimator object detection predictions." )
538+ y = self .estimator .predict (x = x )
516539
517540 if isinstance (y , np .ndarray ):
518541 x_tensor = torch .Tensor (x )
0 commit comments