Image Classifier with PyTorch

Updated: 2026-02-17

Deep learning has become the cornerstone of modern computer vision. When the goal is to distinguish between objects in images—a task once deemed exclusive to human perception—developers now rely on convolutional neural networks (CNNs) powered by libraries like PyTorch. This guide takes you through the full development cycle: from data acquisition and preprocessing to model design, training, evaluation, and deployment. By the end, you’ll have produced a functioning image classifier that you can tailor to your own needs.

Why PyTorch?

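  • Dynamic (define-by-run) graphs: The computation graph is built as your code executes, so ordinary Python control flow and debuggers work, and torch.autograd differentiates whichever path actually ran.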
  • Pythonic API: Seamless integration with the wider scientific stack (NumPy, SciPy, matplotlib).
  • Rich ecosystem: Domain libraries such as torchvision, which supplies the datasets, transforms, and pre-trained models used in this guide.
  • Community support: Extensive tutorials, forums, and a thriving user base.
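As a minimal sketch of define-by-run autograd (the function `branchy_loss` is hypothetical, chosen only for illustration), the graph below is recorded while ordinary Python control flow runs, and `backward()` differentiates whichever branch was taken:

```python
import torch

def branchy_loss(x: torch.Tensor) -> torch.Tensor:
    # Plain Python control flow: the graph is recorded while this runs,
    # so each call can take a different branch.
    if x.sum() > 0:
        return (x ** 2).sum()   # gradient: 2 * x
    return (3 * x).sum()        # gradient: 3 for every element

x = torch.tensor([1.0, 2.0, -0.5], requires_grad=True)
loss = branchy_loss(x)   # x.sum() = 2.5 > 0, so the squared branch runs
loss.backward()
print(x.grad)            # equals 2 * x
```

An input with a negative sum would take the other branch and receive a gradient of 3 per element, with no separate graph-compilation step.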

1. Setting the Stage: Define the Problem

Before coding, clarify:

Aspect | Detail
Goal | Classify images into predefined categories (e.g., cats vs. dogs).
Labels | Binary or multi-class?
Evaluation metric | Accuracy, F1-score, confusion matrix.
Deployment | Web API, mobile, or embedded?

Defining constraints early prevents wasted effort on unsuitable solutions.

1.1 Project Skeleton

image_classifier/
├── data/
│   ├── train/
│   ├── val/
│   └── test/
├── models/
│   └── resnet18.py
├── utils/
│   ├── dataset.py
│   └── metrics.py
├── train.py
├── evaluate.py
├── inference.py
└── requirements.txt

2. Data Pipeline

A robust pipeline is the engine of model performance. PyTorch’s torch.utils.data.DataLoader and torchvision.transforms make this straightforward.

2.1 Acquire the Dataset

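Source | Description | Size | Licensing
CIFAR-10 | 32×32 color images across 10 classes | 60k images (50k train, 10k test) | No formal license stated; check the dataset page before commercial use
Custom | Collected via web scraping or lab imaging | Varies | Depends on source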

For illustration, we’ll focus on CIFAR-10, which is available as torchvision.datasets.CIFAR10.
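One way to obtain train, validation, and test sets is sketched below, under the assumption that the validation split is carved out of CIFAR-10's 50k training images (the helper names `split_indices` and `load_cifar10` and the 10% fraction are illustrative). Two dataset objects share the same images but use different transforms, so validation images are never augmented:

```python
import torch
from torch.utils.data import Subset

def split_indices(n, val_fraction=0.1, seed=0):
    """Shuffle range(n) reproducibly and split it into (train, val) index lists."""
    generator = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=generator).tolist()
    n_val = int(n * val_fraction)
    return perm[n_val:], perm[:n_val]

def load_cifar10(root, train_tf, val_tf, val_fraction=0.1, download=True):
    """Return (train, val, test) datasets; val comes from the official training split."""
    from torchvision import datasets  # imported here so split_indices works without torchvision

    # Two views of the same 50k training images that differ only in their transform
    train_view = datasets.CIFAR10(root, train=True, transform=train_tf, download=download)
    val_view = datasets.CIFAR10(root, train=True, transform=val_tf, download=download)
    test_set = datasets.CIFAR10(root, train=False, transform=val_tf, download=download)

    train_idx, val_idx = split_indices(len(train_view), val_fraction)
    return Subset(train_view, train_idx), Subset(val_view, val_idx), test_set
```

For example, `train_dataset, val_dataset, test_dataset = load_cifar10('data', train_transforms, val_transforms)` produces the three datasets the DataLoaders expect, once the transforms from Section 2.2 are defined.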

2.2 Data Augmentation & Normalization

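import torchvision.transforms as T

# ImageNet channel statistics, matching the pre-trained weights used in Section 3
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transforms = T.Compose([
    T.RandomHorizontalFlip(),        # flips with p=0.5 by default
    T.RandomCrop(32, padding=4),     # pad 4 px per side, then crop back to 32x32
    T.ToTensor(),                    # PIL image -> float tensor in [0, 1], shape (C, H, W)
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transforms = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

  • RandomHorizontalFlip: Mirrors half of the images left to right, so the model does not overfit to one left/right orientation.
  • RandomCrop: Pads each side by 4 pixels and crops a random 32×32 window, shifting the image by up to 4 pixels and making the model robust to small translations.
  • Normalize: Standardizes each channel as (x - mean) / std, which centers the inputs and helps convergence; the ImageNet statistics match what the pre-trained backbone saw during its own training.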

2.3 Custom Dataset Class

from torch.utils.data import Dataset
from PIL import Image
import torch

class CustomImageDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = Image.open(self.paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)
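CustomImageDataset needs parallel lists of image paths and integer labels. Below is a minimal sketch for building them, assuming a hypothetical layout with one sub-folder per class (e.g., data/train/cat/*.jpg); the helper name `scan_image_folder` and the extension list are illustrative. Sorting the class names keeps the class-to-index mapping identical across the train, val, and test folders:

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}

def scan_image_folder(root):
    """Collect (paths, labels, class_to_idx) from root/<class_name>/<image files>."""
    root = Path(root)
    # Sorted class names give a stable, reproducible label mapping
    class_names = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: i for i, name in enumerate(class_names)}

    paths, labels = [], []
    for name in class_names:
        for p in sorted((root / name).iterdir()):
            if p.suffix.lower() in IMAGE_EXTENSIONS:  # skip non-image files
                paths.append(str(p))
                labels.append(class_to_idx[name])
    return paths, labels, class_to_idx
```

Typical use is `paths, labels, class_to_idx = scan_image_folder('data/train')` followed by `CustomImageDataset(paths, labels, transform=train_transforms)`. torchvision.datasets.ImageFolder follows the same folder convention if you prefer not to write this yourself.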

2.4 DataLoader Instantiation

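from torch.utils.data import DataLoader

# train_dataset, val_dataset and test_dataset are the Dataset objects from Sections 2.1-2.3
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                          num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False,
                        num_workers=4, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False,
                         num_workers=4, pin_memory=True)

Multiple worker processes prepare batches in parallel with training, and pinned (page-locked) host memory speeds up copies to the GPU; pin_memory brings no benefit on CPU-only machines. Only the training set needs shuffling.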

3. Model Design

3.1 Choosing a Back‑bone

Model | Parameters | Top-1 accuracy on ImageNet | Comments
ResNet-18 | ~11.7M | ~69.8% | Lightweight; a good starting point.
VGG-16 | ~138M | ~71.5% | Far more parameters and compute; slower.
EfficientNet-B0 | ~5.3M | ~77.5% | High accuracy per parameter thanks to compound scaling.

Why ResNet-18? It balances accuracy against computational cost, which makes it a good choice for a proof of concept.

3.2 Fine‑Tuning Strategy

  1. Load pre-trained ImageNet weights.
  2. Freeze the pre-trained layers to preserve their learned features (the code below freezes the entire backbone).
  3. Replace the classifier (fully connected layer) to match the target number of classes.

import torchvision.models as models
import torch.nn as nn

def create_model(num_classes):
    # weights= replaces the deprecated pretrained=True argument (torchvision 0.13+)
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    # Freeze every pre-trained parameter
    for param in model.parameters():
        param.requires_grad = False

    # Replace the final fully connected layer; the new layer is trainable by default
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

3.3 Parameter Counting

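Layer | Parameters | Trainable
conv1 (7×7, 3 to 64 channels, no bias) | 9,408 | No (frozen)
fc (new, 512 inputs to 10 classes) | 5,130 | Yes
Whole model | ~11.18M | 5,130 (fc only)

Only the final layer is trainable, so backpropagation stops at fc and the optimizer updates just 5,130 values. This cuts backward-pass compute and optimizer memory; the forward pass costs the same as for the full network.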
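Total and trainable parameter counts are easy to verify for any model. A small helper sketch (`count_parameters` is an illustrative name):

```python
import torch.nn as nn

def count_parameters(model: nn.Module):
    """Return (total, trainable) numbers of scalar parameters in `model`."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable
```

Applied to create_model(num_classes=10), the trainable count should be 512 × 10 + 10 = 5,130, since only the new fc layer has requires_grad=True.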

4. Training Loop

Effective training hinges on correct hyperparameters and monitoring.

4.1 Hyperparameter Blueprint

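Hyperparameter | Value | Rationale
Learning rate | 0.01 | A common starting point for SGD with momentum.
Momentum | 0.9 | Smooths updates and stabilizes convergence.
Weight decay | 5e-4 | Regularizes the weights.
Optimizer | SGD | A strong baseline for classification.
Scheduler | StepLR (step_size=10, gamma=0.5) | Halves the LR every 10 epochs.
Epochs | 30 | A reasonable budget for CIFAR-10 when only the head is trained; confirm on the validation curve.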

4.2 Loss Function

Cross‑entropy is the default choice for multi‑class tasks:

import torch.nn as nn

criterion = nn.CrossEntropyLoss()  # takes raw logits; applies log-softmax internally
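nn.CrossEntropyLoss expects raw, unnormalized logits and integer class indices. The sketch below, with made-up logits, checks that it equals the mean negative log-softmax probability of the true class:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0],
                       [0.1, 0.2, 3.0]])   # batch of 2 samples, 3 classes
targets = torch.tensor([0, 2])             # true class index per sample

loss = nn.CrossEntropyLoss()(logits, targets)

# Same value by hand: mean over the batch of -log(softmax(logits)[true class])
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(targets)), targets].mean()
print(loss.item(), manual.item())
```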

4.3 Training Code Snippet

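import torch
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = create_model(num_classes=10).to(device)  # create_model from Section 3.2
# Only the unfrozen parameters (the new fc layer) need to be optimized
optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=0.01,
                      momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # halve LR every 10 epochs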

for epoch in range(30):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    scheduler.step()
    epoch_loss = train_loss / total
    epoch_acc  = 100 * correct / total
    print(f'Epoch {epoch+1:02d} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%')

Tips:

  • Log gradients and activations via TensorBoard for early anomaly detection.
  • Save a checkpoint every 5 epochs (model, optimizer, and scheduler state) so an interrupted run can resume.
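A minimal checkpointing sketch for the tip above (the helper names and dictionary keys are illustrative choices, not a PyTorch convention). It saves state_dicts rather than whole objects, as the PyTorch serialization docs recommend:

```python
import torch

def save_checkpoint(path, model, optimizer, scheduler, epoch):
    """Write everything needed to resume training after `epoch` (0-based)."""
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),   # includes momentum buffers
        "scheduler": scheduler.state_dict(),
    }, path)

def load_checkpoint(path, model, optimizer, scheduler, device="cpu"):
    """Restore state in place and return the epoch to resume from."""
    ckpt = torch.load(path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"] + 1
```

Inside the epoch loop, `if (epoch + 1) % 5 == 0: save_checkpoint(f'ckpt_{epoch + 1:02d}.pt', model, optimizer, scheduler, epoch)` would implement the every-5-epochs tip; to resume, loop over `range(start_epoch, 30)` with the returned value.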

5. Evaluation

A rigorous evaluation prevents complacency and reveals failure modes.

5.1 Validation Routine

model.eval()
val_loss = 0.0
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        val_loss += loss.item() * inputs.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

val_loss /= total
val_acc = 100 * correct / total
print(f'Validation Loss: {val_loss:.4f} | Validation Acc: {val_acc:.2f}%')
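The validation routine can be wrapped in a reusable function so it runs after every epoch. A sketch in which `evaluate` is an illustrative name and the logic mirrors the routine above:

```python
import torch

def evaluate(model, loader, criterion, device):
    """Return (mean loss, accuracy in %) of `model` over every batch in `loader`."""
    model.eval()  # inference behavior for BatchNorm/Dropout
    total_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # criterion returns a batch mean, so weight it by batch size
            total_loss += criterion(outputs, labels).item() * inputs.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, 100.0 * correct / total
```

Calling it at the end of each training epoch and saving model.state_dict() to best_model.pt whenever validation accuracy improves yields the best-epoch weights that the inference server in Section 6.2 loads. Remember to switch back with model.train() before the next epoch.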

5.2 Advanced Metrics

Accuracy can hide poor performance on minority classes when the data is imbalanced. Macro-averaged F1 computes an F1 score per class and averages them with equal weight:

from sklearn.metrics import f1_score

model.eval()
all_preds = []
all_labels = []
with torch.no_grad():  # no gradients needed for evaluation
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

f1 = f1_score(all_labels, all_preds, average='macro')
print(f'Macro F1-score: {f1:.4f}')
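Why macro F1 matters under imbalance, shown with a made-up toy example: a classifier that always predicts the majority class scores high accuracy but a poor macro F1. The `zero_division=0` argument silences scikit-learn's warning for the class that is never predicted:

```python
from sklearn.metrics import accuracy_score, f1_score

# Imbalanced toy labels: 8 samples of class 0, 2 of class 1
y_true = [0] * 8 + [1] * 2
# A degenerate "model" that always predicts the majority class
y_pred = [0] * 10

acc = accuracy_score(y_true, y_pred)                                   # 0.8
macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)  # 4/9 ≈ 0.444
print(f'accuracy={acc:.3f}  macro F1={macro_f1:.3f}')
```

Class 0 gets F1 = 8/9 (precision 0.8, recall 1.0) while class 1 gets 0, so their unweighted mean is 4/9.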

5.3 Confusion Matrix

A confusion matrix shows which classes the model mixes up: rows are true labels and columns are predicted labels.

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(all_labels, all_preds)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.show()

In CIFAR-10, classes such as “automobile” and “truck” are often confused because of their similar silhouettes; the off-diagonal cells of the matrix show where targeted data augmentation would help most.
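To turn the matrix into a ranked list of the most-confused pairs, a small NumPy sketch (`top_confusions` is a hypothetical helper) zeroes the diagonal and sorts the remaining counts:

```python
import numpy as np

def top_confusions(cm, class_names, k=3):
    """Return up to k (true class, predicted class, count) tuples, largest count first."""
    off_diag = np.array(cm, copy=True)
    np.fill_diagonal(off_diag, 0)                      # ignore correct predictions
    order = np.argsort(off_diag, axis=None)[::-1][:k]  # flat indices, largest first
    rows, cols = np.unravel_index(order, off_diag.shape)
    return [(class_names[r], class_names[c], int(off_diag[r, c]))
            for r, c in zip(rows, cols) if off_diag[r, c] > 0]
```

Pass the matrix from above together with the dataset's class names (torchvision's CIFAR10 exposes them as `.classes`) to get the pairs worth targeting first.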

6. Model Inference

6.1 Single Image Prediction

import torch
from PIL import Image

def predict(image_path, model, transform):
    device = next(model.parameters()).device  # run on whichever device holds the model
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)  # add a batch dimension
    model.eval()
    with torch.no_grad():
        logits = model(image)
    prob = torch.softmax(logits, dim=1)
    pred_class = torch.argmax(prob, dim=1).item()
    return pred_class, prob.squeeze(0).cpu().numpy()

6.2 Serving Predictions over HTTP

Wrap the inference logic into a lightweight HTTP API using Flask:

import io

import torch
from flask import Flask, request, jsonify
from PIL import Image

# create_model, val_transforms and device are defined in the earlier sections
app = Flask(__name__)
model = create_model(num_classes=10).to(device)
model.load_state_dict(torch.load('best_model.pt', map_location=device))
model.eval()  # inference behavior for BatchNorm/Dropout

@app.route('/predict', methods=['POST'])
def predict_endpoint():
    img_bytes = request.files['image'].read()
    img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
    img_t = val_transforms(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_t)
        prob = torch.softmax(logits, dim=1)
    return jsonify({
        'label': int(torch.argmax(prob, dim=1).item()),
        'probabilities': prob.squeeze(0).tolist()
    })

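For production, run the app behind a WSGI server such as Gunicorn rather than Flask’s built-in development server. The service can then be containerized and exposed via Kubernetes or a serverless platform.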
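For true batch inference, several images go through the network in one forward pass, which usually uses the hardware much more efficiently than one call per image. A sketch under the assumption that images are already preprocessed into equal-sized CHW tensors (for example with val_transforms); `predict_batch` and the default batch size are illustrative:

```python
import torch

def predict_batch(model, images, batch_size=64):
    """Classify a non-empty list of preprocessed CHW tensors, batch_size images per pass."""
    device = next(model.parameters()).device
    model.eval()
    all_preds, all_probs = [], []
    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            # Stack a slice of images into one (B, C, H, W) tensor
            batch = torch.stack(images[start:start + batch_size]).to(device)
            probs = torch.softmax(model(batch), dim=1)
            all_probs.append(probs.cpu())
            all_preds.append(probs.argmax(dim=1).cpu())
    return torch.cat(all_preds), torch.cat(all_probs)
```

The list must be non-empty, since torch.cat needs at least one tensor; the batch size trades memory for throughput.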

7. Deployment Considerations

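Format | Strengths | Weaknesses
TorchScript (JIT) | Runs without a Python interpreter (e.g., from C++ via LibTorch); can speed up inference. | Requires a conversion step; limited support for dynamic Python constructs.
ONNX | Interoperable across frameworks and runtimes. | Extra conversion step; possible small numerical differences.
Mobile (Core ML / NNAPI) | On-device inference. | Usually needs quantization or other compression.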

7.1 TorchScript Example

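import torch

model.eval()
# Tracing records the operations executed for this example input (same shape as real inputs)
example_input = torch.rand(1, 3, 32, 32).to(device)
scripted_model = torch.jit.trace(model, example_input)
scripted_model.save('resnet18_cifar10.pt')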
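Because tracing records only the operations executed for the example input, it is worth checking that the saved module reproduces the eager model. The sketch below uses a small stand-in CNN (`toy_model`, hypothetical) and an in-memory buffer in place of a file:

```python
import io
import torch
import torch.nn as nn

# Hypothetical stand-in network so the check runs quickly
toy_model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(),
                          nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10)).eval()
traced = torch.jit.trace(toy_model, torch.rand(1, 3, 32, 32))

buffer = io.BytesIO()            # in practice, a path such as 'resnet18_cifar10.pt'
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

test_input = torch.rand(4, 3, 32, 32)   # a different batch size than the trace input
with torch.no_grad():
    max_diff = (loaded(test_input) - toy_model(test_input)).abs().max().item()
print(f'max abs difference: {max_diff:.2e}')
```

Models with data-dependent control flow should be converted with torch.jit.script instead, because a trace silently bakes in the one branch that ran.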

7.2 Quantization for Edge Devices

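PyTorch supports both post-training dynamic and static quantization. The call below performs dynamic quantization: nn.Linear weights are stored as int8 and activations are quantized on the fly during inference. nn.Linear is supported as-is, so no layer replacement is needed, but the function returns a new model that you must keep:

import torch
import torch.nn as nn

model = model.to('cpu').eval()  # dynamic quantization targets CPU inference
quantized_model = torch.ao.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

In ResNet-18 only the final fc layer is an nn.Linear, so dynamic quantization leaves the convolutions (where nearly all of the compute happens) in float32 and the speedup is modest. Larger CPU gains on convolutional networks require static quantization of the conv layers with calibration data. In either case, measure latency and accuracy on your target hardware rather than assuming a fixed figure.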
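To see what dynamic quantization does to a layer, the sketch below applies it to a small stand-in MLP (`float_model` and the helper `state_dict_size_bytes` are illustrative, not part of the original pipeline) and compares serialized size and outputs. It assumes a PyTorch build with a quantized CPU engine (FBGEMM on x86 or QNNPACK on ARM), which standard builds include:

```python
import io
import torch
import torch.nn as nn

def state_dict_size_bytes(module):
    """Size of the module's serialized state_dict, measured in memory."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    return buffer.getbuffer().nbytes

float_model = nn.Sequential(nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, 10)).eval()
# Returns a quantized copy; float_model itself is left unchanged
quantized = torch.ao.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

x = torch.randn(8, 256)
with torch.no_grad():
    max_diff = (quantized(x) - float_model(x)).abs().max().item()
print(type(quantized[0]).__name__, state_dict_size_bytes(float_model),
      state_dict_size_bytes(quantized), f'{max_diff:.4f}')
```

The int8 weights take roughly a quarter of the float32 storage, and the output difference shows the small accuracy cost; real speed must still be benchmarked per device.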

8. Best Practices Checklist

  1. Validate data integrity: missing labels or corrupted JPEGs can silently ruin training.
  2. Monitor loss curves: a flat loss may indicate too low a learning rate.
  3. Save model checkpoints: track the best validation score and keep the weights from that epoch.
  4. Test on unseen data: held-out data reveals overfitting early.
  5. Use documentation and version control: keep your code reproducible.
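For the reproducibility item, a common first step is fixing random seeds. A sketch of a seeding helper (`seed_everything` is an illustrative name); the optional deterministic mode can slow training, and on CUDA some operations additionally require the CUBLAS_WORKSPACE_CONFIG environment variable:

```python
import random

import numpy as np
import torch

def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs; optionally force deterministic kernels."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    if deterministic:
        torch.backends.cudnn.benchmark = False      # autotuning can pick different kernels
        torch.use_deterministic_algorithms(True)    # error on known nondeterministic ops
```

Call it once at the top of train.py, before building datasets, models, and DataLoaders.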

9. Case Study: Real‑World Cat–Dog Classifier

Stage | Result
Dataset | 10k user-collected photos (1:1 cat/dog).
Augmentation | 60% horizontal flips, random brightness.
Model | ResNet-50 (full fine-tuning).
Accuracy | 94.2% on 2,000 test images.
Deployment | Docker container served via FastAPI; 50 ms latency on an NVIDIA Jetson Nano.

This example showcases how the same pipeline can be scaled up with larger models and custom data.

10. Troubleshooting Common Pitfalls

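Symptom | Likely cause | Fix
No improvement after 10 epochs | Learning rate too high or too low | Adjust the learning rate or scheduler, or switch to AdamW.
GPU out of memory (OOM) | Batch size (or input resolution) too large | Reduce the batch size or enable mixed precision.
Test accuracy < 50% | Wrong label mapping | Verify that the label_to_idx mapping is identical for training and test data.
Slow inference | Model too large for the hardware, or Python overhead | Use a smaller backbone, quantize, batch requests, or export to TorchScript.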

11. Next Steps

  • Curriculum learning: Order training samples by difficulty.
  • Active learning: Query uncertain samples for labeling.
  • Adversarial robustness: Adversarial training to withstand malicious noise.
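Active learning often starts with uncertainty sampling: send the unlabeled images whose predicted class distribution has the highest entropy to annotators. A minimal sketch, assuming `probs` holds softmax outputs for an unlabeled pool (the function name is illustrative):

```python
import torch

def select_most_uncertain(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Return indices of the k rows of `probs` (N x C softmax outputs) with highest entropy."""
    # Clamp avoids log(0); entropy is largest for a uniform distribution
    entropy = -(probs * torch.log(probs.clamp_min(1e-12))).sum(dim=1)
    return torch.topk(entropy, k).indices
```

The returned indices point into the unlabeled pool; after labeling, those images join the training set and the model is retrained.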

12. Final Thoughts

Building an image classifier in PyTorch is no longer just an academic exercise; the same workflow carries over to production. By structuring your project, crafting a solid data pipeline, leveraging transferable knowledge from pre-trained networks, and tuning hyperparameters rigorously, you can reach strong accuracy on your own data. Keep your pipeline modular, monitor your metrics closely, and iterate until you hit the sweet spot between speed and accuracy.

In a world where billions of images are captured every moment, your deep‑learning skills are the brush that paints insight onto the canvas—stroke by careful stroke.

Train carefully, deploy with confidence, and let the world see what you’ve made.
