Files
deep-learning-with-python-n…/chapter12_part04_variational-autoencoders.ipynb
2021-06-05 20:28:07 -07:00

339 lines
9.4 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"This is a companion notebook for the book [Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition?a_aid=keras&a_bid=76564dff). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThis notebook was generated for TensorFlow 2.6."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"## Generating images with variational autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Sampling from latent spaces of images"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Concept vectors for image editing"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Variational autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Implementing a VAE with Keras"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**VAE encoder network**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"from tensorflow import keras\n",
"from tensorflow.keras import layers\n",
"\n",
"# Dimensionality of the latent space; 2 so the latent plane can be plotted as a grid.\n",
"latent_dim = 2\n",
"\n",
"# Two stride-2 convolutions shrink the 28x28x1 input before the dense head.\n",
"encoder_inputs = keras.Input(shape=(28, 28, 1))\n",
"x = layers.Conv2D(\n",
"    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
")(encoder_inputs)\n",
"x = layers.Conv2D(\n",
"    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
")(x)\n",
"x = layers.Flatten()(x)\n",
"x = layers.Dense(units=16, activation=\"relu\")(x)\n",
"\n",
"# Two linear heads: the mean and the log-variance of the latent distribution.\n",
"z_mean = layers.Dense(units=latent_dim, name=\"z_mean\")(x)\n",
"z_log_var = layers.Dense(units=latent_dim, name=\"z_log_var\")(x)\n",
"\n",
"encoder = keras.Model(\n",
"    inputs=encoder_inputs, outputs=[z_mean, z_log_var], name=\"encoder\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Print the encoder's layers, output shapes and parameter counts.\n",
"encoder.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**Latent-space-sampling layer**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"\n",
"class Sampler(layers.Layer):\n",
"    \"\"\"Draw a latent point from the Gaussian described by (z_mean, z_log_var).\"\"\"\n",
"\n",
"    def call(self, z_mean, z_log_var):\n",
"        \"\"\"Return z_mean + exp(0.5 * z_log_var) * epsilon, with epsilon ~ N(0, 1).\n",
"\n",
"        Both arguments are indexed as (batch, latent) tensors; the result has\n",
"        the same shape.\n",
"        \"\"\"\n",
"        batch_size = tf.shape(z_mean)[0]\n",
"        z_size = tf.shape(z_mean)[1]\n",
"        # One standard-normal draw per latent coordinate of every sample.\n",
"        epsilon = tf.random.normal(shape=(batch_size, z_size))\n",
"        # exp(0.5 * log_var) is the standard deviation.\n",
"        return z_mean + tf.exp(0.5 * z_log_var) * epsilon"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**VAE decoder network, mapping latent space points to images**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Expand a latent point into a 7x7x64 feature map, then upsample twice with\n",
"# stride-2 transposed convolutions.\n",
"latent_inputs = keras.Input(shape=(latent_dim,))\n",
"x = layers.Dense(units=7 * 7 * 64, activation=\"relu\")(latent_inputs)\n",
"x = layers.Reshape(target_shape=(7, 7, 64))(x)\n",
"x = layers.Conv2DTranspose(\n",
"    filters=64, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
")(x)\n",
"x = layers.Conv2DTranspose(\n",
"    filters=32, kernel_size=3, strides=2, padding=\"same\", activation=\"relu\"\n",
")(x)\n",
"\n",
"# Single-channel output; the sigmoid keeps every pixel in [0, 1].\n",
"decoder_outputs = layers.Conv2D(\n",
"    filters=1, kernel_size=3, padding=\"same\", activation=\"sigmoid\"\n",
")(x)\n",
"\n",
"decoder = keras.Model(inputs=latent_inputs, outputs=decoder_outputs, name=\"decoder\")"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"# Print the decoder's layers, output shapes and parameter counts.\n",
"decoder.summary()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**VAE model with custom `train_step()`**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"class VAE(keras.Model):\n",
"    \"\"\"Variational autoencoder trained through a custom `train_step()`.\n",
"\n",
"    Wraps an encoder (image -> z_mean, z_log_var), a `Sampler`, and a decoder\n",
"    (latent point -> image), and tracks the running mean of the total,\n",
"    reconstruction and KL losses.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, encoder, decoder, **kwargs):\n",
"        \"\"\"Store the sub-models and create one Mean metric per loss term.\"\"\"\n",
"        super().__init__(**kwargs)\n",
"        self.encoder = encoder\n",
"        self.decoder = decoder\n",
"        self.sampler = Sampler()\n",
"        self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n",
"        self.reconstruction_loss_tracker = keras.metrics.Mean(\n",
"            name=\"reconstruction_loss\")\n",
"        self.kl_loss_tracker = keras.metrics.Mean(name=\"kl_loss\")\n",
"\n",
"    @property\n",
"    def metrics(self):\n",
"        \"\"\"Expose the trackers so Keras resets them at the start of each epoch.\"\"\"\n",
"        return [self.total_loss_tracker,\n",
"                self.reconstruction_loss_tracker,\n",
"                self.kl_loss_tracker]\n",
"\n",
"    def train_step(self, data):\n",
"        \"\"\"Run one optimization step on a batch of images.\n",
"\n",
"        `data` is the image batch itself (`fit()` is called without targets).\n",
"        Returns a dict with the running means of the three loss terms.\n",
"        \"\"\"\n",
"        with tf.GradientTape() as tape:\n",
"            z_mean, z_log_var = self.encoder(data)\n",
"            z = self.sampler(z_mean, z_log_var)\n",
"            # Use the decoder owned by this model, not the notebook-level global.\n",
"            reconstruction = self.decoder(z)\n",
"            # binary_crossentropy averages over the last (channel) axis, so sum\n",
"            # over the two spatial axes to get one loss value per image.\n",
"            reconstruction_loss = tf.reduce_mean(\n",
"                tf.reduce_sum(\n",
"                    keras.losses.binary_crossentropy(data, reconstruction),\n",
"                    axis=(1, 2)\n",
"                )\n",
"            )\n",
"            # KL divergence between N(z_mean, exp(z_log_var)) and N(0, 1),\n",
"            # kept per latent coordinate here and averaged below.\n",
"            kl_loss = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))\n",
"            total_loss = reconstruction_loss + tf.reduce_mean(kl_loss)\n",
"        grads = tape.gradient(total_loss, self.trainable_weights)\n",
"        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))\n",
"        self.total_loss_tracker.update_state(total_loss)\n",
"        self.reconstruction_loss_tracker.update_state(reconstruction_loss)\n",
"        self.kl_loss_tracker.update_state(kl_loss)\n",
"        return {\n",
"            \"total_loss\": self.total_loss_tracker.result(),\n",
"            \"reconstruction_loss\": self.reconstruction_loss_tracker.result(),\n",
"            \"kl_loss\": self.kl_loss_tracker.result(),\n",
"        }"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**Training the VAE**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Labels are discarded: the model is trained to reconstruct its own input.\n",
"(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\n",
"\n",
"# Pool the train and test images, add a trailing channel axis, divide by 255.\n",
"mnist_digits = np.concatenate([x_train, x_test], axis=0)\n",
"mnist_digits = mnist_digits[..., np.newaxis].astype(\"float32\") / 255\n",
"\n",
"vae = VAE(encoder=encoder, decoder=decoder)\n",
"# No `loss` argument: the loss is computed inside VAE.train_step().\n",
"vae.compile(optimizer=keras.optimizers.Adam(), run_eagerly=True)\n",
"vae.fit(mnist_digits, epochs=30, batch_size=128)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"**Sampling a grid of points from the 2D latent space and decoding them to images**"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Decode an n x n grid of latent points and tile the digits into one image.\n",
"n = 30\n",
"digit_size = 28\n",
"figure = np.zeros((digit_size * n, digit_size * n))\n",
"\n",
"grid_x = np.linspace(-1, 1, n)\n",
"# Reversed so the largest z[1] value ends up in the first (top) row of tiles.\n",
"grid_y = np.linspace(-1, 1, n)[::-1]\n",
"\n",
"for i, yi in enumerate(grid_y):\n",
"    for j, xi in enumerate(grid_x):\n",
"        z_sample = np.array([[xi, yi]])\n",
"        x_decoded = vae.decoder.predict(z_sample)\n",
"        digit = x_decoded[0].reshape(digit_size, digit_size)\n",
"        # Paste the decoded digit into tile (row i, column j).\n",
"        figure[\n",
"            i * digit_size : (i + 1) * digit_size,\n",
"            j * digit_size : (j + 1) * digit_size,\n",
"        ] = digit\n",
"\n",
"plt.figure(figsize=(15, 15))\n",
"# Tick positions sit at the horizontal/vertical centre of each tile.\n",
"start_range = digit_size // 2\n",
"end_range = n * digit_size + start_range\n",
"pixel_range = np.arange(start_range, end_range, digit_size)\n",
"sample_range_x = np.round(grid_x, 1)\n",
"sample_range_y = np.round(grid_y, 1)\n",
"plt.xticks(pixel_range, sample_range_x)\n",
"plt.yticks(pixel_range, sample_range_y)\n",
"plt.xlabel(\"z[0]\")\n",
"plt.ylabel(\"z[1]\")\n",
"# NOTE: axis(\"off\") hides the ticks and labels configured above; remove this\n",
"# line to display the latent coordinates.\n",
"plt.axis(\"off\")\n",
"plt.imshow(figure, cmap=\"Greys_r\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text"
},
"source": [
"### Wrapping up"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "chapter12_part04_variational-autoencoders.i",
"private_outputs": false,
"provenance": [],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 0
}