93 lines
3.1 KiB
Python

import json
import os
import random
import cv2
from PIL import Image
import torch
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
class ClsDataset(Dataset):
    """Image-classification dataset backed by parallel lists of paths and labels.

    Each sample is read from disk with OpenCV, converted from BGR to RGB,
    wrapped in a PIL image and, if a transform was given, passed through it.
    """

    def __init__(self, images_path, images_labels, transform=None):
        """Store the sample lists and the optional transform.

        Args:
            images_path: Sequence of image file paths.
            images_labels: Sequence of labels, index-aligned with
                ``images_path``.
            transform: Optional callable applied to the PIL image of every
                sample; ``None`` returns the PIL image unchanged.
        """
        self.images_path = images_path
        self.images_labels = images_labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples (length of ``images_path``)."""
        return len(self.images_path)

    def __getitem__(self, idx):
        """Load the sample at ``idx`` and return ``(image, label)``.

        Raises:
            OSError: If OpenCV cannot read the image file.
        """
        path = self.images_path[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing or undecodable file by returning
            # None instead of raising; fail here with the offending path
            # rather than with an opaque cvtColor error further down.
            raise OSError("failed to read image: {}".format(path))

        # OpenCV decodes to BGR channel order; convert to RGB before
        # handing the pixels to PIL.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img)
        label = self.images_labels[idx]

        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Merge a list of ``(image, label)`` samples into batched tensors.

        Args:
            batch: Iterable of ``(image, label)`` pairs. ``torch.stack`` is
                used on the images, so they must already be tensors of equal
                shape.

        Returns:
            ``(images, labels)`` where ``images`` is the samples stacked along
            a new leading dimension and ``labels`` is a tensor of the labels.
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
def read_split_data(root, valid_rate=0.2, *, plot_distribution=False):
    """Scan a class-per-folder dataset and split it into train/validation sets.

    Every sub-directory of ``root`` is treated as one class. Class indices are
    assigned in sorted order of the directory names, and the index-to-name
    mapping is written to ``class_indices.json`` in the current working
    directory. Within each class, ``int(count * valid_rate)`` images are drawn
    at random (via the ``random`` module) for validation; the rest go to the
    training set.

    Args:
        root: Dataset root directory containing one folder per class.
        valid_rate: Fraction of each class's images used for validation.
        plot_distribution: When true, show a bar chart of the per-class image
            counts with matplotlib before returning.

    Returns:
        A tuple ``([train_paths, train_labels], [val_paths, val_labels],
        num_classes)``.

    Raises:
        AssertionError: If ``root`` does not exist, or if either the training
            or the validation split ends up empty.
    """
    assert os.path.exists(root), 'dataset root: {} does not exist.'.format(root)

    class_names = sorted(
        cls for cls in os.listdir(root) if os.path.isdir(os.path.join(root, cls))
    )
    class_indices = {name: i for i, name in enumerate(class_names)}

    # Persist the inverse mapping (index -> class name) for later lookup.
    json_str = json.dumps({v: k for k, v in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    train_images_path = []
    train_labels = []
    val_images_path = []
    val_labels = []
    per_class_num = []
    supported = ('.jpg', '.JPG', '.png', '.PNG')

    for cls in class_names:
        cls_path = os.path.join(root, cls)
        # os.listdir returns entries in arbitrary order; sort so that a fixed
        # random seed yields the same split on every platform/filesystem.
        images_path = sorted(
            os.path.join(cls_path, name)
            for name in os.listdir(cls_path)
            if os.path.splitext(name)[-1] in supported
        )
        images_label = class_indices[cls]
        per_class_num.append(len(images_path))

        # A set gives O(1) membership tests in the loop below.
        val_path = set(random.sample(images_path, int(len(images_path) * valid_rate)))
        for img_path in images_path:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_labels.append(images_label)
            else:
                train_images_path.append(img_path)
                train_labels.append(images_label)

    print("{} images were found in the dataset.".format(sum(per_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))
    assert len(train_images_path) > 0, "number of training images must greater than zero"
    assert len(val_images_path) > 0, "number of validation images must greater than zero"

    if plot_distribution:
        plt.bar(range(len(class_names)), per_class_num, align='center')
        plt.xticks(range(len(class_names)), class_names)
        # Print the exact count slightly above each bar.
        for i, v in enumerate(per_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('classes')
        plt.ylabel('numbers')
        plt.title('the distribution of dataset')
        plt.show()

    return [train_images_path, train_labels], [val_images_path, val_labels], len(class_names)