티스토리 뷰

Pytorch의 Dataset과 Dataloader
PyTorch는 데이터를 다루기 위해 두 가지 클래스인 torch.utils.data.DataLoader 와 torch.utils.data.Dataset 를 제공한다. 두 가지 데이터 기본 요소를 제공하여 미리 준비해둔(pre-loaded) 데이터셋 뿐만 아니라 가지고 있는 데이터를 사용할 수 있다. Dataset은 데이터셋을 정의하는 클래스로 샘플과 정답(label)을 저장하고, DataLoader는 Dataset을 감싸서 미니배치 학습, suffling 등을 쉽게 처리할 수 있도록 도와준다. 즉, Dataset은 데이터의 구조를 정의하고, DataLoader는 학습에 적합하게 데이터를 꺼내는 역할이다.
Dataset 다루기
PyTorch는 torchvision.datasets 모듈에서 여러 데이터셋을 미리 제공합니다. 예를 들어 Fashion-MNIST는 28×28 크기의 흑백 이미지(의류 이미지)와 그에 해당하는 10개의 클래스 레이블로 구성되어 있습니다.
from torchvision import datasets
from torchvision.transforms import ToTensor
training_data = datasets.FashionMNIST(
root="data", # 데이터가 저장될 위치
train=True, # 학습용 데이터셋
download=True, # 없으면 다운로드
transform=ToTensor() # 텐서 형태로 변환(특징(feature)과 정답(label) 변형(transform)을 지정)
)
test_data = datasets.FashionMNIST(
root="data",
train=False, # 테스트용 데이터셋
download=True,
transform=ToTensor()
)
Dataset 커스터마이징
직접 데이터셋을 정의하려면 torch.utils.data.Dataset 을 상속받아 3가지 메서드를 구현해야 한다.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self):
# 데이터 경로, 파일명, transform 정의
pass
def __len__(self):
# 전체 데이터 개수 반환
pass
def __getitem__(self, idx):
# idx번째 샘플과 라벨 반환
# (전처리, 데이터 증강 등을 수행)
pass
- __init__ : 데이터 위치 지정, transform 정의
- __len__ : 전체 데이터 개수 반환 (len(dataset) 호출 시 사용됨)
- __getitem__ : dataset[i] 호출 시 i번째 샘플과 레이블 반환
모든 데이터를 한꺼번에 메모리에 올리지 않고, 필요한 샘플만 불러오기 때문에 효율적이다.
DataLoader 다루기
모델 학습 시에는 데이터를 미니배치(minibatch) 단위로 꺼내오고,
매 에폭마다 데이터를 섞어주며(shuffle),
빠른 데이터 로딩을 위해 멀티프로세싱(num_workers) 을 사용할 수 있습니다.
이 역할을 간단하게 처리해주는 것이 DataLoader이다.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
DataLoader의 기본 구성 요소
DataLoader는 모델 학습을 위해서 데이터를 미니 배치 단위로 제공해주는 역할을 한다.
DataLoader(
dataset, batch_size=1, shuffle=False,
sampler=None, batch_sampler=None,
num_workers=0, collate_fn=None,
pin_memory=False, drop_last=False
)
- dataset : 불러올 Dataset 객체
- batch_size : 한 배치에 담을 데이터 개수 (기본값=1)
- shuffle : 매 에폭마다 데이터를 섞을지 여부
- sampler : 데이터 샘플링 방식 직접 정의 (shuffle과 동시에 사용 불가)
- num_workers : 데이터를 불러올 때 사용할 서브 프로세스 개수
- collate_fn : 배치 단위로 데이터를 합치는 방식 정의 (예: 패딩 필요 시 활용)
- drop_last : 마지막 배치가 batch_size보다 작을 경우 버릴지 여부
'파이썬 > Pytorch' 카테고리의 다른 글
| [Pytorch] Tensor 인덱싱(Indexing), 슬라이싱(Slicing) (1) | 2025.08.13 |
|---|---|
| [Pytorch] Tensor 생성, 속성 그리고 연산 방법 (3) | 2025.08.08 |
- Total
- Today
- Yesterday
- 손실함수
- 비용함수
- ndarray
- ML 종류
- python
- baekjoon
- 강의노트 정리
- **
- 딥러닝
- 숏코딩
- *args
- **kwargs
- NumPy
- ML Process
- Action spaces
- cnn
- 클래스 총 정리
- ML
- 경사하강법
- 파이썬
- *
- Sort
- 강화학습
- Andrew Ng
- ML 프로세스
- 머신러닝
- 앤드류응
- 로지스틱 회귀
- sorted
- 백준
| 일 | 월 | 화 | 수 | 목 | 금 | 토 |
|---|---|---|---|---|---|---|
| 1 | 2 | 3 | 4 | |||
| 5 | 6 | 7 | 8 | 9 | 10 | 11 |
| 12 | 13 | 14 | 15 | 16 | 17 | 18 |
| 19 | 20 | 21 | 22 | 23 | 24 | 25 |
| 26 | 27 | 28 | 29 | 30 |
