Source code for mira.detectors.fasterrcnn

# pylint: disable=too-many-instance-attributes
import typing

import torch
import torchvision
import typing_extensions as tx
import pkg_resources

from .. import datasets as mds
from .. import core as mc
from . import detector
from . import common as mdc


BACKBONE_TO_PARAMS = {
    "resnet50": {
        "weights_url": "https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth",
        "fpn_func": torchvision.models.detection.backbone_utils.resnet_fpn_backbone,
        "fpn_extra_blocks_kwargs": {},
        "default_fpn_kwargs": {
            "trainable_layers": 3,
            "backbone_name": "resnet50",
            "extra_blocks": "lastlevelmaxpool",
        },
        "default_anchor_kwargs": {
            "sizes": ((32,), (64,), (128,), (256,), (512,)),
            "aspect_ratios": ((0.5, 1.0, 2.0),) * 5,
        },
        "default_detector_kwargs": {},
    },
    "timm": {
        "fpn_func": mdc.BackboneWithTIMM,
        "default_fpn_kwargs": {
            "model_name": "mnasnet_small",
            "out_indices": (2, 3, 4),
        },
        "default_anchor_kwargs": {
            "sizes": tuple(
                (x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3)))
                for x in [32, 64, 128]
            ),
            "aspect_ratios": ((1.0,),) * 3,
        },
        "default_detector_kwargs": {},
    },
    "mobilenet_large": {
        "weights_url": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth",
        "fpn_extra_blocks_kwargs": {},
        "fpn_func": torchvision.models.detection.backbone_utils.mobilenet_backbone,
        "default_fpn_kwargs": {
            "trainable_layers": 3,
            "backbone_name": "mobilenet_v3_large",
            "fpn": True,
            "extra_blocks": "lastlevelmaxpool",
        },
        "default_anchor_kwargs": {
            "sizes": (
                (
                    32,
                    64,
                    128,
                    256,
                    512,
                ),
            )
            * 3,
            "aspect_ratios": ((0.5, 1.0, 2.0),) * 3,
        },
        "default_detector_kwargs": {
            "rpn_score_thresh": 0.05,
        },
    },
    "mobilenet_large_320": {
        "weights_url": "https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth",
        "fpn_extra_blocks_kwargs": {},
        "fpn_func": torchvision.models.detection.backbone_utils.mobilenet_backbone,
        "default_fpn_kwargs": {
            "trainable_layers": 3,
            "backbone_name": "mobilenet_v3_large",
            "fpn": True,
            "extra_blocks": "lastlevelmaxpool",
        },
        "default_anchor_kwargs": {
            "sizes": (
                (
                    32,
                    64,
                    128,
                    256,
                    512,
                ),
            )
            * 3,
            "aspect_ratios": ((0.5, 1.0, 2.0),) * 3,
        },
        "default_detector_kwargs": {
            "min_size": 320,
            "max_size": 640,
            "rpn_pre_nms_top_n_test": 150,
            "rpn_post_nms_top_n_test": 150,
            "rpn_score_thresh": 0.05,
        },
    },
}
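
# Illustrative sketch (assumptions flagged in comments): these entries are
# presumably consumed by `mdc.initialize_basic`, which merges user-supplied
# kwargs over the defaults above and calls `fpn_func`. For "resnet50" the
# result would look roughly like:
#
#     params = BACKBONE_TO_PARAMS["resnet50"]
#     fpn_kwargs = dict(params["default_fpn_kwargs"])
#     # Assumption: the "lastlevelmaxpool" marker is resolved to the actual
#     # torchvision extra-blocks module before the call.
#     fpn_kwargs["extra_blocks"] = (
#         torchvision.ops.feature_pyramid_network.LastLevelMaxPool()
#     )
#     # Older torchvision releases take `pretrained=`; newer ones take `weights=`.
#     fpn = params["fpn_func"](pretrained=True, **fpn_kwargs)
#     anchor_generator = torchvision.models.detection.anchor_utils.AnchorGenerator(
#         **params["default_anchor_kwargs"]
#     )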


class ModifiedFasterRCNN(torchvision.models.detection.faster_rcnn.FasterRCNN):
    """Modified version of Faster RCNN that always computes inferences."""

    def forward(self, images, targets=None):
        if self.training:
            if targets is None:
                torch._assert(False, "targets should not be none when in training mode")
            else:
                for target in targets:
                    boxes = target["boxes"]
                    if isinstance(boxes, torch.Tensor):
                        torch._assert(
                            len(boxes.shape) == 2 and boxes.shape[-1] == 4,
                            f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.",
                        )
                    else:
                        torch._assert(
                            False,
                            f"Expected target boxes to be of type Tensor, got {type(boxes)}.",
                        )

        original_image_sizes: typing.List[typing.Tuple[int, int]] = []
        for img in images:
            val = img.shape[-2:]
            torch._assert(
                len(val) == 2,
                f"expecting the last two dimensions of the Tensor to be H and W, instead got {img.shape[-2:]}",
            )
            original_image_sizes.append((val[0], val[1]))

        images, targets = self.transform(images, targets)

        # Check for degenerate boxes
        if targets is not None:
            for target_idx, target in enumerate(targets):
                boxes = target["boxes"]
                degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
                if degenerate_boxes.any():
                    # print the first degenerate box
                    bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
                    degen_bb: typing.List[float] = boxes[bb_idx].tolist()
                    torch._assert(
                        False,
                        "All bounding boxes should have positive height and width."
                        f" Found invalid box {degen_bb} for target at index {target_idx}.",
                    )

        features = self.backbone(images.tensors)
        if isinstance(features, torch.Tensor):
            features = typing.OrderedDict([("0", features)])
        else:
            # Forces features to comply with self.roi_heads.box_roi_pool.featmap_names
            features = typing.OrderedDict(
                [(str(idx), f) for idx, f in enumerate(features.values())]
            )
        proposals, proposal_losses = self.rpn(images, features, targets)
        if self.training:
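            # Run the ROI heads twice: once with training flipped off to get
            # detections, then again in training mode to get the losses. The
            # stock torchvision implementation returns only losses here.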
            self.roi_heads.training = False
            detections, _ = self.roi_heads(features, proposals, images.image_sizes)
            self.roi_heads.training = True
            _, detector_losses = self.roi_heads(
                features, proposals, images.image_sizes, targets
            )
        else:
            detections, detector_losses = self.roi_heads(
                features, proposals, images.image_sizes, targets
            )

        detections = self.transform.postprocess(detections, images.image_sizes, original_image_sizes)  # type: ignore[operator]

        losses = {}
        losses.update(detector_losses)
        losses.update(proposal_losses)
        return {
            "loss": sum(loss for loss in losses.values()) if losses else None,
            "output": detections,
        }
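
# Usage sketch for ModifiedFasterRCNN: a single training-mode forward pass
# yields both a differentiable loss and detections, so validation metrics can
# be computed without a separate inference pass. Variable names below are
# illustrative:
#
#     model.train()
#     out = model(images, targets)  # images: list of CHW float tensors
#     out["loss"].backward()        # summed RPN + ROI-head losses
#     detections = out["output"]    # list of {"boxes", "labels", "scores"} dicts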


class FasterRCNN(detector.Detector):
    """A wrapper around the Faster R-CNN models in torchvision."""

    def __init__(
        self,
        categories=mds.COCOCategories90,
        pretrained_backbone: bool = True,
        pretrained_top: bool = False,
        backbone: tx.Literal[
            "resnet50", "mobilenet_large", "mobilenet_large_320"
        ] = "resnet50",
        device="cpu",
        fpn_kwargs=None,
        detector_kwargs=None,
        anchor_kwargs=None,
        resize_config: typing.Optional[mc.resizing.ResizeConfig] = None,
    ):
        self.categories = mc.Categories.from_categories(categories)
        self.backbone_name = backbone
        if pretrained_top:
            pretrained_backbone = False
        (
            self.resize_config,
            self.anchor_kwargs,
            self.fpn_kwargs,
            self.detector_kwargs,
            self.backbone,
            fpn,
            anchor_generator,
        ) = mdc.initialize_basic(
            BACKBONE_TO_PARAMS,
            backbone=backbone,
            fpn_kwargs=fpn_kwargs,
            anchor_kwargs=anchor_kwargs,
            detector_kwargs=detector_kwargs,
            resize_config=resize_config,
            pretrained_backbone=pretrained_backbone,
        )
        # num_classes includes one extra slot for the background class.
        self.model = ModifiedFasterRCNN(
            fpn,
            len(categories) + 1,
            rpn_anchor_generator=anchor_generator,
            **self.detector_kwargs,
        )
        self.model.transform = mdc.convert_rcnn_transform(self.model.transform)
        if pretrained_top:
            self.model.load_state_dict(
                torch.hub.load_state_dict_from_url(
                    BACKBONE_TO_PARAMS[backbone]["weights_url"],
                    progress=True,
                )
            )
            torchvision.models.detection.faster_rcnn.overwrite_eps(self.model, 0.0)
        self.set_device(device)

    def serve_module_string(self):
        """Build the serving module source by filling in the placeholders in
        the bundled template."""
        return (
            pkg_resources.resource_string(
                "mira", "detectors/assets/serve/fasterrcnn.py"
            )
            .decode("utf-8")
            .replace("NUM_CLASSES", str(len(self.categories) + 1))
            .replace("BACKBONE_NAME", f"'{self.backbone_name}'")
            .replace("RESIZE_CONFIG", str(self.resize_config))
            .replace("DETECTOR_KWARGS", str(self.detector_kwargs))
            .replace("ANCHOR_KWARGS", str(self.anchor_kwargs))
            .replace("FPN_KWARGS", str({**self.fpn_kwargs, "pretrained": False}))
        )

    def invert_targets(self, y, threshold=0.5):
        """Convert raw model outputs back into annotations, keeping boxes
        with score above `threshold`. Label indices are shifted down by one
        to undo the background-class offset."""
        return [
            mc.torchtools.InvertedTarget(
                annotations=[
                    mc.Annotation(
                        category=self.categories[int(labelIdx) - 1],
                        x1=x1,
                        y1=y1,
                        x2=x2,
                        y2=y2,
                        score=score,
                    )
                    for (x1, y1, x2, y2), labelIdx, score in zip(
                        labels["boxes"].detach().cpu().numpy(),
                        labels["labels"].detach().cpu().numpy(),
                        labels["scores"].detach().cpu().numpy(),
                    )
                    if score > threshold
                ],
                labels=[],
            )
            for labels in y["output"]
        ]

    def compute_targets(self, targets, width, height):
        """Convert annotation groups into the target dicts that torchvision
        expects; labels are offset by +1 because 0 is the background class."""
        return [
            {
                "boxes": torch.tensor(
                    b[:, :4], dtype=torch.float32, device=self.device
                ),
                "labels": torch.tensor(
                    b[:, -1] + 1, dtype=torch.int64, device=self.device
                ),
            }
            for b in [
                self.categories.bboxes_from_group(t.annotations) for t in targets
            ]
        ]

    def compute_anchor_boxes(self, width, height):
        """Compute the anchor boxes for an image of the given size."""
        return mdc.get_torchvision_anchor_boxes(
            model=self.model,
            anchor_generator=typing.cast(
                torch.nn.Module, self.model.rpn
            ).anchor_generator,
            device=self.device,
            height=height,
            width=width,
        )
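
# End-to-end sketch using only the API visible in this module; the
# `detector.Detector` base class may provide higher-level training helpers
# not shown here, and the random input tensor is purely for illustration:
#
#     det = FasterRCNN(backbone="resnet50", pretrained_top=True)
#     det.model.eval()
#     with torch.no_grad():
#         y = det.model([torch.rand(3, 480, 640)])
#     inverted = det.invert_targets(y, threshold=0.5)
#     # each InvertedTarget holds Annotation objects for boxes scoring > 0.5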