マルチステージプログラミングのいいところ

鈴木健一氏(以下、鈴木):鈴木健一と申します。よろしくお願いします。本日はマルチステージプログラミングの話をします。

Dotty(※Scala3)でメタプログラミングの機能が強化されてマルチステージプログラミングができるようになったので、せっかくなのでそのお話をしまして、後半はTagless-finalのご紹介をします。最後にマルチステージプログラミングとTagless-final両方を組み合わせてインタプリタを作る話をいたします。

マルチステージプログラミングはジェネリックなソフトウェアを作るためのパラダイムになっていて、コード生成がサポートされています。

そもそもコード生成ってどういうメリットがあるんだっけ? というところですけれども、メンテナンス性が向上するところとか、開発の生産性、あとはパフォーマンスですね。そういったところが向上できる見込みがあります。

ただその一方でメタプログラミングを直接手でゴリゴリ書いてしまうと、どこかで間違ったときにその原因を特定するのが非常につらいという問題があると思います。

MSP、マルチステージプログラミングだとどういったいいことがあるかと言うと、生成されたコードが構文的に正しいかたちになっていて、さらに正しく型付けされた状態になっているものが生成されることが保証されているところです。

抽象のレベルとパフォーマンスって一般にはトレードオフの関係にあるかなと思うんですけれども、マルチステージプログラミングの場合はその抽象レベルを保持しながらパフォーマンスのいいコードを生成できるといった特徴もあります。

古典的な例でマルチステージプログラミングをみる

古典的な例を使ってマルチステージプログラミングを見ていきます。(スライドを示し)ここに出しているのはいわゆるpower関数、べき乗を計算する関数ですね。こちらを載せています。

このpower関数に2と10をパラメータとして与えると、2の10乗を計算して1,024の値を返してくれます。例えば2の10乗を計算するんだったら、10乗の部分を全部 a × a × … × aで掛け算をそのまま10個並べてしまえば同じ結果が得られるわけですね。

右のSpecial-purpose programと書いてあるところは、左側のpower関数に現れるような再帰の呼び出しがないというところで、左のプログラムよりもパフォーマンスがいいわけですね。これはいま10乗しているだけですが、これが例えば1,000個掛け算しなきゃいけないといった場合、パフォーマンスの効果がもっと顕著になっていきます。

とはいえ、この掛け算をひたすら並べていく退屈なコードを書きたいかと言われると、まあ書きたくないわけです(笑)。それももちろんバリエーションがいっぱいあって、じゃあ1万の場合どうするのかとか、50の場合どうするのかみたいに、いちいち書いていられないわけですね。

では、左の汎用的なプログラムから右の特定の計算に特化したプログラムを生成するにはどうしたらいいのか? というところでマルチステージプログラミングが役に立ってきます。

今回使うDottyのバージョンは、最新の0.27.0を使っています。ステージングプログラミングを始めるにあたっては、sbtの lampepfl/dotty-staging.g8 というテンプレートがありますので、そちらでプロジェクトを作るのが一番手っ取り早いと思います。

プロジェクトを作ったあとに、実際にそのステージングのコードを書くためには2つ、scala.quoted._とscala.quoted.staging._というパッケージをインポートする必要があります。さらにQuotationを扱うためにToolboxを与える必要があります。

3つ押さえておかなきゃいけない構成子

マルチステージプログラミングをするにあたって、3つ押さえておかなきゃいけない構成子があります。これを順にご紹介します。

まず1つ目がブラケットと呼ばれるものですね。これはDottyの場合だとQuotationとか呼ばれるものになります。何をやるかと言いますと、いま1+2という式を与えていて、こちらをブラケットで囲むとこのようなかたち(’{ 1+2 })になります。

これは何を意味しているかと言うと、1+2という計算を遅らせる、評価を遅らせることができるものです。これはマルチステージプログラミングというのが段階的に計算するものになっていて、特になにもない状態、1+2と書いてあるだけの状態だと、これはCurrentステージ、現在のステージの計算として扱われます。

そこにブラケットで囲ってあげると、その中の計算が次の、未来のステージに送られて計算を遅らせることができるという仕掛けになっています。

ちょっと注意事項として、今のDottyのバージョンは、Quotationを使うときにQuoteContextというのを暗黙的に与えてあげないといけません。これをスライドのコードのあちこちに書くのがつらかったので、スライドの中では省略しています。実際に書くときにはQuoteContextを与える必要があるというのをご承知おきください。詳しくは公式のマニュアルを読んでいただければと思います。

次、2つ目の構成子はエスケープです。こちらはブラケットの反対の操作を行うものです。上のほうのdef xで定義しているxが1+2という計算を囲って、それをコード化しています。

その下の def xxで上のxを参照して、$で書いてあるところがエスケープの操作になるんですが、ここでコードを展開しています。中身を見ると、1+2であるxを展開して、(1+2) + (1+2)ということをやって、さらにその式全体をブラケットで囲って1つのコード ’{ (1+2) + (1+2) } にしています。

つまり小さいコード片をつなぎ合わせて大きいコード片を作ることができるのが、このエスケープになります。モナドの箱から値を取り出して、また箱に入れるみたいな。そんなイメージで思っていただければいいかなと思います。

あともう1つは、今ご紹介したブラケットとエスケープは、双対の関係になっています。いわゆるdualですね。上の例で言いますと、まずeをブラケットで囲って、それをエスケープしています。それが打ち消しあって、結果的にただのeに戻る。

下の例も、コードのeをほどいてあげて、それをブラケットで囲ってあげるので、それが打ち消しあってただのeに戻ります。

最後は、runという構成子があって、これが組み立てたコードを実行するものになっています。コードを作っておいて、それをどこかで持ちまわして、任意のタイミングで走らせるといった使い方ができます。

注意点として、ステージングのレベルですね。例に出している計算の仕方が、今ハイライトしているのが1なんですけれども、注目したい項に対してその周囲にブラケットとエスケープの数がどれだけあるのかで計算します(周囲のブラケット数 − エスケープ数 = その項のステージレベル)。

‘{ ${ ‘{ 1 + 2 }} * … } の1の場合、ブラケット ‘{ } の数が2個あって、エスケープ$の数が1個になっていて、2-1なので、ステージレベルは1であるということが計算できます。このように見ていくというかたちになります。

MSPの構成子を使って汎用プログラムにステージングの注釈をつける

話を戻しまして、これらのMSPの構成子を使って左の汎用的なプログラムから右のプログラムをどうやって生成するのか。先ほどのブラケット、エスケープ、これがステージングアノテーションと呼ばれるもので、通常のpower関数に対して、そのプログラムの構造を壊すわけじゃなくて、そこに対してアノテーションを付けるような感覚でコードを生成していくかたちになります。

こちらのpower関数ですと、結果を返す値のところでブラケットで囲ってあげて、これをコード化してあげます。そうすると戻り値がExpr、コードの値になります。それを全部受けるときにコードの値を展開して、それを掛け算でつないで、さらにその全体をまたブラケットで囲ってあげる。

そうしますと、これは再帰で回していくので、掛け算というのがどんどんどんどん積み重なりつつコードが増えていくという、そういうイメージのコードが生成されます。

ここから例えば先ほどの2の1,000乗の1,000乗に特化したプログラムを生成したい場合どうするかと言うと、べき乗の値、1,000の値を与えます。それを先ほどの左のコードを生成するpower関数のパラメータに与えた状態でこのコードをrunすると、1,000に特化したコードが生成される。

その1,000の部分がこのパラメータとして与えられて、それが-1されて999になって、それがまた再帰で繰り返されて998になってというのを繰り返してコードが作られていくというイメージですね。

使い方としては上のstagedPowerという関数に特化したいパラメータ1,000を与えてあげて、これを関数として使う、と。そこに例えば2を与えると、2の1,000乗が計算されるという仕組みです。

生成されるコードを見ると、stagedPowerに1,000を与えるとまずパラメータとしてaを受けています。そのaをひたすら掛け算するコードが生成されているという感じですね。

この関数に2を与えると、このaの部分が2に置き換わって、それをひたすら掛け算する形になります。なので非常にパフォーマンスがいいコードが生成されます。ここに2を与えるだけでもパッと計算結果が返ってくるので、パフォーマンスがいいものができあがるという感じですね。

ステージ化したインタプリタ

もう1つ、ステージングのおもしろい例として、ステージ化したインタプリタですね。これがトランスレーターとして使えるというところをお話しします。まず例として、非常に単純なIntのリテラルと、足し算と、掛け算、この3つがあるだけの非常にシンプルな言語を考えます。

サンプルの項を作る時に、こんな感じで代数的データ型なのでそのまま、 And(IntLit(1), Mul(IntLit(2), IntLit(3)) みたいなかたちでデータを作って、これをインタプリタのパラメータに渡してあげます。

これを解釈する側は、このデータ型になっているので中身をパターンマッチして、さらに例えばAddの場合、そのAddの中のe1とe2をさらに評価して、evalの再帰関数に渡してグルグルほじくって計算していかなきゃいけないと。これが、パフォーマンスが悪くなるわけですね。

実際にやりたいことは、ここでは1 + (2 * 3)をただ計算したいだけなんですね。このコードができあがってくればそれを計算するだけで、パターンマッチとかなにもいらないわけです。

これをどうやるかというと、これも同じように先ほどのインタプリタのところにステージングのアノテーションを付け加えるだけです。ここではIntLitの場合はその値に対してブラケットで囲ってあげるだけで、Addの場合はe1とe2の値を、まず中身がどうなっているのかわからないのでevalで再帰的に呼び出します。

その結果はコード、Exprの型になっているので、ほどいて足し算でつなぎ合わせて、全体をまた先ほどみたいにコードで囲います。そうすると中身がほじくられてどんどんコードができあがってくる。最終的に下の1 + (2 * 3)というかたちになります。

もともと我々は左側の特別な言語、我々が作った言語でプログラムを書きたい。でもそいつを、いろいろがんばって遅い処理をして解釈するんじゃなくて、もっと効率のいいScalaネイティブなコードに置き換えてあげる、そのトランスレーションができるという使い方もできます。

(後半へつづく)