2024.12.10
“放置系”なのにサイバー攻撃を監視・検知、「統合ログ管理ツール」とは 最先端のログ管理体制を実現する方法
リンクをコピー
記事をブックマーク
石上氏(以下、石上):続いて実践編で、実際に「VGG16」を題材に、PyTorchからTensorFlowにモデル移植、Weight移植を行ってみます。
PyTorchに関しては、今回モデルとしてtorchvision.modelsのVGG16を用いて、TensorFlowに自前で実装して、Weightを移植することを行っていきたいと思います。
VGG16はかなり有名なモデルだとは思いますが、あらためて復習していきたいと思います。
VGG16は、Convolutionのレイヤーと、Fully Connectedのレイヤーに分かれていて、Convolutionのレイヤーでは、3掛ける3の畳み込みを行いつつ、チャネル数は増えながら画像のスケール自体は下がります。
最終的には、ここの特徴マップですね、7掛ける7掛ける512の特徴マップをflattenによって平坦化をして、そこからFC層で最終的なクラス数分の予測を行うというモデルとなっています。
なので、大きく分けてConvolution層と、Fully Connected layerの2つに分かれているのが特徴となっています。
このそれぞれを、PyTorchからTensorFlowに移植していくのですが、まずはtorchvision.modelsのConvolution層の移植を行ってみます。
右側がPyTorchのVGG16の実際のスクリプトコードですが、make_layersでは、Convolution層というのを定義しています。
上から見ていくと、まずlayersという入れ物をリストで用意して、あらかじめcgfでそれぞれのConvolutionのチャネル、もしくはMax Poolingがどこにあるかをリストで用意しておいて、それを順々にfor文で回していって、レイヤーをそれぞれ追加していくという書き方となっています。
例えばリストに「MaxPool」や「Conv2d」を入れて、レイヤーを構築してSequentialレイヤーに追加することで、最終的にConvolution層を作っています。
このような書き方にならってTensorFlowでも実装すると、左のようなかたちとなります。
cgfは同じで、まずtflayersという入れ物を用意して、このリストのチャネル数に従ってアウトプットのチャネルを変えつつ、Convolutionを入れたり、Mが入っている場合は、Max Poolingを入れたりしてリストにモデルを入れて、最終的にtf.keras.SequentialでConvolution層を構築します。
例えばここの数字、「64, 64, 'M', 128, 128, 'M'」というのは、先ほどの図でいうと、ここの「64、64、Max Pooling、128、128、Max Pooling」に該当します。
これでConvolution層の定義が終わったので、続いては右側のメインのVGGのfeatureの特徴抽出器およびclassifierレイヤーの定義を行います。
右の元実装を見てみると、self.featuresは、先ほどの特徴抽出器のConvolution層に該当するので、make_layersで作ったレイヤーをここに入れてます。
続いてavgpoolは実はあまり意味がなくて、実際はflatten(平坦化)するだけの役割となっています。
self.classifierでは、最後にFC層として3つのLinear層があって、その間にReLUとDropoutが挟まっているかたちになっています。
forwardレイヤーでは、特徴抽出、平坦化、そしてclassifierでクラス数分の予測を行います。このようなモデルをTensorFlowで実装すると、左のようなかたちとなります。
TensorFlowの場合は、tf.keras.Modelで定義をして、featuresは同じですね。Flattenレイヤーとして、PyTorchの場合もflattenで平坦化しているのですが、TensorFlowもlayers.Flattenを用意して、のちほど平坦化を行います。
classifierの定義は、LinearがDenceに該当するので、Denseの4096。TensorFlowの場合は、Denseの中でactivationを定義できるので、より短く書くことができます。そしてDropoutを挟んで、最終的に3つのLinearレイヤーからなっています。
ここのcallの部分ですが、若干の違いがあります。TensorFlowの場合は、self.featuresで得られた特徴マップは、[B,H,W,C]と、Cがラストチャネルの形式になっているのですが、PyTorchの場合は、ファーストチャネルとなっているので、ここで軸を入れ替える必要があります。
ここでは、tf.transposeという操作を行うことによって、PyTorchの場合と同様の軸の順番で推論が行えるようになっています。
そのあとにflattenで平坦化を行って、classifierで分類を行うようなモデルとなっています。ここまででモデル定義の比較というのを行ってきました。
実際に定義できたので、これらのWeightをちょっと見ていきたいと思います。
TFVGGのname、shapeの表示です。torchvision.models.vgg16のname、shapeの表示結果がこのようになります。
conv2dのkernelとbias。それから、Linear、Dense層のkernelとbiasが重みの中身となっています。
ここで注意が必要なのは、conv2dのkernelの場合は軸が4つあって、それぞれ正しく入れ替える必要があるというところと、denseのkernelも軸を0番目と1番目で入れ替える必要があるというところです。
というわけで最後に、PyTorchからTensorFlowへ重みの移植を行っていきます。
手順は、はじめに説明したとおりですが、torchのparamとkeyを取得して、上から順番に中身を見ていきます。今回はtconvのkernelは4つの軸があるので、このように軸の入れ替え操作を行います。
2つの場合は、Linear層のWeightに該当するので、その1と0を入れ替えるというif文を使っています。
biasの場合は、そのまま入れてあげる操作を行っています。この操作を行うことによって、モデルの移植が完了します。
最後に、Sanity checkを行います。もとのPyTorchからTFVGGに重みを移植したもので、それぞれ同じインプットを入力として出力を行って、そのアウトプットの絶対値の差の平均を取っています。
例えばこの場合だと、4.7掛けるeのマイナス7乗という非常に小さい値でdiffが取れているので、重み移植が成功であることがわかります。
というわけで、ここまでで実践的なTensorFlow、PyTorchのモデル移植の説明を行ってきました。
最後にまとめです。今回は4つの項目についてお話ししました。まずはTensorFlow/PyTorchのモデル実装の基礎ということで、TensorFlowとPyTorchが似た書き方で書けるというところを説明しました。
続いてTensorFlow/PyTorchの比較では、主な違いとしてチャネルの違いや、featureを書く必要性の違いを説明をしました。
そして3番目は、TensorFlow/PyTorchの重み変換のテクニックとして、どういうふうに重みを取得するか、そしてどういうふうに移植するかを説明しました。
そして最後に、実際にTensorFlow/PyTorchのモデル移植というところで、VGG16を題材に、実際にモデルを移植して、Weightを移植して、最後に重みが正しく移植できているかを確認しました。
本来は日本語「RoBERTa」をTensorFlowからPyTorchに移植するという内容も考えていたのですが、今回は時間の都合上で入れることができなかったので、またなにかの機会に紹介できればなと思っています。
というわけで、「TensorFlow/PyTorchのモデル移植のススメ」というお話をしました。ご清聴ありがとうございました。
司会者:すごい発表でした。
これはどんなモデルでも、基本的には移植できると思っていいんですか? それともこれは難しいみたいなのはあるんでしょうか?
石上:そうですね、GRUが入っている場合はちょっと注意が必要かなというところですね。
司会者:ResNet系だと、やはり難しそうですが、ああいうのもできるんですか?
石上:基本的にモデル構造が正しく書けていれば、重みというのは基本的にここに書かれている重みだけなので、重みの移植に関しては、このような手順を踏んでいけば問題ないかなとは思います。
司会者:なるほど、そうなんですね。これで重みの移植にチャレンジする人が増えそうな気がします。
石上:そうですね。今回、分析コンペにおいて何がうれしいかというと、TensorFlowのTPUが使えるというところで、もちろんPyTorchも使えるのですが、TensorFlowのほうがトラブルが起きにくいんです。
例えばTensorFlowでTPUで高速に学習をさせて、その重みをPyTorchに移植して、あとはファインチューニングを行うということも可能になるかと思います。
司会者:なるほど、そうですよね。今はTPUが「Colab」でも使えるし、Kaggleの「Notebook」でも使えますもんね。
石上:そうですね。TPUをまだ試したことない方がいたら、ぜひ試してほしいですね。感動します。
司会者:本当に早くて、驚く感じでした。これのおかげでできる人も増えたのかもというのでスパチャしたいです。
石上:スパチャの代わりになるのですが、PyTorchは、このあと俵さんがお話しになる「TIMM」という画像認識のライブラリがあるのですが、それのTensorFlow版をどなたか作ってもらえないかなという思いを込めて、今回発表しました。
司会者:そうですね、TIMMのTensorFlow版をみなさんで作ってもらえるとメッチャありがたいですね。
石上:はい、発表者冥利に尽きます。
司会者:ありがとうございます。これで終わりたいと思います。
2024.12.10
メールのラリー回数でわかる「評価されない人」の特徴 職場での評価を下げる行動5選
2024.12.09
10点満点中7点の部下に言うべきこと 部下を育成できない上司の特徴トップ5
2024.12.09
国内の有名ホテルでは、マグロ丼がなんと1杯「24,000円」 「良いものをより安く」を追いすぎた日本にとって値上げが重要な理由
2024.12.12
会議で発言しやすくなる「心理的安全性」を高めるには ファシリテーションがうまい人の3つの条件
2023.03.21
民間宇宙開発で高まる「飛行機とロケットの衝突」の危機...どうやって回避する?
2024.12.10
職場であえて「不機嫌」を出したほうがいいタイプ NOと言えない人のための人間関係をラクにするヒント
2024.12.12
今までとこれからで、エンジニアに求められる「スキル」の違い AI時代のエンジニアの未来と生存戦略のカギとは
PR | 2024.11.26
なぜ電話営業はなくならない?その要因は「属人化」 通話内容をデータ化するZoomのクラウドサービス活用術
PR | 2024.11.22
「闇雲なAI導入」から脱却せよ Zoom・パーソル・THE GUILD幹部が語る、従業員と顧客体験を高めるAI戦略の要諦
2024.12.11
大企業への転職前に感じた、「なんか違うかも」の違和感の正体 「親が喜ぶ」「モテそう」ではない、自分の判断基準を持つカギ