|
20 | 20 | from models.common import * |
21 | 21 | from models.experimental import * |
22 | 22 | from utils.autoanchor import check_anchor_order |
23 | | -from utils.general import check_yaml, make_divisible, print_args, set_logging |
| 23 | +from utils.general import check_yaml, make_divisible, print_args, set_logging, check_version |
24 | 24 | from utils.plots import feature_visualization |
25 | 25 | from utils.torch_utils import copy_attr, fuse_conv_and_bn, initialize_weights, model_info, scale_img, \ |
26 | 26 | select_device, time_sync |
@@ -74,7 +74,10 @@ def forward(self, x): |
74 | 74 |
|
75 | 75 | def _make_grid(self, nx=20, ny=20, i=0): |
76 | 76 | d = self.anchors[i].device |
77 | | - yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)]) |
| 77 | + if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility |
| 78 | + yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)], indexing='ij') |
| 79 | + else: |
| 80 | + yv, xv = torch.meshgrid([torch.arange(ny).to(d), torch.arange(nx).to(d)]) |
78 | 81 | grid = torch.stack((xv, yv), 2).expand((1, self.na, ny, nx, 2)).float() |
79 | 82 | anchor_grid = (self.anchors[i].clone() * self.stride[i]) \ |
80 | 83 | .view((1, self.na, 1, 1, 2)).expand((1, self.na, ny, nx, 2)).float() |
|
0 commit comments