Conversation
* use logging.getLogger to get a logger to write to instead of calling logging directly. Also set the matplotlib logger to only print ERROR level messages, so we don't get extra output when plotting. * create a generic function to plot (an) image(s) using a specified amount of columns. This allows the user to easily plot more samples
Passing in the class instead of a string makes construct_vae generic enough that it doesn't need changes when you want to play with different implementations.
the kl_divergence parameter in the GaussianVAE was never used, and it seems diagonal_gaussian_kl should be private to GaussianInferenceNetwork, so remove it.
philschulz
left a comment
There was a problem hiding this comment.
This a nice clean-up of the notebook. Thanks for doing this. I only have one remark on the changes that I left inline. Sorry taking so long to get back to you.
| " def __init__(self,\n", | ||
| " generator: Generator,\n", | ||
| " inference_net: GaussianInferenceNetwork,\n", | ||
| " kl_divergence: Callable) -> None:\n", |
There was a problem hiding this comment.
Why are you removing the KL divergence. It is needed to properly construct the ELBO.
I have thought about this a lot before and one alternative is to rewrite the ELBO as E[p(x,z)] + H(q(z)). Then we could provide the entropy as a method of the inference net. However, it's out of whack with the slides in that case. That's why for the time being I would just pass in the kl_divergence as an argument.
There was a problem hiding this comment.
I removed it because it wasn't actually being used (as far as I can see). The GaussianInferenceNetwork calls diagonal_gaussian_kl() directly. I've been implementing inference networks with different distributions with Wilker, which require different ways to calculate the KL divergence, it seemed to me that the KL divergence was conceptually linked more to the inference network than to the VAE as a whole, which is why I opted for removing the parameter rather than making the GaussianInferenceNetwork not hardcode the function.
Some bits I changed while trying to figure things out/implement a pytorch version. Mostly consist of a flexible plot_images function that lets you display arbitrary amounts of images.