画像分類(MNIST) – Flax & Optax

2022年12月3日

概要

こちらは Flax による画像分類のコードです。前回 Haiku & Optax の組み合わせで、コードを書いてみましたが、今回はそのコードを Flax 向けに改造してみたいと思います。

基本的には Neural Network のライブラリが Haiku から Flax に変わるだけなので、その部分だけが変わります。MNIST を使ったコードは本家のサイトでも紹介されていますので、英語が苦痛でない方はこちらを見て頂いてもよいかもしれません。

また、Flax の方は日本語でも幾つか素晴らしい解説記事がありますので、ぜひ検索してみて下さい。紹介されているコードの多くは、TrainState というクラスを使ったスマートな実装をしているものが多いかと思いますが、こちらの記事ではもうちょっと冗長な感じのコードを紹介してみたいと思います。

Install Packages

Google Colab 上で動かすことを前提に、話を進めます。まずは Flax と Optax をインストールします。JAX や tensorflow-datasets も利用するのですが、これらは Google Colab には既にインストールされているので、改めてインストールしなくても大丈夫です。

!pip install flax optax

Import Packages

次に、必要なパッケージをインポートします。

import jax
import jax.numpy as jnp

import optax
import flax.linen as nn

import tensorflow_datasets as tfds

Define Parameters

次に、コードの実行に必要なパラメータを定義します。

H1_DIM = 300
H2_DIM = 100
NUM_CLASSES = 10

LEARN_RATE = 1e-3
NUM_STEPS = 10000
CHECK_FREQ = 1000

Define Data Loader

MNIST のデータは TensorFlow のものをそのまま使います。

def load_dataset(split, batch_size, is_training):
    
  ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
  
  if is_training:

    ds = ds.shuffle(10 * batch_size, seed=0)
  
  ds = ds.batch(batch_size)
  
  return iter(tfds.as_numpy(ds))
ds_train = load_dataset('train', batch_size=1000, is_training=True)
ds_test = load_dataset('test', batch_size=10000, is_training=False)

Define Network – Standard

TensorFlow や PyTorch と同様に、Haiku でもさまざまなスタイルで、ネットワークを定義することができるのですが、まずはスタンダードなスタイルでネットワークを定義します。ベースのクラス nn.Module を継承して、次のようなクラスを定義します。

class MLP_standard(nn.Module):

    def setup(self):

        self.dense1 = nn.Dense(H1_DIM)
        self.dense2 = nn.Dense(H2_DIM)
        self.dense3 = nn.Dense(NUM_CLASSES)
    
    def __call__(self, batch):

        x = batch['image'].astype(jnp.float32) / 255.

        x = x.reshape((x.shape[0], -1)) # Flatten
        x = self.dense1(x)
        x = nn.relu(x)
        x = self.dense2(x)
        x = nn.relu(x)
        x = self.dense3(x)
        
        return x
net = MLP_standard()

Define Network – Compact

上で紹介した標準的な方法以外に、もう少しコンパクトな形でネットワークを定義することもできます。Flax のドキュメンテーションには、「どちらを使うべきか?」ということに関して指針を示してくれているページもありますので、迷ったときにはそちらを参照して頂くのがよいかもしれません。

class MLP_compact(nn.Module):
    
    @nn.compact
    def __call__(self, batch):

        x = batch['image'].astype(jnp.float32) / 255.
        x = x.reshape((x.shape[0], -1)) # Flatten
        
        x = nn.Dense(H1_DIM)(x)
        x = nn.relu(x)
        x = nn.Dense(H2_DIM)(x)
        x = nn.relu(x)
        x = nn.Dense(NUM_CLASSES)(x)
        
        return x
net = MLP_compact()

Define Function : Loss

Haiku のときと同様に、損失の関数を定義します。クロスエントロピーの関数は Optax にあるものを使ってしまっています。

実際にネットワークに予測をさせるには、apply メソッドを使います。このメソッドにネットワークのパラメータ(params)とバッチ(batch)を渡して、logits を計算させています。

@jax.jit
def loss_fn(params, batch):

    logits = net.apply(params, batch)
    labels = jax.nn.one_hot(batch['label'], NUM_CLASSES)
    
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    
    return loss

Define Function : Evaluate Network

Haiku のときと同様に、精度評価のための関数を定義します。

@jax.jit
def evaluate(params, batch):

    logits = net.apply(params, batch)
    y_pred = jnp.argmax(logits, axis=-1)
  
    accuracy = jnp.mean(y_pred == batch['label'])

    return accuracy

Define Function : Update Parameters

Optax によるパラメータ更新のための関数を定義します。これも Haiku のときと同じ関数です。

@jax.jit
def update(params, opt_state, batch):

  grads = jax.grad(loss_fn)(params, batch)

  updates, opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)
  
  return new_params, opt_state

Train Network

必要な関数を定義し終わったので、次に実際の学習を進めます。まず、最初のバッチ(dummy)を使ってネットワークの初期化を行い、ネットワークのパラメータ(params)を取得します。

dummy = next(ds_train)
params = net.init(jax.random.PRNGKey(42), dummy)

次に、上で取得したパラメータ(params)を使って optimizer の初期化を行い、そのステート(opt_state)を取得しておきます。

optimizer = optax.adam(LEARN_RATE)
opt_state = optimizer.init(params)

最後に学習のループを回します。基本的なコードの流れは、TensorFlow や PyTorch などと大きく変わるところはないかと思います。なお、jax.device_get 関数は、JAX の方の配列を通常の NumPy の配列へ変換する関数です。jax.device_put 関数というものあるので、組にして覚えておくと便利です。

%%time

for step in range(NUM_STEPS + 1):

  if step % CHECK_FREQ == 0:
    
    acc_A = evaluate(params, next(ds_train))
    acc_B = evaluate(params, next(ds_test))
    
    acc_A = jax.device_get(acc_A)
    acc_B = jax.device_get(acc_B)

    msg = 'Step {:d}, Train accuracy: {:.3f}, Test accuracy: {:.3f}'
    print(msg.format(step, acc_A, acc_B))

  params, opt_state = update(params, opt_state, next(ds_train))
Step 0, Train acc: 0.145, Test acc: 0.138
Step 1000, Train acc: 0.995, Test acc: 0.980
Step 2000, Train acc: 0.999, Test acc: 0.981
Step 3000, Train acc: 1.000, Test acc: 0.981
Step 4000, Train acc: 1.000, Test acc: 0.981
Step 5000, Train acc: 1.000, Test acc: 0.981
Step 6000, Train acc: 1.000, Test acc: 0.981
Step 7000, Train acc: 1.000, Test acc: 0.980
Step 8000, Train acc: 1.000, Test acc: 0.980
Step 9000, Train acc: 1.000, Test acc: 0.980
Step 10000, Train acc: 1.000, Test acc: 0.980
CPU times: user 3min 43s, sys: 4.95 s, total: 3min 48s
Wall time: 57.5 s

Summary

Flax と Optax の組み合わせで、画像分類を行うコードを紹介してみました。

お願い

記事につきましては、間違いないように十分に気をつけて書いたつもりなのですが、どこかで変なことをやっているかもしれません。お気付きの点がありましたら、ご指摘頂けますとありがたいです。

Flax

Posted by 管理者