2024.10.21
お互い疑心暗鬼になりがちな、経営企画と事業部の壁 組織に「分断」が生まれる要因と打開策
リンクをコピー
記事をブックマーク
石上氏(以下、石上):続いて実践編で、実際に「VGG16」を題材に、PyTorchからTensorFlowにモデル移植、Weight移植を行ってみます。
PyTorchに関しては、今回モデルとしてtorchvision.modelsのVGG16を用いて、TensorFlowに自前で実装して、Weightを移植することを行っていきたいと思います。
VGG16はかなり有名なモデルだとは思いますが、あらためて復習していきたいと思います。
VGG16は、Convolutionのレイヤーと、Fully Connectedのレイヤーに分かれていて、Convolutionのレイヤーでは、3掛ける3の畳み込みを行いつつ、チャネル数は増えながら画像のスケール自体は下がります。
最終的には、ここの特徴マップですね、7掛ける7掛ける512の特徴マップをflattenによって平坦化をして、そこからFC層で最終的なクラス数分の予測を行うというモデルとなっています。
なので、大きく分けてConvolution層と、Fully Connected layerの2つに分かれているのが特徴となっています。
このそれぞれを、PyTorchからTensorFlowに移植していくのですが、まずはtorchvision.modelsのConvolution層の移植を行ってみます。
右側がPyTorchのVGG16の実際のスクリプトコードですが、make_layersでは、Convolution層というのを定義しています。
上から見ていくと、まずlayersという入れ物をリストで用意して、あらかじめcgfでそれぞれのConvolutionのチャネル、もしくはMax Poolingがどこにあるかをリストで用意しておいて、それを順々にfor文で回していって、レイヤーをそれぞれ追加していくという書き方となっています。
例えばリストに「MaxPool」や「Conv2d」を入れて、レイヤーを構築してSequentialレイヤーに追加することで、最終的にConvolution層を作っています。
このような書き方にならってTensorFlowでも実装すると、左のようなかたちとなります。
cgfは同じで、まずtflayersという入れ物を用意して、このリストのチャネル数に従ってアウトプットのチャネルを変えつつ、Convolutionを入れたり、Mが入っている場合は、Max Poolingを入れたりしてリストにモデルを入れて、最終的にtf.keras.SequentialでConvolution層を構築します。
例えばここの数字、「64, 64, 'M', 128, 128, 'M'」というのは、先ほどの図でいうと、ここの「64、64、Max Pooling、128、128、Max Pooling」に該当します。
これでConvolution層の定義が終わったので、続いては右側のメインのVGGのfeatureの特徴抽出器およびclassifierレイヤーの定義を行います。
右の元実装を見てみると、self.featuresは、先ほどの特徴抽出器のConvolution層に該当するので、make_layersで作ったレイヤーをここに入れてます。
続いてavgpoolは実はあまり意味がなくて、実際はflatten(平坦化)するだけの役割となっています。
self.classifierでは、最後にFC層として3つのLinear層があって、その間にReLUとDropoutが挟まっているかたちになっています。
forwardレイヤーでは、特徴抽出、平坦化、そしてclassifierでクラス数分の予測を行います。このようなモデルをTensorFlowで実装すると、左のようなかたちとなります。
TensorFlowの場合は、tf.keras.Modelで定義をして、featuresは同じですね。Flattenレイヤーとして、PyTorchの場合もflattenで平坦化しているのですが、TensorFlowもlayers.Flattenを用意して、のちほど平坦化を行います。
classifierの定義は、LinearがDenceに該当するので、Denseの4096。TensorFlowの場合は、Denseの中でactivationを定義できるので、より短く書くことができます。そしてDropoutを挟んで、最終的に3つのLinearレイヤーからなっています。
ここのcallの部分ですが、若干の違いがあります。TensorFlowの場合は、self.featuresで得られた特徴マップは、[B,H,W,C]と、Cがラストチャネルの形式になっているのですが、PyTorchの場合は、ファーストチャネルとなっているので、ここで軸を入れ替える必要があります。
ここでは、tf.transposeという操作を行うことによって、PyTorchの場合と同様の軸の順番で推論が行えるようになっています。
そのあとにflattenで平坦化を行って、classifierで分類を行うようなモデルとなっています。ここまででモデル定義の比較というのを行ってきました。
実際に定義できたので、これらのWeightをちょっと見ていきたいと思います。
TFVGGのname、shapeの表示です。torchvision.models.vgg16のname、shapeの表示結果がこのようになります。
conv2dのkernelとbias。それから、Linear、Dense層のkernelとbiasが重みの中身となっています。
ここで注意が必要なのは、conv2dのkernelの場合は軸が4つあって、それぞれ正しく入れ替える必要があるというところと、denseのkernelも軸を0番目と1番目で入れ替える必要があるというところです。
というわけで最後に、PyTorchからTensorFlowへ重みの移植を行っていきます。
手順は、はじめに説明したとおりですが、torchのparamとkeyを取得して、上から順番に中身を見ていきます。今回はtconvのkernelは4つの軸があるので、このように軸の入れ替え操作を行います。
2つの場合は、Linear層のWeightに該当するので、その1と0を入れ替えるというif文を使っています。
biasの場合は、そのまま入れてあげる操作を行っています。この操作を行うことによって、モデルの移植が完了します。
最後に、Sanity checkを行います。もとのPyTorchからTFVGGに重みを移植したもので、それぞれ同じインプットを入力として出力を行って、そのアウトプットの絶対値の差の平均を取っています。
例えばこの場合だと、4.7掛けるeのマイナス7乗という非常に小さい値でdiffが取れているので、重み移植が成功であることがわかります。
というわけで、ここまでで実践的なTensorFlow、PyTorchのモデル移植の説明を行ってきました。
最後にまとめです。今回は4つの項目についてお話ししました。まずはTensorFlow/PyTorchのモデル実装の基礎ということで、TensorFlowとPyTorchが似た書き方で書けるというところを説明しました。
続いてTensorFlow/PyTorchの比較では、主な違いとしてチャネルの違いや、featureを書く必要性の違いを説明をしました。
そして3番目は、TensorFlow/PyTorchの重み変換のテクニックとして、どういうふうに重みを取得するか、そしてどういうふうに移植するかを説明しました。
そして最後に、実際にTensorFlow/PyTorchのモデル移植というところで、VGG16を題材に、実際にモデルを移植して、Weightを移植して、最後に重みが正しく移植できているかを確認しました。
本来は日本語「RoBERTa」をTensorFlowからPyTorchに移植するという内容も考えていたのですが、今回は時間の都合上で入れることができなかったので、またなにかの機会に紹介できればなと思っています。
というわけで、「TensorFlow/PyTorchのモデル移植のススメ」というお話をしました。ご清聴ありがとうございました。
司会者:すごい発表でした。
これはどんなモデルでも、基本的には移植できると思っていいんですか? それともこれは難しいみたいなのはあるんでしょうか?
石上:そうですね、GRUが入っている場合はちょっと注意が必要かなというところですね。
司会者:ResNet系だと、やはり難しそうですが、ああいうのもできるんですか?
石上:基本的にモデル構造が正しく書けていれば、重みというのは基本的にここに書かれている重みだけなので、重みの移植に関しては、このような手順を踏んでいけば問題ないかなとは思います。
司会者:なるほど、そうなんですね。これで重みの移植にチャレンジする人が増えそうな気がします。
石上:そうですね。今回、分析コンペにおいて何がうれしいかというと、TensorFlowのTPUが使えるというところで、もちろんPyTorchも使えるのですが、TensorFlowのほうがトラブルが起きにくいんです。
例えばTensorFlowでTPUで高速に学習をさせて、その重みをPyTorchに移植して、あとはファインチューニングを行うということも可能になるかと思います。
司会者:なるほど、そうですよね。今はTPUが「Colab」でも使えるし、Kaggleの「Notebook」でも使えますもんね。
石上:そうですね。TPUをまだ試したことない方がいたら、ぜひ試してほしいですね。感動します。
司会者:本当に早くて、驚く感じでした。これのおかげでできる人も増えたのかもというのでスパチャしたいです。
石上:スパチャの代わりになるのですが、PyTorchは、このあと俵さんがお話しになる「TIMM」という画像認識のライブラリがあるのですが、それのTensorFlow版をどなたか作ってもらえないかなという思いを込めて、今回発表しました。
司会者:そうですね、TIMMのTensorFlow版をみなさんで作ってもらえるとメッチャありがたいですね。
石上:はい、発表者冥利に尽きます。
司会者:ありがとうございます。これで終わりたいと思います。
2024.11.13
週3日働いて年収2,000万稼ぐ元印刷屋のおじさん 好きなことだけして楽に稼ぐ3つのパターン
2024.11.21
40代〜50代の管理職が「部下を承認する」のに苦戦するわけ 職場での「傷つき」をこじらせた世代に必要なこと
2024.11.20
成果が目立つ「攻めのタイプ」ばかり採用しがちな職場 「優秀な人材」を求める人がスルーしているもの
2024.11.20
「元エースの管理職」が若手営業を育てる時に陥りがちな罠 順調なチーム・苦戦するチームの違いから見る、育成のポイント
2024.11.11
自分の「本質的な才能」が見つかる一番簡単な質問 他者から「すごい」と思われても意外と気づかないのが才能
2023.03.21
民間宇宙開発で高まる「飛行機とロケットの衝突」の危機...どうやって回避する?
2024.11.18
20名の会社でGoogleの採用を真似するのはもったいない 人手不足の時代における「脱能力主義」のヒント
2024.11.19
がんばっているのに伸び悩む営業・成果を出す営業の違い 『無敗営業』著者が教える、つい陥りがちな「思い込み」の罠
2024.11.13
“退職者が出た時の会社の対応”を従業員は見ている 離職防止策の前に見つめ直したい、部下との向き合い方
2024.11.15
好きなことで起業、赤字を膨らませても引くに引けない理由 倒産リスクが一気に高まる、起業でありがちな失敗