Model Card
A fine-tuned keypoint detection model for detecting 14 keypoints on trousers.
The keypoint definitions follow the annotations of the DeepFashion2 dataset.
Model Description
- Model type: Computer vision - keypoint detection
- Fine-tuned from: keypointrcnn_resnet50_fpn
Get Started
Install PyTorch, torchvision, and safetensors.
Instantiate the model and replace the prediction heads.
```python
from torchvision.models.detection import keypointrcnn_resnet50_fpn, KeypointRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.keypoint_rcnn import KeypointRCNNPredictor

# Load a pre-trained Keypoint R-CNN model
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
model = keypointrcnn_resnet50_fpn(weights=weights)

# Replace the box predictor: 2 classes (background + trousers)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Replace the keypoint predictor: 14 trouser keypoints
num_keypoints = 14
in_features = model.roi_heads.keypoint_predictor.kps_score_lowres.in_channels
model.roi_heads.keypoint_predictor = KeypointRCNNPredictor(in_features, num_keypoints)
```

Download the model weights (`model.safetensors`) and load the state dict into the model.
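```python
from safetensors.torch import load_model

# `device` sets where the file's tensors are loaded; use device="cpu" if no GPU is available
load_model(model, "model.safetensors", device="cuda")
model.eval()
```

Refer to the torchvision Keypoint R-CNN documentation for the model's usage.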
To export the model to ONNX and OpenVINO IR, refer to this script.

