
Commit 9f8494d

minor update use_permute_fix
1 parent 8ce22df commit 9f8494d

File tree

3 files changed (+7, -5 lines)


models/dualprompt.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def get_parser(parser) -> ArgumentParser:
     parser.set_defaults(optimizer='adam', batch_size=128)
 
     parser.add_argument('--pretrained', default=1, type=binary_to_boolean_type, help='Load pretrained model or not')
-    parser.add_argument('--use_fix_permute', type=binary_to_boolean_type, default=0, help='Apply fix to reshape issue from original implementation (ref: issue #56)')
+    parser.add_argument('--use_permute_fix', type=binary_to_boolean_type, default=0, help='Apply fix to reshape issue from original implementation (ref: issue #56)')
 
     # Optimizer parameters
     parser.add_argument('--clip_grad', type=float, default=1.0, metavar='NORM', help='Clip gradient norm (default: None, no clipping)')
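
The only change here is the flag rename from --use_fix_permute to --use_permute_fix; the converter and the default of 0 (disabled) are untouched. A minimal sketch of the renamed option's 0/1 semantics, with binary_to_boolean_type replaced by a stand-in (the real converter lives in the repository's utils and may differ):

    # Minimal sketch, not the repository's code: the renamed flag with
    # assumed 0/1-to-boolean semantics.
    from argparse import ArgumentParser

    def binary_to_boolean_type(value):
        # Stand-in for the repository's converter: treat '1'/'true'/'yes'
        # as True, everything else as False.
        return str(value).lower() in ('1', 'true', 'yes')

    parser = ArgumentParser()
    parser.add_argument('--use_permute_fix', type=binary_to_boolean_type, default=0,
                        help='Apply fix to reshape issue from original implementation (ref: issue #56)')

    args = parser.parse_args(['--use_permute_fix', '1'])
    print(args.use_permute_fix)  # True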

models/dualprompt_utils/model.py

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ def __init__(self, args: Namespace, n_classes: int):
             num_classes=n_classes,
             drop_rate=0,
             drop_path_rate=0,
+            args=args
         )
         self.original_model.eval()
 
@@ -41,6 +42,7 @@ def __init__(self, args: Namespace, n_classes: int):
             e_prompt_layer_idx=args.e_prompt_layer_idx,
             use_prefix_tune_for_e_prompt=args.use_prefix_tune_for_e_prompt,
             same_key_value=args.same_key_value,
+            args=args
         )
 
         if args.freeze:

models/dualprompt_utils/prompt.py

Lines changed: 4 additions & 4 deletions
@@ -120,10 +120,10 @@ def forward(self, x_embed, prompt_mask=None, cls_features=None):
             num_layers, dual, batch_size, top_k, length, num_heads, heads_embed_dim = batched_prompt_raw.shape
             if self.use_permute_fix:  # this fixes issue #56
                 batched_prompt_raw = batched_prompt_raw.permute(0, 2, 1, 3, 4, 5, 6)
-            else:  # this follows the original implementation from https://github.dev/google-research/l2p
-                batched_prompt = batched_prompt_raw.reshape(
-                    num_layers, batch_size, dual, top_k * length, num_heads, heads_embed_dim
-                )
+            # else this follows the original implementation from https://github.dev/google-research/l2p
+            batched_prompt = batched_prompt_raw.reshape(
+                num_layers, batch_size, dual, top_k * length, num_heads, heads_embed_dim
+            )
         else:
             batched_prompt_raw = self.prompt[:, idx]
             num_layers, batch_size, top_k, length, embed_dim = batched_prompt_raw.shape
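
Net effect of this hunk: the reshape now runs unconditionally, so batched_prompt is also assigned when use_permute_fix is enabled (previously the fixed branch permuted but never reshaped). A minimal sketch with made-up shapes, not the repository's code, illustrating why permuting the dual and batch axes before the flatten matters:

    import torch

    # Hypothetical sizes standing in for (num_layers, dual, batch_size,
    # top_k, length, num_heads, heads_embed_dim) in prompt.py.
    num_layers, dual, batch, top_k, length, heads, dim = 2, 2, 3, 1, 4, 2, 5
    raw = torch.randn(num_layers, dual, batch, top_k, length, heads, dim)

    # Original path: reshape alone reinterprets the flat memory, silently
    # interleaving the dual (key/value) axis with the batch axis.
    original = raw.reshape(num_layers, batch, dual, top_k * length, heads, dim)

    # use_permute_fix path: swap axes 1 and 2 first, then flatten top_k * length.
    fixed = raw.permute(0, 2, 1, 3, 4, 5, 6).reshape(
        num_layers, batch, dual, top_k * length, heads, dim
    )

    print(torch.equal(original, fixed))  # False: the two layouts disagree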
