Google Colaboratory + pytorch で SIGNATEコンペ参加(開発編2)

引き続きGoogle Colaboratory + pytorch で SIGNATEの
AIエッジコンテスト(オブジェクト検出)に参加するお話です。

開発項目おさらい

前回のエントリーで、開発項目を幾つかあげました。
それぞれ紹介するつもりでしたが、あまりGoogle Colaboratoryと関係ない一般的なものについては
省略することにしました。
代わりにGoogle Colaboratoryの機能を使ってGitHubにnotebookをアップロードしましたので、
そのリンクを紹介いたします。

  • Box変換系関数群 、学習データ選択関数

https://github.com/yoneken1/colab_pytorch_detection/blob/master/roi_util.ipynb

  • Anchor

https://github.com/yoneken1/colab_pytorch_detection/blob/master/Anchor.ipynb

  • Dataset

後程紹介します。

  • Loss

https://github.com/yoneken1/colab_pytorch_detection/blob/master/Loss.ipynb

  • Model

前回紹介済み。

次のエントリーで紹介します。

Dataset

pytorchではDatasetクラスを継承してデータ読み込み・変換部を作り、
作成したDatasetオブジェクトをDatalodarに渡すことで、
効率よく読み込んでくれるようになります。
Data Loading and Processing Tutorial — PyTorch Tutorials 1.0.0.dev20190104 documentation
今回はチュートリアルに従って、Signateコンペ用のデータセットにアクセスするクラスを作成します。

Google Driveのマウント

おさらいになりますが、Google ColaboratoryからGoogle Driveにアクセスするために、
ランタイムに接続後、以下を実行する必要があります。

from google.colab import drive
drive.mount('/content/gdrive')

データの準備

必要なデータはGoogle Driveにすべてあげておきます。
今回は下記の様なデータ構造にしてデータを置いています。

/content/gdrive/My Drive/colab_pytorch_detection/data
 /annotations
  /dtc_train_annotations
   train_00000.json
   ・・・
  train.txt
  val.txt
 /dtc_train_images_res
   train_00000.jpg
   ・・・
 /dtc_test_images_res
   test_00000.jpg
   ・・・

dtc_train_annotationsはコンペサイトから落としてきたものを解凍したものです。
dtc_train_images_resとdtc_test_images_resは元データを縦横1/2に縮小したものです。

本コンペでは、validationデータが無かったので、trainデータを9:1の割合でtrainとvalに分けました。
分けたファイルのリストがtrain.txtおよびval.txtです。
それぞれ、以下の様なソースで簡単に作れます。
折角なので、pytorchで作ってみました。

def create_trainval_list():
  
  ROOT_DIR='/content/gdrive/My Drive/colab_pytorch_detection/data'
  ANNO_DIR = os.path.join(ROOT_DIR,'annotations','dtc_train_annotations')
  TRAIN_FILE = os.path.join(ROOT_DIR,'annotations','train.txt')
  VAL_FILE = os.path.join(ROOT_DIR,'annotations','val.txt')
  TRAIN_RATIO = 0.9

  anno = os.listdir(ANNO_DIR)
  anno = sorted(anno)
  
  ids = torch.randperm(len(anno))
  print(ids)
  train_num = int(len(anno) * TRAIN_RATIO)

  train_ids = ids[:train_num]
  val_ids = ids[train_num:]
  
  print(train_ids.size())
  print(val_ids.size())

  with open(TRAIN_FILE, 'w') as f:
    for i in train_ids:
      f.write(anno[i]+'\n')
      
  with open(VAL_FILE, 'w') as f:
    for i in val_ids:
      f.write(anno[i]+'\n')
  
create_trainval_list()

Datasetクラス

冗長な部分があって少々格好悪いですが、以下の様なソースになりました。
targetによってtrain,val,trainval,testを切り替えます。

class SignateDataset(Dataset):

    def __init__(self, root_dir='/content/gdrive/My Drive/colab_pytorch_detection/data', target='trainval', transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.target = target
        
        self.annotation_dir = os.path.join(root_dir,'annotations','dtc_train_annotations')
        if target=='train':
          self.img_dir = os.path.join(root_dir,'dtc_train_images_res')
          read_file = os.path.join(root_dir,'annotations','train.txt')
          self.anno_list = []
          with open(read_file,'r') as f:
            self.anno_list = [fname.strip() for fname in f]
            self.anno_list = sorted(self.anno_list)
          self.img_list = [os.path.splitext(fname)[0]+".jpg" for fname in self.anno_list]
          

        elif target=='val':
          self.img_dir = os.path.join(root_dir,'dtc_train_images_res')
          read_file = os.path.join(root_dir,'annotations','val.txt')
          self.anno_list = []
          with open(read_file,'r') as f:
            self.anno_list = [fname.strip() for fname in f]
            self.anno_list = sorted(self.anno_list)
          self.img_list = [os.path.splitext(fname)[0]+".jpg" for fname in self.anno_list]

        elif target=='test':
          self.img_dir = os.path.join(root_dir,'dtc_test_images_res')
          self.anno_list = None        
          self.img_list =  sorted(os.listdir(self.img_dir))
          
        else:
          self.img_dir = os.path.join(root_dir,'dtc_train_images_res')
          self.anno_list = sorted(os.listdir(self.annotation_dir))
          self.img_list =  sorted(os.listdir(self.img_dir))
          
        
        self._classes = ('Car','Truck','Pedestrian',
                   'Bicycle','Signal','Signs')
#        self._classes = ('Car','Bus','Truck','SVehicle','Pedestrian','Motorbike',
#                   'Bicycle','Train','Signal','Signs')
        
        self.boxes_list = [None] * len(self.img_list)
        self.labels_list = [None] * len(self.img_list)
        self.scale = 0.5
            
    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
      
      if((self.boxes_list[idx] is None) and (self.anno_list is not None)):
        anno_file_path = os.path.join(self.annotation_dir,self.anno_list[idx])
        with open(anno_file_path) as f:
          data = json.load(f)
          boxes = []
          lbls = []
          if 'labels' in data:
            labels = data['labels']
            for label in labels:
              if 'box2d' in label:
                if label['category'] in self._classes:
                  x1 = float(label['box2d']['x1']) * self.scale
                  x2 = float(label['box2d']['x2']) * self.scale
                  y1 = float(label['box2d']['y1']) * self.scale
                  y2 = float(label['box2d']['y2']) * self.scale
                  category=self._classes.index(label['category'])
                  boxes.append(torch.tensor([x1,y1,x2,y2]).float())
                  lbls.append(torch.tensor([category]).long())

          self.boxes_list[idx] = torch.stack(boxes)
          self.labels_list[idx] = torch.stack(lbls)

      
      img_path = os.path.join(self.img_dir, self.img_list[idx])
      image = io.imread(img_path)
         
      sample = {'image': image, 'boxes': self.boxes_list[idx], 'labels':self.labels_list[idx]}

      if self.transform:
          sample = self.transform(sample)
      
      return sample
    
    def collate_fn(self, batch):
      imgs = [x['image'] for x in batch]
      boxes = [x['boxes'] for x in batch]
      labels = [x['labels'] for x in batch]
      
      return {'image': torch.stack(imgs), 'boxes': boxes, 'labels':labels}

ポイントは、__getitem__の中でファイルを読み込んでいる部分です。
普段Datasetクラスを作るときは、__init__内でデータを展開してしまう事が多いのですが、
Google Driveへのアクセス速度が非常にネックになり、
学習を開始するまでに延々待たされてしまいます。
そこで今回は__getitem__内で1ファイルずつ読み込むことにしました。
こうすることで、pytorchが勝手にパイプライン化してくれるため、
学習処理をしている間に並列してファイルを読み込みを行うことができ、
時間短縮になります。

datasetの使用例は以下の様になります。

dataset = SignateDataset(transform=train_transform,target='train')
datalodar = DataLoader(dataset, batch_size=BATCH_SIZE,
                        shuffle=True, num_workers=8,collate_fn=dataset.collate_fn)

for i_batch, sample_batched in enumerate(datalodar ):
      
      image = sample_batched['image']
      boxes = sample_batched['boxes']
      labels = sample_batched['labels']
      # 学習処理とか

Datasetについてはここまで。
後は評価スクリプトと学習スクリプトを書けば、一旦完成です。