CG×ML #2 勾配ベースの最適化(Mitsuba3) [Devlog #011]

Table of Contents

勾配ベースの最適化(Mitsuba3)

Mitsuba3の動作確認


Mitsuba3はPyPIからpip経由でインストールすることが推奨されている
公式ドキュメントはこちら

Mitsuba3の使い方(Python)

インポート

import mitsuba as mi

バリアントの選択

バリアントについての詳細

mi.variants()
['scalar_rgb', 'scalar_spectral', 'cuda_ad_rgb', 'llvm_ad_rgb']

この4つ以外のバリアントを必要とする場合は、pipではなく自力でコンパイルする必要がある(ソースからのコンパイルに関するドキュメントはこちら

mi.set_variant("scalar_rgb")
計算バックエンド(Computational backend)の指定

scalar:旧Mitsubaと同様にCPU上で浮動小数点演算を実行する(一度に個々のレイを処理するモードであるため、コンパイルエラーの修正やレンダラーのデバッグに適している)

cuda:Dr.JITが計算をCUDAカーネルに変換してGPUにオフロードする(GPUレイトレーシングにNVIDIAのOptiXライブラリを使用する)

llvm:Dr.JITが計算を並列CPUカーネルにコンパイルする(LLVMコンパイラフレームワークを使用する)(※ NVIDIA GPUを持っていない場合の代替手段)

自動微分(Automatic differentiation)の指定

_ad:自動微分を有効にする(cudallvmモードで指定可能)

主に微分可能レンダリングを行う際に利用するモードで、光の動きを含めた複雑な逆問題を解く

色の表現(Color representation)の指定

mono:単色ベースの色表現

rgb:RGBベースの色表現

spectral:可視域をカバーする完全なスペクトルカラー表現

偏光(Polarization)の指定

偏光の追跡が必要な場合に指定する追加オプション

偏光は逆問題を解決するための強力なツールである

精度(Precision)の指定

_double:通常の演算で設定している単精度(32bit)から倍精度(64bit)に設定を変更する(EmbreeとOptiXが倍精度をサポートしていないことを考慮した設定が必要な場合がある)

シーンのロード

scene = mi.load_file("../scenes/cbox.xml")

Mitsuba3ではシーンの記述にXMLベースの形式を利用する(シーンXMLファイルフォーマットに関するドキュメントはこちら

シーンのレンダリング

image = mi.render(scene, spp=256)

render()関数を使用してシーンをレンダリング

render()線形RGB色空間テンソル(NumPy配列に似たmi.TensorXf)を返す

(APIリファレンスはこちら ※サイトが重いので注意)

mi.Bitmap(image)

レンダリング画像の表示

matplotlibを使っても画像を表示できる
import matplotlib.pyplot as plt

plt.axis("off")
plt.imshow(image ** (1.0 / 2.2));
レンダリング画像の保存
mi.util.write_bitmap("my_first_render.png", image)
mi.util.write_bitmap("my_first_render.exr", image)

勾配ベースの最適化(Gradient-based optimization)


勾配ベースの最適化とは

シーン入力$\textbf x$からレンダリング結果$\textbf y$を出力する関数$f(\textbf x)$をレンダリングアルゴリズムと解釈すると、関数$f$を微分して$\frac{d\textbf y}{d\textbf x}$を求めることで、シーン入力$\textbf x$の変化によるレンダリング出力$\textbf y$の変化を見ることができ、

シーンパラメータの適合度(suitability)を定量化する微分可能な目的関数を$g(\textbf y)$とすると、

確率的勾配降下法(SGD)やAdamなどの勾配ベースの最適化アルゴリズムを用いることで、目的関数$g(\textbf y)$を改善するシーンパラメータを求めることができる

Mitsuba3公式ドキュメントでの説明はこちら

ここではMitsuba3を用いて光の動きを含めた逆問題を解く

Mitsuba3での実装

セットアップ

import drjit as dr
import mitsuba as mi
mi.set_variant('cuda_ad_rgb')

シーンのロード

scene = mi.load_file('scenes/cbox.xml', res=128, integrator='prb')

リファレンス画像

image_ref = mi.render(scene, spp=512)
mi.util.convert_to_bitmap(image_ref)

最適化するための設定

traverseメカニズムを使い、シーン内のパラメータを取得する

scene_params = mi.traverse(scene)
type(scene_params)
mitsuba.python.util.SceneParameters
scene_params
SceneParameters[
  ----------------------------------------------------------------------------------------
  Name                                 Flags    Type           Parent
  ----------------------------------------------------------------------------------------
  sensor.near_clip                              float          PerspectiveCamera
  
  ~省略~

  gray.reflectance.value               ∂        Color3f        SRGBReflectanceSpectrum
  white.reflectance.value              ∂        Color3f        SRGBReflectanceSpectrum
  green.reflectance.value              ∂        Color3f        SRGBReflectanceSpectrum
  red.reflectance.value                ∂        Color3f        SRGBReflectanceSpectrum
  
  ~省略~

  redwall.vertex_texcoords             ∂        Float          OBJMesh
  mirrorsphere.to_world                ∂, D     Transform4f    Sphere
  glasssphere.to_world                 ∂, D     Transform4f    Sphere
]

最適化するパラメータを選び、値を変更する(元の値をparam_refに保存しておく)

param_key = 'red.reflectance.value'
param_ref = mi.Color3f(scene_params[param_key])
scene_params[param_key] = mi.Color3f(0.01, 0.0, 0.9)
scene_params.update()
image = mi.render(scene, spp=128)
mi.util.convert_to_bitmap(image)

最適化

勾配降下を使ってパラメータの値を変更する

Momentum有り/無しの確率的勾配降下(SGD)を含む標準的なオプティマイザとAdamのうち、後者をインスタンス化し、学習率0.05でパラメータを最適化する

オプティマイザはパラメータ情報を取得し、指定されたパラメータを常に更新する

これらの変更をシーンに反映させる必要がある(updateメソッドで呼び出す)

オプティマイザクラスに関する公式ドキュメントはこちら

opt = mi.ad.Adam(lr=0.05)
opt[param_key] = scene_params[param_key]
scene_params.update(opt)

勾配降下する毎に、目的関数に関するパラメータの導関数を計算するため、画像間の平均二乗誤差(L2 error)を利用する
def mse(image):
    return dr.mean(dr.sqr(image - image_ref))

最適化ループ回数を制御するハイパーパラメータを定義する
iteration_count = 50

勾配降下ループによって、50回の微分可能レンダリングを行う

ロスに基づいて勾配を計算しているのがdr.backward(loss)であり、
計算された勾配に基づいてパラメータを更新するのがopt.step()である

ここではパラメータの値を0.0~1.0にクランプし、updateメソッドでoptの内容を反映してパラメータを更新する

それ以降はデバッグ用のコードになる

images = []
for it in range(iteration_count):
    image = mi.render(scene, scene_params, spp=8)
    loss = mse(image)

    dr.backward(loss)
    opt.step()

    opt[param_key] = dr.clamp(opt[param_key], 0.0, 1.0)

    scene_params.update(opt)

    err_ref = dr.sum(dr.sqr(param_ref - scene_params[param_key]))
    print(f"Iteration {it:02d}: paramerer error = {err_ref[0]:6f}", end='\r')
    images.append(image)
print('\nOptimization complete.')

結果

最適化結果の確認

最適化によって、壁の色が元の赤色に復元されることがわかる

image_final = mi.render(scene, spp=128)
mi.util.convert_to_bitmap(image_final)

最適化についての詳細(Mistuba3 + Dr.Jit)

参考公式ドキュメント

最適化の仕組み

基本 - OptimizerのAPI

自動微分に互換性のあるバリアントを選択

import mitsuba as mi
import drjit as dr

mi.set_variant('cuda_ad_rgb')

Adamと同様に、運動量あり/なしの確率的勾配降下(SGD)を含む標準的なオプティマイザを同梱している

学習率0.25の単純なSGDオプティマイザを作成

opt = mi.ad.SGD(lr=0.25, params={'x': mi.Float(1.0), 'y': mi.Float(2.0)})
opt
SGD[
  variables = ['x', 'y'],
  lr = {'default': 0.25},
  momentum = 0
]

勾配がオプティマイザの変数に適切に伝搬されることを確認する
print(f"Optimizer: {opt['x']}, grad_enabled={dr.grad_enabled(opt['x'])}.")
print(f"Optimizer: {opt['y']}, grad_enabled={dr.grad_enabled(opt['y'])}.")
Original:  [2.0], grad_enabled=True.
Optimizer: [2.0], grad_enabled=True.

dr.backward(z)では、レンダリングプロセスを誤差逆伝播している

z = opt['x'] + 2.0 * opt['y']
dr.backward(z)

print(f"x grad={dr.grad(opt['x'])}")
print(f"y grad={dr.grad(opt['y'])}")
x grad=[1.0]
y grad=[2.0]

opt.step()では、変数(パラメータ)の更新を適用している

print(f"Before the gradient step: x={opt['x']}, y={opt['y']}")
opt.step()
print(f"After the gradient step:  x={opt['x']}, y={opt['y']}")
Before the gradient step: x=[1.0], y=[2.0]
After the gradient step:  x=[0.75], y=[1.5]

シーンパラメータの最適化

自動微分に互換性のあるバリアントを選択

import mitsuba as mi
import drjit as dr

mi.set_variant('cuda_ad_rgb')

mi.traverse()SceneParametersを返す

scene = mi.load_file('../scenes/cbox.xml')
params = mi.traverse(scene)
params
SceneParameters[
  ----------------------------------------------------------------------------------------
  Name                                 Flags    Type            Parent
  ----------------------------------------------------------------------------------------
  sensor.near_clip                              float           PerspectiveCamera
  sensor.far_clip                               float           PerspectiveCamera
  sensor.shutter_open                           float           PerspectiveCamera
  sensor.shutter_open_time                      float           PerspectiveCamera
  sensor.x_fov                                  float           PerspectiveCamera
  sensor.to_world                               Transform4f     PerspectiveCamera
  gray.reflectance.value               ∂        Color3f         SRGBReflectanceSpectrum
  white.reflectance.value              ∂        Color3f         SRGBReflectanceSpectrum
  green.reflectance.value              ∂        Color3f         SRGBReflectanceSpectrum
  red.reflectance.value                ∂        Color3f         SRGBReflectanceSpectrum
  glass.eta                                     float           SmoothDielectric
  mirror.eta.value                     ∂, D     Float           UniformSpectrum
  mirror.k.value                       ∂, D     Float           UniformSpectrum
  mirror.specular_reflectance.value    ∂        Float           UniformSpectrum
  light.emitter.radiance.value         ∂        Color3f         SRGBEmitterSpectrum
  light.vertex_count                            int             OBJMesh
  light.face_count                              int             OBJMesh
  light.faces                                   UInt            OBJMesh
  light.vertex_positions               ∂, D     Float           OBJMesh
  light.vertex_normals                 ∂, D     Float           OBJMesh
  light.vertex_texcoords               ∂        Float           OBJMesh
  floor.vertex_count                            int             OBJMesh
  floor.face_count                              int             OBJMesh
  floor.faces                                   UInt            OBJMesh
  floor.vertex_positions               ∂, D     Float           OBJMesh
  floor.vertex_normals                 ∂, D     Float           OBJMesh
  floor.vertex_texcoords               ∂        Float           OBJMesh
  ceiling.vertex_count                          int             OBJMesh
  ceiling.face_count                            int             OBJMesh
  ceiling.faces                                 UInt            OBJMesh
  ceiling.vertex_positions             ∂, D     Float           OBJMesh
  ceiling.vertex_normals               ∂, D     Float           OBJMesh
  ceiling.vertex_texcoords             ∂        Float           OBJMesh
  back.vertex_count                             int             OBJMesh
  back.face_count                               int             OBJMesh
  back.faces                                    UInt            OBJMesh
  back.vertex_positions                ∂, D     Float           OBJMesh
  back.vertex_normals                  ∂, D     Float           OBJMesh
  back.vertex_texcoords                ∂        Float           OBJMesh
  greenwall.vertex_count                        int             OBJMesh
  greenwall.face_count                          int             OBJMesh
  greenwall.faces                               UInt            OBJMesh
  greenwall.vertex_positions           ∂, D     Float           OBJMesh
  greenwall.vertex_normals             ∂, D     Float           OBJMesh
  greenwall.vertex_texcoords           ∂        Float           OBJMesh
  redwall.vertex_count                          int             OBJMesh
  redwall.face_count                            int             OBJMesh
  redwall.faces                                 UInt            OBJMesh
  redwall.vertex_positions             ∂, D     Float           OBJMesh
  redwall.vertex_normals               ∂, D     Float           OBJMesh
  redwall.vertex_texcoords             ∂        Float           OBJMesh
  mirrorsphere.to_world                         Transform4f     Sphere
  glasssphere.to_world                          Transform4f     Sphere
]

keep()paramsオブジェクトの項目をフィルタリングしている

params.keep(r'.*\.reflectance\.value')
params
SceneParameters[
  ------------------------------------------------------------------------------
  Name                       Flags    Type            Parent
  ------------------------------------------------------------------------------
  gray.reflectance.value     ∂        Color3f         SRGBReflectanceSpectrum
  white.reflectance.value    ∂        Color3f         SRGBReflectanceSpectrum
  green.reflectance.value    ∂        Color3f         SRGBReflectanceSpectrum
  red.reflectance.value      ∂        Color3f         SRGBReflectanceSpectrum
]

よって、反射率の値を最適化するオプティマイザをコンストラクタする

opt = mi.ad.SGD(lr=0.25, params=params)
opt
SGD[
  variables = ['gray.reflectance.value', 'white.reflectance.value', 'green.reflectance.value', 'red.reflectance.value'],
  lr = {'default': 0.25},
  momentum = 0
]

ロードされたパラメータは内部的にコピーされるため、オプティマイザーでパラメータ値を変更しようとしても、直接paramsに反映されることはない

opt['red.reflectance.value'] *= 0.5

print(f"params:   {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")
params:   [[0.5700680017471313, 0.043013498187065125, 0.04437059909105301]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]

これらの変更をparamsに(そしてScene自体にも)伝えるため、SceneParametersupdate()メソッドを使用する

オプティマイザの変数キーがparamsのものと一致するのを探し、対応する値を上書きする

params.update(opt);

print(f"params:   {params['red.reflectance.value']}")
print(f"optimize: {opt['red.reflectance.value']}")
params:   [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]
optimize: [[0.2850340008735657, 0.021506749093532562, 0.022185299545526505]]