# VAE
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

# Dimensionality of the latent space. Kept at 2 so the latent space can be
# visualised on a 2-D grid after training.
latent_dim = 2

# --- Encoder: (28, 28, 1) image -> [z_mean, z_log_var], each of size latent_dim ---
encoder_inputs = keras.Input(shape=(28, 28, 1))
# Two stride-2 "same"-padded convolutions downsample 28x28 -> 14x14 -> 7x7.
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
# Two heads: mean and log-variance of the Gaussian that the Sampler draws z from.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var], name="encoder")
# --- Decoder: latent vector of size latent_dim -> (28, 28, 1) image ---
latent_inputs = keras.Input(shape=(latent_dim,))
# Project the latent vector up to a 7x7x64 feature map.
x = layers.Dense(7 * 7 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((7, 7, 64))(x)
# Two stride-2 "same"-padded transposed convolutions upsample
# 7x7 -> 14x14 -> 28x28, mirroring the encoder's convolution stack.
for filters in (64, 32):
    x = layers.Conv2DTranspose(
        filters, 3, activation="relu", strides=2, padding="same"
    )(x)
# Single-channel sigmoid output, so every pixel lies in [0, 1].
decoder_outputs = layers.Conv2D(1, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
class Sampler(layers.Layer):
    """Sample a latent point via the reparameterization trick.

    Returns ``z_mean + exp(0.5 * z_log_var) * eps`` with ``eps`` drawn from a
    standard normal, so the randomness is isolated in ``eps`` and gradients
    flow through ``z_mean`` and ``z_log_var``.
    """

    def call(self, z_mean, z_log_var):
        # Noise shaped (batch, latent) from the first two dimensions of z_mean.
        dims = tf.shape(z_mean)
        noise = tf.random.normal(shape=(dims[0], dims[1]))
        # exp(0.5 * log(sigma^2)) == sigma, the standard deviation.
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * noise
class VAE(keras.Model):
    """Variational autoencoder trained with a custom ``train_step``.

    Args:
        encoder: model mapping an image batch to ``[z_mean, z_log_var]``.
        decoder: model mapping a batch of latent vectors back to images.
        **kwargs: forwarded to ``keras.Model``.

    The training loss is reconstruction + KL:

    * the reconstruction term is the binary cross-entropy summed over the
      spatial axes of each image, then averaged over the batch;
    * the KL term is KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over
      latent dimensions, then averaged over the batch.

    Both terms are therefore on the same per-image scale.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.sampler = Sampler()
        # Running means of each loss component, reported by fit().
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        # Exposing the trackers here lets Keras reset them between epochs
        # (see the Keras guide on customizing what happens in fit()).
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one gradient step on the image batch ``data``.

        Returns:
            dict mapping metric names to their running means for the epoch.
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_var = self.encoder(data)
            z = self.sampler(z_mean, z_log_var)
            # Use the decoder this model was built with (and whose weights are
            # in self.trainable_weights), not a module-level global.
            reconstruction = self.decoder(z)
            # Sum the per-pixel cross-entropy over the two spatial axes,
            # then average over the batch.
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(
                    keras.losses.binary_crossentropy(data, reconstruction),
                    axis=(1, 2),
                )
            )
            # Closed-form KL divergence per latent dimension; sum over the
            # latent axis and average over the batch to match the
            # reconstruction term's per-image scale.
            kl_per_dim = -0.5 * (
                1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
            )
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))
            total_loss = reconstruction_loss + kl_loss
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        # Track the same scalar KL term that enters total_loss.
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "total_loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

# Epoch 1/30
# 547/547 [==============================] - 13s 18ms/step - total_loss: 213.6910 - reconstruction_loss: 210.9417 - kl_loss: 2.7493
# Epoch 2/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 186.1581 - reconstruction_loss: 183.5664 - kl_loss: 2.5915
# Epoch 3/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 167.9219 - reconstruction_loss: 164.0332 - kl_loss: 3.8887
# Epoch 4/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 161.4123 - reconstruction_loss: 157.4814 - kl_loss: 3.9309
# Epoch 5/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 158.0701 - reconstruction_loss: 154.0942 - kl_loss: 3.9760
# Epoch 6/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 156.2735 - reconstruction_loss: 152.3186 - kl_loss: 3.9549
# Epoch 7/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 154.9481 - reconstruction_loss: 151.0216 - kl_loss: 3.9264
# Epoch 8/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 153.7915 - reconstruction_loss: 149.8839 - kl_loss: 3.9075
# Epoch 9/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 153.1475 - reconstruction_loss: 149.2586 - kl_loss: 3.8889
# Epoch 10/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 152.2777 - reconstruction_loss: 148.4026 - kl_loss: 3.8751
# Epoch 11/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 151.8590 - reconstruction_loss: 147.9890 - kl_loss: 3.8701
# Epoch 12/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 151.3233 - reconstruction_loss: 147.4735 - kl_loss: 3.8498
# Epoch 13/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 150.9009 - reconstruction_loss: 147.0544 - kl_loss: 3.8466
# Epoch 14/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 150.4010 - reconstruction_loss: 146.5613 - kl_loss: 3.8396
# Epoch 15/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 150.2408 - reconstruction_loss: 146.4061 - kl_loss: 3.8348
# Epoch 16/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 149.8306 - reconstruction_loss: 145.9978 - kl_loss: 3.8328
# Epoch 17/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 149.5211 - reconstruction_loss: 145.6898 - kl_loss: 3.8313
# Epoch 18/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 149.2525 - reconstruction_loss: 145.4385 - kl_loss: 3.8140
# Epoch 19/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 149.0624 - reconstruction_loss: 145.2484 - kl_loss: 3.8140
# Epoch 20/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 148.8579 - reconstruction_loss: 145.0437 - kl_loss: 3.8142
# Epoch 21/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 148.6375 - reconstruction_loss: 144.8200 - kl_loss: 3.8175
# Epoch 22/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 148.4102 - reconstruction_loss: 144.5960 - kl_loss: 3.8142
# Epoch 23/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 148.1947 - reconstruction_loss: 144.3832 - kl_loss: 3.8115
# Epoch 24/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 148.0633 - reconstruction_loss: 144.2524 - kl_loss: 3.8109
# Epoch 25/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 147.8548 - reconstruction_loss: 144.0357 - kl_loss: 3.8193
# Epoch 26/30
# 547/547 [==============================] - 10s 19ms/step - total_loss: 147.7023 - reconstruction_loss: 143.8790 - kl_loss: 3.8234
# Epoch 27/30
# 547/547 [==============================] - 10s 19ms/step - total_loss: 147.6233 - reconstruction_loss: 143.8067 - kl_loss: 3.8168
# Epoch 28/30
# 547/547 [==============================] - 11s 19ms/step - total_loss: 147.3167 - reconstruction_loss: 143.5024 - kl_loss: 3.8143
# Epoch 29/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 147.2497 - reconstruction_loss: 143.4479 - kl_loss: 3.8020
# Epoch 30/30
# 547/547 [==============================] - 10s 18ms/step - total_loss: 147.1389 - reconstruction_loss: 143.3195 - kl_loss: 3.8194
# <keras.callbacks.History at 0x7f910c627ee0>
# Training
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
# Labels are discarded; train on the train and test splits combined.
mnist_digits = np.concatenate([x_train, x_test], axis=0)
# Add a trailing channel axis and scale pixels from [0, 255] to [0, 1] to match
# the encoder's (28, 28, 1) input and the decoder's sigmoid output range.
mnist_digits = np.expand_dims(mnist_digits, -1).astype("float32") / 255
vae = VAE(encoder, decoder)
vae.compile(optimizer=keras.optimizers.Adam(), run_eagerly=True)
vae.fit(mnist_digits, epochs=30, batch_size=128)

# Visualise the 2-D latent space: decode an n x n grid of latent points in
# [-1, 1] x [-1, 1] and tile the generated digits into one large image.
n = 30
digit_size = 28
grid_x = np.linspace(-1, 1, n)
# Reversed so that z[1] increases from the bottom of the image to the top.
grid_y = np.linspace(-1, 1, n)[::-1]

# All n*n latent points in row-major order (row i <-> grid_y[i],
# column j <-> grid_x[j]), decoded in one batched predict() call instead of
# n*n separate single-sample calls.
mesh_x, mesh_y = np.meshgrid(grid_x, grid_y)
z_samples = np.stack([mesh_x.ravel(), mesh_y.ravel()], axis=1)
decoded = vae.decoder.predict(z_samples, verbose=0)

figure = np.zeros((digit_size * n, digit_size * n))
for k, decoded_digit in enumerate(decoded):
    i, j = divmod(k, n)
    figure[
        i * digit_size : (i + 1) * digit_size,
        j * digit_size : (j + 1) * digit_size,
    ] = decoded_digit.reshape(digit_size, digit_size)

plt.figure(figsize=(15, 15))
# One tick at the centre of each tile, labelled with its latent coordinate.
start_range = digit_size // 2
end_range = n * digit_size + start_range
pixel_range = np.arange(start_range, end_range, digit_size)
sample_range_x = np.round(grid_x, 1)
sample_range_y = np.round(grid_y, 1)
plt.xticks(pixel_range, sample_range_x)
plt.yticks(pixel_range, sample_range_y)
plt.xlabel("z[0]")
plt.ylabel("z[1]")
# plt.axis("off") is intentionally not called: it would hide the ticks and
# axis labels configured above.
plt.imshow(figure, cmap="Greys_r")
plt.show()
