History

2023.11.15. 초안 작성

torch.utils.data.Subset으로 데이터셋을 감싸면 원본 데이터셋의 __len__이 호출되지 않습니다. (Subset은 자신의 __len__에서 indices의 길이를 반환하기 때문입니다.)

어떻게 알았냐면, 제 데이터셋의 __len__에 버그가 있었는데 Subset을 쓰는 동안에는 잘 돌아가다가 Subset을 빼니까 터졌습니다. 저도 알고 싶지 않았어요.

관련 실험 스크립트

"""
python scripts/test_subset_and_len.py \\
2>&1 | tee output/test_subset_and_len.log
"""
import torch

class ExampleDataset(torch.utils.data.Dataset):
    """Toy dataset holding the integers 0..9 that logs every dunder call.

    Each call to ``__len__`` or ``__getitem__`` prints a ``[DEBUG]`` line so
    the caller can see which of the two hooks a consumer actually invokes.
    """

    def __init__(self):
        # Ten consecutive integers, so each item equals its own index.
        self.data = [*range(10)]

    def __len__(self):
        """Log the call, then return the number of stored items."""
        print("[DEBUG] __len__() is called")
        size = len(self.data)
        return size

    def __getitem__(self, index):
        """Log the call, then return the item stored at ``index``."""
        print("[DEBUG] __getitem__() is called")
        item = self.data[index]
        return item

if __name__ == "__main__":
    # Short alias for the torch data utilities used by both experiments.
    data_utils = torch.utils.data

    # Experiment 1: feed the dataset to the DataLoader directly.
    print("[DEBUG] no subset experiment")
    dataset = ExampleDataset()
    dataloader = data_utils.DataLoader(dataset, shuffle=False, batch_size=2)
    for step, items in enumerate(dataloader):
        print(f"[DEBUG] no subset i: {step}, batch: {items}")

    # Visual separator between the two runs.
    print("=" * 40)

    # Experiment 2: wrap the same dataset in a Subset over indices 3..6 and
    # iterate again, to compare which ExampleDataset hooks get logged.
    print("[DEBUG] subset experiment")
    dataset_subset = data_utils.Subset(dataset, indices=range(3, 7))
    dataloader_subset = data_utils.DataLoader(
        dataset_subset, shuffle=False, batch_size=2
    )
    for step, items in enumerate(dataloader_subset):
        print(f"[DEBUG] subset i: {step}, batch: {items}")

결과

[DEBUG] no subset experiment
[DEBUG] __len__() is called
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 0, batch: tensor([0, 1])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 1, batch: tensor([2, 3])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 2, batch: tensor([4, 5])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 3, batch: tensor([6, 7])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] no subset i: 4, batch: tensor([8, 9])
========================================
[DEBUG] subset experiment
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 0, batch: tensor([3, 4])
[DEBUG] __getitem__() is called
[DEBUG] __getitem__() is called
[DEBUG] subset i: 1, batch: tensor([5, 6])