학습을 진행하며 특정 조건들을 수집 후 해당 조건을 가지고 기존 데이터 셋에서 데이터를 삭제할 일이 생김.

데이터셋을 불러오는 코드부터 전체를 짜긴 귀찮으니.. torchvision에서 제공하고 있는 torchvision.datasets.imagefolder를 이용하자!

 

먼저 해당 페이지에서 코드를 가져오자  github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py

 

ImageFolder를 사용하는 방법은 간단하다.

- 데이터가 있는 경로만 넣어주면 바로 dataset 형태로 만들어지며, DataLoader에 넘겨주기만 하면 된다.

- 보통의 경우 x(input), y(target) 데이터를 받아서 학습을 시켜주게 된다. 

- 여기서 우리는 x,y 값 말고도 다른 정보를 얻고 싶고, 정보를 얻어오는 방법과 데이터를 핸들링하는 방법을 알아본다.

- 데이터셋의 경로 구조

----/test

           ---class1

           ---class2

import imagefolder

test_dataset = imagefolder.ImageFolder('./test', transform=test_transform)


test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=4)
                                          
for i, (input, target) in enumerate(train_loader):
	print(input, tartget)

수정 한 부부은 # edit 으로 표시했고 추가한 기능은 다음과 같다.

- loader에서 데이터의 index와 파일명을 추가로 받아옴

- 특정 조건을 만족할 경우 데이터 셋 삭제

 

수정한 코드내용

- self.samples, self.targets : 기존에 tuple 형식이던 데이터를 각각 list 형태로 받아옴

- self.file_names : file_name을 가지고 데이터를 삭제할 것이기 때문에 추가해주고 numpy array 형식으로 변환

- remove_data : 해당 함수를 말들어 filename이 같을 경우에 데이터를 삭제해줌 -> x, y 둘다 삭제

- __getitem__ : path, target을 각각 받도록 수정, return 값들 추가 (index, filename)

       - 해당 함수의 경우 index를 인자로 받아서 해당 index에 해당하는 파일들만 return 해줌

from torchvision.datasets.vision import VisionDataset
from PIL import Image

import numpy as np
import os
import os.path

class DatasetFolder(VisionDataset):
    def __init__(self, root, loader, extensions=None, transform=None,
                 target_transform=None, is_valid_file=None):
        super(DatasetFolder, self).__init__(root, transform=transform,
                                            target_transform=target_transform)
        classes, class_to_idx = self._find_classes(self.root)
        samples, file_names = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
        if len(samples) == 0:
            msg = "Found 0 files in subfolders of: {}\n".format(self.root)
            if extensions is not None:
                msg += "Supported extensions are: {}".format(",".join(extensions))
            raise RuntimeError(msg)
        self.loader = loader
        self.extensions = extensions
        self.classes = classes
        self.class_to_idx = class_to_idx

        # edit
        self.file_names = np.array(file_names)
        self.samples = [s[0] for s in samples]
        self.targets = [s[1] for s in samples]

	# edit
    def remove_data(self, filename):
        idx = np.where(self.file_names == filename)[0]
        self.targets = np.delete(self.targets, idx)
        self.samples = np.delete(self.samples, idx)
	
    def __getitem__(self, index):
      # edit
      path = self.samples[index]
      target = self.targets[index]

      sample = self.loader(path)

      if self.transform is not None:
      sample = self.transform(sample)
      if self.target_transform is not None:
      target = self.target_transform(target)

      # edit
      return sample, target, index, self.file_names[index]

    def _find_classes(self, dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    def __len__(self):
        return len(self.samples)

datasets에 접근

- DatasetFolder의 init 초기화 함수에서 설정했던 self 변수들에 접근이 가능하고

- remove를 했을 경우 데이터셋이 삭제된다.

import imagefolder

test_dataset = imagefolder.ImageFolder('./test', transform=test_transform)

test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=1,
                                          shuffle=False,
                                          num_workers=4)

print(len(test_dataset.samples), len(test_dataset.targets))
test_dataset.remove_data('20200401_161246.jpg')
print(len(test_dataset.samples), len(test_dataset.targets))
                                          
for i, (input, target, idx, file_name) in enumerate(train_loader):
	print(input, tartget, idx, file_name)

 

+ Recent posts