# U-Net
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
def conv_block(x, filters, **kwargs):
    """Apply two stacked ``Conv2D`` layers with ``filters`` output channels.

    Args:
        x: input tensor.
        filters: number of filters for both convolutions.
        **kwargs: must contain ``name``, which is used as a prefix: the layers are
            named ``'<name>-1'`` and ``'<name>-2'``.  Every other keyword is passed
            unchanged to both ``layers.Conv2D`` constructors.

    Returns:
        The output tensor of the second convolution.
    """
    prefix = kwargs.pop('name')
    for index in (1, 2):
        x = layers.Conv2D(filters=filters, name=f'{prefix}-{index}', **kwargs)(x)
    return x
def upsample_block(x1, x2, **kwargs):
    """Upsample ``x1`` and concatenate it with a centre-cropped ``x2``.

    Args:
        x1: tensor to upsample with ``layers.Conv2DTranspose(**kwargs)``.
        x2: skip-connection tensor; it is cropped on axes 1 and 2 so its spatial
            size equals the upsampled tensor's.  When the excess is odd, the extra
            row/column is removed from the bottom/right side.
        **kwargs: forwarded verbatim to ``layers.Conv2DTranspose``.

    Returns:
        The upsampled tensor concatenated with the cropped ``x2`` along axis 3.
    """
    upsampled = layers.Conv2DTranspose(**kwargs)(x1)

    def split_excess(excess):
        # (before, after) crop amounts; ``after`` receives the odd remainder.
        half, remainder = divmod(excess, 2)
        return half, half + remainder

    cropping = (
        split_excess(x2.shape[1] - upsampled.shape[1]),
        split_excess(x2.shape[2] - upsampled.shape[2]),
    )
    skip = layers.Cropping2D(cropping=cropping)(x2)
    return layers.concatenate([upsampled, skip], axis=3)
def create_unet(input_shape, num_classes, padding='valid', kernel_initializer='glorot_uniform', kernel_regularizer=None):
    """Build a U-Net as an (uncompiled) Keras functional ``Model`` named ``'unet'``.

    Encoder:
        Four ``conv_block`` levels (64, 128, 256 and 512 filters), each followed by
        2x2/stride-2 max pooling, then a 1024-filter bottleneck block.

    Decoder:
        Mirrors the encoder.  Each level first upsamples with a stride-2
        transposed convolution (``upsample_block``).  It then concatenates the
        centre-cropped encoder output of the same level.  Finally it applies
        another ``conv_block``.

    Head:
        A 1x1 convolution produces one channel per class.

    Layer names are ``conv_block1``..``conv_block9``, ``pool1``..``pool4``,
    ``up6``..``up9`` and ``output``.

    Args:
        input_shape: shape of one input image, e.g. ``(572, 572, 3)``.
        num_classes: number of output channels.  ``1`` yields a sigmoid head,
            because a softmax over a single channel is identically 1.  Any other
            value yields a softmax head.
        padding: padding of the 3x3 convolutions (``'valid'`` or ``'same'``).
        kernel_initializer: initializer used for every convolution kernel,
            including the transposed convolutions and the 1x1 output layer.
        kernel_regularizer: regularizer (or ``None``) applied to the same kernels
            as ``kernel_initializer``.

    Returns:
        A ``tensorflow.keras.Model``.
    """
    conv_kwargs = dict(
        kernel_size=3,
        activation='relu',
        padding=padding,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )
    # Previously only the 3x3 convolutions received these arguments; apply them to
    # the transposed convolutions and the head too so they cover every kernel.
    weight_kwargs = dict(kernel_initializer=kernel_initializer, kernel_regularizer=kernel_regularizer)
    encoder_filters = (64, 128, 256, 512)
    bottleneck_filters = 1024

    inputs = layers.Input(shape=input_shape)

    # Encoder: conv block -> remember it as a skip connection -> 2x2 max pooling.
    x = inputs
    skips = []
    for level, filters in enumerate(encoder_filters, start=1):
        x = conv_block(x, filters, name=f'conv_block{level}', **conv_kwargs)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2), strides=2, name=f'pool{level}')(x)

    level = len(encoder_filters) + 1
    x = conv_block(x, bottleneck_filters, name=f'conv_block{level}', **conv_kwargs)

    # Decoder: the deepest skip is consumed first (levels 6..9).
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        level += 1
        x = upsample_block(
            x, skip,
            filters=filters, kernel_size=(2, 2), strides=(2, 2), padding='same',
            name=f'up{level}', **weight_kwargs,
        )
        x = conv_block(x, filters, name=f'conv_block{level}', **conv_kwargs)

    # A single-channel softmax always outputs 1, so binary masks need a sigmoid.
    head_activation = 'sigmoid' if num_classes == 1 else 'softmax'
    outputs = layers.Conv2D(
        filters=num_classes, kernel_size=1, activation=head_activation,
        name='output', **weight_kwargs,
    )(x)
    return Model(inputs, outputs, name='unet')


model = create_unet(input_shape=(572, 572, 3), num_classes=2)
model.summary()

# Image Segmentation
# Oxford Pets Dataset
import os
input_dir = '../data/oxford-iiit-pet/images/'
target_dir = '../data/oxford-iiit-pet/annotations/trimaps/'


def _list_files(directory, extension):
    """Return the sorted paths of non-hidden files in *directory* ending with *extension*."""
    return sorted(
        os.path.join(directory, fname)
        for fname in os.listdir(directory)
        # Hidden files are skipped in both folders so inputs and targets are filtered alike.
        if fname.endswith(extension) and not fname.startswith('.')
    )


def _stems(paths):
    """Return the file names of *paths* without directory and extension."""
    return [os.path.splitext(os.path.basename(path))[0] for path in paths]


input_img_paths = _list_files(input_dir, '.jpg')
target_img_paths = _list_files(target_dir, '.png')

# Downstream code pairs inputs and masks purely by list position (same slice
# indices, parallel train/val split), so the lists must match file-for-file,
# not just in length.
if _stems(input_img_paths) != _stems(target_img_paths):
    raise ValueError(
        f'Input images and trimaps do not pair up by file name '
        f'({len(input_img_paths)} images, {len(target_img_paths)} trimaps).'
    )

# Visualize the input and target images (segmentation masks)
import matplotlib.pyplot as plt
from matplotlib.image import imread
def normalize_target_img(img):
    """Map trimap labels 1/2/3 to gray levels 0/127/254 for display.

    *img* must expose ``astype`` (e.g. a NumPy array).  It is cast to uint8 first
    and the arithmetic stays in uint8, so values outside 1..3 wrap around.
    """
    labels = img.astype('uint8')
    return (labels - 1) * 127
from tensorflow.keras.preprocessing.image import load_img

# Show the first input image next to its trimap.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(imread(input_img_paths[0]))
ax[0].set_title('input')

# Read the trimap exactly as the OxfordPets generator does (8-bit grayscale,
# labels 1/2/3).  normalize_target_img then maps it to fixed gray levels
# 0/127/254.  vmin/vmax keep each class the same shade regardless of which
# classes appear in this particular mask.
target = np.array(load_img(target_img_paths[0], color_mode='grayscale'))
ax[1].imshow(normalize_target_img(target), cmap='gray', vmin=0, vmax=255)
ax[1].set_title('target mask')

for axis in ax:
    axis.axis('off')
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img
class OxfordPets(keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(images, labels)`` batches of Oxford-IIIT Pet data.

    Args:
        batch_size: number of samples per batch.
        img_size: ``(height, width)`` that both images and masks are resized to.
        input_img_paths: paths of the input images, loaded as RGB.
        target_img_paths: paths of the trimap masks, aligned index-by-index with
            ``input_img_paths``.
        **kwargs: forwarded to the ``keras.utils.Sequence`` base constructor.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths, **kwargs):
        # Newer Keras versions expect subclasses to run the base constructor.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        # Tuple so it can be concatenated with the batch/channel dims below.
        self.img_size = tuple(img_size)
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Number of full batches; a trailing partial batch is dropped."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return batch *idx* as ``(x, y)``.

        ``x`` is float32 of shape ``(batch_size, H, W, 3)`` with raw 0..255 pixels.
        ``y`` is uint8 of shape ``(batch_size, H, W, 1)`` with labels 0/1/2.
        """
        start = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[start:start + self.batch_size]
        batch_target_img_paths = self.target_img_paths[start:start + self.batch_size]

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype='float32')
        for j, path in enumerate(batch_input_img_paths):
            x[j] = load_img(path, target_size=self.img_size)

        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype='uint8')
        for j, path in enumerate(batch_target_img_paths):
            mask = load_img(path, target_size=self.img_size, color_mode='grayscale')
            y[j] = np.expand_dims(mask, 2)
            # Trimap labels 1, 2, 3 -> 0, 1, 2 for sparse categorical crossentropy.
            y[j] -= 1
        return x, y


from sklearn.model_selection import train_test_split
batch_size = 32
img_size = (160, 160)

# Hold out a fixed number of samples for validation; random_state makes the split reproducible.
val_samples = 1000
(
    train_input_img_paths,
    val_input_img_paths,
    train_target_img_paths,
    val_target_img_paths,
) = train_test_split(input_img_paths, target_img_paths, test_size=val_samples, random_state=42)

train_generator = OxfordPets(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=train_input_img_paths,
    target_img_paths=train_target_img_paths,
)
val_generator = OxfordPets(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=val_input_img_paths,
    target_img_paths=val_target_img_paths,
)

# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Three output channels match the trimap labels 0/1/2 produced by OxfordPets.
model = create_unet(input_shape=img_size + (3,), num_classes=3, padding='same')
model.summary()

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Keep only the checkpoint with the best validation loss seen so far.
# NOTE(review): Keras 3's ModelCheckpoint rejects non-'.keras' paths when saving
# full models; '.h5' works with tf.keras 2.x — confirm the installed version.
checkpoint_callback = keras.callbacks.ModelCheckpoint(
    '../checkpoints/unet_oxford_pet_segmentation.h5',
    save_best_only=True,
)
model.fit(
    train_generator,
    epochs=15,
    validation_data=val_generator,
    callbacks=[checkpoint_callback],
)

from tensorflow.keras.preprocessing.image import load_img, img_to_array
def display_mask(img_path):
    """Predict the segmentation mask of one image and plot it next to the image.

    Uses the module-level ``model`` and ``img_size``.  The image is fed as raw
    0..255 pixels, the same way ``OxfordPets`` feeds the training data.

    Args:
        img_path: path of an image readable by ``load_img``.
    """
    image = load_img(img_path, target_size=img_size)
    # (H, W, 3) -> (1, H, W, 3): the model expects a batch axis.
    batch = np.expand_dims(img_to_array(image), axis=0)
    class_scores = model.predict(batch)[0]
    # Per-pixel predicted label (0/1/2), spread to 0/127/254 for display.
    mask = np.argmax(class_scores, axis=-1) * (255 // 2)

    _, subplots = plt.subplots(1, 2, figsize=(10, 5))
    subplots[0].imshow(image)
    # Fixed range so each class keeps its colour even if some classes are absent.
    subplots[1].imshow(mask, vmin=0, vmax=2 * (255 // 2))
    for axis in subplots:
        axis.axis('off')


display_mask(val_input_img_paths[2])