-
Notifications
You must be signed in to change notification settings - Fork 929
Open
Description
Edit `forward()` in `vgg16_bn.py` to return a plain tuple instead of a namedtuple, because TorchScript (`torch.jit`) does not support the "named tuple" used there:
def forward(self, X):
    """Run X through the five VGG16-BN slices and return the intermediate feature maps.

    Returns a plain tuple ``(fc7, relu5_3, relu4_3, relu3_2, relu2_2)``. This is
    the same order as the ``VggOutputs`` namedtuple it replaces, so the module can
    be compiled with ``torch.jit.script``. This issue reports that TorchScript
    rejected the namedtuple version.

    NOTE(review): any caller that read fields by name (``out.fc7``) must switch
    to positional indexing (``out[0]``). Verify the callers.
    """
    h = self.slice1(X)
    h_relu2_2 = h
    h = self.slice2(h)
    h_relu3_2 = h
    h = self.slice3(h)
    h_relu4_3 = h
    h = self.slice4(h)
    h_relu5_3 = h
    h = self.slice5(h)
    h_fc7 = h
    # Deepest feature first, matching the old namedtuple field order.
    return h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2


# Convert CRAFT to JIT
# Load a CRAFT checkpoint and compile the model with TorchScript.
# NOTE(review): `CRAFT`, `_get_state_dict` and `cuda` come from the surrounding
# script and are not shown here. `cuda` is presumably a bool such as
# `torch.cuda.is_available()`; confirm this.
craft = CRAFT()
ckpt_path = "~~"  # placeholder: replace with the real checkpoint path
# presumably keeps/strips the "module." key prefix (as saved from a wrapped
# model) so the keys match a bare CRAFT instance; verify against _get_state_dict
state_dict = _get_state_dict(
    ckpt_path=ckpt_path, include="module.", delete="module.", cuda=cuda
)
craft.load_state_dict(state_dict=state_dict, strict=True)
# Switch to inference mode before scripting.
craft.eval()
craft = torch.jit.script(craft)
# Result on my GPU: about 5% faster (1.34 s -> 1.27 s); memory usage was the same (not sure why).

wkpark
Metadata
Metadata
Assignees
Labels
No labels
