Skip to content

第2章:PyTorch中的数据处理

在深度学习项目中,数据处理是一个关键步骤。PyTorch提供了强大的工具来加载和预处理数据,使得数据准备变得更加高效和便捷。

2.1 数据加载和预处理

2.1.1 数据加载器(DataLoader)

PyTorch的DataLoader是用于加载数据集的类,它支持批量加载、打乱数据和多线程加载。

python
from torch.utils.data import DataLoader, Dataset

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

2.1.2 数据预处理

使用torchvision.transforms进行数据预处理,如归一化、裁剪和数据增强。

python
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 应用转换
transformed_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

2.2 使用torchvisiontorchtext进行数据增强

2.2.1 图像数据增强

使用torchvision进行图像数据增强,如随机裁剪、旋转和翻转。

python
from torchvision import transforms

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

2.2.2 文本数据预处理

使用torchtext进行文本数据预处理,如分词、构建词汇表和编码。

python
from torchtext import data
from torchtext.datasets import IMDB

TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = IMDB.splits(TEXT, LABEL)

TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)

2.3 自定义数据集和数据加载器

2.3.1 自定义数据集

创建自定义数据集类,继承自torch.utils.data.Dataset

python
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data, targets):
        self.data = data
        self.targets = targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.targets[index]

2.3.2 数据加载器

使用DataLoader加载自定义数据集,设置批量大小、是否打乱和多线程加载。

python
from torch.utils.data import DataLoader

custom_dataset = CustomDataset(data, targets)
data_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4)

2.4 本章小结

本章介绍了PyTorch中的数据加载和预处理,包括如何使用DataLoader、数据预处理和增强,以及如何自定义数据集。这些技能对于准备和处理深度学习数据至关重要。