MG Week 11: Latent Space as Materials Genome Coordinate

Energy-ordered representations with CahnHilliardDataset

Open In Colab

Learning Objectives

  • Use a continuous physical quantity (energy) to interpret latent space structure
  • Discover whether autoencoders preserve thermodynamic ordering
  • Apply nearest-neighbor retrieval in latent space

Setup

!pip install git+https://github.com/ECLIPSE-Lab/Ai4MatLectures.git "mdsdata>=0.1.5"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from ai4mat.datasets import CahnHilliardDataset
import matplotlib.pyplot as plt
import numpy as np

1. Load the Data

# Use three simulations for training diversity
dataset = CahnHilliardDataset(simulation_number=[0, 1, 2])
print(f"Dataset size: {len(dataset)}")
x0, y0 = dataset[0]
print(f"Sample x shape: {x0.shape}  (1, 64, 64)")
print(f"Sample y (energy): {y0:.4f}  (free energy of microstructure)")

energies_all = torch.tensor([dataset[i][1].item() for i in range(len(dataset))]).numpy()
print(f"\nEnergy range: [{energies_all.min():.4f}, {energies_all.max():.4f}]")
print(f"Energy mean:  {energies_all.mean():.4f} ± {energies_all.std():.4f}")
  0%|          | 0/3 [00:00<?, ?it/s] 33%|███▎      | 1/3 [00:00<00:00,  2.40it/s] 67%|██████▋   | 2/3 [00:01<00:00,  1.82it/s]100%|██████████| 3/3 [00:01<00:00,  1.65it/s]                                             
Dataset size: 2975
Sample x shape: torch.Size([1, 64, 64])  (1, 64, 64)
Sample y (energy): 586.8312  (free energy of microstructure)

Energy range: [457.6855, 1099.6606]
Energy mean:  629.7600 ± 136.7265
# Visualize microstructures sorted by energy
idx_sorted = np.argsort(energies_all)
fig, axes = plt.subplots(2, 6, figsize=(16, 6))

# Row 0: six samples stepping through the sorted order from the low-energy end
# (low energy = more phase-separated microstructures).
for i, ax in enumerate(axes[0]):
    idx = idx_sorted[i * (len(idx_sorted) // 6)]
    ax.imshow(dataset[idx][0].squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
    ax.set_title(f"E={dataset[idx][1]:.3f}", fontsize=8)
    ax.axis('off')

# Row 1: six samples stepping in from the high-energy end (more mixed).
for i, ax in enumerate(axes[1]):
    idx = idx_sorted[-(i * (len(idx_sorted) // 6)) - 1]
    ax.imshow(dataset[idx][0].squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
    ax.set_title(f"E={dataset[idx][1]:.3f}", fontsize=8)
    ax.axis('off')

# BUGFIX: set_ylabel on an axis disabled with axis('off') is never rendered,
# so the row labels were invisible. Draw them as rotated text anchored just
# left of each row's first panel instead.
axes[0, 0].text(-0.08, 0.5, "Low energy\n(phase-separated)", fontsize=8,
                rotation=90, va='center', ha='center',
                transform=axes[0, 0].transAxes)
axes[1, 0].text(-0.08, 0.5, "High energy\n(mixed/disordered)", fontsize=8,
                rotation=90, va='center', ha='center',
                transform=axes[1, 0].transAxes)

plt.suptitle("Cahn-Hilliard microstructures ordered by free energy")
plt.tight_layout()
plt.show()

2. Train/Val Split

# Reproducible 80/20 split: seed the generator so the same samples land in
# train/val every run (random_split is otherwise nondeterministic, which makes
# the downstream plots and correlations unrepeatable).
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
train_ds, val_ds = random_split(
    dataset, [n_train, n_val],
    generator=torch.Generator().manual_seed(42),
)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=32, shuffle=False)
print(f"Train: {n_train} | Val: {n_val}")
Train: 2380 | Val: 595

3. Define the Autoencoder

class ConvAutoencoder(nn.Module):
    """Convolutional autoencoder for (1, 64, 64) microstructure images.

    The encoder downsamples 64x64 -> 32x32 -> 16x16 with two strided
    convolutions, then projects the flattened feature map to ``latent_dim``.
    The decoder mirrors this with transposed convolutions; a final Sigmoid
    keeps reconstructions in [0, 1] to match the input range.
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        feat = 32 * 16 * 16  # flattened size after the second conv stage
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1),   # -> (16, 32, 32)
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # -> (32, 16, 16)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(feat, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, feat),
            nn.ReLU(),
            nn.Unflatten(1, (32, 16, 16)),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2,
                               padding=1, output_padding=1),        # -> (16, 32, 32)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=2,
                               padding=1, output_padding=1),        # -> (1, 64, 64)
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return (reconstruction, latent_code) for a batch of images."""
        code = self.encoder(x)
        return self.decoder(code), code

# Instantiate a 2-D latent model and report its size as a quick capacity check.
model = ConvAutoencoder(latent_dim=2)
param_count = sum(p.numel() for p in model.parameters())
print(f"Autoencoder parameters: {param_count:,}")
Autoencoder parameters: 50,531

4. Training Loop

# Train the autoencoder on pixel-wise reconstruction MSE (labels are unused).
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_losses, val_losses = [], []

n_epochs = 30
for epoch in range(n_epochs):
    # --- optimize on the training split ---
    model.train()
    running = 0.0
    for xb, _ in train_loader:
        optimizer.zero_grad()
        recon, _ = model(xb)
        batch_loss = criterion(recon, xb)
        batch_loss.backward()
        optimizer.step()
        # un-average the batch mean so the epoch figure is a true dataset mean
        running += batch_loss.item() * len(xb)
    train_losses.append(running / n_train)

    # --- evaluate on the held-out split, no gradients ---
    model.eval()
    running = 0.0
    with torch.no_grad():
        for xb, _ in val_loader:
            recon, _ = model(xb)
            running += criterion(recon, xb).item() * len(xb)
    val_losses.append(running / n_val)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1:3d} | Train MSE: {train_losses[-1]:.5f} | Val MSE: {val_losses[-1]:.5f}")

plt.plot(train_losses, label='Train'); plt.plot(val_losses, label='Val')
plt.xlabel("Epoch"); plt.ylabel("Reconstruction MSE")
plt.title("Autoencoder Training"); plt.legend()
plt.tight_layout(); plt.show()
Epoch  10 | Train MSE: 0.00311 | Val MSE: 0.00287
Epoch  20 | Train MSE: 0.00101 | Val MSE: 0.00095
Epoch  30 | Train MSE: 0.00060 | Val MSE: 0.00061

5. Latent Space Colored by Energy

# Encode the full dataset (fixed order) and collect latents + energies.
model.eval()
full_loader = DataLoader(dataset, batch_size=64, shuffle=False)
latent_chunks, energy_chunks = [], []
with torch.no_grad():
    for xb, yb in full_loader:
        _, code = model(xb)
        latent_chunks.append(code.numpy())
        energy_chunks.append(yb.numpy())

all_z = np.concatenate(latent_chunks, axis=0)
all_energies = np.concatenate(energy_chunks, axis=0)

# Scatter the 2-D latent codes, colored by each sample's free energy.
fig, ax = plt.subplots(figsize=(8, 6))
sc = ax.scatter(all_z[:, 0], all_z[:, 1], c=all_energies,
                cmap='viridis', alpha=0.5, s=12)
plt.colorbar(sc, ax=ax, label="Free energy")
ax.set_xlabel("Latent dim 1")
ax.set_ylabel("Latent dim 2")
ax.set_title("Latent Space — colored by free energy")
plt.tight_layout()
plt.show()

# Does the latent space preserve the energy ordering?  Spearman's rank
# correlation between each latent coordinate and the free energy quantifies it.
from scipy.stats import spearmanr

corr = {dim: spearmanr(all_z[:, dim], all_energies) for dim in (0, 1)}
for dim, (r, p) in corr.items():
    print(f"Spearman correlation: latent dim {dim + 1} vs energy: r={r:.3f} (p={p:.2e})")
print()

strongest = max(abs(corr[0][0]), abs(corr[1][0]))
if strongest > 0.5:
    print("Strong correlation: the latent space has organized microstructures")
    print("by their thermodynamic energy — a physically meaningful coordinate!")
else:
    print("Weak correlation: the latent dims may encode other structural features.")
    print("Try rerunning or increasing training epochs.")
Spearman correlation: latent dim 1 vs energy: r=0.665 (p=0.00e+00)
Spearman correlation: latent dim 2 vs energy: r=0.392 (p=1.35e-109)

Strong correlation: the latent space has organized microstructures
by their thermodynamic energy — a physically meaningful coordinate!

6. Reconstruction Quality

# Side-by-side comparison: validation originals (top row) vs reconstructions.
fig, axes = plt.subplots(2, 8, figsize=(16, 5))
x_show = torch.stack([val_ds[i][0] for i in range(8)])
with torch.no_grad():
    x_recon, _ = model(x_show)
for i in range(8):
    axes[0, i].imshow(x_show[i].squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
    axes[1, i].imshow(x_recon[i].squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
    axes[0, i].axis('off'); axes[1, i].axis('off')
# BUGFIX: set_ylabel is never drawn on axes disabled with axis('off'), so the
# row labels were invisible. Use rotated text anchored left of each row.
axes[0, 0].text(-0.1, 0.5, "Original", fontsize=9, rotation=90,
                va='center', ha='center', transform=axes[0, 0].transAxes)
axes[1, 0].text(-0.1, 0.5, "Reconstructed", fontsize=9, rotation=90,
                va='center', ha='center', transform=axes[1, 0].transAxes)
plt.suptitle("Original vs. Reconstructed Cahn-Hilliard Microstructures")
plt.tight_layout(); plt.show()

7. Nearest-Neighbor Retrieval in Latent Space

# Retrieval demo: find the 5 microstructures closest to a query by Euclidean
# distance in the 2-D latent space.
query_idx = 0

# Distance from the query's latent code to every latent code (broadcast).
diffs = all_z - all_z[query_idx]
distances = np.sqrt((diffs ** 2).sum(axis=1))

# Sorted order puts the query itself first (distance 0); keep ranks 1..5.
top5_idx = np.argsort(distances)[1:6]

def _show(ax, sample_idx, title):
    # Render one microstructure panel with a small title and no axis frame.
    ax.imshow(dataset[sample_idx][0].squeeze().numpy(), cmap='RdBu_r', vmin=0, vmax=1)
    ax.set_title(title, fontsize=8)
    ax.axis('off')

fig, axes = plt.subplots(1, 6, figsize=(16, 3))
_show(axes[0], query_idx, f"Query\nE={dataset[query_idx][1]:.3f}")
for col, idx in enumerate(top5_idx, 1):
    _show(axes[col], idx, f"Rank {col}\nE={dataset[idx][1]:.3f}\nd={distances[idx]:.3f}")

plt.suptitle("Nearest-Neighbor Retrieval in Latent Space")
plt.tight_layout()
plt.show()

Exercises

  1. Does the latent space preserve energy ordering? The Spearman correlation above quantifies this. Try training longer or with a larger latent_dim — does the correlation improve?
  2. Try simulation_number=-1 (all 18 simulations). How does the latent space look when trained on all data? Does the energy gradient in the scatter plot become clearer or more diffuse?
  3. For the nearest-neighbor retrieval: do the 5 nearest neighbors have similar energies to the query? Compute the mean absolute energy difference between query and its 5 neighbors — how does this compare to the global energy standard deviation?