Skip to article frontmatter | Skip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

U-Net

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

def conv_block(x, filters, **kwargs):
    """Stack two ``Conv2D`` layers that share the same configuration.

    Args:
        x: Tensor fed into the first convolution.
        filters: Number of output filters used by both convolutions.
        **kwargs: Must include ``name``, the prefix for the layer names
            (the layers are named ``<name>-1`` and ``<name>-2``). All other
            entries are forwarded unchanged to each ``layers.Conv2D``.

    Returns:
        The output tensor of the second convolution.

    Raises:
        KeyError: If ``name`` is not supplied.
    """
    prefix = kwargs.pop('name')
    for index in (1, 2):
        x = layers.Conv2D(filters=filters, name=f'{prefix}-{index}', **kwargs)(x)
    return x

def upsample_block(x1, x2, **kwargs):
    """Upsample ``x1`` and concatenate it with a center-cropped ``x2``.

    Args:
        x1: Tensor passed through ``layers.Conv2DTranspose``.
        x2: Skip-connection tensor; it is cropped along axes 1 and 2 so its
            size matches the upsampled tensor before concatenation.
        **kwargs: Forwarded unchanged to ``layers.Conv2DTranspose``.

    Returns:
        The upsampled tensor and the cropped ``x2`` joined along axis 3.
    """
    upsampled = layers.Conv2DTranspose(**kwargs)(x1)

    # For each spatial axis, split the size surplus of x2 over the upsampled
    # tensor into a leading half (floored) and the remaining trailing part.
    crop_spec = []
    for axis in (1, 2):
        surplus = x2.shape[axis] - upsampled.shape[axis]
        leading = surplus // 2
        crop_spec.append((leading, surplus - leading))

    cropped_skip = layers.Cropping2D(cropping=tuple(crop_spec))(x2)
    return layers.concatenate([upsampled, cropped_skip], axis=3)

def create_unet(input_shape, num_classes, padding='valid', kernel_initializer='glorot_uniform', kernel_regularizer=None):
    """Build a U-Net with four downsampling and four upsampling stages.

    Args:
        input_shape: Shape passed to ``layers.Input`` (no batch dimension).
        num_classes: Number of filters in the final 1x1 softmax convolution.
        padding: Padding mode used by every ``conv_block`` convolution.
        kernel_initializer: Initializer used by every ``conv_block`` convolution.
        kernel_regularizer: Regularizer used by every ``conv_block`` convolution.

    Returns:
        A ``Model`` named ``'unet'`` mapping the input to per-pixel softmax scores.
    """
    # Settings shared by every double-convolution block.
    conv_settings = dict(
        kernel_size=3,
        activation='relu',
        padding=padding,
        kernel_initializer=kernel_initializer,
        kernel_regularizer=kernel_regularizer,
    )

    inputs = layers.Input(shape=input_shape)
    x = inputs

    # Contracting path: keep each block's output for the skip connections.
    skips = []
    for stage, filters in enumerate((64, 128, 256, 512), start=1):
        x = conv_block(x, filters, name=f'conv_block{stage}', **conv_settings)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2), strides=2, name=f'pool{stage}')(x)

    # Bottleneck.
    x = conv_block(x, 1024, name='conv_block5', **conv_settings)

    # Expanding path: consume the skip tensors from deepest to shallowest.
    for stage, filters in enumerate((512, 256, 128, 64), start=6):
        x = upsample_block(x, skips.pop(), filters=filters, kernel_size=(2, 2), strides=(2, 2), padding='same', name=f'up{stage}')
        x = conv_block(x, filters, name=f'conv_block{stage}', **conv_settings)

    # output
    outputs = layers.Conv2D(filters=num_classes, kernel_size=1, activation='softmax', name='output')(x)
    return Model(inputs, outputs, name='unet')
# Instantiate the network for a 572x572 three-channel input and two output
# classes, using create_unet's default 'valid' padding, then print its layers.
model = create_unet(input_shape=(572, 572, 3), num_classes=2)
model.summary()

1. Image Segmentation

Oxford Pets Dataset

import os

input_dir = '../data/oxford-iiit-pet/images/'
target_dir = '../data/oxford-iiit-pet/annotations/trimaps/'

# Hidden files (names starting with '.') are skipped in BOTH directories so
# that the two sorted listings stay index-aligned.
input_img_paths = sorted(
    os.path.join(input_dir, fname)
    for fname in os.listdir(input_dir)
    if fname.endswith('.jpg') and not fname.startswith('.')
)
target_img_paths = sorted(
    os.path.join(target_dir, fname)
    for fname in os.listdir(target_dir)
    if fname.endswith('.png') and not fname.startswith('.')
)

# Explicit errors instead of `assert`, which is stripped under `python -O`.
if len(input_img_paths) != len(target_img_paths):
    raise ValueError(
        f'Found {len(input_img_paths)} images but {len(target_img_paths)} masks'
    )

# Pairing relies purely on sort order, so verify that every image and mask at
# the same index share the same file stem.
mismatched_pairs = [
    (img_path, mask_path)
    for img_path, mask_path in zip(input_img_paths, target_img_paths)
    if os.path.splitext(os.path.basename(img_path))[0]
    != os.path.splitext(os.path.basename(mask_path))[0]
]
if mismatched_pairs:
    raise ValueError(
        f'{len(mismatched_pairs)} image/mask pairs do not match, '
        f'first mismatch: {mismatched_pairs[0]}'
    )

입력과 목표 이미지(분할 마스크) 시각화

import matplotlib.pyplot as plt
from matplotlib.image import imread

def normalize_target_img(img):
    """Cast ``img`` to uint8, subtract 1 and scale by 127.

    For label values 1, 2, 3 this yields 0, 127, 254.
    """
    return (img.astype('uint8') - 1) * 127


# Show the first input image next to its target mask.
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for axis, path in zip(ax, (input_img_paths[0], target_img_paths[0])):
    axis.imshow(imread(path))
from tensorflow import keras
from tensorflow.keras.preprocessing.image import load_img

class OxfordPets(keras.utils.Sequence):
    """Keras sequence yielding batches of images and their segmentation masks.

    Each batch is a pair ``(x, y)`` where ``x`` is a float32 array of shape
    ``(batch_size, *img_size, 3)`` and ``y`` is a uint8 array of shape
    ``(batch_size, *img_size, 1)`` holding the mask values shifted down by one.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths, **kwargs):
        """Store the batch configuration and the paired path lists.

        Args:
            batch_size: Number of samples per batch.
            img_size: Target size as a tuple; it is concatenated with other
                tuples to build the batch array shapes.
            input_img_paths: Paths of the input images.
            target_img_paths: Paths of the masks, index-aligned with
                ``input_img_paths``.
            **kwargs: Forwarded to the base class initializer.

        Raises:
            ValueError: If the two path lists differ in length.
        """
        # Newer Keras versions expect subclasses to run the base initializer;
        # on older versions this resolves to a no-op.
        super().__init__(**kwargs)
        if len(input_img_paths) != len(target_img_paths):
            raise ValueError(
                f'Got {len(input_img_paths)} input images but '
                f'{len(target_img_paths)} target masks'
            )
        self.batch_size = batch_size
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def __len__(self):
        """Return the number of full batches; a trailing partial batch is dropped."""
        return len(self.target_img_paths) // self.batch_size

    def __getitem__(self, idx):
        """Return the ``(x, y)`` batch at position ``idx``."""
        i = idx * self.batch_size
        batch_input_img_paths = self.input_img_paths[i : i + self.batch_size]
        batch_target_img_paths = self.target_img_paths[i : i + self.batch_size]

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype='float32')
        for j, path in enumerate(batch_input_img_paths):
            x[j] = load_img(path, target_size=self.img_size)

        y = np.zeros((self.batch_size,) + self.img_size + (1,), dtype='uint8')
        for j, path in enumerate(batch_target_img_paths):
            mask = load_img(path, target_size=self.img_size, color_mode='grayscale')
            y[j] = np.expand_dims(mask, 2)
            # Shift labels 1, 2, 3 -> 0, 1, 2. Done per loaded sample so that
            # any unfilled (all-zero) rows are not wrapped around in uint8.
            y[j] -= 1
        return x, y
from sklearn.model_selection import train_test_split

# Hold out a fixed number of image/mask pairs for validation.
val_samples = 1000
(
    train_input_img_paths,
    val_input_img_paths,
    train_target_img_paths,
    val_target_img_paths,
) = train_test_split(
    input_img_paths,
    target_img_paths,
    test_size=val_samples,
    random_state=42,
)

batch_size = 32
img_size = (160, 160)
train_generator = OxfordPets(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=train_input_img_paths,
    target_img_paths=train_target_img_paths,
)
val_generator = OxfordPets(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=val_input_img_paths,
    target_img_paths=val_target_img_paths,
)

# Free up RAM in case the model definition cells were run multiple times
keras.backend.clear_session()

# Three output classes with 'same' padding.
model = create_unet(input_shape=(*img_size, 3), num_classes=3, padding='same')
model.summary()
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')

# Checkpoint callback configured with save_best_only=True.
checkpoint = keras.callbacks.ModelCheckpoint(
    '../checkpoints/unet_oxford_pet_segmentation.h5',
    save_best_only=True,
)
model.fit(
    train_generator,
    epochs=15,
    validation_data=val_generator,
    callbacks=[checkpoint],
)
from tensorflow.keras.preprocessing.image import load_img, img_to_array

def display_mask(img_path):
    """Plot an image next to the segmentation mask predicted for it.

    Reads the module-level ``img_size`` and ``model``.

    Args:
        img_path: Path of the image to load and segment.
    """
    image = load_img(img_path, target_size=img_size)

    # Add the leading batch dimension expected by model.predict.
    X = img_to_array(image).reshape(1, *img_size, 3)
    mask = model.predict(X)[0]
    # Collapse the per-class scores into one class index per pixel.
    mask = np.argmax(mask, axis=-1)
    # Spread the class indices over the 0-254 range for display.
    mask *= 255 // 2

    _, subplots = plt.subplots(1, 2, figsize=(10, 5))
    subplots[0].imshow(image)
    subplots[1].imshow(mask)
    # Hide the axes on both panels; plt.axis('off') would only affect the
    # current (last) axes. The mask is drawn exactly once.
    for subplot in subplots:
        subplot.axis('off')

# Show the predicted mask for the validation image at index 2.
display_mask(val_input_img_paths[2])