From 25a7502d034ef0c3432b19a6ea4afaa8ffee0862 Mon Sep 17 00:00:00 2001 From: abhi-glitchhg <72816663+abhi-glitchhg@users.noreply.github.com> Date: Mon, 30 Aug 2021 19:04:46 +0530 Subject: [PATCH 1/2] Update gan.py added zero_grad() method on discriminator; and now step method returns Generator loss --- code_soup/ch5/models/gan.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/code_soup/ch5/models/gan.py b/code_soup/ch5/models/gan.py index 71e62c0..f3eea30 100644 --- a/code_soup/ch5/models/gan.py +++ b/code_soup/ch5/models/gan.py @@ -169,6 +169,8 @@ def step(self, data: torch.Tensor) -> Tuple: Discriminator loss D_G_z2: Average discriminator outputs for the all fake batch after updating discriminator + errG: + Generator loss """ real_image, _ = data real_image = real_image.to(self.device) @@ -176,6 +178,7 @@ def step(self, data: torch.Tensor) -> Tuple: label = torch.full( (batch_size,), self.real_label, dtype=torch.float, device=self.device ) + self.discriminator.zero_grad() # Forward pass real batch through D output = self.discriminator(real_image).view(-1) # Calculate loss on all-real batch @@ -211,4 +214,4 @@ def step(self, data: torch.Tensor) -> Tuple: D_G_z2 = output.mean().item() # Update G self.generator.optimizer.step() - return D_x, D_G_z1, errD, D_G_z2 + return D_x, D_G_z1, errD, D_G_z2,errG From d2312e8b95d2b71478ea4df77d8d867a67c85b0e Mon Sep 17 00:00:00 2001 From: abhi-glitchhg <72816663+abhi-glitchhg@users.noreply.github.com> Date: Mon, 30 Aug 2021 19:17:18 +0530 Subject: [PATCH 2/2] Update gan.py --- code_soup/ch5/models/gan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code_soup/ch5/models/gan.py b/code_soup/ch5/models/gan.py index f3eea30..d8236fb 100644 --- a/code_soup/ch5/models/gan.py +++ b/code_soup/ch5/models/gan.py @@ -214,4 +214,4 @@ def step(self, data: torch.Tensor) -> Tuple: D_G_z2 = output.mean().item() # Update G self.generator.optimizer.step() - return D_x, D_G_z1, errD, D_G_z2,errG + return D_x, D_G_z1, errD, D_G_z2, errG