
Train CycleGAN on Multiple GPUs with Determined


By Sean Rowan, Shiyuan Zhu, Angela Jiang

October 07, 2020

Tags: GAN, PyTorch

Generative Adversarial Networks (GANs) are a powerful class of machine learning models that can generate images, video, text, and audio. With this capability, ML practitioners are building novel applications, such as generating images of people who never existed or composing never-before-heard music. Unfortunately, GANs can be very expensive to train: even on GPUs, it is not uncommon for a single GAN model to take days or weeks to train, let alone to perform hyperparameter tuning or architecture search.

With Determined’s open source deep learning training platform, we can dramatically reduce the time required to train a GAN. Determined makes distributed training incredibly easy: once a model has been adapted to use Determined’s API, switching from single-GPU training to distributed training is a simple configuration change.

CycleGAN with PyTorch and Determined

In this post, we examine how to train a CycleGAN implementation using Determined. CycleGAN learns unpaired image-to-image translation; here we use it to turn photographs into Monet-style paintings and, in the inverse direction, Monet paintings into photographs.

Let’s walk through the steps we took to train CycleGAN in Determined and then scale training from 1 GPU to 64 GPUs. We’ll show that, given the same 45 hours of training time, distributed training with Determined reaches a much higher-quality model, as judged by its generated images. Beyond faster training, features such as hyperparameter tuning, experiment tracking, and TensorBoard integration work out of the box.

[Figure: Monet]

First, we implement Determined’s Trial class (PyTorchTrial for PyTorch models). By organizing our code into the Trial API shown below, we can leverage features such as experiment tracking, metric visualization, hyperparameter tuning, and distributed training.

In the process, we organize our PyTorch code, remove unnecessary boilerplate, and abstract away engineering code.

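from typing import Any, Dict

import torch
from determined.pytorch import DataLoader, PyTorchTrial, PyTorchTrialContext, TorchData


class CycleGANTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext) -> None:
        # Create and wrap models, optimizers, and LR schedulers.
        pass

    def build_training_data_loader(self) -> DataLoader:
        # Return the loader for the training set.
        pass

    def build_validation_data_loader(self) -> DataLoader:
        # Return the loader for the validation set.
        pass

    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        # Run one training step and return training metrics.
        pass

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        # Compute validation metrics for one batch.
        pass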

The Trial class manages the training loop and checkpointing for you, and key experiment settings, such as the length of training and the hyperparameters, live in an experiment configuration file rather than in the code.
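To make “the Trial class manages the training loop” concrete, here is a deliberately tiny, hypothetical harness showing the inversion of control: the trial only supplies a data loader and per-batch logic, while the harness owns the loop and collects whatever metrics train_batch returns. ToyTrial and run_training are illustrative names, not Determined APIs, and plain lists stand in for data loaders so the sketch runs without PyTorch. The real platform additionally handles devices, validation, checkpointing, and distribution.

```python
from typing import Any, Dict, Iterable, List


class ToyTrial:
    """Hypothetical stand-in for a trial: fits y = w * x with plain SGD."""

    def __init__(self, lr: float) -> None:
        self.w = 0.0
        self.lr = lr

    def build_training_data_loader(self) -> Iterable[Dict[str, float]]:
        # Each "batch" is one (x, y) pair sampled from y = 3x.
        return [{"x": x, "y": 3.0 * x} for x in (1.0, 2.0, 3.0)]

    def train_batch(self, batch: Dict[str, float], epoch_idx: int, batch_idx: int) -> Dict[str, Any]:
        err = self.w * batch["x"] - batch["y"]
        self.w -= self.lr * 2 * err * batch["x"]  # gradient of err**2 with respect to w
        return {"loss": err ** 2}


def run_training(trial: ToyTrial, epochs: int) -> List[Dict[str, Any]]:
    """The harness, not the trial, owns the loop and collects the returned metrics."""
    history = []
    batch_idx = 0
    for epoch_idx in range(epochs):
        for batch in trial.build_training_data_loader():
            history.append(trial.train_batch(batch, epoch_idx, batch_idx))
            batch_idx += 1
    return history


trial = ToyTrial(lr=0.05)
history = run_training(trial, epochs=20)
print(round(trial.w, 3))  # 3.0
```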

To start, let’s create the PyTorchTrial constructor. It looks similar to the setup code that preceded the training loop in the original implementation. We wrap our models, optimizers, and learning rate schedulers with the trial context (wrap_model, wrap_optimizer, and wrap_lr_scheduler) and store them on the trial so that train_batch and evaluate_batch can use them.

import itertools

import torch
from PIL import Image
from torchvision import transforms

from determined.pytorch import LRScheduler, PyTorchTrial, PyTorchTrialContext
from determined.tensorboard.metric_writers.pytorch import TorchWriter

# Helpers carried over from the original CycleGAN code (not shown in this post);
# these module names are assumptions.
from datasets import ImageDataset
from models import Discriminator, GeneratorResNet, weights_init_normal
from utils import LambdaLR, ReplayBuffer


class CycleGANTrial(PyTorchTrial):
    def __init__(self, context: PyTorchTrialContext) -> None:
        self.context = context
        self.logger = TorchWriter()
        data_config = self.context.get_data_config()
        self.dataset_path = f"{data_config['downloaded_path']}/{data_config['dataset_name']}"

        # Initialize the models.
        input_shape = (
            data_config["channels"],
            data_config["img_height"],
            data_config["img_width"],
        )
        n_residual_blocks = context.get_hparam("n_residual_blocks")

        self.G_AB = self.context.wrap_model(GeneratorResNet(input_shape, n_residual_blocks))
        self.G_BA = self.context.wrap_model(GeneratorResNet(input_shape, n_residual_blocks))
        self.D_A = self.context.wrap_model(Discriminator(input_shape))
        self.D_B = self.context.wrap_model(Discriminator(input_shape))

        # Losses
        self.criterion_GAN = self.context.wrap_model(torch.nn.MSELoss())
        self.criterion_cycle = self.context.wrap_model(torch.nn.L1Loss())
        self.criterion_identity = self.context.wrap_model(torch.nn.L1Loss())

        # Initialize weights
        self.G_AB.apply(weights_init_normal)
        self.G_BA.apply(weights_init_normal)
        self.D_A.apply(weights_init_normal)
        self.D_B.apply(weights_init_normal)

        # Initialize the optimizers and learning rate schedulers.
        lr = context.get_hparam("lr")
        betas = (context.get_hparam("b1"), context.get_hparam("b2"))
        n_epochs = context.get_experiment_config()["searcher"]["max_length"]["epochs"]
        decay_epoch = context.get_hparam("decay_epoch")

        self.optimizer_G = self.context.wrap_optimizer(torch.optim.Adam(
            itertools.chain(self.G_AB.parameters(), self.G_BA.parameters()), lr=lr, betas=betas
        ))
        self.optimizer_D_A = self.context.wrap_optimizer(
            torch.optim.Adam(self.D_A.parameters(), lr=lr, betas=betas)
        )
        self.optimizer_D_B = self.context.wrap_optimizer(
            torch.optim.Adam(self.D_B.parameters(), lr=lr, betas=betas)
        )

        # The LambdaLR helper supplies a per-epoch LR multiplier; with STEP_EVERY_EPOCH,
        # Determined steps these schedulers once per epoch.
        self.lr_scheduler_G = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
        self.lr_scheduler_D_A = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_A, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)
        self.lr_scheduler_D_B = self.context.wrap_lr_scheduler(torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D_B, lr_lambda=LambdaLR(n_epochs, 0, decay_epoch).step
        ), step_mode=LRScheduler.StepMode.STEP_EVERY_EPOCH)

        # Buffers of previously generated samples
        self.fake_A_buffer = ReplayBuffer()
        self.fake_B_buffer = ReplayBuffer()

        # Image transformations
        img_height = data_config["img_height"]
        img_width = data_config["img_width"]
        self.transforms_ = [
            transforms.Resize(int(img_height * 1.12), Image.BICUBIC),
            transforms.RandomCrop((img_height, img_width)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]

        # Test images that are used for displaying in TensorBoard.
        self.test_dataloader = torch.utils.data.DataLoader(
            ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True, mode="test"),
            batch_size=5,
            shuffle=True,
            num_workers=1,
        )
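The three LambdaLR schedulers above take their multiplier from a LambdaLR helper class that is not shown in this post. Assuming it follows the linear-decay rule of the original PyTorch-GAN CycleGAN utilities (a constant learning rate until decay_epoch, then a straight line down to zero at n_epochs), it can be sketched as follows. LinearDecay is a hypothetical name chosen to avoid clashing with torch.optim.lr_scheduler.LambdaLR; treat this as an approximation of the helper, not its verified source.

```python
class LinearDecay:
    """Sketch of a CycleGAN-style LR multiplier (assumed to match the post's LambdaLR helper)."""

    def __init__(self, n_epochs: int, offset: int, decay_start_epoch: int) -> None:
        if n_epochs <= decay_start_epoch:
            raise ValueError("Decay must start before training ends")
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def step(self, epoch: int) -> float:
        # 1.0 until decay_start_epoch, then a straight line down to 0.0 at n_epochs.
        decayed = max(0, epoch + self.offset - self.decay_start_epoch)
        return 1.0 - decayed / (self.n_epochs - self.decay_start_epoch)


schedule = LinearDecay(n_epochs=2000, offset=0, decay_start_epoch=100)
print([schedule.step(e) for e in (0, 100, 1050, 2000)])  # [1.0, 1.0, 0.5, 0.0]
```

Under that assumption, and with the configuration values used later (2000 epochs, decay_epoch 100), the learning rate stays at 0.0002 for the first 100 epochs and is halved by epoch 1050.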

Now that the models, optimizers, and schedulers are set up, we implement train_batch. Its code is almost identical to the body of the original training loop, except that it uses the models and optimizers wrapped by the trial context and relies on the context for backpropagation. Each optimizer is stepped, and its gradients cleared, right after its loss is backpropagated. Determined automatically captures and plots the metrics that train_batch returns.
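For reference, the objectives that train_batch optimizes can be read directly off its code. Here a is a real image from domain A, b a real image from domain B, MSE is the mean squared error used by criterion_GAN, the L1 norms denote the mean absolute error used by criterion_cycle and criterion_identity, and 1 and 0 stand for the all-ones (valid) and all-zeros (fake) targets with the discriminator's output shape:

```latex
\begin{aligned}
\mathcal{L}_{GAN} &= \tfrac{1}{2}\Big[\operatorname{MSE}\big(D_B(G_{AB}(a)),\,1\big) + \operatorname{MSE}\big(D_A(G_{BA}(b)),\,1\big)\Big] \\
\mathcal{L}_{cyc} &= \tfrac{1}{2}\Big[\big\lVert G_{BA}(G_{AB}(a)) - a\big\rVert_1 + \big\lVert G_{AB}(G_{BA}(b)) - b\big\rVert_1\Big] \\
\mathcal{L}_{id} &= \tfrac{1}{2}\Big[\big\lVert G_{BA}(a) - a\big\rVert_1 + \big\lVert G_{AB}(b) - b\big\rVert_1\Big] \\
\mathcal{L}_{G} &= \mathcal{L}_{GAN} + \lambda_{cyc}\,\mathcal{L}_{cyc} + \lambda_{id}\,\mathcal{L}_{id} \\
\mathcal{L}_{D_A} &= \tfrac{1}{2}\Big[\operatorname{MSE}\big(D_A(a),\,1\big) + \operatorname{MSE}\big(D_A(\tilde{a}),\,0\big)\Big]
\end{aligned}
```

Here \tilde{a} is a generated domain-A sample drawn from fake_A_buffer; the loss for D_B is symmetric. The experiment configuration sets \lambda_{cyc} = 10 and \lambda_{id} = 5.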

    def train_batch(
        self, batch: TorchData, epoch_idx: int, batch_idx: int
    ) -> Dict[str, torch.Tensor]:
        # Set model input (Determined has already moved the batch to this slot's device)
        real_A = batch["A"]
        real_B = batch["B"]

        # Adversarial ground truths
        target_shape = (real_A.size(0), *self.D_A.output_shape)
        valid = torch.ones(target_shape, device=real_A.device)
        fake = torch.zeros(target_shape, device=real_A.device)

        # ------------------
        #  Train Generators
        # ------------------

        self.G_AB.requires_grad_(True)
        self.G_BA.requires_grad_(True)
        self.D_A.requires_grad_(False)
        self.D_B.requires_grad_(False)

        self.G_AB.train()
        self.G_BA.train()

        # Identity loss
        loss_id_A = self.criterion_identity(self.G_BA(real_A), real_A)
        loss_id_B = self.criterion_identity(self.G_AB(real_B), real_B)

        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss
        fake_B = self.G_AB(real_A)
        loss_GAN_AB = self.criterion_GAN(self.D_B(fake_B), valid)
        fake_A = self.G_BA(real_B)
        loss_GAN_BA = self.criterion_GAN(self.D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Cycle loss
        recov_A = self.G_BA(fake_B)
        loss_cycle_A = self.criterion_cycle(recov_A, real_A)
        recov_B = self.G_AB(fake_A)
        loss_cycle_B = self.criterion_cycle(recov_B, real_B)

        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        lambda_cyc = self.context.get_hparam("lambda_cyc")
        lambda_id = self.context.get_hparam("lambda_id")
        loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity

        self.context.backward(loss_G)
        # step_optimizer steps the optimizer and, by default, also zeroes its gradients.
        self.context.step_optimizer(self.optimizer_G)

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        # Set `requires_grad_` to only update parameters on the discriminator A.
        self.G_AB.requires_grad_(False)
        self.G_BA.requires_grad_(False)
        self.D_A.requires_grad_(True)
        self.D_B.requires_grad_(False)

        # Real loss
        loss_real = self.criterion_GAN(self.D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        fake_A_ = self.fake_A_buffer.push_and_pop(fake_A)
        loss_fake = self.criterion_GAN(self.D_A(fake_A_.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        self.context.backward(loss_D_A)
        self.context.step_optimizer(self.optimizer_D_A)

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        # Set `requires_grad_` to only update parameters on the discriminator B.
        self.G_AB.requires_grad_(False)
        self.G_BA.requires_grad_(False)
        self.D_A.requires_grad_(False)
        self.D_B.requires_grad_(True)

        # Real loss
        loss_real = self.criterion_GAN(self.D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        fake_B_ = self.fake_B_buffer.push_and_pop(fake_B)
        loss_fake = self.criterion_GAN(self.D_B(fake_B_.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        self.context.backward(loss_D_B)
        self.context.step_optimizer(self.optimizer_D_B)

        loss_D = (loss_D_A + loss_D_B) / 2

        # --------------
        #  Log Progress
        # --------------

        # If at sample interval save image
        global_records = self.context.get_global_batch_size() * batch_idx
        sample_interval = self.context.get_data_config()["sample_interval"]
        if self.context.distributed.get_rank() == 0 and global_records % sample_interval == 0:
            self.sample_images(self.context.get_data_config()["dataset_name"], batch_idx)

        return {
            "loss_D": loss_D,
            "loss_G": loss_G,
            "loss_GAN": loss_GAN,
            "loss_cycle": loss_cycle,
            "loss_identity": loss_identity,
        }
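The discriminator updates in train_batch call push_and_pop on fake_A_buffer and fake_B_buffer, so each discriminator sometimes sees older generator outputs instead of only the latest ones, which tends to stabilize GAN training. The ReplayBuffer class itself is not shown in this post. The sketch below assumes the common CycleGAN image-pool policy (store up to 50 samples; once full, half the time return a stored sample and replace it with the new one). It works on a list of plain Python items rather than a batch tensor, and SampleReplayBuffer, max_size, p_swap, and rng are hypothetical names.

```python
import random
from typing import Any, List, Optional


class SampleReplayBuffer:
    """Hypothetical history buffer of generated samples (CycleGAN-style image pool)."""

    def __init__(self, max_size: int = 50, p_swap: float = 0.5, rng: Optional[random.Random] = None) -> None:
        if max_size <= 0:
            raise ValueError("max_size must be positive")
        self.max_size = max_size
        self.p_swap = p_swap
        self.rng = rng or random.Random()
        self.data: List[Any] = []

    def push_and_pop(self, samples: List[Any]) -> List[Any]:
        out = []
        for sample in samples:
            if len(self.data) < self.max_size:
                # Still filling up: store the new sample and return it unchanged.
                self.data.append(sample)
                out.append(sample)
            elif self.rng.random() < self.p_swap:
                # Return an older sample and put the new one in its slot.
                i = self.rng.randrange(self.max_size)
                out.append(self.data[i])
                self.data[i] = sample
            else:
                # Pass the new sample straight through.
                out.append(sample)
        return out
```

Because the returned list always has the same length as the input, the discriminator's batch size is unchanged; only the mix of recent and older fakes varies.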

We compute the validation metrics in the evaluate_batch function. Determined automatically captures these metrics as well.

    def evaluate_batch(self, batch: TorchData) -> Dict[str, Any]:
        # Set model input (already on this slot's device)
        real_A = batch["A"]
        real_B = batch["B"]
        # Adversarial ground truths
        valid = torch.ones((real_A.size(0), *self.D_A.output_shape), device=real_A.device)
        # Real loss
        loss_real_D_A = self.criterion_GAN(self.D_A(real_A), valid)
        loss_real_D_B = self.criterion_GAN(self.D_B(real_B), valid)
        # Total loss
        loss_real_D = (loss_real_D_A + loss_real_D_B) / 2

        return {
            "loss_real_D": loss_real_D,
            "loss_real_D_A": loss_real_D_A,
            "loss_real_D_B": loss_real_D_B,
        }
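evaluate_batch returns metrics for a single validation batch, and Determined reduces them over the whole validation set before reporting them. The helper below is a hypothetical illustration of a simple mean over per-batch metric dictionaries, using plain floats instead of tensors; consult the Determined documentation for the reduction it actually applies.

```python
from typing import Dict, List


def mean_reduce(batch_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Average each metric over validation batches (hypothetical helper)."""
    if not batch_metrics:
        raise ValueError("no validation batches")
    keys = batch_metrics[0].keys()
    return {k: sum(m[k] for m in batch_metrics) / len(batch_metrics) for k in keys}


per_batch = [
    {"loss_real_D": 0.25, "loss_real_D_A": 0.5, "loss_real_D_B": 0.0},
    {"loss_real_D": 0.75, "loss_real_D_A": 1.0, "loss_real_D_B": 0.5},
]
print(mean_reduce(per_batch))
# {'loss_real_D': 0.5, 'loss_real_D_A': 0.75, 'loss_real_D_B': 0.25}
```

An unweighted mean over batches equals the per-image mean only when all batches have the same size; with batch_size=5, the last validation batch is smaller whenever the test set size is not a multiple of 5.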

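Finally, we move our data loaders into the Trial class. The training loader asks the context for the per-slot batch size, so that each GPU loads its share of the global batch.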

    def build_training_data_loader(self) -> DataLoader:
        return DataLoader(
            ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True),
            batch_size=self.context.get_per_slot_batch_size(),
            shuffle=True,
            num_workers=self.context.get_data_config()["n_cpu"],
        )

    def build_validation_data_loader(self) -> DataLoader:
        return DataLoader(
            ImageDataset(self.dataset_path, transforms_=self.transforms_, unaligned=True, mode="test"),
            batch_size=5,
            shuffle=True,
            num_workers=1,
        )
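When a trial runs on many GPUs, each slot must see a different part of the training data. Determined takes care of this for the loaders returned by the trial; the sketch below only illustrates one simple way to do it, round-robin sharding of dataset indices, and is not Determined's implementation. shard_indices is a hypothetical name.

```python
from typing import List


def shard_indices(num_records: int, rank: int, num_replicas: int) -> List[int]:
    """Round-robin split: replica `rank` gets indices rank, rank + R, rank + 2R, ..."""
    if not 0 <= rank < num_replicas:
        raise ValueError("rank must be in [0, num_replicas)")
    return list(range(rank, num_records, num_replicas))


# With 6287 training records spread over 64 slots:
sizes = [len(shard_indices(6287, r, 64)) for r in range(64)]
print(min(sizes), max(sizes))  # 98 99
```

Every record lands in exactly one shard, and shard sizes differ by at most one, so no slot waits long on the others within an epoch.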

Check out the completed Trial Class.

Single-GPU Training

Each Experiment requires an Experiment Configuration to run on a Determined cluster. The Experiment Configuration is where we specify dataset information, such as records_per_epoch, and training settings, such as the number of training epochs and the validation metric the searcher tracks (here loss_real_D, where smaller is better). It also provides a hyperparameters section and a data section for arbitrary key-value pairs that the trial can read.

1-gpu.yaml:

description: Cycle GAN PyTorch 1 GPU
data:
  downloaded_path: /tmp
  dataset_name: monet2photo
  n_cpu: 8
  img_height: 256
  img_width: 256
  channels: 3
  sample_interval: 3000
hyperparameters:
  global_batch_size: 1
  lr: 0.0002
  b1: 0.5
  b2: 0.999
  decay_epoch: 100  # epoch from which to start lr decay
  n_residual_blocks: 9  # number of residual blocks in generator
  lambda_cyc: 10.0
  lambda_id: 5.0
records_per_epoch: 6287
searcher:
  name: single
  metric: loss_real_D
  max_length:
    epochs: 2000
  smaller_is_better: True
entrypoint: determined_model_def:CycleGANTrial
min_checkpoint_period:
  epochs: 1
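The constructor reads the data section through context.get_data_config(), hyperparameters through context.get_hparam(), and the epoch budget through context.get_experiment_config(). To check that configuration keys are spelled the way the trial expects, a hypothetical dictionary-backed stand-in (not part of Determined, and not enough to construct the full trial) could mirror those three accessors. The dictionary below copies a subset of 1-gpu.yaml by hand.

```python
from typing import Any, Dict


class DictContext:
    """Hypothetical stand-in exposing the config accessors used by CycleGANTrial."""

    def __init__(self, config: Dict[str, Any]) -> None:
        self._config = config

    def get_experiment_config(self) -> Dict[str, Any]:
        return self._config

    def get_data_config(self) -> Dict[str, Any]:
        return self._config["data"]

    def get_hparam(self, name: str) -> Any:
        return self._config["hyperparameters"][name]


config = {
    "data": {"downloaded_path": "/tmp", "dataset_name": "monet2photo", "img_height": 256},
    "hyperparameters": {"global_batch_size": 1, "lr": 0.0002, "decay_epoch": 100},
    "searcher": {"name": "single", "max_length": {"epochs": 2000}},
}
ctx = DictContext(config)
data = ctx.get_data_config()
print(f"{data['downloaded_path']}/{data['dataset_name']}")  # /tmp/monet2photo
print(ctx.get_experiment_config()["searcher"]["max_length"]["epochs"])  # 2000
```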

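Now that we have implemented the Determined PyTorch Trial class, we can launch our training job on the cluster using the Determined CLI. For this first experiment, we train on a single GPU. The final . argument tells the CLI to upload the current directory as the model definition.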

det experiment create 1-gpu.yaml .

About 45 hours into our single-GPU experiment, we clicked the “View in Tensorboard” button to inspect our generated images. The top row shows the input Monet paintings, the second row the corresponding generated photos, the third row the input photos, and the last row the generated Monet-style paintings.

[Figure: generated images after about 45 hours of single-GPU training]

Multi-GPU Distributed Training

This model produces decent images, but it takes a long time to train. Determined’s built-in distributed training speeds up training without extra code or infrastructure setup. We set the number of GPUs this training job uses with slots_per_trial (under resources) and adjust global_batch_size accordingly in the Experiment Configuration.

Let’s create a new configuration file with both slots_per_trial and global_batch_size set to 64. This configuration distributes training across 64 GPUs in our cluster, and each GPU processes a single record per step. We chose a global batch size of 64 because we use 64 slots and the per-slot batch size (global_batch_size / slots_per_trial) must be at least 1. This choice might not be the most scaling-efficient setting for distributed training; however, we can easily tune scaling efficiency by changing a few settings in the experiment configuration. See Optimizing Distributed Training for details.
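The arithmetic behind this choice is easy to check. The hypothetical helpers below compute the per-slot batch size, the approximate number of optimizer steps per epoch for records_per_epoch: 6287, and which batch_idx values satisfy the sample_images condition in train_batch (global_batch_size * batch_idx % sample_interval == 0). The per-slot helper insists on an even split, which is an assumption of this sketch; Determined's own validation may differ.

```python
from typing import List


def per_slot_batch_size(global_batch_size: int, slots_per_trial: int) -> int:
    # Every slot needs at least one record per step, and this sketch requires an even split.
    if slots_per_trial <= 0 or global_batch_size < slots_per_trial or global_batch_size % slots_per_trial != 0:
        raise ValueError("global_batch_size must be a positive multiple of slots_per_trial")
    return global_batch_size // slots_per_trial


def steps_per_epoch(records_per_epoch: int, global_batch_size: int) -> float:
    return records_per_epoch / global_batch_size


def sampling_batches(global_batch_size: int, sample_interval: int, num_batches: int) -> List[int]:
    # Mirrors the check in train_batch: global_records % sample_interval == 0.
    return [b for b in range(num_batches) if (global_batch_size * b) % sample_interval == 0]


print(per_slot_batch_size(64, 64))      # 1
print(steps_per_epoch(6287, 1))         # 6287.0
print(steps_per_epoch(6287, 64))        # 98.234375
print(sampling_batches(64, 3000, 800))  # [0, 375, 750]
```

So the 64-GPU run takes roughly 98 optimizer steps per epoch instead of 6287. With sample_interval: 3000 it also writes sample images every 375 batches (24,000 records) rather than every 3000 records, because 64 × batch_idx is a multiple of 3000 only when batch_idx is a multiple of 375.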

64-gpus.yaml:

description: Cycle GAN Pytorch 64 GPUs
data:
  downloaded_path: /tmp
  dataset_name: monet2photo
  n_cpu: 8
  img_height: 256
  img_width: 256
  channels: 3
  sample_interval: 3000
hyperparameters:
  global_batch_size: 64
  lr: 0.0002
  b1: 0.5
  b2: 0.999
  decay_epoch: 100  # epoch from which to start lr decay
  n_residual_blocks: 9  # number of residual blocks in generator
  lambda_cyc: 10.0
  lambda_id: 5.0
records_per_epoch: 6287
searcher:
  name: single
  metric: loss_real_D
  max_length:
    epochs: 2000
  smaller_is_better: True
entrypoint: determined_model_def:CycleGANTrial
resources:
  slots_per_trial: 64
min_checkpoint_period:
  epochs: 1
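Apart from the description, 64-gpus.yaml differs from 1-gpu.yaml in exactly two places: global_batch_size and the new resources.slots_per_trial. To confirm that, or to audit any pair of experiment configurations, a small hypothetical helper can list the differing key paths of two nested dictionaries. The dictionaries below transcribe parts of both files by hand, since parsing YAML would require a third-party library such as PyYAML.

```python
from typing import Any, Dict, List


def diff_paths(a: Dict[str, Any], b: Dict[str, Any], prefix: str = "") -> List[str]:
    """Return dotted key paths whose values differ (or exist on only one side)."""
    paths = []
    for key in sorted(set(a) | set(b)):
        path = f"{prefix}{key}"
        if key not in a or key not in b:
            paths.append(path)
        elif isinstance(a[key], dict) and isinstance(b[key], dict):
            paths.extend(diff_paths(a[key], b[key], path + "."))
        elif a[key] != b[key]:
            paths.append(path)
    return paths


base = {
    "description": "Cycle GAN PyTorch 1 GPU",
    "hyperparameters": {"global_batch_size": 1, "lr": 0.0002, "lambda_cyc": 10.0},
    "records_per_epoch": 6287,
    "searcher": {"name": "single", "metric": "loss_real_D", "max_length": {"epochs": 2000}},
}
scaled = {
    "description": "Cycle GAN Pytorch 64 GPUs",
    "hyperparameters": {"global_batch_size": 64, "lr": 0.0002, "lambda_cyc": 10.0},
    "records_per_epoch": 6287,
    "searcher": {"name": "single", "metric": "loss_real_D", "max_length": {"epochs": 2000}},
    "resources": {"slots_per_trial": 64},
}
print(diff_paths(base, scaled))
# ['description', 'hyperparameters.global_batch_size', 'resources']
```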

We launch this experiment with the CLI and the new configuration file. To satisfy the experiment’s requirements, the Determined master launches 8 GCP VM instances with 8 GPUs each.

det experiment create 64-gpus.yaml .

Opening TensorBoard after the same 45 hours of training, we can see that the generated images are significantly better than those from single-GPU training.

[Figure: generated images after about 45 hours of 64-GPU training]

Next Steps

You can check out the source code for this blog post here. We encourage you to give Determined a spin by trying this example or any of the others available in the Determined repository. If you have any questions along the way, hop on our community Slack or reach out on GitHub. We’d love to help!

2025 Determined AI. All rights reserved.