Skip to article frontmatterSkip to article content
Site not loading correctly?

This may be due to an incorrect BASE_URL configuration. See the MyST Documentation for reference.

TensorFlow Dataset

import tensorflow as tf

1Oxford IIIT Pet 데이터

from pathlib import Path

dir_path = Path('data/oxford-iiit-pet')
img_files = list((dir_path / 'images').glob('*.jpg'))
mask_files = list((dir_path / 'annotations' / 'trimaps').glob('*.png'))
img_files.sort()
mask_files.sort()

print(f'파일수: {len(img_files)}')
assert len(img_files) == len(mask_files), '이미지 파일과 목표 파일의 개수가 다릅니다.'
for filepath1, filepath2 in zip(img_files, mask_files):
    assert filepath1.stem == filepath2.stem, f'파일이 짝이 맞지 않습니다: {filepath1}, {filepath2}'
from PIL import Image

이상파일목록 = []
for 파일경로 in img_files:
    with Image.open(파일경로) as 이미지:
        try:
            이미지.verify()
            assert 이미지.mode == 'RGB'
        except:
            이상파일목록.append(파일경로.stem)

# 파일형식이 RGB가 아닌 파일
print(f'{len(이상파일목록)}/{len(img_files)}')
print(이상파일목록[:5])

2tf.Dataset

import numpy as np
import tensorflow as tf

def transform(img_filepath, mask_filepath):
    img = tf.io.read_file(img_filepath)
    img = tf.image.decode_jpeg(img, channels=3)

    mask = tf.io.read_file(mask_filepath)
    mask = tf.image.decode_png(mask, channels=1)

    img = tf.image.resize(img, (200, 200))
    mask = tf.image.resize(mask, (200, 200))
    mask = tf.cast(mask, tf.uint8) - 1  # 1, 2, 3 값을 0, 1, 2로 변경

    return img, mask

dir_path = Path('data/oxford-iiit-pet')

img_dataset = tf.data.Dataset.list_files(str(dir_path / 'images' / '*.jpg'), shuffle=False)
mask_dataset = tf.data.Dataset.list_files(str(dir_path / 'annotations' / 'trimaps' / '*.png'), shuffle=False)

# 이상파일 필터
def get_stem(path: tf.Tensor) -> tf.Tensor:
    """
    path: b'/some/dir/Abyssinian_12.jpg'
    return: b'Abyssinian_12'
    """
    # 1) 디렉토리 부분 제거 (슬래시/역슬래시 모두 처리)
    filename = tf.strings.regex_replace(path, r'.*[\\/]', '')
    # 2) 확장자 제거 (마지막 점부터 끝까지)
    stem = tf.strings.regex_replace(filename, r'\.[^\.]+$', '')
    return stem

이상파일목록 = tf.constant(이상파일목록)
이상탐지 = lambda x: tf.logical_not(tf.reduce_any(tf.equal(get_stem(x), 이상파일목록)))

print(f'before: {sum(1 for _ in img_dataset)}')
img_dataset = img_dataset.filter(이상탐지)
mask_dataset = mask_dataset.filter(이상탐지)
print(f'after: {sum(1 for _ in img_dataset)}')

# 파일 이름 쌍 확인
print("파일 쌍 확인")
for (img_path, mask_path) in zip(img_dataset, mask_dataset):
    img_name = Path(img_path.numpy().decode()).stem
    mask_name = Path(mask_path.numpy().decode()).stem
    assert img_name == mask_name, f"파일 이름 불일치: {img_name} != {mask_name}"


tf_dataset = tf.data.Dataset.zip((img_dataset, mask_dataset))
tf_dataset = tf_dataset.map(lambda x, y: transform(x, y))

for sample, target in tf_dataset.take(1):
    assert sample.shape == (200, 200, 3)
    assert target.shape == (200, 200, 1)
    # 마스크 값: 0, 1, 2 범위 확인
    min_val = tf.reduce_min(target)
    max_val = tf.reduce_max(target)
    assert min_val >= 0 and max_val < 3, f"마스크 값 범위 오류: [{min_val}, {max_val}]"

3배치 생성

batch_size = 32
batch_dataset = tf_dataset.batch(batch_size).shuffle(buffer_size=100)

for batch_data, batch_target in batch_dataset:
    try:
        assert batch_data.shape == (batch_size, 200, 200, 3), batch_data.shape
        assert batch_target.shape == (batch_size, 200, 200, 1), batch_target.shape
    except Exception as e:
        # 배치 크기가 맞지 않는 경우 출력
        # 마지막 배치는 크기가 다를 수 있음
        print(e)

3.1배치별 소요 시간

import time

def transform(path):
    img = tf.io.read_file(path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (200, 200))
    return img

dir_path = Path('data/oxford-iiit-pet')
files = list((dir_path / 'images').glob('*.jpg'))
print(f'파일수: {len(files)}')

tf_dataset = tf.data.Dataset.list_files([str(filepath) for filepath in files])
tf_dataset = tf_dataset.map(transform, num_parallel_calls=tf.data.AUTOTUNE)
tf_dataset = tf_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

results = {}
for batch_size in [1, 8, 32, 64, 128, 256, 512, 1024]:
    print(f'배치 크기: {batch_size} 에폭당 배치수: {len(tf_dataset) // batch_size}')
    batch_dataset = tf_dataset.batch(batch_size).shuffle(buffer_size=batch_size * 2)
    
    start_time = time.time()
    for batch in batch_dataset:
        pass

    duration = time.time() - start_time
    print(f'소요 시간: {duration:.2f} 초\n')
    results[batch_size] = duration
import pandas as pd

frame = pd.DataFrame.from_dict(results, orient='index', columns=['소요 시간 (초)'])
frame.T.round(1)