티스토리 뷰

 

Pytorch의 Dataset과 Dataloader

PyTorch는 데이터를 다루기 위해 두 가지 클래스인 torch.utils.data.DataLoadertorch.utils.data.Dataset 를 제공한다. 두 가지 데이터 기본 요소를 제공하여 미리 준비해둔(pre-loaded) 데이터셋 뿐만 아니라 가지고 있는 데이터를 사용할 수 있다. Dataset은 데이터셋을 정의하는 클래스로 샘플과 정답(label)을 저장하고, DataLoader는  Dataset을 감싸서 미니배치 학습, suffling 등을 쉽게 처리할 수 있도록 도와준다. 즉, Dataset은 데이터의 구조를 정의하고, DataLoader는 학습에 적합하게 데이터를 꺼내는 역할이다.

 

Dataset 다루기

PyTorch는 torchvision.datasets 모듈에서 여러 데이터셋을 미리 제공합니다. 예를 들어 Fashion-MNIST는 28×28 크기의 흑백 이미지(의류 이미지)와 그에 해당하는 10개의 클래스 레이블로 구성되어 있습니다.

from torchvision import datasets
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(
    root="data",       # 데이터가 저장될 위치
    train=True,        # 학습용 데이터셋
    download=True,     # 없으면 다운로드
    transform=ToTensor()  # 텐서 형태로 변환(특징(feature)과 정답(label) 변형(transform)을 지정)
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,       # 테스트용 데이터셋
    download=True,
    transform=ToTensor()
)

 

Dataset 커스터마이징

직접 데이터셋을 정의하려면 torch.utils.data.Dataset 을 상속받아 3가지 메서드를 구현해야 한다.

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        # 데이터 경로, 파일명, transform 정의
        pass  

    def __len__(self):
        # 전체 데이터 개수 반환
        pass  

    def __getitem__(self, idx):
        # idx번째 샘플과 라벨 반환
        # (전처리, 데이터 증강 등을 수행)
        pass

 

  • __init__ : 데이터 위치 지정, transform 정의
  • __len__ : 전체 데이터 개수 반환 (len(dataset) 호출 시 사용됨)
  • __getitem__ : dataset[i] 호출 시 i번째 샘플과 레이블 반환

모든 데이터를 한꺼번에 메모리에 올리지 않고, 필요한 샘플만 불러오기 때문에 효율적이다.

 

DataLoader 다루기

모델 학습 시에는 데이터를 미니배치(minibatch) 단위로 꺼내오고,
매 에폭마다 데이터를 섞어주며(shuffle),
빠른 데이터 로딩을 위해 멀티프로세싱(num_workers) 을 사용할 수 있습니다.

이 역할을 간단하게 처리해주는 것이 DataLoader이다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

DataLoader의 기본 구성 요소

DataLoader는 모델 학습을 위해서 데이터를 미니 배치 단위로 제공해주는 역할을 한다.

DataLoader(
    dataset, batch_size=1, shuffle=False,
    sampler=None, batch_sampler=None,
    num_workers=0, collate_fn=None,
    pin_memory=False, drop_last=False
)

 

  • dataset : 불러올 Dataset 객체
  • batch_size : 한 배치에 담을 데이터 개수 (기본값=1)
  • shuffle : 매 에폭마다 데이터를 섞을지 여부
  • sampler : 데이터 샘플링 방식 직접 정의 (shuffle과 동시에 사용 불가)
  • num_workers : 데이터를 불러올 때 사용할 서브 프로세스 개수
  • collate_fn : 배치 단위로 데이터를 합치는 방식 정의 (예: 패딩 필요 시 활용)
  • drop_last : 마지막 배치가 batch_size보다 작을 경우 버릴지 여부

 

공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2026/04   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30
글 보관함