GIL

Python 自带 Gobal Interpreter Lock (GIL),任何时候,Python只能运行一个线程

DotaLoader的构建

DataLoader(dataset, batch_size=200, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None)

官方模板

PyTorch官方为我们提供了自定义数据读取的标准化代码代码模块。

from torch.utils.data import Dataset
class CustomDataset(Dataset):
    def __init__(self, ...):
        pass

    def __getitem__(self, index):
        return (img, label)

    def __len__(self):
        # return examples size
        return count
  1. __init__()函数用于初始化数据读取逻辑,比如读取包含标签和图片地址的csv文件、定义transform组合等。
  2. __getitem__()函数用来返回数据和标签。目的上是为了能够被后续的dataloader所调用。
  3. __len__()函数则用于返回样本数量。

其中,__getitem__()__len__()用于构建Map-style datasets;__iter__()用于构建Iterable-style datasets(一般不太用)

训练集和验证集的划分

如果需要对数据划分训练集和验证集,torch的Dataset对象也提供了random_split函数作为数据划分工具,且划分结果可直接供后续的DataLoader使用。

from torch.utils.data import random_split

trainset, valset = random_split(dataset, [len_dataset*0.7, len_dataset*0.3])

Pytorch 并行化

  1. Data Parallel, DP: 数据并行化
  2. Distributed Data Parallel, DDP: 分布式数据并行化

Mnist Dataset 的实现

参考

[1] 官方文档 torch.utils.data

[2] 夕小瑶的卖萌屋-PyTorch数据Pipeline标准化代码模板