학습을 진행하며 특정 조건들을 수집 후 해당 조건을 가지고 기존 데이터 셋에서 데이터를 삭제할 일이 생김.
데이터셋을 불러오는 코드부터 전체를 짜긴 귀찮으니.. torchvision에서 제공하고 있는 torchvision.datasets.imagefolder를 이용하자!
먼저 해당 페이지에서 코드를 가져오자 github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
ImageFolder를 사용하는 방법은 간단하다.
- 데이터가 있는 경로만 넣어주면 바로 dataset 형태로 만들어지며, DataLoader에 넘겨주기만 하면 된다.
- 보통의 경우 x(input), y(target) 데이터를 받아서 학습을 시켜주게 된다.
- 여기서 우리는 x,y 값 말고도 다른 정보를 얻고 싶고, 정보를 얻어오는 방법과 데이터를 핸들링하는 방법을 알아본다.
- 데이터셋의 경로 구조
----/test
---class1
---class2
import imagefolder
test_dataset = imagefolder.ImageFolder('./test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=1,
shuffle=False,
num_workers=4)
for i, (input, target) in enumerate(train_loader):
print(input, tartget)
수정 한 부부은 # edit 으로 표시했고 추가한 기능은 다음과 같다.
- loader에서 데이터의 index와 파일명을 추가로 받아옴
- 특정 조건을 만족할 경우 데이터 셋 삭제
수정한 코드내용
- self.samples, self.targets : 기존에 tuple 형식이던 데이터를 각각 list 형태로 받아옴
- self.file_names : file_name을 가지고 데이터를 삭제할 것이기 때문에 추가해주고 numpy array 형식으로 변환
- remove_data : 해당 함수를 말들어 filename이 같을 경우에 데이터를 삭제해줌 -> x, y 둘다 삭제
- __getitem__ : path, target을 각각 받도록 수정, return 값들 추가 (index, filename)
- 해당 함수의 경우 index를 인자로 받아서 해당 index에 해당하는 파일들만 return 해줌
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import numpy as np
import os
import os.path
class DatasetFolder(VisionDataset):
def __init__(self, root, loader, extensions=None, transform=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=transform,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples, file_names = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.loader = loader
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
# edit
self.file_names = np.array(file_names)
self.samples = [s[0] for s in samples]
self.targets = [s[1] for s in samples]
# edit
def remove_data(self, filename):
idx = np.where(self.file_names == filename)[0]
self.targets = np.delete(self.targets, idx)
self.samples = np.delete(self.samples, idx)
def __getitem__(self, index):
# edit
path = self.samples[index]
target = self.targets[index]
sample = self.loader(path)
if self.transform is not None:
sample = self.transform(sample)
if self.target_transform is not None:
target = self.target_transform(target)
# edit
return sample, target, index, self.file_names[index]
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __len__(self):
return len(self.samples)
datasets에 접근
- DatasetFolder의 init 초기화 함수에서 설정했던 self 변수들에 접근이 가능하고
- remove를 했을 경우 데이터셋이 삭제된다.
import imagefolder
test_dataset = imagefolder.ImageFolder('./test', transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,
batch_size=1,
shuffle=False,
num_workers=4)
print(len(test_dataset.samples), len(test_dataset.targets))
test_dataset.remove_data('20200401_161246.jpg')
print(len(test_dataset.samples), len(test_dataset.targets))
for i, (input, target, idx, file_name) in enumerate(train_loader):
print(input, tartget, idx, file_name)