When using PyTorch with a training set of tens of millions of images, the DataLoader is very slow to load data. What can be done?

  

Suggestion from user fang-niu-wa-28-17:

Below is the most elegantly written prefetching dataloader I have seen; its iteration approach is worth referencing:

Usage is the same as a regular dataloader: for xxx in trainloader (a usage sketch follows the code below).

There are two main ideas: first, override __iter__ and __next__; second, load asynchronously in a background thread through a Queue.

import numbers
import os
import queue as Queue
import threading

import mxnet as mx
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms


class BackgroundGenerator(threading.Thread):
    """Background thread that pre-fetches items from a generator into a bounded queue."""

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        # Bounded queue limits how far ahead the thread pre-fetches.
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        # Sentinel marking the end of the underlying iterator.
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    """DataLoader that pre-fetches batches and copies them to the GPU on a side CUDA stream."""

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        # Dedicated CUDA stream for asynchronous host-to-device copies.
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        # Fetch the next batch in advance and start its copy to the GPU.
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank,
                                                 non_blocking=True)

    def __next__(self):
        # Make sure the asynchronous copy has finished before handing the batch out.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch


class MXFaceDataset(Dataset):
    """Dataset backed by an MXNet indexed RecordIO file (train.rec / train.idx)."""

    def __init__(self, root_dir, local_rank):
        super(MXFaceDataset, self).__init__()
        self.transform = transforms.Compose(
            [transforms.ToPILImage(),
             transforms.RandomHorizontalFlip(),
             transforms.ToTensor(),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
             ])
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)
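
For completeness, here is a minimal usage sketch, assuming a RecordIO dataset directory, GPU index, batch size and worker count that are placeholders rather than values from the original answer:

# Minimal usage sketch: build the dataset and DataLoaderX, then iterate it
# like a regular DataLoader. Paths and hyperparameters are illustrative.
local_rank = 0  # GPU index; under DistributedDataParallel this comes from the launcher
trainset = MXFaceDataset(root_dir='/path/to/rec_dataset', local_rank=local_rank)
trainloader = DataLoaderX(
    local_rank=local_rank,
    dataset=trainset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True)

for img, label in trainloader:
    # img and label are already on cuda:local_rank, because preload() copied
    # them on the side stream with non_blocking=True.
    ...

Note that pin_memory=True is what allows the non_blocking copy in preload() to actually overlap with computation; copies from pageable host memory fall back to being synchronous.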




  






