import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import h5py as h5
import torch as th
In [2]:
Download the data from GDrive into your local folder
or download the data from the notebook
In [4]:
import gdown
# https://drive.google.com/file/d/1sVATyJhuX0UEdtd07gDnVTmg5xu1mp-g/view?usp=drive_link
# Google Drive file ID from the shareable link
= "1sVATyJhuX0UEdtd07gDnVTmg5xu1mp" # Replace with actual file ID
file_id = "./01_segmentation.tar.gz"
output_path
# Create directory if it doesn't exist
=True)
os.makedirs(os.path.dirname(output_path), exist_ok
# Download file from Google Drive
= f'https://drive.google.com/uc?id={file_id}'
url =False) gdown.download(url, output_path, quiet
Let’s load the data and display some examples
In [6]:
= '/mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/'
file_path
print("Folders in file_path:")
= [f for f in os.listdir(file_path) if os.path.isdir(os.path.join(file_path, f))]
folders for folder in folders[:5]:
print(f"- {folder}")
Folders in file_path:
- image_batch_52494658464_20240403
- image_batch_5254823359_20240403
- image_batch_52136360101_20240403
- image_batch_52393632444_20240403
- image_batch_52048343522_20240403
Get all .h5 files from all subfolders
In [7]:
= []
h5_files for folder in folders:
= os.path.join(file_path, folder)
folder_path = [f for f in os.listdir(folder_path) if f.endswith('.h5')]
files for f in files])
h5_files.extend([os.path.join(folder_path, f)
print(f"\nFound {len(h5_files)} .h5 files:")
for f in h5_files[:5]: # Print first 5 files as example
print(f"- {f}")
if len(h5_files) > 5:
print("...")
Found 128 .h5 files:
- /mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/image_batch_52494658464_20240403/image_batch.h5
- /mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/image_batch_5254823359_20240403/image_batch.h5
- /mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/image_batch_52136360101_20240403/image_batch.h5
- /mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/image_batch_52393632444_20240403/image_batch.h5
- /mnt/data/insync/braunphil/Public/datascience_miniproject/optimal_Au/images/image_batch_52048343522_20240403/image_batch.h5
...
In [8]:
= h5_files[0]
file_names
with h5.File(file_names, 'r') as f:
= f['train_batch'][...]
d = f['mask_batch'][...]
m
= plt.subplots(1, 3, figsize=(15/2, 5/2))
fig, axes for i in range(3):
axes[i].imshow(d[i])'off')
axes[i].axis(f'Image {i+1}')
axes[i].set_title(
plt.tight_layout()
plt.show()# Plot first 3 masks
= plt.subplots(1, 3, figsize=(15/2, 5/2))
fig, axes for i in range(3):
axes[i].imshow(m[i])'off')
axes[i].axis(f'Mask {i+1}')
axes[i].set_title(
plt.tight_layout()
plt.show()
# %%
In [10]:
d.shape, m.shape
((64, 512, 512), (64, 512, 512))
In [13]:
class SegmentationDataset:
def __init__(self, file_paths):
self.file_paths = file_paths
def __len__(self):
return len(self.file_paths)
def __getitem__(self, idx):
= self.file_paths[idx]
file_path with h5.File(file_path, 'r') as f:
# Load image and mask from h5 file
= f['train_batch'][...]
image = f['mask_batch'][...]
mask return image, mask
# Create dataset instance
= SegmentationDataset(h5_files)
dataset
# Test loading first item
= dataset[0]
images, masks print(f"Loaded images shape: {images.shape}")
print(f"Loaded masks shape: {masks.shape}")
Loaded images shape: (64, 512, 512)
Loaded masks shape: (64, 512, 512)
In [14]:
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
=3, padding=1),
nn.Conv2d(in_channels, out_channels, kernel_size
nn.BatchNorm2d(out_channels),=True),
nn.ReLU(inplace=3, padding=1),
nn.Conv2d(out_channels, out_channels, kernel_size
nn.BatchNorm2d(out_channels),=True)
nn.ReLU(inplace
)
def forward(self, x):
return self.double_conv(x)
class UNet(nn.Module):
def __init__(self, in_channels=1, out_channels=1):
super().__init__()
# Encoder
self.conv1 = DoubleConv(in_channels, 64)
self.pool1 = nn.MaxPool2d(2)
self.conv2 = DoubleConv(64, 128)
self.pool2 = nn.MaxPool2d(2)
self.conv3 = DoubleConv(128, 256)
self.pool3 = nn.MaxPool2d(2)
self.conv4 = DoubleConv(256, 512)
self.pool4 = nn.MaxPool2d(2)
self.conv5 = DoubleConv(512, 1024)
# Decoder
self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.upconv4 = DoubleConv(1024, 512)
self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.upconv3 = DoubleConv(512, 256)
self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.upconv2 = DoubleConv(256, 128)
self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.upconv1 = DoubleConv(128, 64)
# Final output
self.outconv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
# Encoder
= self.conv1(x)
conv1 = self.pool1(conv1)
pool1 = self.conv2(pool1)
conv2 = self.pool2(conv2)
pool2 = self.conv3(pool2)
conv3 = self.pool3(conv3)
pool3 = self.conv4(pool3)
conv4 = self.pool4(conv4)
pool4 = self.conv5(pool4)
conv5
# Decoder with skip connections
= self.up4(conv5)
up4 = torch.cat([up4, conv4], dim=1)
up4 = self.upconv4(up4)
up4
= self.up3(up4)
up3 = torch.cat([up3, conv3], dim=1)
up3 = self.upconv3(up3)
up3
= self.up2(up3)
up2 = torch.cat([up2, conv2], dim=1)
up2 = self.upconv2(up2)
up2
= self.up1(up2)
up1 = torch.cat([up1, conv1], dim=1)
up1 = self.upconv1(up1)
up1
= self.outconv(up1)
out return out
# Create model instance
= UNet(in_channels=1, out_channels=1)
model
In [18]:
import torch.optim as optim
from torch.utils.data import DataLoader
# Define loss function and optimizer
= nn.BCEWithLogitsLoss()
criterion = optim.Adam(model.parameters(), lr=0.001)
optimizer
# Create data loaders
= 4
batch_size = DataLoader(dataset, batch_size=batch_size, shuffle=True)
train_loader # val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
# Training parameters
= 50
num_epochs = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = model.to(device)
model
# Training loop
for epoch in range(num_epochs):
model.train()= 0.0
running_loss
for i, (images, masks) in enumerate(train_loader):
# Move data to device
= images.to(device)
images = masks.to(device)
masks
# Zero the gradients
optimizer.zero_grad()
# Forward pass
= model(images)
outputs = criterion(outputs, masks)
loss
# Backward pass and optimize
loss.backward()
optimizer.step()
+= loss.item()
running_loss
# Print statistics every 10 batches
if (i + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')
# Print epoch statistics
= running_loss / len(train_loader)
epoch_loss print(f'Epoch [{epoch+1}/{num_epochs}] complete. Average Loss: {epoch_loss:.4f}')
print('Training finished!')
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) Cell In[18], line 33 30 optimizer.zero_grad() 32 # Forward pass ---> 33 outputs = model(images) 34 loss = criterion(outputs, masks) 36 # Backward pass and optimize File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1543 try: 1544 result = None Cell In[14], line 49, in UNet.forward(self, x) 47 def forward(self, x): 48 # Encoder ---> 49 conv1 = self.conv1(x) 50 pool1 = self.pool1(conv1) 51 conv2 = self.conv2(pool1) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1543 try: 1544 result = None Cell In[14], line 17, in DoubleConv.forward(self, x) 16 def forward(self, x): ---> 17 return self.double_conv(x) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1543 try: 1544 result = None File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/container.py:217, in Sequential.forward(self, input) 215 def forward(self, input): 216 for module in self: --> 217 input = module(input) 218 return input File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs) 1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1531 else: -> 1532 return self._call_impl(*args, **kwargs) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs) 1536 # If we don't have any hooks, we want to skip the rest of the logic in 1537 # this function, and just call forward. 1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1539 or _global_backward_pre_hooks or _global_backward_hooks 1540 or _global_forward_hooks or _global_forward_pre_hooks): -> 1541 return forward_call(*args, **kwargs) 1543 try: 1544 result = None File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/conv.py:460, in Conv2d.forward(self, input) 459 def forward(self, input: Tensor) -> Tensor: --> 460 return self._conv_forward(input, self.weight, self.bias) File ~/mambaforge/envs/main11/lib/python3.11/site-packages/torch/nn/modules/conv.py:456, in Conv2d._conv_forward(self, input, weight, bias) 452 if self.padding_mode != 'zeros': 453 return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), 454 weight, bias, self.stride, 455 _pair(0), self.dilation, self.groups) --> 456 return F.conv2d(input, weight, bias, self.stride, 457 self.padding, self.dilation, self.groups) RuntimeError: Input type (unsigned char) and bias type (float) should be the same