俵氏:それでは「backboneとしてのtimm入門」というタイトルで発表したいと思います。

まず自己紹介ですが、前も死語だと言ったのですが、僕は「JTC Kaggler」というやつです。一応、研究開発職で、社会人からMachine Learningをちょっとやっています。一応、Kaggler Masterなのですが、最近日本人がどんどんGMになっていくので、置いていかれている感が半端なくて、ツラいです。近況としては、ちょっと前のatmaCupの第11回で入賞狙っていたのですが、残念ながら5位でした。

ここでこの話を出したのは、実装を公開しているのですが、これが前回の分析コンペLTで紹介したpfn-extrasを使った実装なので、興味がある人は見てみてください。

あと、宇宙人との交信(https://www.kaggle.com/c/seti-breakthrough-listen)が続いていて、明後日締め切りなのですが、そのことを忘れてこの日程を決めてしまいました。実はわりとヤバいです。僕の後ろで今もGPUマシンが動いています。

本題ですが、このtimm、公開NotebooksやSolutionでけっこう見かけます。知っている人はもうお馴染みだと思うのですが、知らない人もいると思うので今回発表することにしました。

timm。発音はおそらく「ティム」ですが、あまり馴染みがないので「ティー・アイ・エム・エム」と僕は言っています。これはPyTorch Image Modelsというライブラリで、通称がtimmです。

(スライドを示して)この右に写真が載っていますが、読み方はロス・ワイトマンさんですかね? この素敵な写真の男性が公開している、しかもこの人は企業やFacebookの人とかではないらしく、謎の人物です。

ここに書いているのですが、このバージョン0.4.12の時点で、612種類の実装と452種類のpretrained modelが使えるという、最強の画像認識ライブラリです。しかもメチャクチャ更新が頻繁です。最近だとVision Transformer系が発表されると、ちょっと後に追加されている。いろいろなモデルが使える、すごいライブラリです。たぶん最強と言っても、異論がある人は本当にいないと思います。

他にも、例えばmixupの実装なども含んでいるみたいです。ただ僕は、そこらへんは使ったことがなく、ちょっと詳しくないので、今回は話しません。

一般的な画像分類モデル構成

その前に「backboneとして」と言いましたが、backboneって何だろう? というところについて。

(スライドを示して)一般的な画像分類モデルは、だいたい下のような形になっていると思っています。まず、複数の畳み込み層、Convolutional Layerで特徴抽出を行います。ここでは入力をCと書いていますが、実際はおおよそ普通の画像がチャネル3で、縦横の長さがここであればどちらも224だとして、これをCNNに通すと、もう少しサイズが小さくなります。H、Wが小さくなり、このC´のチャネルの数がたくさん増えた状態になって出てきます。

これを「特徴の集約」と書いています。いわゆるGlobal Average Poolingなどをかけることで、HとWの方向、縦と横方向を潰して、最後にheadと呼ばれるもの、だいたいFully Connected Layerなのですが、これに通すことで、例えば10クラス分類だったらこの出力が10クラスとなります。

この時に、だいたい特徴抽出部とGlobal Average Pooling部分も含めてbackboneと呼んで、出力部をheadと呼ぶことが多いです。今回のtimmは、このbackboneの部分がメッチャたくさん実装されているという感じですね。

ちなみに物体検出などだと、このbackboneとheadの間にneckと呼ばれる部分があるので、もしかしたらそちら由来の呼び方なのかもしれません。ただ、きちんと調べてはいないです。

timmの基本的な使い方

ここからは、基本的な使い方と、ちょっとだけ凝った使い方と、どういうモデルを使うかという話をします。

まず、基本的な使い方ですが、これはメチャクチャ簡単で、import timmをした後にtimm.create_modelという関数を呼ぶだけでOKです。model_nameに使いたいモデルの名前を指定して「pretrained=True」とすると、なんとpretrained modelが勝手にダウンロードされて、しかも勝手に読み込んでくれます。

ここに表示していますが、ImageNetが基準になっているので、読み込んでそのまま適当に、ランダムな変数を用意してフォワードすると、出力で1,000次元が出力されます。

注意点ですが、pretrained modelがないのに「pretrained=True」としても、なにも言ってくれません。あと、ダウンロード済みの場合も、特になにも表示されません。この赤い表示はKaggle Notebook上で実行したやつですが、初回だと「今ダウンロードしていますよ」というのがここに出るのに、ダウンロード済みだと特になにも言ってくれないので、そこは注意が必要かもしれません。

このまま使うと1,000クラスなので、自分で適当なクラス数の分類をやりたい時、例えば先ほど10クラスと言いましたが、その時にはどうすればいいのか。

これもまた簡単で、num_classesという引数を指定すると、勝手にheadを置き換えてくれます。だいたいheadは基本的に1層だけの全結合層、つまりPyTorchだとnn.Linearが勝手に置き換わってくれます。

実はこのtimmの中を見渡すと、headの名前はまちまちだったりしますが、そこは勝手に処理して置き換えてくれます。この出力を見ると、ちゃんと[1, 10]で次元が10の出力になっているのがわかりますね。

もうこれでほぼ終わりじゃない? と思うわけです。あと学習させるだけじゃん、と。終わりですよね? ダメですかね?

headの部分やPoolingも無効化できる

もうちょっとだけ話を続けると、backboneとして使いたい場合があります。先ほどの場合だと、timmのモデルを読み込んで、num_classesを10と指定すると、Linear層がheadとして入りますが「もうちょっとheadの部分を複雑にしたい」というケースや、「backboneとしてひとまとめに扱いたい」というケースがあります。

どういうことかというと、例えば「backboneだけ学習率をちょっと落としたい」とか「backboneだけフリーズしたい」とか、そういう時はひとまとめになっているほうが便利なんですね。

ではどうやってやるの? というと、これもまたすごく簡単です。(スライドを示して)「num_classes=0」 ってすると、resnet18dをここでは指定していますが、こいつの取り出されるfeatureの次元は512のため、512が出ています。「num_classes=0」とするだけでOKなので、メチャクチャ楽です。

ちなみにPoolingを無効化することも可能です。これもglobal_poolの部分を空文字にします。ダブルクォーテーションで括っているだけでここにはなにも文字が入っていないのですが、実行すると今度はGlobal Average Poolingが適用されません。ここに縦と横の次元が残っているのがわかると思います。こうすることで、backboneとして使うことができます。

ちなみに、さきほどはcreate_modelを呼び出す瞬間に実行していました。しかし、実はreset_classifier関数を使うと、その場でいったん呼び出した後に「num_classes=0」としたり、Poolingを無効化したりできます。ここに今出力が3行出ていて、ちょっと小さくて見えないかもしれませんが、最初はモデルをそのまま呼び出したので、1,000次元の出力がされています。

次に「num_classes=0」とした場合、今度はCNNの最終の出力である512が出ているのがわかると思います。最後にGlobal Poolingを無効化した時には、縦と横がそのまま残っているので、1、512、7、7という出力になっていますね。

という感じで、あまりこの機能は使いませんが、何かあとから変更するのが実は可能です。

ちなみに、無効化するのは、そのheadなどを消しているわけではなく、全部torch内に存在するnn.Identityという、なにもしないクラスに置き換えることで行われています。

(次回へつづく)