MLPC Week 11: Anomaly Detection with Autoencoders

Reconstruction error as anomaly score on CahnHilliardDataset

Open In Colab

Learning Objectives

  • Use an autoencoder’s reconstruction error as an anomaly score
  • Train on one distribution and detect samples from a different distribution
  • Set detection thresholds and visualize the highest-error samples

Setup

!pip install git+https://github.com/ECLIPSE-Lab/Ai4MatLectures.git "mdsdata>=0.1.5"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from ai4mat.datasets import CahnHilliardDataset
import matplotlib.pyplot as plt
import numpy as np

1. Load the Data

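We train on one simulation and treat a different simulation as “out-of-distribution” (anomalous). The autoencoder only learns to reconstruct the patterns it saw during training. Samples from the unseen simulation should therefore be reconstructed poorly, and their higher reconstruction error is what flags them as anomalies.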

# Simulation 0 is the "normal" distribution the autoencoder learns from;
# simulation 1 is never seen during training and stands in for anomalies.
train_dataset, test_dataset = (
    CahnHilliardDataset(simulation_number=sim) for sim in (0, 1)
)

print(f"Train dataset (sim 0): {len(train_dataset)} samples")
print(f"Test  dataset (sim 1): {len(test_dataset)} samples")

# Peek at one (image, energy) pair to confirm the tensor layout.
first_sample = train_dataset[0]
x0, y0 = first_sample
print(f"\nSample x shape: {x0.shape}  (1, 64, 64)")
print(f"Sample y (energy): {y0:.4f}")
Train dataset (sim 0): 989 samples
Test  dataset (sim 1): 992 samples

Sample x shape: torch.Size([1, 64, 64])  (1, 64, 64)
Sample y (energy): 586.8312
# Visualize microstructure snapshots from both simulations.
# Row labels are drawn with set_ylabel, which ax.axis('off') would hide, so the
# ticks and frame are removed individually instead to keep the labels visible.
n_cols = 6
rows = [(train_dataset, "Sim 0 (train)"), (test_dataset, "Sim 1 (test)")]

fig, axes = plt.subplots(2, n_cols, figsize=(16, 6))
for row, (ds, row_label) in enumerate(rows):
    # Evenly spaced snapshot indices, computed per dataset so they always stay
    # in range even when the two simulations have different lengths.
    step = len(ds) // n_cols
    for col in range(n_cols):
        img, energy = ds[col * step]  # index each sample once
        ax = axes[row, col]
        ax.imshow(img.squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
        ax.set_title(f"E={energy:.3f}", fontsize=8)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[row, 0].set_ylabel(row_label, fontsize=9)

plt.suptitle("Cahn-Hilliard microstructures — red/blue = phase separation")
plt.tight_layout()
plt.show()

2. Build the Autoencoder

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for single-channel, square images.

    Encoder: two stride-2 3x3 convolutions (each halves height and width)
    followed by a linear projection to a ``latent_dim`` vector. Decoder: a
    linear layer back to the bottleneck feature map, two stride-2 transposed
    convolutions (each doubles height and width) and a sigmoid, so
    reconstructions lie in [0, 1].

    Args:
        latent_dim: Size of the bottleneck vector ``z``.
        image_size: Height and width of the square input images. Must be a
            positive multiple of 4 so the two down-/up-sampling stages return
            exactly to the input size. Defaults to 64, the size the layers
            were originally hard-coded for.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim=2, image_size=64):
        super().__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        side = image_size // 4       # bottleneck feature-map size (16 for 64x64)
        flat = 32 * side * side      # number of flattened bottleneck features
        self.image_size = image_size
        # Layer order is unchanged from the original so state_dict keys match.
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=2, padding=1),   # -> image_size/2
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, stride=2, padding=1),  # -> image_size/4
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat),
            nn.ReLU(),
            nn.Unflatten(1, (32, side, side)),
            nn.ConvTranspose2d(32, 16, 3, stride=2, padding=1, output_padding=1),  # -> image_size/2
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, 3, stride=2, padding=1, output_padding=1),   # -> image_size
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        Args:
            x: Images of shape ``(batch, 1, image_size, image_size)``.

        Returns:
            Tuple ``(reconstruction, z)``: ``reconstruction`` has the same shape
            as ``x``; ``z`` has shape ``(batch, latent_dim)``.
        """
        z = self.encoder(x)
        return self.decoder(z), z

# Use a 16-dimensional bottleneck and report how many trainable weights that gives.
model = ConvAutoencoder(latent_dim=16)
n_params = sum(param.numel() for param in model.parameters())
print(f"Autoencoder parameters: {n_params:,}")
Autoencoder parameters: 279,921

3. Train on Simulation 0

# Fit the autoencoder to reproduce simulation-0 images. The energy labels the
# dataset also returns are ignored: the target is the input image itself.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

train_losses = []
for epoch in range(30):
    model.train()
    running_sum = 0.0
    for x_batch, _energy in train_loader:
        optimizer.zero_grad()
        x_recon, _ = model(x_batch)
        loss = criterion(x_recon, x_batch)
        loss.backward()
        optimizer.step()
        # Weight by batch size so the epoch value is a true per-sample mean.
        running_sum += loss.item() * x_batch.size(0)
    epoch_mse = running_sum / len(train_dataset)
    train_losses.append(epoch_mse)
    if epoch % 10 == 9:
        print(f"Epoch {epoch+1:3d} | Train recon MSE: {train_losses[-1]:.5f}")

plt.plot(train_losses)
plt.xlabel("Epoch")
plt.ylabel("Reconstruction MSE")
plt.title("Autoencoder Training on Simulation 0")
plt.tight_layout()
plt.show()
Epoch  10 | Train recon MSE: 0.00056
Epoch  20 | Train recon MSE: 0.00021
Epoch  30 | Train recon MSE: 0.00013

4. Score: Reconstruction Error as Anomaly Signal

def compute_errors(dataset, model, batch_size=64):
    """Return the per-sample reconstruction MSE of ``model`` on ``dataset``.

    Args:
        dataset: Map-style dataset yielding ``(x, label)`` pairs; labels are ignored.
        model: Module whose forward returns ``(reconstruction, latent)``.
        batch_size: Inference batch size (does not change the result).

    Returns:
        1-D float64 array with one mean-squared reconstruction error per
        sample, in dataset order. The array is empty if ``dataset`` is empty.

    The model is switched to eval mode for scoring and then restored to the
    training/eval mode it was in before the call.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    was_training = model.training
    chunks = []
    model.eval()
    try:
        with torch.no_grad():
            for x_batch, _ in loader:
                x_recon, _ = model(x_batch)
                # Average over every non-batch dimension, so the inputs do not
                # have to be exactly 4-D.
                per_sample = (x_recon - x_batch).pow(2).flatten(start_dim=1).mean(dim=1)
                # .cpu() so this also works when the model and data are on a GPU.
                chunks.append(per_sample.cpu().numpy())
    finally:
        model.train(was_training)
    if not chunks:
        return np.array([])
    return np.concatenate(chunks).astype(np.float64)

# Score every sample from both simulations with the trained model.
errors_train, errors_test = (
    compute_errors(ds, model) for ds in (train_dataset, test_dataset)
)

print(f"Train (sim 0) — mean error: {errors_train.mean():.5f} ± {errors_train.std():.5f}")
print(f"Test  (sim 1) — mean error: {errors_test.mean():.5f}  ± {errors_test.std():.5f}")
Train (sim 0) — mean error: 0.00012 ± 0.00011
Test  (sim 1) — mean error: 0.15296  ± 0.02372

5. Detect Anomalies

# Simple threshold: mean + 2*std of the training errors.
# NOTE: the false-positive rate printed below is measured on the same samples
# that were used to train the model and to set the threshold, so it is an
# optimistic (in-sample) estimate.
threshold = errors_train.mean() + 2 * errors_train.std()

# Histogram of reconstruction errors.
# Train and OOD errors differ by orders of magnitude. With linear bins every
# training sample (and the threshold line) would collapse into the first bin,
# so use log-spaced bins on a log x-axis. Exactly-zero errors cannot be shown
# on a log axis and are left out of the plot.
all_errors = np.concatenate([errors_train, errors_test, [threshold]])
positive = all_errors[all_errors > 0]
bins = np.logspace(np.log10(positive.min() * 0.95), np.log10(positive.max() * 1.05), 50)

plt.figure(figsize=(8, 4))
plt.hist(errors_train, bins=bins, alpha=0.6, color='tab:blue', label='Train (sim 0 — normal)')
plt.hist(errors_test,  bins=bins, alpha=0.6, color='tab:red',  label='Test  (sim 1 — OOD)')
plt.axvline(threshold, color='black', linestyle='--', label=f'Threshold = {threshold:.5f}')
plt.xscale('log')
plt.xlabel("Reconstruction MSE (log scale)")
plt.ylabel("Count")
plt.title("Anomaly Detection: Reconstruction Error Distribution")
plt.legend()
plt.tight_layout()
plt.show()

flagged_train = (errors_train > threshold).mean() * 100
flagged_test  = (errors_test  > threshold).mean() * 100
print(f"\nFlagged as anomalous:")
print(f"  Train (sim 0): {flagged_train:.1f}%  (false positive rate)")
print(f"  Test  (sim 1): {flagged_test:.1f}%   (detection rate)")


Flagged as anomalous:
  Train (sim 0): 5.0%  (false positive rate)
  Test  (sim 1): 100.0%   (detection rate)
# Show the 5 highest-error test images alongside their reconstructions.
# Indices are ordered from largest to smallest error.
top5_idx = np.argsort(errors_test)[-5:][::-1]

# Set eval mode explicitly instead of relying on an earlier cell having done it.
# This makes no difference for the current architecture, but it matters once
# dropout or batch norm are added.
model.eval()

fig, axes = plt.subplots(2, 5, figsize=(14, 6))
for col, idx in enumerate(top5_idx):
    x_orig = test_dataset[idx][0].unsqueeze(0)  # add a batch dimension
    with torch.no_grad():
        x_recon, _ = model(x_orig)

    for row, img in enumerate((x_orig, x_recon)):
        ax = axes[row, col]
        ax.imshow(img.squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
        # Hide ticks and frame rather than calling ax.axis('off'), which would
        # also hide the row y-labels set below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[0, col].set_title(f"err={errors_test[idx]:.4f}", fontsize=8)

axes[0, 0].set_ylabel("Original (sim 1)", fontsize=9)
axes[1, 0].set_ylabel("Reconstructed",    fontsize=9)
plt.suptitle("5 Highest-Error Test Images (anomalies)")
plt.tight_layout()
plt.show()

Exercises

  1. What threshold would you set to catch 90% of simulation-1 samples while keeping the false positive rate below 10%? Adjust the multiplier k in `mean + k * std`.
  2. Try training on simulations 0, 1 and 2 combined, e.g. `ConcatDataset([CahnHilliardDataset(simulation_number=k) for k in (0, 1, 2)])` from `torch.utils.data`. Does the detection rate for simulation-1 samples decrease? What does this tell you about the training distribution?
  3. What other types of microstructure anomalies could this approach detect — for example, images with different phase fractions, coarser morphology, or added noise?