NVIDIA DALIを使ってPyTorchのDataIOを高速化する

はじめに

学習にとても時間のかかるDeepLearningですが、
計算している部分よりも、データの前処理などに時間がかかっているということはよくあります。
少しでも学習を早くするために実装レベルでいろいろな工夫がありますが、
このエントリーではNVIDIA DALIを使ってPyTorchのDataIOを高速化した際のメモを紹介します。

最初に結論

PyTorchのDataLoaderをうまく組み合わせるべし

DALIとは?

NVIDIAが開発したライブラリで、データの前処理(augmentationなど)をGPU側に回すことが可能となります。
またメジャーなフレームワークとの連携用APIを提供しているため、
簡単に試すことができます。

以下などで例が紹介されています。
xvideos.hatenablog.com

DALIをPyTorchで使うには?

公式に丁寧なサンプルがあります。
https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/pytorch-external_input.ipynb

基本的にはこれに従うだけですが、
このままじゃ状況によってちょっと遅くなる時があります。
 
 

まずはDataLoaderそのままの場合

ms-cocoのデータセットを使った場合以下の様な実装になります。

class CocoDataset(Dataset):

    def __init__(self, dataType='val2017'):

        annFile='/content/annotations/instances_{}.json'.format(dataType)
        self.coco=COCO(annFile)
        self.ids = list(self.coco.imgToAnns.keys())
        self.imgs = self.coco.loadImgs(self.ids)
            
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
      
        img_path = os.path.join('/content/val2017', self.imgs[idx]['file_name'])
        time.sleep(DATA_LOAD_TIME) #for simulation of very slow file IO
        img = cv2.imread(img_path)
        img = cv2.resize(img, (512,512))
        return img

def collate_fn(batch):
    imgs = [x for x in batch]
    return imgs

coco_dataset = CocoDataset()
coco_dataloader = DataLoader(coco_dataset, num_workers=8, batch_size=BATCH_SIZE, collate_fn=collate_fn)

途中

time.sleep(DATA_LOAD_TIME) #for simulation of very slow file IO

という怪しい部分がありますが、これは画像サイズがもっと大きいなど不幸なことが起きた時に
データ読み込みに10msくらいかかる場合をシミュレートしています。

これを使って速度をはかりました。

data_iter = iter(coco_dataloader)

start = time.time()
for _ in range(100):
  images = next(data_iter)
  time.sleep(LEARNING_TIME) #for simulation of DeepLearning
end = time.time()
print('total_time : %s' % ( str(end-start) ))

total_time : 11.102628707885742
(今回の設定ではDATA_LOAD_TIME=0.01 LEARNING_TIME=0.1)

10秒はtime.sleep(LEARNING_TIME)に取られているので、実質IOのところは1.1秒くらいです。
このままでも十分早いですね。
pytorchのDataLoaderはいい感じにパイプライン処理してくれるため、
学習の処理が行われている裏でこっそりデータを準備してくれます。
そのため、データIOの時間が少なく見えるようになっています。
 
 

DALI公式のサンプルに従う

DALIではデータIO処理を行うpipelineと、pipelineにデータを供給するiteratorを定義して、
DALIGenericIteratorに渡して制御してもらいます。

class ExternalInputIterator(object):
    def __init__(self, batch_size, dataType='val2017'):

        annFile='/content/annotations/instances_{}.json'.format(dataType)
        self.coco=COCO(annFile)
        self.ids = list(self.coco.imgToAnns.keys())
        self.imgs = self.coco.loadImgs(self.ids)

        self.batch_size = batch_size
        self.data_set_len = len(self.imgs) 
        self.n = len(self.imgs)

    def __iter__(self):
        self.i = 0
        return self

    def __next__(self):
        batch = []

        if self.i >= self.n:
            raise StopIteration

        for _ in range(self.batch_size):
            img_path = os.path.join('/content/val2017', self.imgs[self.i]['file_name'])
            f = open(img_path, 'rb')
            time.sleep(DATA_LOAD_TIME) #for simulation of very slow file IO
            batch.append(np.frombuffer(f.read(), dtype = np.uint8))
            self.i = (self.i + 1) % self.n
        return batch

    @property
    def size(self,):
        return self.data_set_len

    next = __next__

class ExternalSourcePipeline(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, external_data):
        super(ExternalSourcePipeline, self).__init__(batch_size,
                                      num_threads,
                                      device_id,
                                      seed=12)
        self.input = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB)
        self.res = ops.Resize(device="gpu", resize_x=512, resize_y=512)
        self.external_data = external_data
        self.iterator = iter(self.external_data)

    def define_graph(self):
        self.jpegs = self.input()
        images = self.decode(self.jpegs)
        images = self.res(images)
        return images

    def iter_setup(self):
        try:
            images = self.iterator.next()
            self.feed_input(self.jpegs, images)
        except StopIteration:
            self.iterator = iter(self.external_data)
            raise StopIteration

from nvidia.dali.plugin.pytorch import DALIGenericIterator

eii = ExternalInputIterator(batch_size=BATCH_SIZE)
pipe = ExternalSourcePipeline(batch_size=BATCH_SIZE, num_threads=2, device_id = 0,
                              external_data = eii)
pii = DALIGenericIterator(pipe, ['image'], size=100*BATCH_SIZE, last_batch_padded=True, fill_last_batch=False)

start = time.time()
for i, data in enumerate(pii):
    images = data[0]["image"]
    time.sleep(LEARNING_TIME) #for simulation of DeepLearning
end = time.time()

print('total_time : %s' % ( str(end-start) ))

公式の実装ほぼそのままです。
結果は
total_time : 14.16463017463684
遅くなった!?
 
 

何が問題なのか

DALIを使うとデコード処理やそのあとのResizeなど確かに高速に動きます。
しかし、pytorchのDataLoaderが持っていたパイプライン機能はありません。
そのため、最初のデータ読み込みに時間がかかってしまう場合、逆に遅くなってしまいます。
データはSSDなど処理の速いデバイスに保存し、TFRecordなどで固めて読み込めばこの問題は解決しますが、
そうも言ってられない場合もあるでしょう。
そこで、下記の様に実装を変更しました。
 
 

PyTorch DataLoader × DALI

まず最初にデータの読み込みだけ行うpytorchのDataLoaderを用意します。

class CocoDatasetForDALI(Dataset):

    def __init__(self, dataType='val2017'):

        annFile='/content/annotations/instances_{}.json'.format(dataType)
        self.coco=COCO(annFile)
        self.ids = list(self.coco.imgToAnns.keys())
        self.imgs = self.coco.loadImgs(self.ids)
            
    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
      
        img_path = os.path.join('/content/val2017', self.imgs[idx]['file_name'])
        f = open(img_path, 'rb')
        time.sleep(DATA_LOAD_TIME) #for simulation of very slow file IO
        img = np.frombuffer(f.read(), dtype = np.uint8)
        return img

次にpipline内でpytorchのdataloaderからデータを取得するようにします。

class ExternalSourcePipelineForPytorch(Pipeline):
    def __init__(self, batch_size, num_threads, device_id, external_data):
        super(ExternalSourcePipelineForPytorch, self).__init__(batch_size,
                                      num_threads,
                                      device_id,
                                      seed=12)
        self.input = ops.ExternalSource()
        self.decode = ops.ImageDecoder(device = "mixed", output_type = types.RGB)
        self.res = ops.Resize(device="gpu", resize_x=512, resize_y=512)
        self.external_data = external_data
        self.iterator = iter(self.external_data)

    def define_graph(self):
        self.jpegs = self.input()
        images = self.decode(self.jpegs)
        images = self.res(images)
        return images

    def iter_setup(self):
        try:
            images = next(self.iterator)
            self.feed_input(self.jpegs, images)
        except StopIteration:
            self.iterator = iter(self.external_data)
            raise StopIteration

これを使って速度をはかってみます。

coco_dataset_for_dali = CocoDatasetForDALI()
coco_dataloader_for_dali = DataLoader(coco_dataset_for_dali, num_workers=8, batch_size=BATCH_SIZE, collate_fn=collate_fn)

pipe = ExternalSourcePipelineForPytorch(batch_size=BATCH_SIZE, num_threads=2, device_id = 0,
                              external_data = coco_dataloader_for_dali)
pii = DALIGenericIterator(pipe, ['image'], size=100*BATCH_SIZE, last_batch_padded=True, fill_last_batch=False)

start = time.time()
for i, data in enumerate(pii):
    images = data[0]["image"]
    time.sleep(LEARNING_TIME) #for simulation of DeepLearning
end = time.time()

print('total_time : %s' % ( str(end-start) ))

total_time : 10.300386428833008

速くなった!
実質データIOにかかっているのは0.3秒ということになりますので、
元の3~4倍くらいには早くなります。
data augmentationをもっとリッチにした場合などを考えると、結構使えるんではないでしょうか?
 
 

おわりに

pytorchにDALIを組み合わせる場合、データ読み込みに時間がかかるような環境の場合は、
pytorchのDataLoaderをうまく使うと良さそうです。
折角DALIにしたのに遅くなったという方は、ぜひ試してみてください。

Google Colaboratoryで確認した際のソースを下記に公開しています。
良ければお試しください。
https://github.com/yoneken1/colab_pytorch_sample/blob/master/pytorch_dali.ipynb