Skip to content

Commit 70ee260

Browse files
committed
fix(preprocessor): add multimodal arc ref
1 parent 64351f4 commit 70ee260

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

test/module/perceiver_io/test_preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -50,7 +50,7 @@ def test_text_preprocessor(batch, in_shape, num_freq_bands):
5050
'vector': {
5151
'type': 'FourierPreprocessor',
5252
'num_freq_bands': 16,
53-
'max_reso': [32],
53+
'max_reso': [31],
5454
'cat_pos': True,
5555
},
5656
}

torcharc/arc_ref.py

Lines changed: 50 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -231,6 +231,56 @@
231231
}
232232
}
233233
},
234+
'perceiver_multimodal2classifier': {
235+
'type': 'Perceiver',
236+
'in_shape': {'image': [224, 224, 3], 'vector': [31, 2]},
237+
'arc': {
238+
'preprocessor': {
239+
'type': 'MultimodalPreprocessor',
240+
'arc': {
241+
'image': {
242+
'type': 'FourierPreprocessor',
243+
'num_freq_bands': 64,
244+
'max_reso': [224, 224],
245+
'cat_pos': True,
246+
},
247+
'vector': {
248+
'type': 'FourierPreprocessor',
249+
'num_freq_bands': 16,
250+
'max_reso': [31],
251+
'cat_pos': True,
252+
},
253+
},
254+
'pad_channels': 2,
255+
},
256+
'encoder': {
257+
'type': 'PerceiverEncoder',
258+
'latent_shape': [2048, 1024],
259+
'head_dim': 1024, # usually preserves latent_shape[-1]
260+
'v_head_dim': None, # defaults to head_dim
261+
'cross_attn_num_heads': 1,
262+
'cross_attn_widening_factor': 1,
263+
'num_self_attn_blocks': 8,
264+
'num_self_attn_per_block': 6,
265+
'self_attn_num_heads': 8,
266+
'self_attn_widening_factor': 1,
267+
'dropout_p': 0.0,
268+
},
269+
'decoder': {
270+
'type': 'PerceiverDecoder',
271+
'out_shape': [1, 1024],
272+
'head_dim': 1024, # usually preserves out_shape[-1]
273+
'v_head_dim': None, # defaults to head_dim
274+
'cross_attn_num_heads': 1,
275+
'cross_attn_widening_factor': 1,
276+
'dropout_p': 0.0,
277+
},
278+
'postprocessor': {
279+
'type': 'ClassificationPostprocessor',
280+
'out_dim': 10,
281+
}
282+
}
283+
},
234284
# DAGs
235285
'forward': {
236286
'dag_in_shape': [8],

torcharc/module/perceiver_io/preprocessor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ def build_pos_encoding(self, pos: torch.Tensor, max_reso: list = None) -> torch.
8585
@return position encodings tensor of shape (x, y,... d*(2*num_freq_bands+1))
8686
'''
8787
max_reso = max_reso or pos.shape[:-1]
88-
assert len(max_reso) == len(pos.shape[:-1]), f'max_reso len(shape) must match pos len(shape), but got {len(max_reso)} != {len(pos.shape[:-1])}'
88+
assert len(max_reso) == len(pos.shape[:-1]), f'max_reso len(shape) must match pos len(shape), but got {len(max_reso)} instead of {len(pos.shape[:-1])}'
8989
freq_bands = torch.stack([torch.linspace(1.0, max_r / 2.0, steps=self.num_freq_bands) for max_r in max_reso])
9090
pos_freqs = rearrange(torch.einsum('...d,df->d...f', pos, freq_bands), 'd ... f -> ... (d f)')
9191

0 commit comments

Comments
 (0)