import os
import gdown

# Google Drive file ID from the shareable link
file_id1 = "1wBeqnl_Hjk3l5zbDHlww4V5WBUSATSuA"  # Replace with actual file ID
output_path1 = "./04_image_to_image.zip"

# Create directory if it doesn't exist
os.makedirs(os.path.dirname(output_path1), exist_ok=True)

# Download file from Google Drive
url = f'https://drive.google.com/uc?id={file_id1}'
gdown.download(url, output_path1, quiet=False)
Downloading...
From (original): https://drive.google.com/uc?id=1wBeqnl_Hjk3l5zbDHlww4V5WBUSATSuA
From (redirected): https://drive.google.com/uc?id=1wBeqnl_Hjk3l5zbDHlww4V5WBUSATSuA&confirm=t&uuid=ecfeb98d-438d-4fc6-8f7a-3c73e74d7870
To: c:\Users\braun\OneDrive\Documents\GitHub\DataScienceForElectronMicroscopy\notebooks\04_image_to_image.zip
100%|██████████| 1.29G/1.29G [00:13<00:00, 95.4MB/s]
'./04_image_to_image.zip'
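Re-running this cell downloads the 1.29 GB archive again. As an optional guard (an addition, not part of the original cell), you could skip the download when the archive is already on disk:

# Optional guard (assumption, not in the original notebook):
# only download if the archive is missing.
if not os.path.exists(output_path1):
    gdown.download(url, output_path1, quiet=False)
else:
    size_gb = os.path.getsize(output_path1) / 1e9
    print(f"{output_path1} already exists ({size_gb:.2f} GB), skipping download")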
In [6]:
import zipfile

# Unzip the file
with zipfile.ZipFile("04_image_to_image.zip", 'r') as zip_ref:
    zip_ref.extractall("04_image_to_image")
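The cells below index a `dataset` object that yields (input, target) image pairs; its definition is not shown in this section. As a rough illustration only, a paired-image dataset over the extracted folder could look like the sketch below. The `PairedImageDataset` class and the `inputs`/`targets` subfolder names are assumptions for illustration, not the notebook's actual code.

from PIL import Image
import torchvision.transforms as T
from torch.utils.data import Dataset

class PairedImageDataset(Dataset):
    """Hypothetical paired dataset: returns (input, target) grayscale tensors in [0, 1]."""
    def __init__(self, root="04_image_to_image", input_dir="inputs", target_dir="targets"):
        self.input_paths = sorted(
            os.path.join(root, input_dir, f) for f in os.listdir(os.path.join(root, input_dir))
        )
        self.target_paths = sorted(
            os.path.join(root, target_dir, f) for f in os.listdir(os.path.join(root, target_dir))
        )
        self.to_tensor = T.ToTensor()  # converts PIL image to float tensor scaled to [0, 1]

    def __len__(self):
        return len(self.input_paths)

    def __getitem__(self, idx):
        x = self.to_tensor(Image.open(self.input_paths[idx]).convert("L"))
        y = self.to_tensor(Image.open(self.target_paths[idx]).convert("L"))
        return x, y

# dataset = PairedImageDataset()  # the notebook defines `dataset` in an earlier cell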
import numpy as np
import matplotlib.pyplot as plt
import torchvision

# Collect the first nine input images from the dataset
inputs = []
for i in range(9):
    x, _ = dataset[i]  # Get just the input image
    inputs.append(x)

# Use torchvision's make_grid to create a grid of images
grid = torchvision.utils.make_grid(inputs, nrow=3, padding=2)

# Convert to numpy and scale to 0-255 range for display
grid_np = (grid.squeeze().numpy() * 255).astype(np.uint8)

# Plot the grid using matplotlib
plt.figure(figsize=(7, 7))
plt.imshow(grid_np.transpose(1, 2, 0))  # Transpose from (C,H,W) to (H,W,C) for matplotlib
plt.axis('off')
plt.show()
Create a 3x3 grid of target images
In [14]:
import torchvision

# Collect the first nine target images from the dataset
targets = []
for i in range(9):
    _, y = dataset[i]  # Get just the target image
    targets.append(y)

# Use torchvision's make_grid to create a grid of images
grid = torchvision.utils.make_grid(targets, nrow=3, padding=2)

# Convert to numpy and scale to 0-255 range for display
grid_np = (grid.squeeze().numpy() * 255).astype(np.uint8)

# Plot the grid using matplotlib
plt.figure(figsize=(7, 7))
plt.imshow(grid_np.transpose(1, 2, 0))  # Transpose from (C,H,W) to (H,W,C) for matplotlib
plt.axis('off')
plt.show()
Create training and validation splits
In [18]:
import torch

dataset_size = len(dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

print(f"Training set size: {len(train_dataset)}")
print(f"Validation set size: {len(val_dataset)}")

# Create data loaders
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True,
    num_workers=2, pin_memory=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False,
    num_workers=2, pin_memory=True
)

# Get a random batch from the training set
dataiter = iter(train_loader)
batch = next(dataiter)
input_images, target_images = batch
print(f"Batch shape: {input_images.shape}")
print(f"Target shape: {target_images.shape}")

# Display a random selection of nine pairs from the batch
idx = torch.randint(0, batch_size, (9,))
sample_inputs = [input_images[i] for i in idx]
sample_targets = [target_images[i] for i in idx]

# Create input grid
input_grid = torchvision.utils.make_grid(sample_inputs, nrow=3, padding=2)
input_grid_np = (input_grid.squeeze().numpy() * 255).astype(np.uint8)

# Create target grid
target_grid = torchvision.utils.make_grid(sample_targets, nrow=3, padding=2)
target_grid_np = (target_grid.squeeze().numpy() * 255).astype(np.uint8)

# Plot side by side
plt.figure(figsize=(14, 7))
plt.subplot(1, 2, 1)
plt.imshow(input_grid_np.transpose(1, 2, 0))
plt.title('Input Images (Training Set)')
plt.axis('off')
plt.subplot(1, 2, 2)
plt.imshow(target_grid_np.transpose(1, 2, 0))
plt.title('Target Images (Training Set)')
plt.axis('off')
plt.show()
Training set size: 1475
Validation set size: 369
Batch shape: torch.Size([32, 1, 1024, 1024])
Target shape: torch.Size([32, 1, 1024, 1024])
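Before moving on to the model, it can be worth confirming that the tensors are in the value range the visualization code assumes (the `* 255` scaling above expects values in [0, 1]). A quick sanity check, not part of the original notebook, might look like:

# Sanity check (assumption: ToTensor-style scaling to [0, 1] was used upstream).
# If the min/max fall outside [0, 1], revisit the *255 display scaling and any
# normalization assumptions made later during training.
print(f"Input range:  [{input_images.min().item():.3f}, {input_images.max().item():.3f}]")
print(f"Target range: [{target_images.min().item():.3f}, {target_images.max().item():.3f}]")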
Now we can start defining our model and training it