Skip to content

Commit 80c325b

Browse files
authored
Merge pull request #78 from shinning0821/main
The update of MobileSAM and other code optimizations.
2 parents 882249a + d5279ad commit 80c325b

File tree

300 files changed

+37432
-145
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

300 files changed

+37432
-145
lines changed

__pycache__/cfg.cpython-37.pyc

49 Bytes
Binary file not shown.
106 Bytes
Binary file not shown.

__pycache__/utils.cpython-37.pyc

264 Bytes
Binary file not shown.

cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ def parse_args():
55
parser = argparse.ArgumentParser()
66
parser.add_argument('-net', type=str, default='sam', help='net type')
77
parser.add_argument('-baseline', type=str, default='unet', help='baseline net type')
8+
parser.add_argument('-encoder', type=str, default='default', help='encoder type')
89
parser.add_argument('-seg_net', type=str, default='transunet', help='net type')
910
parser.add_argument('-mod', type=str, default='sam_adpt', help='mod type:seg,cls,val_ad')
1011
parser.add_argument('-exp_name', default='msa_test_isic', type=str, help='net type')

function.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -124,16 +124,20 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
124124
# imgs = imgs.to(dtype = mask_type,device = GPUdevice)
125125

126126
'''Train'''
127-
if args.net == 'sam' or args.net == 'efficient_sam':
127+
if args.mod == 'sam_adpt':
128128
for n, value in net.image_encoder.named_parameters():
129129
if "Adapter" not in n:
130130
value.requires_grad = False
131131
else:
132132
value.requires_grad = True
133+
else:
134+
for n, value in net.image_encoder.named_parameters():
135+
value.requires_grad = True
136+
133137
imge= net.image_encoder(imgs)
134138

135139
with torch.no_grad():
136-
if args.net == 'sam':
140+
if args.net == 'sam' or args.net == 'mobile_sam':
137141
se, de = net.prompt_encoder(
138142
points=pt,
139143
boxes=None,
@@ -146,7 +150,7 @@ def train_sam(args, net: nn.Module, optimizer, train_loader,
146150
labels=labels_torch,
147151
)
148152

149-
if args.net == 'sam':
153+
if args.net == 'sam' or args.net == 'mobile_sam':
150154
pred, _ = net.mask_decoder(
151155
image_embeddings=imge,
152156
image_pe=net.prompt_encoder.get_dense_pe(),
@@ -276,8 +280,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
276280
'''test'''
277281
with torch.no_grad():
278282
imge= net.image_encoder(imgs)
279-
280-
if args.net == 'sam':
283+
if args.net == 'sam' or args.net == 'mobile_sam':
281284
se, de = net.prompt_encoder(
282285
points=pt,
283286
boxes=None,
@@ -290,7 +293,7 @@ def validation_sam(args, val_loader, epoch, net: nn.Module, clean_dir=True):
290293
labels=labels_torch,
291294
)
292295

293-
if args.net == 'sam':
296+
if args.net == 'sam' or args.net == 'mobile_sam':
294297
pred, _ = net.mask_decoder(
295298
image_embeddings=imge,
296299
image_pe=net.prompt_encoder.get_dense_pe(),

models/ImageEncoder/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
from .adapter_block import AdapterBlock
1+
from .tinyvit.tiny_vit import TinyViT
2+
from .vit import AdapterBlock, Block
54 Bytes
Binary file not shown.
Binary file not shown.
5.09 KB
Binary file not shown.
13.5 KB
Binary file not shown.

0 commit comments

Comments
 (0)