Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

전이 학습

전이 학습은 하나의 문제를 풀기 위해 학습한 모형을 다른 관련 문제에 옮겨 적용하는 기법입니다. 사전 훈련된 모형을 미세 조정하는 것도 전이 학습의 한 사례에 해당합니다.

전이 학습은 크게 두 단계로 이루어집니다. 먼저 사전 훈련(pre-training) 단계에서는 대규모의 일반적인 데이터셋에서 모형을 학습시킵니다. 이미지 인식이라면 ImageNet과 같은 데이터셋이 여기에 쓰이며, 이 과정에서 모형은 폭넓게 적용 가능한 일반적인 특성을 학습합니다. 이어지는 미세 조정(fine-tuning) 단계에서는 사전 훈련된 모형을 더 작고 특화된 데이터셋에 적용하고, 가중치를 미세하게 조정하여 특정 작업에 최적화합니다.

전이 학습의 핵심은 이미 학습된 특성과 패턴을 다른 작업에 재사용한다는 데 있습니다. 덕분에 적은 양의 데이터로도 효과적인 성능을 얻을 수 있고 학습 시간도 단축됩니다. 특히 데이터가 제한적이거나 학습에 필요한 계산 자원이 부족한 상황에서 유용합니다.

0.1전이 학습 전략

전이 학습은 미세 조정 외에도 다양한 형태로 구현됩니다.

  • 특성 추출(feature extraction): 사전 훈련된 모형에서 마지막 몇 개 층을 제외한 나머지 층의 가중치를 고정합니다. 고정된 층은 입력으로부터 특성을 추출하는 역할을 하고, 마지막 층만 새로운 작업에 맞게 새로 학습합니다.

  • 멀티태스크 학습(multi-task learning): 여러 관련 작업을 동시에 학습하여 작업 사이에 유용한 지식을 공유합니다. 예를 들어 얼굴 인식과 감정 분석을 함께 학습시킬 수 있습니다.

  • 도메인 적응(domain adaptation): 소스 도메인에서 학습한 지식을 다른 대상 도메인에 적용합니다. 사진 이미지에서 학습한 모형을 그림 스타일의 이미지 분류에 전이하는 경우가 그 예입니다.

  • 제로샷·퓨샷 학습(zero/few-shot learning): 제로샷 학습은 학습 중에 보지 못한 범주의 객체를 인식하도록 하고, 퓨샷 학습은 한두 개 정도의 매우 적은 예시만으로 새로운 범주를 학습합니다.

  • 프로그레시브 네트워크(progressive networks): 새로운 작업을 학습할 때마다 네트워크 컬럼을 추가하고, 기존 컬럼에 축적된 지식을 활용하여 새 작업을 더 빠르게 학습합니다.

  • 메타-러닝(meta-learning): 모델 아그노스틱 메타-러닝(MAML)처럼, 모형이 여러 작업에 걸쳐 빠르게 적응하도록 학습시켜 새로운 작업에 신속하게 최적화될 수 있게 합니다.

이처럼 전이 학습은 다양한 방식으로 구현되며, 그 방법과 적용 범위는 계속 확장되고 있습니다.

import sys, platform
import tensorflow as tf
import keras

print(f'Python version: {sys.version}')
print(f'platform: {platform.platform()}')
print(f'TensorFlow version: {tf.__version__}')
print(f'keras version: {keras.__version__}')
# GPU check
try:
    assert tf.config.list_physical_devices('GPU')
except:
    print('No GPU detected')
Python version: 3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:45:18) [GCC 12.3.0]
platform: Linux-5.15.146.1-microsoft-standard-WSL2-x86_64-with-glibc2.36
TensorFlow version: 2.16.1
keras version: 3.3.3
def print_scores(metrics):
    print(f'Loss: {metrics["loss"]:.3f}', end='\t')
    metrics.pop('loss')
    for k, v in metrics.items():
        print(f'{k}: {v:.3f}', end='\t')

1데이터

import os.path

data_dir = '../data/cats_dogs_small'
join_path = lambda *args: os.path.join(data_dir, *args)

batch_size = 32
image_size = (180, 180)
train_dataset = keras.utils.image_dataset_from_directory(
    join_path('train'), batch_size=batch_size, image_size=image_size)
validation_dataset = keras.utils.image_dataset_from_directory(
    join_path('validation'), batch_size=batch_size, image_size=image_size)
test_dataset = keras.utils.image_dataset_from_directory(
    join_path('test'), batch_size=batch_size, image_size=image_size)

train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
validation_dataset = validation_dataset.prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)
Found 2000 files belonging to 2 classes.
Found 1000 files belonging to 2 classes.
Found 2000 files belonging to 2 classes.

2응용 모형

from keras import layers
from keras.applications.vgg16 import VGG16

base_model = VGG16(include_top=False)
base_model.trainable = False
# base_model.summary()

def create_model(input_shape=None):
    inputs = layers.Input(shape=input_shape)
    # preprocessing
    x = keras.applications.vgg16.preprocess_input(inputs)
    # feature extraction
    x = base_model(inputs, training=False)
    # prediction
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(1)(x)
    return keras.Model(inputs, outputs)
keras.backend.clear_session()
model = create_model(input_shape=image_size + (3,))
# model.summary()
model.compile(
    optimizer='adam', 
    loss=keras.losses.BinaryCrossentropy(from_logits=True), 
    metrics=['acc'])

model_name = 'cats_dogs_vgg16'
checkpoint_filepath = f'checkpoints/{model_name}.keras'
log_dir = f'logs/{model_name}'
history = model.fit(
    train_dataset, epochs=30, validation_data=validation_dataset, 
    callbacks=[
        keras.callbacks.ModelCheckpoint(checkpoint_filepath, save_best_only=True),
        keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        keras.callbacks.TensorBoard(log_dir)
    ]
)
Epoch 1/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 4s 45ms/step - acc: 0.7688 - loss: 1.2706 - val_acc: 0.8990 - val_loss: 0.4341
Epoch 2/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9001 - loss: 0.4032 - val_acc: 0.9160 - val_loss: 0.3278
Epoch 3/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9263 - loss: 0.2974 - val_acc: 0.9270 - val_loss: 0.2763
Epoch 4/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 0.9460 - loss: 0.2100 - val_acc: 0.9380 - val_loss: 0.2612
Epoch 5/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9483 - loss: 0.1799 - val_acc: 0.9410 - val_loss: 0.2457
Epoch 6/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 32ms/step - acc: 0.9556 - loss: 0.1559 - val_acc: 0.9460 - val_loss: 0.2482
Epoch 7/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9628 - loss: 0.1085 - val_acc: 0.9470 - val_loss: 0.2408
Epoch 8/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9734 - loss: 0.0932 - val_acc: 0.9480 - val_loss: 0.2323
Epoch 9/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - acc: 0.9748 - loss: 0.0916 - val_acc: 0.9460 - val_loss: 0.2353
Epoch 10/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 0.9723 - loss: 0.0721 - val_acc: 0.9450 - val_loss: 0.2311
Epoch 11/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 34ms/step - acc: 0.9794 - loss: 0.0638 - val_acc: 0.9460 - val_loss: 0.2311
Epoch 12/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 32ms/step - acc: 0.9838 - loss: 0.0501 - val_acc: 0.9490 - val_loss: 0.2331
Epoch 13/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 32ms/step - acc: 0.9829 - loss: 0.0432 - val_acc: 0.9450 - val_loss: 0.2382
Epoch 14/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 33ms/step - acc: 0.9838 - loss: 0.0451 - val_acc: 0.9460 - val_loss: 0.2321
best_model = keras.models.load_model('checkpoints/cats_dogs_vgg16.keras')
metrics = model.evaluate(test_dataset, return_dict=True, verbose=0)
print('[Test]')
print_scores(metrics)
[Test]
Loss: 0.211	acc: 0.949	

3미세조정 (Fine-Tuning)

미세 조정은 사전 훈련된 모형을 특정 응용 작업에 적합하도록 추가로 학습시키는 과정입니다. 이때 모형의 가중치는 새로운 데이터셋에 맞춰 조정되며, 자연어 처리에서라면 단어와 문장을 벡터 공간에 표현하는 임베딩(embedding) 층 또한 함께 조정됩니다.

임베딩 층은 사전 훈련 과정에서 학습한 지식, 즉 단어 사이의 관계나 문맥 정보를 담고 있습니다. 미세 조정을 거치면 이 지식이 특정 도메인의 언어 사용 특성을 더 잘 반영하도록 다듬어집니다. 결과적으로 사전 훈련된 모형이 지닌 일반적 표현이 새로운 작업에 더 적합하게 조정되면서 전체 모형의 성능 향상에 기여합니다.

from keras.applications.vgg16 import VGG16, preprocess_input

base_model = VGG16(include_top=False)
base_model.trainable = True
# 마지막 세 개를 제외한 나머지 층들은 가중치 고정
for layer in base_model.layers[:-3]:
    layer.trainable = False

keras.backend.clear_session()
model = create_model(input_shape=image_size + (3,))
# model.summary()
model.compile(
    optimizer=keras.optimizers.Adam(1e-5), # 학습률 낮춤 (1/100) 
    loss=keras.losses.BinaryCrossentropy(from_logits=True), 
    metrics=['acc'])

model_name = 'cats_dogs_vgg16-fine-tuned'
checkpoint_filepath = f'checkpoints/{model_name}.keras'
log_dir = f'logs/{model_name}'
history = model.fit(
    train_dataset, epochs=30, validation_data=validation_dataset, 
    callbacks=[
        keras.callbacks.ModelCheckpoint(checkpoint_filepath, save_best_only=True),
        keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        keras.callbacks.TensorBoard(log_dir)
    ]
)
Epoch 1/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 5s 53ms/step - acc: 0.7896 - loss: 1.2397 - val_acc: 0.9150 - val_loss: 0.3467
Epoch 2/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 0.9472 - loss: 0.1817 - val_acc: 0.9240 - val_loss: 0.2912
Epoch 3/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - acc: 0.9811 - loss: 0.0586 - val_acc: 0.9280 - val_loss: 0.2708
Epoch 4/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 0.9938 - loss: 0.0219 - val_acc: 0.9310 - val_loss: 0.2689
Epoch 5/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 0.9994 - loss: 0.0110 - val_acc: 0.9250 - val_loss: 0.2686
Epoch 6/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0072 - val_acc: 0.9260 - val_loss: 0.2660
Epoch 7/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - acc: 1.0000 - loss: 0.0057 - val_acc: 0.9260 - val_loss: 0.2631
Epoch 8/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0042 - val_acc: 0.9280 - val_loss: 0.2609
Epoch 9/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 37ms/step - acc: 1.0000 - loss: 0.0035 - val_acc: 0.9290 - val_loss: 0.2595
Epoch 10/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0030 - val_acc: 0.9310 - val_loss: 0.2577
Epoch 11/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0025 - val_acc: 0.9320 - val_loss: 0.2566
Epoch 12/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0022 - val_acc: 0.9330 - val_loss: 0.2556
Epoch 13/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0018 - val_acc: 0.9350 - val_loss: 0.2543
Epoch 14/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0016 - val_acc: 0.9370 - val_loss: 0.2536
Epoch 15/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - acc: 1.0000 - loss: 0.0015 - val_acc: 0.9370 - val_loss: 0.2527
Epoch 16/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0013 - val_acc: 0.9370 - val_loss: 0.2525
Epoch 17/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0012 - val_acc: 0.9380 - val_loss: 0.2522
Epoch 18/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 0.0011 - val_acc: 0.9380 - val_loss: 0.2515
Epoch 19/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 1.0000 - loss: 9.6643e-04 - val_acc: 0.9370 - val_loss: 0.2516
Epoch 20/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 9.0749e-04 - val_acc: 0.9380 - val_loss: 0.2512
Epoch 21/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 1.0000 - loss: 8.4548e-04 - val_acc: 0.9380 - val_loss: 0.2513
Epoch 22/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 7.1048e-04 - val_acc: 0.9380 - val_loss: 0.2511
Epoch 23/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 6.5007e-04 - val_acc: 0.9400 - val_loss: 0.2506
Epoch 24/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - acc: 1.0000 - loss: 5.9501e-04 - val_acc: 0.9400 - val_loss: 0.2503
Epoch 25/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 5.4211e-04 - val_acc: 0.9400 - val_loss: 0.2503
Epoch 26/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 5.4306e-04 - val_acc: 0.9400 - val_loss: 0.2503
Epoch 27/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 1.0000 - loss: 4.8749e-04 - val_acc: 0.9410 - val_loss: 0.2503
Epoch 28/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - acc: 1.0000 - loss: 4.6722e-04 - val_acc: 0.9400 - val_loss: 0.2500
Epoch 29/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 38ms/step - acc: 1.0000 - loss: 4.3405e-04 - val_acc: 0.9400 - val_loss: 0.2499
Epoch 30/30
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 35ms/step - acc: 1.0000 - loss: 4.0452e-04 - val_acc: 0.9400 - val_loss: 0.2500
[Test]
Loss: 0.211	acc: 0.949	
best_model = keras.models.load_model('checkpoints/cats_dogs_vgg16-fine-tuned.keras')
metrics = model.evaluate(test_dataset, return_dict=True)
print('[Test]')
print_scores(metrics)
63/63 ━━━━━━━━━━━━━━━━━━━━ 2s 21ms/step - acc: 0.9540 - loss: 0.1864
[Test]
Loss: 0.211	acc: 0.949