Mental Model of Backpropagation

NN の授業で bprob の挙動を理解しようとするたびいつも混乱してきたが、今日ふと Deep Learning Book の section 6.5 を読んだところだいぶ混乱が収まった。

この節では TensorFlow のような NN 処理系が bprob をモデル化する方法を大まかに説明している。基本的には欲しい gradient から output の方向に向かって再帰的に gradient を求めては chain rule を適用していく。少なくともプログラマにとって、これはすごくわかりやすい気がする。そして自分が混乱した理由もわかった。

授業で習う bprob がわかりにくい理由は大きく分けて二つあると思う。一つは data flow graph の概念がないこと。上の section 6.5 を見るとわかるように, data flow graph の DAG は NN のネットワークと一見似ているけれど少し違う。Data flow graph は AST みたいなものである。そして言語処理系が AST をたどって式を評価するように NN 処理系は DAG を辿って bprob を評価する。数式だけ眺めてこのフローを理解するのは、けっこう math 力がないと難しい気がする。自分はわかっていなかった。

理由その二は、"Backpropagation" という名前に output から input の「下方向」にむけてデータを押し出す (propagate する) イメージがあること。これは間違っていないし実際そういう由来なのだろうけれど、gradient を求めるという観点では 下流にある weight のノードから上流の output のノードに向け「上方向に」必要な値(gradient)を取りに行くと考えた方が馴染みがある。言語処理系が AST をたどるのと同じ順番だから。

汎用フレームワークのモデルである「取りに行く」スタイルが backpropagation の名前が示唆する「押し下げる」スタイルよりわかりやすい別の理由は、前者だと計算したい値が比較的はっきりしているからでもある気がする。「押し下げる」スタイルで考えていた時は、自分がどの値を propagate すべきなのかわからず混乱した。DAG の概念の不在が混乱に拍車をかけた。DAG の下流に propagate すると考えれば、まあ無駄はあれど間違ってはいないからね。

Data flow の DAG 上を再帰的に pull するスタイルで素朴に考えると gradient の部分式を何度も計算する羽目になり、実装に無駄が多い。だから効率的な push/propagate のモデルで説明したくなる気持ちはわからなくもない。でもプログラマにしてみれば部分式をテーブルにキャッシュするなり DP に書き直すなりしてこの手の問題を高速化できるのは当たり前なわけで、それなら「欲しい値を再帰的に求める」というシンプルなモデルとして理解し、ただし遅いから DP しましょうねと言われるほうがだいぶわかりやすいよなあ。まあでも腑に落ちてよかった。達成感大。

「プログラマのための XX 入門」みたいのは胡散臭いので読まないようにしてるけれど、場合によっては CS の基礎知識が役に立つ事もある例。まあ NN も CS なので当たり前か・・・