Model Card

A fine-tuned keypoint detection model for detecting 14 keypoints on trousers.

The keypoint definitions follow the annotations of the DeepFashion2 dataset.

Trousers Landmarks

Sample Prediction

Model Description

Get Started

  1. Install PyTorch and Torchvision.

  2. Instantiate the model and replace the prediction heads.

    from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
    from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor
    
    # Load a pre-trained Keypoint RCNN model
    weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
    model = keypointrcnn_resnet50_fpn(weights=weights)
    
    # Replace the box head: 2 classes (background + trousers)
    num_classes = 2
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    
    # Replace the keypoint head: 14 keypoints instead of the 17 COCO ones
    num_keypoints = 14
    in_features = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
    model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_features, num_keypoints)
    
  3. Download the model weights and load them into the model.

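    import torch
    from safetensors.torch import load_model
    device = "cuda" if torch.cuda.is_available() else "cpu"  # fall back to CPU without a GPU
    load_model(model, "model.safetensors", device=device)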
    
  4. Refer to the Keypoint R-CNN documentation for details on using the model.

  5. Alternatively, refer to this script to export the model to ONNX and OpenVINO IR.
