diff --git a/code_soup/ch5/models/gan.py b/code_soup/ch5/models/gan.py index 71e62c0..d8236fb 100644 --- a/code_soup/ch5/models/gan.py +++ b/code_soup/ch5/models/gan.py @@ -169,6 +169,8 @@ def step(self, data: torch.Tensor) -> Tuple: Discriminator loss D_G_z2: Average discriminator outputs for the all fake batch after updating discriminator + errG: + Generator loss """ real_image, _ = data real_image = real_image.to(self.device) @@ -176,6 +178,7 @@ def step(self, data: torch.Tensor) -> Tuple: label = torch.full( (batch_size,), self.real_label, dtype=torch.float, device=self.device ) + self.discriminator.zero_grad() # Forward pass real batch through D output = self.discriminator(real_image).view(-1) # Calculate loss on all-real batch @@ -211,4 +214,4 @@ def step(self, data: torch.Tensor) -> Tuple: D_G_z2 = output.mean().item() # Update G self.generator.optimizer.step() - return D_x, D_G_z1, errD, D_G_z2 + return D_x, D_G_z1, errD, D_G_z2, errG