Balanced Trainer


Last updated 4 years ago


  • Source:

d_fake, d_real = self.gan.forward_discriminator()
if d_real < (d_fake + config.imbalance):
  self.train_d()
else:
  self.train_g()
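The decision rule above can be sketched as a standalone function: train the discriminator while the real score has not pulled ahead of the fake score by at least `imbalance`, otherwise train the generator. This is an illustrative sketch of the conditional only; the function name and return values are ours, not part of the HyperGAN API.

```python
def balanced_step(d_real, d_fake, imbalance=0.1):
    """Pick which network to train this step (sketch of the
    BalancedTrainer conditional; 0.1 is the documented default)."""
    if d_real < d_fake + imbalance:
        return "d"  # D has not yet separated real from fake enough
    return "g"      # D is ahead; let G catch up
```

For example, with `d_real=0.5` and `d_fake=0.45` the gap is below the 0.1 threshold, so D trains; with `d_real=0.9` and `d_fake=0.4`, G trains.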

examples

  • Configurations:

{
  "class": "class:hypergan.trainers.balanced_trainer.BalancedTrainer",
  "imbalance": 0.06,
  "pretrain_d": 1000,
  "d_optimizer": {
    "class": "class:torch.optim.Adam",
    "lr": 1e-4,
    "betas":[0.0,0.999]
  },
  "g_optimizer": {
    "class": "class:torch.optim.Adam",
    "lr": 1e-4,
    "betas":[0.0,0.999]
  },
  "hooks": [
    {
      "class": "function:hypergan.train_hooks.adversarial_norm_train_hook.AdversarialNormTrainHook",
      "gamma": 2e4,
      "loss": ["d"]
    },
    {
      "class": "function:hypergan.train_hooks.initialize_as_autoencoder.InitializeAsAutoencoder",
      "steps": 10000,
      "optimizer": {
        "class": "class:torch.optim.Adam",
        "lr": 1e-4,
        "betas":[0.9,0.999]
      },
      "encoder": {
        "class": "class:hypergan.discriminators.configurable_discriminator.ConfigurableDiscriminator",
        "layers":[
          "conv 32 stride=1", "adaptive_avg_pool", "relu",
          "conv 64 stride=1", "adaptive_avg_pool", "relu",
          "conv 128 stride=1", "adaptive_avg_pool", "relu",
          "conv 256 stride=1", "adaptive_avg_pool", "relu",
          "conv 512 stride=1", "adaptive_avg_pool", "relu",
          "conv 512 stride=1", "adaptive_avg_pool", "relu",
          "flatten",
          "linear 256 bias=false", "tanh"
        ]
      }
    }
  ]
}
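The configuration above is plain JSON, so it can be sanity-checked with Python's standard `json` module before training. The trimmed fragment below keeps only the trainer-level keys from the example; the `class:` prefix is the loader convention used throughout HyperGAN configs.

```python
import json

# Trimmed from the example configuration above: just the
# trainer-level fields the balanced trainer reads.
config_text = """
{
  "class": "class:hypergan.trainers.balanced_trainer.BalancedTrainer",
  "imbalance": 0.06,
  "pretrain_d": 1000
}
"""
config = json.loads(config_text)
print(config["imbalance"], config["pretrain_d"])
```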

options

| attribute      | description                                        | type                        |
| -------------- | -------------------------------------------------- | --------------------------- |
| g_optimizer    | Optimizer configuration for G                      | Config (required)           |
| d_optimizer    | Optimizer configuration for D                      | Config (required)           |
| hooks          | Train Hooks                                        | Array of configs (optional) |
| pretrain_d     | First N steps only trains D                        | Integer (optional)          |
| imbalance      | Threshold distance for G training. Defaults to 0.1 | Float (optional)            |
| d_fake_balance | Changes conditional to d_fake(t) > d_fake(t-1)     | Boolean (optional)          |
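With `d_fake_balance` set, the trainer compares the current fake score against the previous step's fake score instead of against the real score. A minimal sketch of that stateful conditional (class and method names are ours, for illustration only):

```python
class DFakeBalance:
    """Sketch of the d_fake_balance rule: train D whenever the fake
    score rose since the last step, i.e. d_fake(t) > d_fake(t-1)."""

    def __init__(self):
        self.prev_d_fake = None  # no history on the first step

    def choose(self, d_fake):
        train_d = self.prev_d_fake is not None and d_fake > self.prev_d_fake
        self.prev_d_fake = d_fake
        return "d" if train_d else "g"
```

On the first step there is no previous score, so this sketch defaults to training G.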

/trainers/balanced_trainer.py
/trainers/balanced_trainer/