
Commit a221f06

Sasha Sheng authored and facebook-github-bot committed
[fix,feat] update the state_key based on module (#675)
Summary:
- Allow for module-wise state dict key updates.
- Make use of the `_register_load_state_dict_pre_hook` mechanism to update the keys of the state dict. Opted for this approach because recursion is already implemented in the `load_state_dict` function, so there is no need to re-implement it; better to make use of the PyTorch implementation.
- Slightly cleaner fix compared to this one: #664
- Some documentation cleanup.

Pull Request resolved: #675

Reviewed By: vedanuj

Differential Revision: D24714619

Pulled By: ytsheng

fbshipit-source-id: ccbf85c9aedae4bded3234d9b178e6b34241bbc3
1 parent 6aec3eb commit a221f06
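A minimal sketch of the mechanism the summary relies on (not the MMF code; `RenamingLinear` and the key names here are assumptions): `load_state_dict` walks the module tree and calls each submodule's `_load_from_state_dict` with that submodule's key prefix, so a module can rewrite only its own keys and the recursion comes for free.

import torch
from torch import nn

class RenamingLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.lc = nn.Linear(4, 4)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                              strict, missing_keys, unexpected_keys, error_msgs):
        # Rewrite old keys like "<prefix>module.lc.weight" to "<prefix>lc.weight"
        # before the default loading logic consumes them; prefix is supplied by
        # load_state_dict as it recurses over submodules.
        old_prefix = prefix + "module."
        for k in list(state_dict.keys()):
            if k.startswith(old_prefix):
                state_dict[prefix + k[len(old_prefix):]] = state_dict.pop(k)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)

model = nn.Sequential(RenamingLinear())
old_ckpt = {"0.module.lc.weight": torch.zeros(4, 4),
            "0.module.lc.bias": torch.zeros(4)}
model.load_state_dict(old_ckpt)  # old-style keys load without errors

The commit overrides `_load_from_state_dict` directly rather than registering a separate pre-hook, but both run through the same per-module dispatch.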

File tree

6 files changed: +35 -25 lines changed

mmf/models/base_model.py

Lines changed: 3 additions & 3 deletions
@@ -60,7 +60,7 @@ def forward(self, sample_list):


 class BaseModel(nn.Module):
-    """For integration with Pythia's trainer, datasets and other features,
+    """For integration with MMF's trainer, datasets and other features,
     models needs to inherit this class, call `super`, write a build function,
     write a forward function taking a ``SampleList`` as input and returning a
     dict as output and finally, register it using ``@registry.register_model``
@@ -124,8 +124,8 @@ def config_path(cls):

     @classmethod
     def format_state_key(cls, key):
-        """Can be implemented if something special needs to be done
-        key when pretrained model is being load. This will adapt and return
+        """Can be implemented if something special needs to be done to the
+        key when pretrained model is being loaded. This will adapt and return
         keys according to that. Useful for backwards compatibility. See
         updated load_state_dict below. For an example, see VisualBERT model's
         code.
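As a hypothetical illustration of the pattern the docstring describes (the model class and key names here are assumptions, not from the repo), a subclass can override `format_state_key` to translate old checkpoint keys:

from mmf.models.base_model import BaseModel

class MyModel(BaseModel):
    @classmethod
    def format_state_key(cls, key):
        # e.g. a layer that was renamed from "old_encoder" to "encoder"
        return key.replace("old_encoder", "encoder")

The Pythia diff below shows a real override of exactly this kind.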

mmf/models/m4c.py

Lines changed: 0 additions & 6 deletions
@@ -33,12 +33,6 @@ def __init__(self, config):
     def config_path(cls):
         return "configs/models/m4c/defaults.yaml"

-    @classmethod
-    def format_state_key(cls, key):
-        key = key.replace("obj_faster_rcnn_fc7.module.lc", "obj_faster_rcnn_fc7.lc")
-        key = key.replace("ocr_faster_rcnn_fc7.module.lc", "ocr_faster_rcnn_fc7.lc")
-        return key
-
     def build(self):
         # modules requiring custom learning rates (usually for finetuning)
         self.finetune_modules = []

mmf/models/movie_mcan.py

Lines changed: 0 additions & 7 deletions
@@ -30,13 +30,6 @@ def __init__(self, config):
     def config_path(cls):
         return "configs/models/movie_mcan/defaults.yaml"

-    @classmethod
-    def format_state_key(cls, key):
-        key = key.replace(
-            "image_feature_encoders.0.module.lc", "image_feature_encoders.0.lc"
-        )
-        return key
-
     def build(self):
         self.image_feature_dim = 2048
         self._build_word_embedding()

mmf/models/pythia.py

Lines changed: 1 addition & 5 deletions
@@ -30,11 +30,7 @@ def config_path(cls):

     @classmethod
     def format_state_key(cls, key):
-        key = key.replace("fa_history", "fa_context")
-        key = key.replace(
-            "image_feature_encoders.0.module.lc", "image_feature_encoders.0.lc"
-        )
-        return key
+        return key.replace("fa_history", "fa_context")

     def build(self):
         self._build_word_embedding()

mmf/modules/encoders.py

Lines changed: 26 additions & 0 deletions
@@ -125,6 +125,32 @@ def __init__(self, config: Config, *args, **kwargs):
         self.lc.bias.data.copy_(torch.from_numpy(bias))
         self.out_dim = out_dim

+    def _load_from_state_dict(
+        self,
+        state_dict,
+        prefix,
+        local_metadata,
+        strict,
+        missing_keys,
+        unexpected_keys,
+        error_msgs,
+    ):
+        old_prefix = prefix + "module."
+        for k in list(state_dict.keys()):
+            if k.startswith(old_prefix):
+                new_k = k.replace(old_prefix, prefix)
+                state_dict[new_k] = state_dict.pop(k)
+
+        super()._load_from_state_dict(
+            state_dict,
+            prefix,
+            local_metadata,
+            strict,
+            missing_keys,
+            unexpected_keys,
+            error_msgs,
+        )
+
     def forward(self, image):
         i2 = self.lc(image)
         i3 = nn.functional.relu(i2)
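This hook is what lets the per-model `format_state_key` overrides in m4c.py and movie_mcan.py above be deleted: checkpoints whose fc7 weights were saved under the old `*.module.lc.*` layout now remap to `*.lc.*` inside the encoder itself. The summary also mentions `_register_load_state_dict_pre_hook`; an equivalent formulation of the same remap using that API would look roughly like this (a sketch with assumed names, not the repo's code):

import torch
from torch import nn

class Fc7Like(nn.Module):
    def __init__(self):
        super().__init__()
        self.lc = nn.Linear(8, 8)
        # The hook runs at the start of this module's _load_from_state_dict,
        # so it sees the same per-module prefix as the override above.
        self._register_load_state_dict_pre_hook(self._remap_module_prefix)

    def _remap_module_prefix(self, state_dict, prefix, local_metadata,
                             strict, missing_keys, unexpected_keys, error_msgs):
        old_prefix = prefix + "module."
        for k in list(state_dict.keys()):
            if k.startswith(old_prefix):
                state_dict[k.replace(old_prefix, prefix)] = state_dict.pop(k)

enc = Fc7Like()
enc.load_state_dict({"module.lc.weight": torch.zeros(8, 8),
                     "module.lc.bias": torch.zeros(8)})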

mmf/utils/checkpoint.py

Lines changed: 5 additions & 4 deletions
@@ -321,11 +321,12 @@ def _load_pretrained(self, ckpt):
             key += "."
             value += "."
             for attr in ckpt:
+                if hasattr(model, "format_state_key"):
+                    formatted_attr = model.format_state_key(attr)
+                else:
+                    formatted_attr = attr
+
                 for own_attr in own_state:
-                    if hasattr(model, "format_state_key"):
-                        formatted_attr = model.format_state_key(attr)
-                    else:
-                        formatted_attr = attr
                     if (
                         key in own_attr
                         and value in formatted_attr
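This change is a loop-invariant hoist: `formatted_attr` depends only on `attr`, so it can be computed once per checkpoint key instead of once per (attr, own_attr) pair, without changing behavior. A self-contained sketch of the resulting shape (names assumed, not the MMF code):

def match_keys(ckpt_keys, own_keys, format_state_key=lambda k: k):
    matches = []
    for attr in ckpt_keys:
        # Loop-invariant: depends only on attr, so compute it once here
        # rather than inside the inner loop over the model's own keys.
        formatted_attr = format_state_key(attr)
        for own_attr in own_keys:
            if own_attr == formatted_attr:
                matches.append((own_attr, attr))
    return matches

print(match_keys(["fa_history.w"], ["fa_context.w"],
                 lambda k: k.replace("fa_history", "fa_context")))
# [('fa_context.w', 'fa_history.w')]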
