Appearance
第2章:PyTorch中的数据处理
在深度学习项目中,数据处理是一个关键步骤。PyTorch提供了强大的工具来加载和预处理数据,使得数据准备变得更加高效和便捷。
2.1 数据加载和预处理
2.1.1 数据加载器(DataLoader)
PyTorch的DataLoader是用于加载数据集的类,它支持批量加载、打乱数据和多线程加载。
python
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index], self.targets[index]
dataset = CustomDataset(data, targets)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)2.1.2 数据预处理
使用torchvision.transforms进行数据预处理,如归一化、裁剪和数据增强。
python
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 应用转换
transformed_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)2.2 使用torchvision和torchtext进行数据增强
2.2.1 图像数据增强
使用torchvision进行图像数据增强,如随机裁剪、旋转和翻转。
python
from torchvision import transforms
data_transforms = {
'train': transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}2.2.2 文本数据预处理
使用torchtext进行文本数据预处理,如分词、构建词汇表和编码。
python
from torchtext import data
from torchtext.datasets import IMDB
TEXT = data.Field(tokenize='spacy', tokenizer_language='en_core_web_sm')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_data, max_size=25000)
LABEL.build_vocab(train_data)2.3 自定义数据集和数据加载器
2.3.1 自定义数据集
创建自定义数据集类,继承自torch.utils.data.Dataset。
python
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data, targets):
self.data = data
self.targets = targets
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index], self.targets[index]2.3.2 数据加载器
使用DataLoader加载自定义数据集,设置批量大小、是否打乱和多线程加载。
python
from torch.utils.data import DataLoader
custom_dataset = CustomDataset(data, targets)
data_loader = DataLoader(custom_dataset, batch_size=32, shuffle=True, num_workers=4)2.4 本章小结
本章介绍了PyTorch中的数据加载和预处理,包括如何使用DataLoader、数据预处理和增强,以及如何自定义数据集。这些技能对于准备和处理深度学习数据至关重要。
