93 lines
3.1 KiB
Python
93 lines
3.1 KiB
Python
import json
|
|
import os
|
|
import random
|
|
|
|
import cv2
|
|
from PIL import Image
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
class ClsDataset(Dataset):
    """Image-classification dataset backed by parallel lists of paths and labels.

    Each sample is read from disk with OpenCV, converted from BGR to RGB,
    wrapped in a PIL image and then passed through the optional transform.
    """

    def __init__(self, images_path, images_labels, transform=None):
        """Store the sample paths, their labels and the optional transform.

        Args:
            images_path: Sequence of image file paths.
            images_labels: Sequence of labels, indexed in step with
                ``images_path``.
            transform: Optional callable applied to the PIL image in
                ``__getitem__``; ``None`` leaves the image untouched.
        """
        self.images_path = images_path
        self.images_labels = images_labels
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images_path)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return it as ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read or decode the image file.
        """
        path = self.images_path[idx]
        img = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of letting cvtColor die
        # on an opaque assertion.
        if img is None:
            raise OSError('failed to read image: {}'.format(path))

        # OpenCV decodes to BGR channel order; convert to RGB before handing
        # the array to PIL.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        label = self.images_labels[idx]

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of ``(image, label)`` samples into batched tensors.

        Images are stacked along a new leading dimension with
        ``torch.stack`` (so they must already be equally-shaped tensors) and
        the labels are converted with ``torch.as_tensor``.

        Returns:
            Tuple ``(images, labels)`` of tensors.
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
|
|
|
|
|
|
def read_split_data(root, valid_rate=0.2, plot_distribution=False):
    """Split a folder-per-class image dataset into training and validation sets.

    Every sub-directory of ``root`` is treated as one class, and class
    indices are assigned by sorted directory name. The index-to-name mapping
    is written to ``class_indices.json`` in the current working directory.

    Args:
        root: Dataset root directory containing one sub-directory per class.
        valid_rate: Fraction of each class's images drawn with
            ``random.sample`` for validation; the per-class count is
            truncated by ``int()``.
        plot_distribution: When true, show a matplotlib bar chart of the
            number of images found per class. Off by default.

    Returns:
        ``([train_paths, train_labels], [val_paths, val_labels], num_classes)``.

    Raises:
        AssertionError: If ``root`` does not exist, or if either the
            training or the validation split ends up empty.
    """
    assert os.path.exists(root), 'dataset root: {} does not exist.'.format(root)

    class_names = sorted(
        name for name in os.listdir(root)
        if os.path.isdir(os.path.join(root, name))
    )
    class_indices = {name: i for i, name in enumerate(class_names)}

    # Persist the index -> class-name mapping (JSON turns the integer keys
    # into strings).
    json_str = json.dumps(dict(enumerate(class_names)), indent=4)
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    train_images_path = []
    train_labels = []
    val_images_path = []
    val_labels = []
    per_class_num = []

    # Extensions are matched exactly, so only these spellings are accepted.
    supported = ('.jpg', '.JPG', '.png', '.PNG')
    for cls in class_names:
        cls_path = os.path.join(root, cls)
        # os.listdir returns entries in arbitrary order; sorting makes the
        # split reproducible for a given random seed on any platform.
        images_path = sorted(
            os.path.join(cls_path, name) for name in os.listdir(cls_path)
            if os.path.splitext(name)[-1] in supported
        )
        images_label = class_indices[cls]
        per_class_num.append(len(images_path))

        # A set keeps the membership test below O(1) per image.
        val_path = set(
            random.sample(images_path, int(len(images_path) * valid_rate))
        )
        for img_path in images_path:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_labels.append(images_label)
            else:
                train_images_path.append(img_path)
                train_labels.append(images_label)

    print("{} images were found in the dataset.".format(sum(per_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    assert len(train_images_path) > 0, "number of training images must greater than zero"
    assert len(val_images_path) > 0, "number of validation images must greater than zero"

    if plot_distribution:
        plt.bar(range(len(class_names)), per_class_num, align='center')
        plt.xticks(range(len(class_names)), class_names)

        # Write each bar's count just above it.
        for i, v in enumerate(per_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')

        plt.xlabel('classes')
        plt.ylabel('numbers')
        plt.title('the distribution of dataset')
        plt.show()

    return [train_images_path, train_labels], [val_images_path, val_labels], len(class_names)
|