Skip to content

Commit e13974b

Browse files
committed
Mask for connection
1 parent 2f709fb commit e13974b

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

bindsnet/network/topology.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,15 +209,7 @@ def update(self, **kwargs) -> None:
209209
210210
:param bool learning: Whether to allow connection updates.
211211
"""
212-
mask = kwargs.get("mask", None)
213-
if mask is None:
214-
return
215-
216-
for f in self.pipeline:
217-
if type(f).__name__ != 'Weight':
218-
continue
219-
220-
f.value.masked_fill_(mask, 0)
212+
pass
221213

222214
@abstractmethod
223215
def reset_state_variables(self) -> None:
@@ -403,6 +395,7 @@ def __init__(
403395
pipeline: list = [],
404396
manual_update: bool = False,
405397
traces: bool = False,
398+
mask: torch.Tensor = None,
406399
**kwargs,
407400
) -> None:
408401
# language=rst
@@ -416,11 +409,13 @@ def __init__(
416409
:param manual_update: Set to :code:`True` to disable automatic updates (applying learning rules) to connection features.
417410
False by default, updates called after each time step
418411
:param traces: Set to :code:`True` to record history of connection activity (for monitors)
412+
:param mask: A mask to zero out weights
419413
"""
420414

421415
super().__init__(source, target, device, pipeline, **kwargs)
422416
self.traces = traces
423417
self.manual_update = manual_update
418+
self.mask = mask
424419
if self.traces:
425420
self.activity = None
426421

@@ -446,6 +441,8 @@ def compute(self, s: torch.Tensor) -> torch.Tensor:
446441

447442
# Run through pipeline
448443
for f in self.pipeline:
444+
if type(f).__name__ == 'Weight' and self.mask:
445+
f.value.masked_fill_(self.mask, 0)
449446
conn_spikes = f.compute(conn_spikes)
450447

451448
# Sum signals for each of the output/terminal neurons
@@ -492,8 +489,6 @@ def update(self, **kwargs) -> None:
492489
"""
493490
learning = kwargs.get("learning", False)
494491
if learning and not self.manual_update:
495-
super().update(**kwargs)
496-
497492
# Pipeline learning
498493
for f in self.pipeline:
499494
f.update(**kwargs)

0 commit comments

Comments
 (0)