import torch as th

class ConvVAE(th.nn.Module):
    def __init__(self):
        super(ConvVAE, self).__init__()
        # Encoder: four stride-2 convolutions halve the spatial size each time,
        # so a 1x128x128 input ends up as a 256x8x8 feature map before flattening.
        self.encoder = th.nn.Sequential(
            th.nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1),
            th.nn.ReLU(),
            th.nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            th.nn.ReLU(),
            th.nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            th.nn.ReLU(),
            th.nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            th.nn.ReLU(),
            th.nn.Flatten()
        )
        # Latent space: project the flattened 256*8*8 features to a
        # 3-dimensional latent code via separate mean and log-variance heads
        self.fc_mu = th.nn.Linear(256 * 8 * 8, 3)
        self.fc_var = th.nn.Linear(256 * 8 * 8, 3)
        # Decoder: mirror the encoder with stride-2 transposed convolutions,
        # doubling the spatial size at each step back up to 128x128
        self.decoder_input = th.nn.Linear(3, 256 * 8 * 8)
        self.decoder = th.nn.Sequential(
            th.nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            th.nn.ReLU(),
            th.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            th.nn.ReLU(),
            th.nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            th.nn.ReLU(),
            th.nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            th.nn.Sigmoid()  # pixel values in [0, 1], matching the BCE loss below
        )
    def encode(self, x):
        # Map an input batch to the parameters of q(z|x)
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return mu, log_var
    def reparameterize(self, mu, log_var):
        # Reparameterization trick: sample z = mu + sigma * eps with
        # eps ~ N(0, I), keeping the sampling step differentiable in mu and sigma
        std = th.exp(0.5 * log_var)
        eps = th.randn_like(std)
        return mu + eps * std
    def decode(self, z):
        # Project the latent code back to a 256x8x8 feature map, then upsample
        x = self.decoder_input(z)
        x = x.view(-1, 256, 8, 8)
        x = self.decoder(x)
        return x
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
# Initialize model (falling back to CPU when CUDA is unavailable)
device = th.device("cuda" if th.cuda.is_available() else "cpu")
model = ConvVAE().to(device)
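# Quick shape check -- a hypothetical smoke test, not part of the original
# model: with 1x128x128 inputs the encoder bottleneck is 256x8x8 and the
# decoder should reconstruct the input resolution exactly.
with th.no_grad():
    dummy = th.rand(2, 1, 128, 128, device=device)  # assumed input size
    recon, mu, log_var = model(dummy)
    assert recon.shape == dummy.shape  # (2, 1, 128, 128)
    assert mu.shape == (2, 3) and log_var.shape == (2, 3)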
# Define loss function: reconstruction term plus KL divergence
def loss_function(recon_x, x, mu, log_var):
    # Pixel-wise binary cross-entropy, summed over pixels and the batch
    BCE = th.nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
    # Closed-form KL divergence between q(z|x) = N(mu, sigma^2) and the
    # standard normal prior N(0, I)
    KLD = -0.5 * th.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD
# Initialize optimizer
optimizer = th.optim.Adam(model.parameters(), lr=1e-3)
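
# A minimal training-loop sketch, not part of the original listing. It assumes
# a `train_loader` yielding batches of 1x128x128 images scaled to [0, 1] (e.g.
# a torchvision dataset with Resize(128) + ToTensor()); the loader and the
# epoch count are placeholders.
num_epochs = 10  # assumed value
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for batch, _ in train_loader:  # labels are unused by the VAE
        batch = batch.to(device)
        optimizer.zero_grad()
        recon, mu, log_var = model(batch)
        loss = loss_function(recon, batch, mu, log_var)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # reduction='sum' means dividing by the dataset size gives a per-image loss
    print(f"epoch {epoch + 1}: loss = {total_loss / len(train_loader.dataset):.4f}")

# After training, new images can be sampled by decoding draws from the
# standard normal prior over the 3-dimensional latent space.
model.eval()
with th.no_grad():
    z = th.randn(16, 3, device=device)
    samples = model.decode(z)  # (16, 1, 128, 128), values in [0, 1]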