画像分類(MNIST) – Flax & Optax
概要
こちらは Flax による画像分類のコードです。前回 Haiku & Optax の組み合わせで、コードを書いてみましたが、今回はそのコードを Flax 向けに改造してみたいと思います。
基本的には Neural Network のライブラリが Haiku から Flax に変わるだけなので、その部分だけが変わります。MNIST を使ったコードは本家のサイトでも紹介されていますので、英語が苦痛でない方はこちらを見て頂いてもよいかもしれません。
また、Flax の方は日本語でも幾つか素晴らしい解説記事がありますので、ぜひ検索してみて下さい。紹介されているコードの多くは、TrainState というクラスを使ったスマートな実装をしているものが多いかと思いますが、こちらの記事ではもうちょっと冗長な感じのコードを紹介してみたいと思います。
Install Packages
Google Colab 上で動かすことを前提に、話を進めます。まずは Flax と Optax をインストールします。JAX や tensorflow-datasets も利用するのですが、これらは Google Colab には既にインストールされているので、改めてインストールしなくても大丈夫です。
!pip install flax optax
Import Packages
次に、必要なパッケージをインポートします。
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
import tensorflow_datasets as tfds
Define Parameters
次に、コードの実行に必要なパラメータを定義します。
H1_DIM = 300
H2_DIM = 100
NUM_CLASSES = 10
LEARN_RATE = 1e-3
NUM_STEPS = 10000
CHECK_FREQ = 1000
Define Data Loader
MNIST のデータは TensorFlow のものをそのまま使います。
def load_dataset(split, batch_size, is_training):
ds = tfds.load("mnist:3.*.*", split=split).cache().repeat()
if is_training:
ds = ds.shuffle(10 * batch_size, seed=0)
ds = ds.batch(batch_size)
return iter(tfds.as_numpy(ds))
ds_train = load_dataset('train', batch_size=1000, is_training=True)
ds_test = load_dataset('test', batch_size=10000, is_training=False)
Define Network – Standard
TensorFlow や PyTorch と同様に、Haiku でもさまざまなスタイルで、ネットワークを定義することができるのですが、まずはスタンダードなスタイルでネットワークを定義します。ベースのクラス nn.Module を継承して、次のようなクラスを定義します。
class MLP_standard(nn.Module):
def setup(self):
self.dense1 = nn.Dense(H1_DIM)
self.dense2 = nn.Dense(H2_DIM)
self.dense3 = nn.Dense(NUM_CLASSES)
def __call__(self, batch):
x = batch['image'].astype(jnp.float32) / 255.
x = x.reshape((x.shape[0], -1)) # Flatten
x = self.dense1(x)
x = nn.relu(x)
x = self.dense2(x)
x = nn.relu(x)
x = self.dense3(x)
return x
net = MLP_standard()
Define Network – Compact
上で紹介した標準的な方法以外に、もう少しコンパクトな形でネットワークを定義することもできます。Flax のドキュメンテーションには、「どちらを使うべきか?」ということに関して指針を示してくれているページもありますので、迷ったときにはそちらを参照して頂くのがよいかもしれません。
class MLP_compact(nn.Module):
@nn.compact
def __call__(self, batch):
x = batch['image'].astype(jnp.float32) / 255.
x = x.reshape((x.shape[0], -1)) # Flatten
x = nn.Dense(H1_DIM)(x)
x = nn.relu(x)
x = nn.Dense(H2_DIM)(x)
x = nn.relu(x)
x = nn.Dense(NUM_CLASSES)(x)
return x
net = MLP_compact()
Define Function : Loss
Haiku のときと同様に、損失の関数を定義します。クロスエントロピーの関数は Optax にあるものを使ってしまっています。
実際にネットワークに予測をさせるには、apply メソッドを使います。このメソッドにネットワークのパラメータ(params)とバッチ(batch)を渡して、logits を計算させています。
@jax.jit
def loss_fn(params, batch):
logits = net.apply(params, batch)
labels = jax.nn.one_hot(batch['label'], NUM_CLASSES)
loss = optax.softmax_cross_entropy(logits, labels).mean()
return loss
Define Function : Evaluate Network
Haiku のときと同様に、精度評価のための関数を定義します。
@jax.jit
def evaluate(params, batch):
logits = net.apply(params, batch)
y_pred = jnp.argmax(logits, axis=-1)
accuracy = jnp.mean(y_pred == batch['label'])
return accuracy
Define Function : Update Parameters
Optax によるパラメータ更新のための関数を定義します。これも Haiku のときと同じ関数です。
@jax.jit
def update(params, opt_state, batch):
grads = jax.grad(loss_fn)(params, batch)
updates, opt_state = optimizer.update(grads, opt_state)
new_params = optax.apply_updates(params, updates)
return new_params, opt_state
Train Network
必要な関数を定義し終わったので、次に実際の学習を進めます。まず、最初のバッチ(dummy)を使ってネットワークの初期化を行い、ネットワークのパラメータ(params)を取得します。
dummy = next(ds_train)
params = net.init(jax.random.PRNGKey(42), dummy)
次に、上で取得したパラメータ(params)を使って optimizer の初期化を行い、そのステート(opt_state)を取得しておきます。
optimizer = optax.adam(LEARN_RATE)
opt_state = optimizer.init(params)
最後に学習のループを回します。基本的なコードの流れは、TensorFlow や PyTorch などと大きく変わるところはないかと思います。なお、jax.device_get 関数は、JAX の方の配列を通常の NumPy の配列へ変換する関数です。jax.device_put 関数というものあるので、組にして覚えておくと便利です。
%%time
for step in range(NUM_STEPS + 1):
if step % CHECK_FREQ == 0:
acc_A = evaluate(params, next(ds_train))
acc_B = evaluate(params, next(ds_test))
acc_A = jax.device_get(acc_A)
acc_B = jax.device_get(acc_B)
msg = 'Step {:d}, Train accuracy: {:.3f}, Test accuracy: {:.3f}'
print(msg.format(step, acc_A, acc_B))
params, opt_state = update(params, opt_state, next(ds_train))
Step 0, Train acc: 0.145, Test acc: 0.138
Step 1000, Train acc: 0.995, Test acc: 0.980
Step 2000, Train acc: 0.999, Test acc: 0.981
Step 3000, Train acc: 1.000, Test acc: 0.981
Step 4000, Train acc: 1.000, Test acc: 0.981
Step 5000, Train acc: 1.000, Test acc: 0.981
Step 6000, Train acc: 1.000, Test acc: 0.981
Step 7000, Train acc: 1.000, Test acc: 0.980
Step 8000, Train acc: 1.000, Test acc: 0.980
Step 9000, Train acc: 1.000, Test acc: 0.980
Step 10000, Train acc: 1.000, Test acc: 0.980
CPU times: user 3min 43s, sys: 4.95 s, total: 3min 48s
Wall time: 57.5 s
Summary
Flax と Optax の組み合わせで、画像分類を行うコードを紹介してみました。
お願い
記事につきましては、間違いないように十分に気をつけて書いたつもりなのですが、どこかで変なことをやっているかもしれません。お気付きの点がありましたら、ご指摘頂けますとありがたいです。