Skip to content

How do you calculate the mean and std? They differ from the values I get when I compute them from the dataset. #36

@kekekeke8

Description

@kekekeke8

def get_dataset(dataset, data_path):
    """Build normalized train/test datasets for one of the supported benchmarks.

    Each branch sets the image geometry (``channel``, ``im_size``), the number
    of classes, the per-channel normalization constants (``mean``, ``std``),
    a ``ToTensor`` + ``Normalize`` transform (no augmentation), the train and
    test dataset objects, and ``class_names``.

    Args:
        dataset: One of 'MNIST', 'FashionMNIST', 'SVHN', 'CIFAR10', 'CIFAR100'.
        data_path: Root directory handed to the torchvision dataset
            constructors (called with ``download=True``).

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.

    Note:
        ``mean``/``std`` are hard-coded constants; nothing here computes them.
        They are presumably per-channel statistics of the training pixels
        after ``ToTensor`` scales them to [0, 1] -- TODO confirm the source.
    """
    if dataset == 'MNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.1307]
        std = [0.3081]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)  # no augmentation
        dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'FashionMNIST':
        channel = 1
        im_size = (28, 28)
        num_classes = 10
        mean = [0.2861]
        std = [0.3530]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.FashionMNIST(data_path, train=True, download=True, transform=transform)  # no augmentation
        dst_test = datasets.FashionMNIST(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'SVHN':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4377, 0.4438, 0.4728]
        std = [0.1980, 0.2010, 0.1970]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        # SVHN's constructor selects the split by name rather than a train flag.
        dst_train = datasets.SVHN(data_path, split='train', download=True, transform=transform)  # no augmentation
        dst_test = datasets.SVHN(data_path, split='test', download=True, transform=transform)
        class_names = [str(c) for c in range(num_classes)]

    elif dataset == 'CIFAR10':
        channel = 3
        im_size = (32, 32)
        num_classes = 10
        mean = [0.4914, 0.4822, 0.4465]
        # NOTE(review): this std is the widely copied [0.2023, 0.1994, 0.2010].
        # The per-channel std taken over all CIFAR-10 training pixels is
        # roughly [0.2470, 0.2435, 0.2616], which is why recomputing it from
        # the dataset gives different numbers. It is kept as-is because
        # changing it changes every normalized input (and anything trained or
        # saved in that space); switch only if matching earlier results does
        # not matter.
        std = [0.2023, 0.1994, 0.2010]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR10(data_path, train=True, download=True, transform=transform)  # no augmentation
        dst_test = datasets.CIFAR10(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    elif dataset == 'CIFAR100':
        channel = 3
        im_size = (32, 32)
        num_classes = 100
        mean = [0.5071, 0.4866, 0.4409]
        std = [0.2673, 0.2564, 0.2762]
        transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
        dst_train = datasets.CIFAR100(data_path, train=True, download=True, transform=transform)  # no augmentation
        dst_test = datasets.CIFAR100(data_path, train=False, download=True, transform=transform)
        class_names = dst_train.classes

    else:
        # Fail loudly instead of leaving every dataset variable unset.
        raise ValueError(f"Unknown dataset: {dataset!r}")

    # NOTE(review): the quoted excerpt ends here without a return statement;
    # the values built above are not returned in this snippet.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions