October 07, 2020
Generative Adversarial Networks (GANs) are a powerful class of machine learning models capable of generating images, video, text, and voice. Equipped with this capability, ML practitioners are realizing novel applications like creating images of people who have never existed or composing never-before-heard music. Unfortunately, GANs can be very expensive to train: even using GPUs, it is not uncommon for a single GAN model to take days or weeks to train, let alone to perform hyperparameter tuning or architecture search.
With Determined’s open source deep learning training platform, we can dramatically reduce the time required to train a GAN. Determined makes distributed training incredibly easy: once a model has been adapted to use Determined’s API, switching from single-GPU training to distributed training is a simple configuration change.
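For example, once a model is ported to the API, moving from one GPU to 64 only requires adding a resources section to the experiment configuration, exactly as we do later in this post:

    resources:
      slots_per_trial: 64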
In this post, we examine how to train a CycleGAN implementation using Determined. CycleGAN aims to generate Monet-style paintings from still images and the inverse: still images from Monet-style paintings.
Let's walk through the steps we took to train CycleGAN in Determined, and then scale model training from 1 GPU to 64 GPUs. We'll show that in the same 45 hours of training time, distributed training with Determined reaches a much higher-quality model. Not only do we achieve faster training, but features such as hyperparameter tuning, experiment tracking, and TensorBoard integration work out of the box.
First, we need to implement a Determined Trial class. By organizing our code into the Trial API shown below, we can leverage features such as experiment tracking, metric visualization, hyperparameter tuning, and distributed training. In the process, we'll organize our PyTorch code, remove unnecessary boilerplate, and abstract away engineering code.
from typing import Any, Dict

import torch
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext, TorchData

class CycleGANTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext) -> None:
        pass

    def build_training_data_loader(self) -> DataLoader:
        pass

    def build_validation_data_loader(self) -> DataLoader:
        pass

    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        pass

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        pass
By using the Determined Trial API, we organize our code and remove unnecessary boilerplate. For instance, the Trial class manages the training loop, saving training artifacts, and saving checkpoints for you. Also, key experiment configuration like the length of training and the hyperparameters is abstracted away into experiment configuration files.
To start, let's create the PyTorchTrial constructor. This will look similar to the original code before the training loop. We save our instantiated models, optimizers, and learning rate schedulers to the Trial context so that we can access them in our training and evaluation functions.
class CycleGANTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext) -> None:
        self.context = context
        self.logger = TorchWriter()
        data_config = self.context.get_data_config()
        self.dataset_path = f"{data_config['downloaded_path']}/{data_config['dataset_name']}"

        # Initialize the models.
        input_shape = (
            data_config["channels"],
            data_config["img_height"],
            data_config["img_width"],
        )
        self.G_AB = self.context.wrap_model(GeneratorResNet(input_shape, context.get_hparam("n_residual_blocks")))
        self.G_BA = self.context.wrap_model(GeneratorResNet(input_shape, context.get_hparam("n_residual_blocks")))
        self.D_A = self.context.wrap_model(Discriminator(input_shape))
        self.D_B = self.context.wrap_model(Discriminator(input_shape))

        # Losses
        self.criterion_GAN = self.context.wrap_model(torch.nn.MSELoss())
        self.criterion_cycle = self.context.wrap_model(torch.nn.L1Loss())
        self.criterion_identity = self.context.wrap_model(torch.nn.L1Loss())

        # Initialize weights
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Initialize the optimizers and learning rate schedulers.
        lr = context.get_hparam("lr")
        b1 = context.get_hparam("b1")
        b2 = context.get_hparam("b2")
        n_epochs = context.get_experiment_config()["searcher"]["max_length"]["epochs"]
        decay_epoch = context.get_hparam("decay_epoch")
        self.optimizer_G = self.context.wrap_optimizer(torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), lr=lr, betas=(b1, b2)
        ))
        self.optimizer_D_A = self.context.wrap_optimizer(torch.optim.Adam(self.D_A.parameters(), lr=lr, betas=(b1, b2)))
        self.optimizer_D_B = self.context.wrap_optimizer(torch.optim.Adam(self.D_B.parameters(), lr=lr, betas=(b1, b2)))
        self.lr_scheduler_G = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
        self.lr_scheduler_D_A = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_A, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
        self.lr_scheduler_D_B = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_B, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations
        img_height = data_config["img_height"]
        img_width = data_config["img_width"]
        self.transforms_ = [
            transforms.Resize(int(img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((img_height, img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Test images used for display in TensorBoard.
        self.test_dataloader = torch.utils.data.DataLoader(
            ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True, mode="test"),
            batch_size=5,
            shuffle=True,
            num_workers=1,
        )
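A quick aside: LambdaLR and ReplayBuffer here are helpers from the original CycleGAN implementation, not Determined APIs. The LambdaLR helper (not to be confused with torch.optim.lr_scheduler.LambdaLR, which wraps it) computes a linear learning rate decay; it looks roughly like this:

def linear decay sketch, mirroring the helper from the original repo:

class LambdaLR:
    """Linearly decays the LR multiplier to zero starting at `decay_start_epoch`."""
    def __init__(self, n_epochs, offset, decay_start_epoch):
        assert (n_epochs - decay_start_epoch) > 0, "Decay must start before training ends!"
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch):
        # Multiplier is 1.0 until decay_start_epoch, then falls linearly to 0.
        return 1.0 - max(0, epoch + self.offset - self.decay_start_epoch) / (self.n_epochs - self.decay_start_epoch)

ReplayBuffer, from the same source, keeps a pool of previously generated images and randomly swaps older fakes into each batch fed to the discriminators, a standard trick for stabilizing GAN training.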
Now that we have our models saved to the context, we need to implement the train_batch function. The code inside train_batch should be almost identical to the code inside the original training loop, except that we make sure to use the models stored on the Trial context. We step the optimizers and zero out their gradients here, and Determined automatically captures and plots the returned metrics.
def train_batch(
    self, batch: TorchData, epoch_idx: int, batch_idx: int
) -> Dict[str, torch.Tensor]:
    # Set model input
    real_A = Variable(batch["A"].type(Tensor))
    real_B = Variable(batch["B"].type(Tensor))

    # Adversarial ground truths
    valid = Variable(Tensor(np.ones((real_A.size(0), *self.D_A.output_shape))), requires_grad=False)
    fake = Variable(Tensor(np.zeros((real_A.size(0), *self.D_A.output_shape))), requires_grad=False)

    # ------------------
    #  Train Generators
    # ------------------
    # Set `requires_grad_` so that only the generators are updated.
    self.G_AB.requires_grad_(True)
    self.G_BA.requires_grad_(True)
    self.D_A.requires_grad_(False)
    self.D_B.requires_grad_(False)
    self.G_AB.train()
    self.G_BA.train()

    # Identity loss
    loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
    loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)
    loss_identity = (loss_id_A + loss_id_B) / 2

    # GAN loss
    fake_B = self.G_AB(real_A)
    loss_GAN_AB = self.criterion_GAN(self.D_B(fake_B), valid)
    fake_A = self.G_BA(real_B)
    loss_GAN_BA = self.criterion_GAN(self.D_A(fake_A), valid)
    loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

    # Cycle loss
    recov_A = self.G_BA(fake_B)
    loss_cycle_A = self.criterion_cycle(recov_A, real_A)
    recov_B = self.G_AB(fake_A)
    loss_cycle_B = self.criterion_cycle(recov_B, real_B)
    loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

    # Total loss
    lambda_cyc = self.context.get_hparam("lambda_cyc")
    lambda_id = self.context.get_hparam("lambda_id")
    loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
    self.context.backward(loss_G)
    self.optimizer_G.step()
    self.optimizer_G.zero_grad()

    # -----------------------
    #  Train Discriminator A
    # -----------------------
    # Set `requires_grad_` to only update parameters on discriminator A.
    self.G_AB.requires_grad_(False)
    self.G_BA.requires_grad_(False)
    self.D_A.requires_grad_(True)
    self.D_B.requires_grad_(False)

    # Real loss
    loss_real = self.criterion_GAN(self.D_A(real_A), valid)
    # Fake loss (on batch of previously generated samples)
    fake_A_ = self.fake_A_buffer.push_and_pop(fake_A)
    loss_fake = self.criterion_GAN(self.D_A(fake_A_.detach()), fake)
    # Total loss
    loss_D_A = (loss_real + loss_fake) / 2
    self.context.backward(loss_D_A)
    self.optimizer_D_A.step()
    self.optimizer_D_A.zero_grad()

    # -----------------------
    #  Train Discriminator B
    # -----------------------
    # Set `requires_grad_` to only update parameters on discriminator B.
    self.G_AB.requires_grad_(False)
    self.G_BA.requires_grad_(False)
    self.D_A.requires_grad_(False)
    self.D_B.requires_grad_(True)

    # Real loss
    loss_real = self.criterion_GAN(self.D_B(real_B), valid)
    # Fake loss (on batch of previously generated samples)
    fake_B_ = self.fake_B_buffer.push_and_pop(fake_B)
    loss_fake = self.criterion_GAN(self.D_B(fake_B_.detach()), fake)
    # Total loss
    loss_D_B = (loss_real + loss_fake) / 2
    self.context.backward(loss_D_B)
    self.optimizer_D_B.step()
    self.optimizer_D_B.zero_grad()

    loss_D = (loss_D_A + loss_D_B) / 2

    # --------------
    #  Log Progress
    # --------------
    # Save sample images every `sample_interval` records, on the chief worker only.
    global_records = self.context.get_global_batch_size() * batch_idx
    sample_interval = self.context.get_data_config()["sample_interval"]
    if self.context.distributed.get_rank() == 0 and global_records % sample_interval == 0:
        self.sample_images(self.context.get_data_config()["dataset_name"], batch_idx)

    return {
        "loss_D": loss_D,
        "loss_G": loss_G,
        "loss_GAN": loss_GAN,
        "loss_cycle": loss_cycle,
        "loss_identity": loss_identity,
    }
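The sample_images helper called at the end of train_batch is not shown above: it writes a grid of real and generated test images to TensorBoard through the TorchWriter created in the constructor. Here is a minimal sketch of what it might look like, assuming self.logger.writer exposes the underlying SummaryWriter and torchvision is imported; the completed Trial class linked below has the exact version.

def sample_images(self, dataset_name: str, batch_idx: int) -> None:
    """Log a grid of real and generated test images to TensorBoard (sketch)."""
    imgs = next(iter(self.test_dataloader))
    self.G_AB.eval()
    self.G_BA.eval()
    with torch.no_grad():
        real_A = imgs["A"].type(Tensor)
        fake_B = self.G_AB(real_A)
        real_B = imgs["B"].type(Tensor)
        fake_A = self.G_BA(real_B)
    # Arrange rows: real A (Monet), generated B (photo), real B (photo), generated A (Monet).
    grid = torchvision.utils.make_grid(
        torch.cat([real_A, fake_B, real_B, fake_A], dim=0),
        nrow=real_A.size(0),
        normalize=True,
    )
    self.logger.writer.add_image(f"{dataset_name}/generated", grid, batch_idx)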
We compute the validation metrics in the evaluate_batch function. Determined automatically captures these metrics as well.
def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
    # Set model input
    real_A = Variable(batch["A"].type(Tensor))
    real_B = Variable(batch["B"].type(Tensor))
    # Adversarial ground truths
    valid = Variable(Tensor(np.ones((real_A.size(0), *self.D_A.output_shape))), requires_grad=False)
    # Real loss
    loss_real_D_A = self.criterion_GAN(self.D_A(real_A), valid)
    loss_real_D_B = self.criterion_GAN(self.D_B(real_B), valid)
    # Total loss
    loss_real_D = (loss_real_D_A + loss_real_D_B) / 2
    return {
        "loss_real_D": loss_real_D,
        "loss_real_D_A": loss_real_D_A,
        "loss_real_D_B": loss_real_D_B,
    }
We also need to move our data loaders into the Trial class. Note that we return Determined's determined.pytorch.DataLoader and use get_per_slot_batch_size(): when training is distributed, Determined shards the data across workers and divides the global batch size among slots automatically.
def build_training_data_loader(self) -> DataLoader:
    return DataLoader(
        ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True),
        batch_size=self.context.get_per_slot_batch_size(),
        shuffle=True,
        num_workers=self.context.get_data_config()["n_cpu"],
    )

def build_validation_data_loader(self) -> DataLoader:
    return DataLoader(
        ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True, mode="test"),
        batch_size=5,
        shuffle=True,
        num_workers=1,
    )
Check out the completed Trial Class.
Each experiment requires an experiment configuration to run on a Determined cluster. The experiment configuration is where we specify dataset information such as records_per_epoch and training information like the number of training epochs and the target validation metric. The configuration also provides sections for hyperparameters and arbitrary key-value pairs.

1-gpu.yaml:
description: Cycle GAN PyTorch 1 GPU
data:
  downloaded_path: /tmp
  dataset_name: monet2photo
  n_cpu: 8
  img_height: 256
  img_width: 256
  channels: 3
  sample_interval: 3000
hyperparameters:
  global_batch_size: 1
  lr: 0.0002
  b1: 0.5
  b2: 0.999
  decay_epoch: 100  # epoch from which to start lr decay
  n_residual_blocks: 9  # number of residual blocks in generator
  lambda_cyc: 10.0
  lambda_id: 5.0
records_per_epoch: 6287
searcher:
  name: single
  metric: loss_real_D
  max_length:
    epochs: 2000
  smaller_is_better: true
entrypoint: determined_model_def:CycleGANTrial
min_checkpoint_period:
  epochs: 1
Now that we've implemented the Determined PyTorchTrial class, we can launch our training job on the cluster using the Determined CLI. For this experiment, we will train on a single GPU.
det experiment create 1-gpu.yaml .
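While the experiment runs, we can keep an eye on it from the CLI as well as the WebUI. For example (the experiment ID below is a placeholder; see det --help for the full command list in your version):

# List experiments on the cluster, including their states.
det experiment list

# Inspect the configuration and progress of a specific experiment.
det experiment describe <experiment_id>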
At nearly 45 hours into our single-GPU experiment, we clicked the “View in TensorBoard” button to inspect our generated images. The top row is the input Monet image, and the second row is the corresponding generated still image. The third row is the input still image, and the last is the inferred Monet interpretation.
This model produces decent images, but it is taking too long to train. Determined's built-in distributed training helps speed up training without extra code or infrastructure setup. We can specify the number of GPUs and machines this training job consumes by changing slots_per_trial and global_batch_size in the experiment configuration.
Let's create a new configuration file with both slots_per_trial and global_batch_size set to 64. This configuration specifies that we want to distribute training across 64 GPUs in our cluster, with each GPU receiving a single training batch at a time. Note that we choose 64 as the global batch size because we use 64 slots and the per-slot batch size must be at least 1.
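As a quick sanity check on that arithmetic (a hypothetical snippet; inside the Trial, this value is what context.get_per_slot_batch_size() returns):

# Determined divides the global batch size evenly across slots:
global_batch_size = 64
slots_per_trial = 64
per_slot_batch_size = global_batch_size // slots_per_trial  # = 1, the minimum allowed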
64-gpus.yaml:
description: Cycle GAN PyTorch 64 GPUs
data:
  downloaded_path: /tmp
  dataset_name: monet2photo
  n_cpu: 8
  img_height: 256
  img_width: 256
  channels: 3
  sample_interval: 3000
hyperparameters:
  global_batch_size: 64
  lr: 0.0002
  b1: 0.5
  b2: 0.999
  decay_epoch: 100  # epoch from which to start lr decay
  n_residual_blocks: 9  # number of residual blocks in generator
  lambda_cyc: 10.0
  lambda_id: 5.0
records_per_epoch: 6287
searcher:
  name: single
  metric: loss_real_D
  max_length:
    epochs: 2000
  smaller_is_better: true
entrypoint: determined_model_def:CycleGANTrial
resources:
  slots_per_trial: 64
min_checkpoint_period:
  epochs: 1
We can launch this experiment using the CLI and the new configuration file. To satisfy the resource requirements, the Determined master launches 8 GCP VM instances with 8 GPUs on each node.
det experiment create 64-gpus.yaml .
Opening TensorBoard after the same 45 hours of training, we can see that the generated images are significantly better than those from single-GPU training.
You can check out the source code for this post here. We encourage you to give Determined a spin by trying this example or any of the others available in the Determined repository. If you have any questions along the way, hop on our community Slack or reach out on GitHub – we'd love to help!