画像分類(MNIST) – Haiku & Optax

2022年12月3日

概要

こちらは Haiku による画像分類のコードです。基本的には、GitHab のレポジトリにある次のコードを少しだけシンプルにして、日本語の解説を加えたものです。

Install Packages

Google Colab 上で動かすことを前提に、話を進めます。まずは Haiku と Optax をインストールします。JAX や tensorflow-datasets も利用するのですが、こちらは Google Colab には既にインストールされているので、改めてインストールしなくても大丈夫です。

!pip install dm-haiku optax

Install Packages

次に、必要なパッケージをインポートします。

import jax
import jax.numpy as jnp

import optax
import haiku as hk

import tensorflow_datasets as tfds

Define Parameters

次に、コードの実行に必要なパラメータを定義します。

H1_DIM = 300
H2_DIM = 100
NUM_CLASSES = 10

LEARN_RATE = 1e-3
NUM_STEPS = 10000
CHECK_FREQ = 1000

Define Data Loader

MNIST のデータは TensorFlow のものをそのまま使います。

def load_dataset(split, batch_size, is_training):
    
  ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  
  if is_training:

    ds = ds.shuffle(10 * batch_size, seed=0)
  
  ds = ds.batch(batch_size)
  
  return iter(tfds.as_numpy(ds))
ds_train = load_dataset('train', batch_size=1000, is_training=True)
ds_test = load_dataset('test', batch_size=10000, is_training=False)

Define Network – Standard

TensorFlow や PyTorch と同様に、Haiku でもさまざまなスタイルで、ネットワークを定義することができるのですが、まずはスタンダードなスタイルでネットワークを定義します。ベースのクラス hk.Module を継承して、次のようなクラスを定義します。

class MLP(hk.Module):

    def __init__(self):

        super().__init__()

        self.flatten = hk.Flatten()
        self.linear1 = hk.Linear(H1_DIM)
        self.linear2 = hk.Linear(H2_DIM)
        self.linear3 = hk.Linear(NUM_CLASSES)

    def __call__(self, x):

        x = self.flatten(x)
        x = self.linear1(x)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        x = jax.nn.relu(x)
        x = self.linear3(x)

        return x

ネットワークを定義した後で、次のような関数を定義します。この関数の中で、入力データをどのように処理するかを定義します。

def mlp_fn(batch):

    x = batch['image'].astype(jnp.float32) / 255.
    mlp = MLP()

    return mlp(x)

次に、上の関数を JAX でうまく扱えるようにピュアな関数へ変換しておきます(hk.transform)。また、更に hk.without_apply_rng で変換しておくことで、ネットワークを使うときに擬似乱数のキー(PRNGKey)を渡さずに済むようになります。

net = hk.transform(mlp_fn)
net = hk.without_apply_rng(net)

Define Network – Compact

下は、上と同じネットワークを別のやり方で定義しているものです。こちらの方がずっとコンパクトに表現できていることがわかります。

def mlp_fn(batch):

    x = batch['image'].astype(jnp.float32) / 255.

    mlp = hk.Sequential([

        hk.Flatten(),
        hk.Linear(H1_DIM), jax.nn.relu,
        hk.Linear(H2_DIM), jax.nn.relu,
        hk.Linear(NUM_CLASSES)
    ])

    return mlp(x)
net = hk.transform(mlp_fn)
net = hk.without_apply_rng(net)

Define Function : Loss

次に、損失関数を定義していくことにします。こちらも基本的には GitHub にあるコードからは大きく変えていませんが、クロスエントロピーの関数は Optax にあるものを使ってしまっています。

実際にネットワークに予測をさせるには、apply メソッドを使います。このメソッドにネットワークのパラメータ(params)とバッチ(batch)を渡して、logits を計算させています。

def loss_fn(params, batch):

    logits = net.apply(params, batch)

    labels = jax.nn.one_hot(batch['label'], NUM_CLASSES)

    loss = optax.softmax_cross_entropy(logits, labels).mean()

    return loss

Define Function : Evaluate Network

精度評価のための関数を定義します。

@jax.jit
def evaluate(params, batch):

    logits = net.apply(params, batch)
    y_pred = jnp.argmax(logits, axis=-1)
  
    accuracy = jnp.mean(y_pred == batch['label'])

    return accuracy

Define Function : Update Params

Optax による最適化を行う関数を定義します。

@jax.jit
def update(params, opt_state, batch):

  grads = jax.grad(loss_fn)(params, batch)
    
  updates, opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
    
  return new_params, opt_state

Train Network

必要な関数を定義し終わったので、次に実際の学習を進めます。まず、最初のバッチ(dummy)を使ってネットワークの初期化を行い、ネットワークのパラメータ(params)を取得します。

dummy = next(ds_train)
params = net.init(jax.random.PRNGKey(42), dummy)

次に、上で取得したパラメータ(params)を使って optimizer の初期化を行い、そのステート(opt_state)を取得しておきます。

optimizer = optax.adam(LEARN_RATE)
opt_state = optimizer.init(params)

最後に学習のループを回します。基本的なコードの流れは、TensorFlow や PyTorch などと大きく変わるところはないかと思います。なお、jax.device_get 関数は、JAX の方の配列を通常の NumPy の配列へ変換する関数です。jax.device_put 関数というものあるので、組にして覚えておくと便利です。

%%time

for step in range(NUM_STEPS + 1):

  if step % CHECK_FREQ == 0:
    
    acc_A = evaluate(params, next(ds_train))
    acc_B = evaluate(params, next(ds_test))
    
    acc_A = jax.device_get(acc_A)
    acc_B = jax.device_get(acc_B)

    msg = 'Step {:d}, Train accuracy: {:.3f}, Test accuracy: {:.3f}'
    print(msg.format(step, acc_A, acc_B))
    
  params, opt_state = update(params, opt_state, next(ds_train))
Step 0, Train accuracy: 0.125, Test accuracy: 0.132
Step 1000, Train accuracy: 0.997, Test accuracy: 0.979
Step 2000, Train accuracy: 1.000, Test accuracy: 0.980
Step 3000, Train accuracy: 1.000, Test accuracy: 0.980
Step 4000, Train accuracy: 1.000, Test accuracy: 0.980
Step 5000, Train accuracy: 1.000, Test accuracy: 0.980
Step 6000, Train accuracy: 1.000, Test accuracy: 0.981
Step 7000, Train accuracy: 1.000, Test accuracy: 0.981
Step 8000, Train accuracy: 1.000, Test accuracy: 0.981
Step 9000, Train accuracy: 1.000, Test accuracy: 0.981
Step 10000, Train accuracy: 1.000, Test accuracy: 0.981
CPU times: user 3min 44s, sys: 4.76 s, total: 3min 48s
Wall time: 53.8 s

Summary

Haiku と Optax の組み合わせで、画像分類を行うコードを紹介してみました。次回は、このコードを Flax version に書き換えてみたいと思います。

お願い

記事につきましては、間違いないように十分に気をつけて書いたつもりなのですが、どこかで変なことをやっているかもしれません。お気付きの点がありましたら、ご指摘頂けますとありがたいです。

Haiku

Posted by 管理者