学習が爆速と噂のTTFNetをGoogle Colaboratoryで動かしてみた

TTFNetとは

2019年9月頭ころにarxivに公開されたObject Detectionのモデルで、学習がとても速いのが特徴です。
論文タイトル:Training-Time-Friendly Network for Real-Time Object Detection

既に実装が公開されています。
github.com

所謂anchor freeのモデルでCenterNet (Object as point)をベースとしています。
そのため、推論時の実行速度が速い事も売りの一つです。
f:id:Yoneken1:20191025142858p:plain

TTFNetの実力

以下論文の表を抜粋。
f:id:Yoneken1:20191025134514p:plain
TT(h)のところが学習のトータル時間です。
この表によると、一番早いやつでms-cocoの学習がたったの1.8hで終わるようです。

Google Colaboratoryで試してみる

論文中では8枚のGTX1080tiを使って学習しているようですが、
無料で誰でも使えるGoogle Colaboratoryではどの位のスピードになるのでしょうか?
実際に試してみました。

TTFNetのインストール
%cd /content/
!git clone https://github.com/ZJULearning/ttfnet.git

%cd /content/ttfnet
!pip install -v -e .

特に問題なくすんなりとインストールできました。

ms-cocoの準備

学習にはms-cocoのデータが必要なので、これを取ってきて配置します。

%cd /content/
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip -q -n '/content/annotations_trainval2017.zip'

%cd /content/
!wget http://images.cocodataset.org/zips/val2017.zip
!unzip -q -n '/content/val2017.zip'

%cd /content/
!wget http://images.cocodataset.org/zips/train2017.zip
!unzip -q -n '/content/train2017.zip'

!mkdir -p data/coco
!mv /content/train2017 /content/ttfnet/data/coco/
!mv /content/val2017 /content/ttfnet/data/coco/
!mv /content/annotations /content/ttfnet/data/coco/annotations

この作業は結構時間がかかります。

Google Driveへの接続

モデルの保存先としてGoogle Driveを利用します。

from google.colab import drive
drive.mount('/content/gdrive')

上記を実行するとリンクと入力枠が表示されるので、リンクをクリックして表示された画面で認証を行い、発行されたチケットを入力する事でGoogle Driveがマウントされます。

学習の実行

今回は一番軽量なResNet18の学習スケジュール短い版を学習させてみます。
また使えるGPUは1枚しかないので、gpus=1を指定します。

%cd /content/ttfnet
!mkdir '/content/gdrive/My Drive/ttfnet_models/ttfnet_r18'
!python ./tools/train.py configs/ttfnet/ttfnet_r18_1x.py --work_dir '/content/gdrive/My Drive/ttfnet_models/ttfnet_r18' --validate --autoscale-lr --gpus=1

以下の様な実行結果が出力されます。

/content/ttfnet
2019-10-25 04:21:28,109 - INFO - Distributed training: False
2019-10-25 04:21:28,253 - INFO - load model from: modelzoo://resnet18
2019-10-25 04:21:28,388 - WARNING - The model and loaded state dict do not match exactly

unexpected key in source state_dict: fc.weight, fc.bias

loading annotations into memory...
Done (t=17.55s)
creating index...
index created!
2019-10-25 04:21:49,935 - INFO - Start running, host: root@436d36ee658c, work_dir: /content/gdrive/My Drive/ttfnet_models/ttfnet_r18
2019-10-25 04:21:49,936 - INFO - workflow: [('train', 1)], max: 12 epochs
2019-10-25 04:23:05,022 - INFO - Epoch [1][50/7330]	lr: 0.00056, eta: 1 day, 12:40:12, time: 1.502, data_time: 0.042, memory: 9804, losses/ttfnet_loss_heatmap: 4.6783, losses/ttfnet_loss_wh: 5.0018, loss: 9.6801
2019-10-25 04:24:08,402 - INFO - Epoch [1][100/7330]	lr: 0.00072, eta: 1 day, 9:47:34, time: 1.268, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 4.5887, losses/ttfnet_loss_wh: 4.9279, loss: 9.5166
2019-10-25 04:25:12,307 - INFO - Epoch [1][150/7330]	lr: 0.00088, eta: 1 day, 8:54:25, time: 1.278, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 4.2175, losses/ttfnet_loss_wh: 4.1373, loss: 8.3549
2019-10-25 04:26:15,777 - INFO - Epoch [1][200/7330]	lr: 0.00104, eta: 1 day, 8:24:09, time: 1.269, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 4.1055, losses/ttfnet_loss_wh: 3.9063, loss: 8.0118
2019-10-25 04:27:19,726 - INFO - Epoch [1][250/7330]	lr: 0.00120, eta: 1 day, 8:08:22, time: 1.279, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.9817, losses/ttfnet_loss_wh: 3.6191, loss: 7.6008
2019-10-25 04:28:23,286 - INFO - Epoch [1][300/7330]	lr: 0.00136, eta: 1 day, 7:55:35, time: 1.271, data_time: 0.024, memory: 9804, losses/ttfnet_loss_heatmap: 3.9224, losses/ttfnet_loss_wh: 3.3858, loss: 7.3082
2019-10-25 04:29:27,332 - INFO - Epoch [1][350/7330]	lr: 0.00152, eta: 1 day, 7:48:11, time: 1.281, data_time: 0.026, memory: 9804, losses/ttfnet_loss_heatmap: 3.9698, losses/ttfnet_loss_wh: 3.2906, loss: 7.2604
2019-10-25 04:30:31,479 - INFO - Epoch [1][400/7330]	lr: 0.00168, eta: 1 day, 7:42:44, time: 1.283, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.8613, losses/ttfnet_loss_wh: 3.1832, loss: 7.0445
2019-10-25 04:31:35,502 - INFO - Epoch [1][450/7330]	lr: 0.00184, eta: 1 day, 7:37:51, time: 1.280, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.8858, losses/ttfnet_loss_wh: 3.1491, loss: 7.0349
2019-10-25 04:32:39,829 - INFO - Epoch [1][500/7330]	lr: 0.00200, eta: 1 day, 7:34:37, time: 1.287, data_time: 0.026, memory: 9804, losses/ttfnet_loss_heatmap: 3.7939, losses/ttfnet_loss_wh: 2.9790, loss: 6.7729
2019-10-25 04:33:43,913 - INFO - Epoch [1][550/7330]	lr: 0.00200, eta: 1 day, 7:31:09, time: 1.282, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.7948, losses/ttfnet_loss_wh: 2.8774, loss: 6.6722
2019-10-25 04:34:47,898 - INFO - Epoch [1][600/7330]	lr: 0.00200, eta: 1 day, 7:27:49, time: 1.280, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.7696, losses/ttfnet_loss_wh: 2.7948, loss: 6.5644
2019-10-25 04:35:52,258 - INFO - Epoch [1][650/7330]	lr: 0.00200, eta: 1 day, 7:25:41, time: 1.287, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.7014, losses/ttfnet_loss_wh: 2.7241, loss: 6.4255
2019-10-25 04:36:56,718 - INFO - Epoch [1][700/7330]	lr: 0.00200, eta: 1 day, 7:23:55, time: 1.289, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.6163, losses/ttfnet_loss_wh: 2.6678, loss: 6.2842
2019-10-25 04:38:01,054 - INFO - Epoch [1][750/7330]	lr: 0.00200, eta: 1 day, 7:22:00, time: 1.287, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.6610, losses/ttfnet_loss_wh: 2.6458, loss: 6.3068
2019-10-25 04:39:05,227 - INFO - Epoch [1][800/7330]	lr: 0.00200, eta: 1 day, 7:19:53, time: 1.283, data_time: 0.025, memory: 9804, losses/ttfnet_loss_heatmap: 3.6560, losses/ttfnet_loss_wh: 2.6244, loss: 6.2804
・・・

おおよそ1日と8時間で学習が完了するようです。
Google Colaboratoryは最大で12時間までしか連続使用できないので、12時間以内に学習が終わってくれればベストだったのですが、
それでもこの環境で30時間強で学習が終わるのはなかなか良いです。

あとはPCの電源が落ちないようにして、ブラウザにオートリフレッシュを仕掛けて、
12時間に1回ランタイムを再起動すればOKです。

おわりに

学習が速い事にどれ位の価値があるのか、個人的には色々と疑問ではあるのですが、
非常に短時間で結果を出さないといけないとき(ちょっとした味見とか、締め切り直前のレポート対策とか)なんかに
使ってみるのはいかがでしょうか?

追記

デフォルトの設定では、4epochに1回しかmodelファイルが保存されないようです。
ttfnet_r18_1x.pyのcheckpoint_config = dict(interval=4)のところを編集し、
保存頻度を上げた方が、ランタイム再起動時の無駄が減ります。
またついでにimgs_per_gpuやlrを調整すれば、もう少し早く学習が完了すると思います。
modelファイルの保存時にsymlinkの作成でエラーで落ちてしまう場合がある様です。
これについては、mmcvのrunner.pyを編集して、mmcv.symlinkの所をコメントアウトすると解決します。