1. Transformer Object Detection

1. DETR

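Transformer-based object detection has gained popularity with models like DETR (DEtection TRansformer). Here’s an example of how to run a pre-trained DETR model for object detection in PyTorch. DETR is not part of torchvision’s detection models; the pre-trained weights are published in the facebookresearch/detr repository and can be loaded through torch.hub.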

Sample Code

import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

# Load a pre-trained DETR model (ResNet-50 backbone) through torch.hub.
# torchvision.models.detection has no detr_resnet50; the weights come from
# the facebookresearch/detr repository.
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
model.eval()  # Set the model to evaluation mode

# Define a function to transform the input image
def transform_image(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = T.Compose([
        T.ToTensor(),   # Convert image to a float tensor in [0, 1]
        T.Resize(800),  # Resize so the shorter side is 800 pixels
    ])
    return transform(image)

# Perform object detection
def detect_objects(image_path):
    image_tensor = transform_image(image_path)
    height, width = image_tensor.shape[-2:]
    # Normalize with ImageNet statistics and add batch dimension
    normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    batch = normalize(image_tensor).unsqueeze(0)
    with torch.no_grad():
        outputs = model(batch)
    # Drop the trailing "no object" class, then pick the best class per query
    probs = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    scores, labels = probs.max(-1)
    # Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2)
    cx, cy, w, h = outputs['pred_boxes'][0].unbind(-1)
    boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
    boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
    return {'boxes': boxes, 'labels': labels, 'scores': scores}

# Draw bounding boxes on the image
def draw_boxes(image_path, predictions):
    image_tensor = transform_image(image_path)
    # Extract bounding boxes, labels, and scores from predictions
    boxes = predictions['boxes']
    labels = predictions['labels']
    scores = predictions['scores']
    # Filter out detections with scores below a threshold (e.g., 0.5)
    threshold = 0.5
    keep = scores >= threshold
    boxes = boxes[keep].cpu().numpy()
    labels = labels[keep].cpu().numpy()
    scores = scores[keep].cpu().numpy()  # Filter scores too so they stay aligned
    # Convert image tensor to numpy array for drawing
    image_np = image_tensor.permute(1, 2, 0).numpy()
    image_np = (image_np * 255).astype(np.uint8)
    # Draw bounding boxes
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(image_np)
    for box, label, score in zip(boxes, labels, scores):
        x1, y1, x2, y2 = box
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, f'{int(label)}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=12, color='black')
    plt.axis('off')
    plt.show()

# Example usage
image_path = 'path_to_your_image.jpg'
predictions = detect_objects(image_path)
draw_boxes(image_path, predictions)

Explanation:

  1. Loading the Model:

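    model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
    model.eval() # Set the model to evaluation mode
    • Loads the DETR model with a ResNet-50 backbone, pre-trained on the COCO dataset, from the facebookresearch/detr repository via torch.hub. torchvision.models.detection does not provide a detr_resnet50 constructor.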
  2. Image Transformation:

    def transform_image(image_path):
        image = Image.open(image_path).convert("RGB")
        transform = T.Compose([
            T.ToTensor(),
            T.Resize(800), # Resize so the shorter side is 800 pixels
        ])
        return transform(image)
    • Converts the image to a tensor and resizes it so that its shorter side is 800 pixels, preserving the aspect ratio.
  3. Object Detection:

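    def detect_objects(image_path):
        image_tensor = transform_image(image_path)
        height, width = image_tensor.shape[-2:]
        batch = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image_tensor).unsqueeze(0) # Normalize, add batch dimension
        with torch.no_grad():
            outputs = model(batch)
        probs = outputs['pred_logits'].softmax(-1)[0, :, :-1] # Drop the "no object" class
        scores, labels = probs.max(-1)
        cx, cy, w, h = outputs['pred_boxes'][0].unbind(-1) # Normalized (cx, cy, w, h)
        boxes = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)
        boxes = boxes * torch.tensor([width, height, width, height], dtype=torch.float32)
        return {'boxes': boxes, 'labels': labels, 'scores': scores}
    • Runs the image through the DETR model, then converts the raw outputs (class logits and normalized center-format boxes) into pixel-space boxes, labels, and scores.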
  4. Drawing Bounding Boxes:

    def draw_boxes(image_path, predictions):
        image_tensor = transform_image(image_path)
        boxes = predictions['boxes']
        labels = predictions['labels']
        scores = predictions['scores']
        threshold = 0.5
        keep = scores >= threshold
        boxes = boxes[keep].cpu().numpy()
        labels = labels[keep].cpu().numpy()
        scores = scores[keep].cpu().numpy() # Filter scores too so they stay aligned
        image_np = image_tensor.permute(1, 2, 0).numpy()
        image_np = (image_np * 255).astype(np.uint8)
        fig, ax = plt.subplots(1, figsize=(12, 9))
        ax.imshow(image_np)
        for box, label, score in zip(boxes, labels, scores):
            x1, y1, x2, y2 = box
            rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
            ax.add_patch(rect)
            ax.text(x1, y1, f'{int(label)}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=12, color='black')
        plt.axis('off')
        plt.show()
    • Keeps only detections at or above the score threshold, draws their bounding boxes on the resized image, and displays the result using matplotlib.
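
One detail the drawing step relies on: the loop unpacks each box as pixel corner coordinates (x1, y1, x2, y2). DETR itself predicts boxes as (center x, center y, width, height) normalized to the range [0, 1], so a conversion along these lines is needed whenever the raw model outputs are used. This is a minimal sketch in plain Python; the helper name `cxcywh_to_xyxy` and the sample numbers are illustrative assumptions, not part of any library.

```python
def cxcywh_to_xyxy(box, image_width, image_height):
    """Convert a normalized (cx, cy, w, h) box to pixel (x1, y1, x2, y2)."""
    cx, cy, w, h = box
    x1 = (cx - w / 2) * image_width
    y1 = (cy - h / 2) * image_height
    x2 = (cx + w / 2) * image_width
    y2 = (cy + h / 2) * image_height
    return x1, y1, x2, y2


# Hypothetical box centered in an 800x600 image, a quarter as wide and half as tall
corners = cxcywh_to_xyxy((0.5, 0.5, 0.25, 0.5), image_width=800, image_height=600)
print(corners)  # (300.0, 150.0, 500.0, 450.0)
```

The image width and height must be those of the image the boxes will be drawn on, otherwise the rectangles will be offset or scaled incorrectly.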

Notes:

  • Threshold: Adjust the threshold for filtering low-confidence detections.
  • Labels: DETR returns class indices. To map these to class names, you need the COCO dataset class names or other relevant labels.
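
Both notes can be sketched together. The snippet below filters detections by a configurable threshold while keeping boxes, labels, and scores aligned, then maps class indices to names. It assumes the indices follow the original COCO category numbering; `COCO_NAMES` is a deliberately partial, hand-written subset, and `filter_detections` is a hypothetical helper, so a real script would load the full class list instead.

```python
# Partial COCO category-ID-to-name mapping (illustrative subset, not the full list)
COCO_NAMES = {1: "person", 2: "bicycle", 3: "car", 17: "cat", 18: "dog"}


def filter_detections(boxes, labels, scores, threshold=0.5, names=COCO_NAMES):
    """Keep detections scoring at least `threshold` and attach class names."""
    results = []
    for box, label, score in zip(boxes, labels, scores):
        if score >= threshold:
            # Fall back to the raw index when the name is not in the mapping
            name = names.get(int(label), f"class {int(label)}")
            results.append((name, float(score), tuple(box)))
    return results


# Made-up detections for illustration only
boxes = [(10, 20, 110, 220), (50, 60, 90, 100), (0, 0, 30, 30)]
labels = [1, 18, 44]
scores = [0.98, 0.42, 0.75]

for name, score, box in filter_detections(boxes, labels, scores, threshold=0.5):
    print(f"{name}: {score:.2f} at {box}")
```

Raising the threshold trades recall for precision: at 0.9 only the first detection survives, while at 0.4 all three are kept.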

This example demonstrates how to use the DETR model for object detection with transformers. DETR combines a CNN backbone with a transformer encoder-decoder that predicts a fixed-size set of boxes directly, so it needs neither anchor boxes nor non-maximum suppression.