PyTorchのObjectDetectionフレームワーク「MMDetection」を使って独自モデルを学習する

公式リリースからしばらく経過したPyTorchですが、最近は便利な周辺ライブラリが揃い始めました。
ObjectDetection用のライブラリもちらほら出てきています。
PyTorch用のObjectDetectionライブラリといえばDetectron2が有名ですね。
GitHub - facebookresearch/detectron2: Detectron2 is FAIR's next-generation research platform for object detection and segmentation.

ところが最近arxivに登場するObjectDetectionモデルはMMDetectionというフレームワークの上に実装されているものが多いです。
arxiv.org

上記文献内の表(下記)によると、学習スピードも速く、inferenceもそこそこ速く、
性能もオリジナルをほぼ再現しているといえます。
f:id:Yoneken1:20191208230650p:plain

何よりカスタマイズ性が高いというところがポイントです。
DeepLearningの研究が盛んな中国で開発されているということもあり、
これから流行っていく可能性を秘めています。

早速使ってみる

どのくらいカスタマイズしやすいか試してみました。
今回は、FCOSにBiFPNをくっつけてみたいと思います。

FCOSはRetinaNetの改良版の様な手法で、Detection HeadにCenternessと呼ばれる部分を追加して性能を向上したものです。
詳しくは論文をご参照ください。
arxiv.org

BiFPNはもはや説明不要かと思いますが、EfficientDetで提案されたFPNの上位版です。
arxiv.org

FCOSはMMDetectionに既に入っていますので、今回はFCOSをベースに、BiFPNを組み込んでいきます。
組込み方法はとても簡単で、MMDetectionのGETTING_STARTED.mdに従うだけです。
github.com

手順1: BiFPNの実装

BiFPNは以下に既にそれっぽいものがあったのでお借りしました。
github.com
2019/12/08時点では、なぜかextraモジュールがなく、階層が足りなかったため、適当に追加しました。

MMDetectionでは、FPNなどのモジュールをNECKというカテゴリとしています。
そこで、次のようなおまじないを書いて、NECKに登録します。

from ..registry import NECKS
@NECKS.register_module
class BIFPN(nn.Module):
 ・・・

基本的な変更点はこれだけです。
とても簡単です。

後は初期化とかを適当に追加しました。
最終的なソースは以下の様になりました。
https://github.com/yoneken1/pytorch_bifpn_for_mmdetection/blob/master/bifpn.py
注:この実装はBiFPNのオリジナルを再現してません。ご使用の際は自己責任でお願いします。

手順2:modulesに追加

上記で作成したpythonファイルを、mmdet/models/necksに配置します。
また以下の様に__init__.pyを修正して、BiFPNをインポートします。

from .bifpn import BIFPN
__all__ = ['FPN', 'BFP', 'HRFPN', 'BIFPN']
手順3:configファイルを修正する

MMDetectionはconfigファイルに使いたいモジュールの設定などを書いておき、
実行時にconfigファイルを読み込むことで、
自動的にモデルを構築して学習や評価を実行してくれます。

今回はfcosのconfigファイルを少し書き換えてBiFPNを読み込むようにします。
設定はとても簡単で、以下の様にneckにBIFPNを指定し、その他パラメータを調整するだけです。

    neck=dict(
        type='BIFPN',
        in_channels=[256, 512, 1024, 2048],
        out_channels=256,
        start_level=1,
        stack=2,
        add_extra_convs=False,
        extra_convs_on_inputs=False,
        num_outs=5,
        relu_before_extra_convs=False),

BiFPNは入力解像度が128の倍数である必要があるため、
今回は以下の様な感じで、Resizeの設定を変えました。

 dict(type='Resize', img_scale=(768, 768), keep_ratio=False),

あとはGPUの数やメモリに合わせてバッチサイズと学習率を調整します。

学習の実行

以下の様なコマンドで学習がスタートします。

python ./tools/train.py configs/fcos/fcos_r50_caffe_bifpn_gn_1x_1gpu.py --work_dir 'models/fcos' --gpus 1

評価なども同様です。
とてもシンプルでわかりやすいですね。

おわりに

とてもお手軽に独自モデルを構築することができました。
ObjectDetectionは実装が複雑になりがちで、
自分で実装するととても複雑なコードが出来上がってしまう事が多々あるかと思います。
私自身も以前FasterRCNNをスクラッチ実装しようとして、全く論文値が再現せず、
1か月くらい苦しんだのち、結局著者実装を使うことにしたという苦い経験があります。
(その後何とか再現までこぎつけましたが・・・)
これからObjectDetectionの研究をしてみたい方で、baselineを手早く作りたい人、
さくさくアイディアを試していきたい人にはお勧めのフレームワークかと思います。
ぜひお試しください。

おまけ

上記モデルは鋭意学習中ですので、結果が出たら追記したいと思います。

追記

上記モデルですが、12epoch学習させたところ、mAPが0.31と非常に低い値になりました。
実は元の実装は結構バグがあり、それを直したつもりなのですが、まだ再現に至っていないようです。
この結果は結構悔しいので、いつかリベンジしようと思っています。