1. Transformer Object Detection
1. DETR
Transformer-based object detection has gained popularity with models like DETR (DEtection TRansformer). Here's an example of how to run the pre-trained DETR model released by Facebook AI Research, loaded through PyTorch's torch.hub, for object detection.
Sample Code
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt

# Load a pre-trained DETR model (ResNet-50 backbone, trained on COCO).
# torchvision does not ship DETR; the official weights come from the
# facebookresearch/detr repository via torch.hub.
model = torch.hub.load('facebookresearch/detr:main', 'detr_resnet50', pretrained=True)
model.eval()  # Set the model to evaluation mode

# Define a function to transform the input image
def transform_image(image_path):
    image = Image.open(image_path).convert("RGB")
    transform = T.Compose([
        T.Resize(800),  # Resize so the shorter side is 800 pixels
        T.ToTensor(),   # Convert image to a float tensor in [0, 1]
        # Normalize with ImageNet statistics, as the pre-trained backbone expects
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    return transform(image)

# Perform object detection
def detect_objects(image_path):
    image_tensor = transform_image(image_path)

    # Add batch dimension
    image_tensor = image_tensor.unsqueeze(0)

    with torch.no_grad():
        outputs = model(image_tensor)

    # DETR returns a dict: 'pred_logits' (1, 100, 92) and 'pred_boxes' (1, 100, 4)
    return outputs

# Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2) boxes
def rescale_boxes(boxes, size):
    width, height = size
    cx, cy, w, h = boxes.unbind(-1)
    boxes_xyxy = torch.stack([cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1)
    return boxes_xyxy * torch.tensor([width, height, width, height], dtype=torch.float32)

# Draw bounding boxes on the image
def draw_boxes(image_path, predictions, threshold=0.5):
    image = Image.open(image_path).convert("RGB")

    # Class probabilities for each query, dropping the last ("no object") class
    probs = predictions['pred_logits'].softmax(-1)[0, :, :-1]
    scores, labels = probs.max(-1)

    # Filter out predictions with scores below the threshold
    keep = scores >= threshold
    boxes = rescale_boxes(predictions['pred_boxes'][0, keep], image.size)
    scores, labels = scores[keep], labels[keep]

    # Draw the boxes on the original (unnormalized) image
    fig, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(image)
    for (x1, y1, x2, y2), label, score in zip(boxes.tolist(), labels.tolist(), scores.tolist()):
        rect = plt.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
        ax.add_patch(rect)
        ax.text(x1, y1, f'{label}: {score:.2f}', bbox=dict(facecolor='yellow', alpha=0.5), fontsize=12, color='black')

    plt.axis('off')
    plt.show()

# Example usage
image_path = 'path_to_your_image.jpg'
predictions = detect_objects(image_path)
draw_boxes(image_path, predictions)
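DETR scores every one of its fixed set of queries with a softmax over the real classes plus a final "no object" class, so confidence filtering means dropping that last class, taking the best remaining class per query, and comparing its probability with the threshold. The pure-Python sketch below mimics that step on made-up logits, without downloading the model. The function names `softmax` and `filter_predictions` are hypothetical helpers for illustration, and the "last logit is no object" layout is assumed from DETR's output format.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def filter_predictions(query_logits, threshold=0.5):
    """Return (query_index, class_index, score) for each confident query.

    Assumes the last logit of every query is DETR's "no object" class,
    so it is dropped before picking the best class.
    """
    kept = []
    for q, logits in enumerate(query_logits):
        probs = softmax(logits)[:-1]   # drop "no object"
        score = max(probs)
        label = probs.index(score)     # best real class
        if score >= threshold:
            kept.append((q, label, score))
    return kept

# Made-up logits for 3 queries over 2 real classes + "no object"
logits = [
    [4.0, 0.0, 0.0],  # confident class 0
    [0.0, 0.0, 5.0],  # mostly "no object"
    [0.0, 3.0, 1.0],  # fairly confident class 1
]

for q, label, score in filter_predictions(logits):
    print(f"query {q}: class {label}, score {score:.2f}")
# query 0: class 0, score 0.96
# query 2: class 1, score 0.84
```

Raising `threshold` to 0.9 in this example keeps only query 0, which is the same trade-off the threshold makes on real DETR output.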
Explanation:
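- Loading the Model: Loads DETR with a ResNet-50 backbone, pre-trained on the COCO dataset, and calls model.eval() to switch it to evaluation mode. The pre-trained DETR weights are published in the facebookresearch/detr repository and loaded through torch.hub; torchvision itself does not include a DETR model.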
- Image Transformation: transform_image opens the image as RGB, resizes it so its shorter side is 800 pixels, and converts it to a tensor. The pre-trained DETR weights also expect the tensor to be normalized with the ImageNet mean and standard deviation.
- Object Detection: detect_objects adds a batch dimension and runs the image through DETR inside torch.no_grad(). DETR outputs a fixed set of 100 predictions per image, each with class logits (including an extra "no object" class) and a box in normalized (center x, center y, width, height) format.
- Drawing Bounding Boxes: draw_boxes discards predictions whose confidence is below the threshold, scales the remaining boxes to the image's pixel coordinates, and draws each box with its class index and score using matplotlib.
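DETR predicts boxes normalized to [0, 1] in (center x, center y, width, height) format relative to the image, so before drawing they have to be converted to corner coordinates and multiplied by the image width and height. The stdlib sketch below shows only that arithmetic; `cxcywh_to_xyxy` and `rescale_box` are hypothetical helper names, not part of any library.

```python
def cxcywh_to_xyxy(box):
    """Convert a (center_x, center_y, width, height) box to (x1, y1, x2, y2)."""
    cx, cy, w, h = box
    return (cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2)

def rescale_box(box, image_width, image_height):
    """Map a normalized cxcywh box (values in [0, 1]) to pixel xyxy coordinates."""
    x1, y1, x2, y2 = cxcywh_to_xyxy(box)
    return (x1 * image_width, y1 * image_height,
            x2 * image_width, y2 * image_height)

# A box centred in a 640x480 image, covering half its width and height
print(rescale_box((0.5, 0.5, 0.5, 0.5), 640, 480))
# (160.0, 120.0, 480.0, 360.0)
```

Scaling by the original image size (rather than the resized 800-pixel input) works because the boxes are normalized, so they are independent of the resolution the model actually saw.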
Notes:
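- Threshold: Adjust the threshold used to filter low-confidence detections: raise it to suppress false positives, lower it to recover more (but less certain) objects.
- Labels: DETR returns class indices. For the COCO checkpoint these are the original COCO category IDs (1 to 90, with some IDs unused), so you need the COCO class names to display readable labels, or your own label list if the model was fine-tuned on another dataset.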
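Because the COCO checkpoint predicts category IDs rather than contiguous indices, a 91-entry lookup list padded with "N/A" for unused IDs lets the label index be used directly. The sketch below assumes this list matches the one in the DETR demo notebook (it follows the standard COCO 2017 category IDs); `COCO_CLASSES` and `label_name` are hypothetical names introduced for illustration.

```python
# COCO category names indexed by category ID (0-90); "N/A" marks IDs
# that are unused in COCO 2017 detection.
COCO_CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush',
]

def label_name(label):
    """Return a readable name for a DETR COCO label index (hypothetical helper)."""
    if 0 <= label < len(COCO_CLASSES):
        return COCO_CLASSES[label]
    return f'unknown ({label})'

print(label_name(1), label_name(18), label_name(91))
# person dog unknown (91)
```

In the detection code, replacing `f'{label}: {score:.2f}'` with `f'{label_name(label)}: {score:.2f}'` would print names instead of indices.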
This example demonstrates how to use DETR for transformer-based object detection. DETR pairs a CNN backbone (here ResNet-50) with a transformer encoder-decoder that predicts a fixed set of boxes and classes directly, removing hand-designed components such as anchor generation and non-maximum suppression.