@@ -209,15 +209,7 @@ def update(self, **kwargs) -> None:
209209
210210 :param bool learning: Whether to allow connection updates.
211211 """
212- mask = kwargs .get ("mask" , None )
213- if mask is None :
214- return
215-
216- for f in self .pipeline :
217- if type (f ).__name__ != 'Weight' :
218- continue
219-
220- f .value .masked_fill_ (mask , 0 )
212+ pass
221213
222214 @abstractmethod
223215 def reset_state_variables (self ) -> None :
@@ -403,6 +395,7 @@ def __init__(
403395 pipeline : list = [],
404396 manual_update : bool = False ,
405397 traces : bool = False ,
398+ mask : torch .Tensor = None ,
406399 ** kwargs ,
407400 ) -> None :
408401 # language=rst
@@ -416,11 +409,13 @@ def __init__(
416409 :param manual_update: Set to :code:`True` to disable automatic updates (applying learning rules) to connection features.
417410 False by default, updates called after each time step
418411 :param traces: Set to :code:`True` to record history of connection activity (for monitors)
412+ :param mask: A mask to zero out weights
419413 """
420414
421415 super ().__init__ (source , target , device , pipeline , ** kwargs )
422416 self .traces = traces
423417 self .manual_update = manual_update
418+ self .mask = mask
424419 if self .traces :
425420 self .activity = None
426421
@@ -446,6 +441,8 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
446441
447442 # Run through pipeline
448443 for f in self .pipeline :
444+ if type (f ).__name__ == 'Weight' and self .mask :
445+ f .value .masked_fill_ (self .mask , 0 )
449446 conn_spikes = f .compute (conn_spikes )
450447
451448 # Sum signals for each of the output/terminal neurons
@@ -492,8 +489,6 @@ def update(self, **kwargs) -> None:
492489 """
493490 learning = kwargs .get ("learning" , False )
494491 if learning and not self .manual_update :
495- super ().update (** kwargs )
496-
497492 # Pipeline learning
498493 for f in self .pipeline :
499494 f .update (** kwargs )
0 commit comments