transとreduce
ローカルテンソルには、テンソルを元に新しいテンソルを作ったり、既存のテンソルを変形する処理として、 テンソルのtransとテンソルのreduceというものがあります。
trnasとreduceは発展的な機能なので初期の頃は不要な機能ですが、 使いこなすと少しGPUに不向きなアルゴリズムの実現が可能となる、 奥の手のような機能です。
ここではtrnasとreduce、及びローカルテンソルの副作用について見ていきます。
テンソルのreduceは、関数のreduceとは異なる機能なので、正しくは「テンソルのreduce」と呼びますが、 本章はテンソルのreduceを扱う章なので、 特に曖昧でない所では単にreduceと呼ぶ事にします。
ローカルテンソルのご使用は計画的に!
テンソルのローカルテンソルのセクションで述べたように、 ローカルテンソルはなるべく避けるべきものですが、 局所的な範囲のヒストグラムなど、必要なアルゴリズムも幾つか存在しています。
MFGでは副作用による更新をローカルテンソルのみに集約する事で、 それ以外の部分を純粋な式に統一し、 副作用によるバグの発生をおさえています。
逆に言えば、副作用が必須な事をするにはローカルテンソルを使う必要がある、 という事でもあります。
副作用を行えるローカルテンソルですが、 副作用はなるべく使わない方がいい機能でもあるため、 副作用を使わなくても既存のテンソルから新しいテンソルを作るreduceという機能を充実させています。
それでどうしても実現出来ない機能だけを、transという既存のテンソルを変形するという、比較的安全な形で提供しています。
transでもreduceでもない副作用は現時点では += のみです。
このtransとreduceはMFG特有のもので他の言語に無く、重要な要素なのでここで詳しく見ていきます。
テンソルreduce
テンソルのreduceは、既存のテンソルから新しいテンソルを作る処理です。
reduceは、何らかの形で元となるテンソルの次元を削減した新しいテンソルを作ります。 例えばある範囲内のrgbごとにヒストグラムの累積和を求めて、それをhistCumSumというテンソルで持っていたとします。 histCumSumは (3, 256) のi32のテンソルです。
そこからメディアンを求めるには以下のようなコードになります。
def median by reduce<histCumSum>.accumulate(dim=0, init=-1) |i, rgb, val, accm| {
ifel(accm != -1, accm, ...)
elif(val < histCumSum(255, rgb)/2, -1, i)
}
medianはhistCumSumの256の方の次元が縮約されて、要素3の1次元のテンソルとなります。
reduceのシンタックス
reduceは
- テンソルの指定
- メソッドの指定
- メソッドに応じた名前付き引数
- ブロック引数
という要素があります。
シンタックスとしては、以下のようになります。
reduce<テンソル名>.メソッド名(...) |...| {...}
テンソルは角カッコ<>で指定するのはMFG共通のシンタックスです。 メソッド名はaccumulateとfind_first_indexで、次に見ていきます。
引数はメソッド名に寄りますが、reduceは共通で必ずdimという名前付き引数があり、 この次元の軸にそってブロックを実行していく事になります。
なかなか複雑ですが、 このテンソルの指定とメソッドの指定はtransでも同様になるので、 一度どちらかを理解してしまえば両方が分かるようになります。
reduceのメソッド: find_first_index
dimの軸にそってブロックを実行していき、0以外を返した最初のインデックスを値とします。 見つからなければ-1です。
find_first_indexは以下のような使い方です。
# ローカルテンソルのテンソルリテラルはNYIなので、通常の作り方
# [[1, 2, 1],
# [2, 3, 2],
# [1, 2, 1]]
@bounds(3, 3)
def weight |i, j| { 3-abs(i-1)-abs(j-1) }
def upper by reduce<weight>.find_first_index(dim=0) |i, j, val| {
val <= i+j
}
upperはdim=0、つまりx軸にそって実行していって、この軸を何か一つの値に置き換えます。 つまり、もともとが(3, 3)なのを、 (_, 3)、つまりy軸だけの3要素のテンソルにします。
reduceはいつもdimの次元を最終的には何か一つの値に置き換えがテンソルを生成する事になります。 find_first_indexの場合はx軸にそってブロックを実行していって、最初に0以外の値を返したindexがその値になります。 インデックスが結果となるので、結果のテンソルはいつもi32です。
上記の例だと、j=0の行では、1, 2, 1と順番に実行していって、
| i+j | val |
|---|---|
| 0 | 1 |
| 1 | 2 |
| 2 | 1 |
となるので、i+j <= val となる最初のインデックスは 2、同様に次は1、最後は0となります。 つまり [2, 1, 0] という1次元テンソルとなります。
引数
引数はdimとブロック引数のみです。
- dim: i32でどの次元にそってreduceを行うかを指定
ブロック
ブロックの引数はインデックスとそのインデックスでの元テンソルの値。
結果は整数値を返す。 この結果が最初に非0になるインデックスを探す。
reduceの考え方
reduceはdimについての何らかの集約を行うものです。
NxMのテンソルについて、dim=0ならMのテンソル、dim=1ならNのテンソルとなります。
dim=0で考えると、各MについてN方向に列を見ていって、何か一つの値にします。 find_first_indexはN方向にblockを実行して最初に0以外だったインデックス、という事になる訳です。
この軸にそって集約していく、という事を理解するとテンソルのreduceは理解出来ると思います。
reduceのメソッド: accumulate
accumulateはより高機能なreduceです。find_first_indexはシンタックスシュガーでaccumulateとして実行されています。
例としては以下のようなものになります。
def median by reduce<histCumSum>.accumulate(dim=0, init=-1) |i, rgb, val, accm| {
ifel(accm != -1, accm, ...)
elif(val < histCumSum(255, rgb)/2, -1, i)
}
引数は以下の2つとブロック引数ですです。
引数
- dim: reduceしていく軸を指定
- init: 最初のaccmの値を指定
ブロック
ブロック引数は
- 元となるテンソルのインデックス(2次元ならx, yの順番)
- 元となるテンソルの該当位置の値
- 前のブロック実行結果の値
となります。元となるテンソルが1次元ならブロックの引数は3個、2次元ならインデックスが2つになるので4個の引数となります。 上記の例で言えば、i, rgbがインデックス、valがhistCumSum(i, rgb)の値、accmが前のブロック実行の値です。
解説
initやブロックの実行のされ方は「関数のreduce」とだいたい同じですが、 一つの軸に対してだけ行う所が違います。
関数のreduceについてはifelとループを参照ください。
「前のブロック実行結果」の前というのは現在のindexに対して、指定した軸の要素を一つ前に戻した場所を意味します。 例えばインデックスが3, 3でdimが0なら一つ前は2, 3、dimが1なら一つ前は3, 2となります。
単一の変数になるreduce
reduceは元となるテンソルの次元を一つ減らす操作となります。 ですから、元となるテンソルが1次元の場合、結果は単一の値となります。
この場合、テンソルのreduceは式として使う事が出来、通常のletで変数に入れたり出来ます。
wcumsumという1次元の重みの累積和から、あるindexより大きな場所を求めたい時、 以下の2つの式は同じです。
# defによる定義。結果は0次元となるのでi3は単なる値となる
def i3 by reduce<wcumsum>.find_first_index(dim=0) |_, val| { index < val }
# 0次元になる時だけはletで変数として普通に代入出来る
let i3 = reduce<wcumsum>.find_first_index(dim=0) |_, val| { index < val }
これはメディアンフィルタで、ヒストグラムを求めるのでは無く、 色を重複して並べる事で重み付きメディアンを求める時に使う計算から持ってきた例です。
reduceは元となるテンソルも現時点ではローカルテンソルのみ
現時点では、ローカルテンソルのサイズは全てコード生成時に決定している、という制約を置いているのですが、 この制約を確実に実現するために、元となるテンソルもローカルテンソルのみとしています。
技術的にはサイズが確定するテンソルリテラルなどであればグローバルテンソルでも構わないはずなので、 将来的にはこの制約は緩和されるかもしれませんが、現時点ではローカルテンソルを元にローカルテンソルを作る事しか出来ません。
テンソルtrans
reduceは元となるテンソルから新しいテンソルを作る処理でしたが、 テンソルtransは元となるテンソルを変更する処理です。
既にある値を変更する、という副作用を行う、MFGでは数少ない機能となります。
例えばhistというテンソルの累積和を求めるなら以下のようになります。
mut! trans<hist>.cumsum!(dim=0)
これでhistというテンソルが、累積和の値に変わります。
transのメリット、デメリット
transは副作用で既にある値を書き換える事になります。 既にあるローカルテンソルを書き換えるのは、reduceで作り直す事に比べるとレジスタ数が少なくて済むというメリットがある場合があります。
一方で同じテンソル名で場所によって入るものが変わるので、コードは読みにくくなる傾向にあります。
transのシンタックス
テンソルのtransのシンタックスは以下となります。
mut! trans<テンソル名>.メソッド名!(...)
まずMFGでは、副作用のある文は必ず mut!から始まります。 そしてメソッド名の最後に ! がつきます。
これは副作用がある文を特にシンタックス上目立たせるという意図でそうなっています。
それ以外はreduceとほとんど変わらないシンタックスとなっています。
transのメソッドとしては以下があります。
- sort
- cumsum
transのメソッド: sort
sortは指定した軸にそって、小さい値から大きい値へとソートします。
mut! trans<wmat>.sort!(dim=0)
現時点ではsortは1次元テンソルでi32に対してしか実装してません。 技術的な理由では無く使い道がなかった為なので、将来使う用途が出てきたらf32用や2次元用を実装するつもりではいます。
引数はdimのみです。
transのメソッド: cumsum
指定した軸にそって累積和の値に置き換えます。
累積和とは、例えば[3, 2, 3, 1]というテンソルに対しては、 [3, 5, 8, 9] と左から順番に足していく結果に置き換える事です。 元と同じ次元になる事に注意してください。
以下のように使います。
mut! trans<wcumsum>.cumsum!(dim=0)
これも引数はdimのみです。
transとreduceのメソッドは全て名前付き引数で使える
dimやinitは全て名前付き引数として使えます。
ts.for_eachと+=
transでもreduceでも無いのですが、transと類似な機能として += というものがあります。 さらにテンソルのループ系メソッドとしてts.for_eachというものもあるのですが、 これは原理的に += と合わせて使うしか使い道が無いので、ここでまとめて説明をしたいと思います。
以下のように、各ピクセルの3x3の範囲で、重み付けヒストグラムを求める事が出来ます。
def weight by [[1, 2, 1],
[2, 3, 2],
[1, 2, 1]]
def result_u8 |x, y| {
@bounds(256, 4)
def hist |i, col| { 0 }
weight.for_each |ix, iy, wval| {
let [b, g, r, a] = input_u8(ix+x, iy+y)
mut! hist(b, 0) += wval
mut! hist(g, 1) += wval
mut! hist(r, 2) += wval
mut! hist(a, 3) += wval
}
...
}
この weight.for_each の所と mut! で始まる4つの文が今回の解説対象です。
ts.for_each
テンソルのfor_eachメソッドは、引数のブロックを各要素に対して実行する、という事をします。 結果は返さずブロックを実行するだけです。
ts.for_each | インデックス, そのインデックスのtsの値 | { ... }
インデックスはtsが1次元なら一つ、2次元なら2つとなります。
値を返さずブロックを実行するだけの為、中で副作用のある事をしないとこの文は意味がありません。
そしてこのコンテキストで実行出来る副作用としては、現時点では+=くらいしかありません。(用途は自分には思いつきませんが、原理的にはtrans系も使う事は出来ます)
+= による副作用
ローカルテンソルの各要素に対する副作用は、今の所 += のみです。 =も実装しても良いとは思っていますが、現時点では実装してません。
副作用に関する機能はなるべく減らしたいと思っているので、十分に必要という結論が出るまでは実装は保留しています。
+= は、以下のように使います。
mut! hist(b, 0) += wval
+=の文はmut!で始まります。 MFGでは、全ての副作用のある文はmut!で始まります。
+=は左辺にテンソルの参照を、右辺に値を置きます。 左辺のテンソルの参照した結果の値と右辺の値を足したものを、 左辺のテンソルの参照先に上書きします。
他の言語の += と同じふるまいと思います。