1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
class DecodeBox(nn.Module):
    """Decode the raw output of one YOLO detection head into normalised boxes.

    The head is selected by ``index``.  ``index`` 0, 1 and 2 correspond to
    feature maps that are 32x, 16x and 8x smaller than ``input_shape``
    respectively (``8 * 2 ** (2 - index)``).

    Args:
        anchors: Sequence of ``(width, height)`` anchor pairs expressed in
            input-image pixels.  May be a NumPy array or a plain list of
            pairs.
        num_classes: Number of class scores predicted per anchor.
        input_shape: ``(height, width)`` of the network input.
        anchors_mask: For each head, the indices into ``anchors`` used by
            that head.  Defaults to ``[[6, 7, 8], [3, 4, 5], [0, 1, 2]]``.
        index: Which head this module decodes.  Values outside ``0..2``
            silently fall back to ``0``.
    """

    def __init__(self, anchors, num_classes, input_shape, anchors_mask=None, index=0):
        super().__init__()
        if anchors_mask is None:
            anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        self.anchors = anchors
        self.num_classes = num_classes
        # 4 box terms + 1 objectness term + one score per class.
        self.bbox_attrs = 5 + num_classes
        self.input_shape = input_shape

        # An out-of-range head index is not an error: it falls back to head 0.
        if index > 2 or index < 0:
            index = 0
        self.index = index
        # Retained for backward compatibility with code that reads this
        # attribute; forward() derives the actual batch size from its input.
        self.batch_size = 1

        downsample = 8 * (2 ** (2 - index))
        self.input_height = int(input_shape[0] / downsample)
        self.input_width = int(input_shape[1] / downsample)
        stride_h = self.input_shape[0] / self.input_height
        stride_w = self.input_shape[1] / self.input_width

        # Divisor that maps (x, y, w, h) from grid-cell units to the 0..1 range.
        self._scale = torch.tensor(
            [self.input_width, self.input_height, self.input_width, self.input_height],
            dtype=torch.float32,
        )
        self.anchors_mask = anchors_mask

        # Pick this head's anchors row by row so that both NumPy arrays and
        # plain Python lists are accepted, then express them in grid cells.
        head_anchors = [self.anchors[i] for i in anchors_mask[self.index]]
        self.scaled_anchors = [
            (anchor_width / stride_w, anchor_height / stride_h)
            for anchor_width, anchor_height in head_anchors
        ]

    def forward(self, x):
        """Decode raw head activations.

        Args:
            x: Tensor holding ``batch * num_anchors * (5 + num_classes) *
                input_height * input_width`` elements, typically shaped
                ``(batch, num_anchors * (5 + num_classes), input_height,
                input_width)``.

        Returns:
            Tensor of shape ``(batch, num_anchors * input_height *
            input_width, 5 + num_classes)``.  The last dimension holds
            ``x, y, w, h`` (divided by the grid width/height), the
            sigmoid objectness score and the sigmoid class scores.
        """
        num_anchors = len(self.anchors_mask[self.index])
        grid_h, grid_w = self.input_height, self.input_width

        # Derive the batch size from the element count instead of assuming 1,
        # so batched inputs work while every layout the fixed-size reshape
        # used to accept keeps working.
        per_sample = num_anchors * self.bbox_attrs * grid_h * grid_w
        batch_size = x.numel() // per_sample

        # (batch, anchors, attrs, H, W) -> (batch, anchors, H, W, attrs)
        prediction = (
            x.view(batch_size, num_anchors, self.bbox_attrs, grid_h, grid_w)
            .permute(0, 1, 3, 4, 2)
            .contiguous()
        )

        box_x = torch.sigmoid(prediction[..., 0])
        box_y = torch.sigmoid(prediction[..., 1])
        w = torch.sigmoid(prediction[..., 2])
        h = torch.sigmoid(prediction[..., 3])
        conf = torch.sigmoid(prediction[..., 4])
        pred_cls = torch.sigmoid(prediction[..., 5:])

        device = prediction.device

        # Cell offsets, shaped to broadcast against (batch, anchors, H, W):
        # grid_x varies along the width axis, grid_y along the height axis.
        grid_x = torch.arange(grid_w, dtype=torch.float32, device=device).view(1, 1, 1, grid_w)
        grid_y = torch.arange(grid_h, dtype=torch.float32, device=device).view(1, 1, grid_h, 1)

        # Per-anchor width/height in grid cells, broadcast along the anchor axis.
        anchors = torch.tensor(self.scaled_anchors, dtype=torch.float32, device=device)
        anchor_w = anchors[:, 0].view(1, num_anchors, 1, 1)
        anchor_h = anchors[:, 1].view(1, num_anchors, 1, 1)

        # Offsets lie in (-0.5, 1.5) around the cell; sizes in (0, 4) x anchor.
        pred_boxes = torch.stack(
            (
                box_x * 2. - 0.5 + grid_x,
                box_y * 2. - 0.5 + grid_y,
                (w * 2) ** 2 * anchor_w,
                (h * 2) ** 2 * anchor_h,
            ),
            dim=-1,
        )

        scale = self._scale.to(device)
        output = torch.cat(
            (
                pred_boxes.view(batch_size, -1, 4) / scale,
                conf.view(batch_size, -1, 1),
                pred_cls.view(batch_size, -1, self.num_classes),
            ),
            -1,
        )
        return output
|