PyStan の実行例 – Eight Schools

2022年12月3日

概要

有名な例題である Eight Schools を Google Colab 上で実行してみました。コードはこちらのサイトを参考に書いています。

Install Package

まずは、PyStan をインストールします。

!pip install pystan==2.19.1.1

Import Packages

次に、必要なパッケージをインポートします。

import pystan
import arviz as az

Define Data

データは、下のデータを利用しました。

data = {'J': 8,
        'y': [28,  8, -3,  7, -1,  1, 18, 12],
        'sigma': [15, 10, 16, 11,  9, 11, 10, 18]}

Define Model & Inference

モデルを定義して、推論を行います。

model_code = """
data {
  int<lower=0> J;         // number of schools
  real y[J];              // estimated treatment effects
  real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
  real mu;                // population treatment effect
  real<lower=0> tau;      // standard deviation in treatment effects
  vector[J] eta;          // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;        // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);       // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""
%%time

sm = pystan.StanModel(model_code=model_code)
INFO:pystan:COMPILING THE C++ CODE FOR MODEL anon_model_67cb7d0f2cb7720776cbeb52007d2dbb NOW.

CPU times: user 1.86 s, sys: 132 ms, total: 1.99 s
Wall time: 1min 10s

Google Colab だと、コードのコンパイルに1分くらいかかりました。手元のマシンで計算すると、もっと速くなる可能性はあるかもしれません。次に、Chain を 1本だけ計算させた場合と Chain を 4本計算させた場合の両方で時間を測ってみます(Chain を並列に計算させる方法がわからなかったため ^^;)

%%time

mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=1, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:2 of 99500 iterations ended with a divergence (0.00201 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.

CPU times: user 7.44 s, sys: 130 ms, total: 7.57 s
Wall time: 7.54 s
%%time

mcmc = sm.sampling(data=data, iter=100000, warmup=500, chains=4, control=dict(adapt_delta=0.99, max_treedepth=10))
WARNING:pystan:10 of 398000 iterations ended with a divergence (0.00251 %).
WARNING:pystan:Try running with adapt_delta larger than 0.99 to remove the divergences.

CPU times: user 7.45 s, sys: 434 ms, total: 7.88 s
Wall time: 38.3 s

Chain 1本あたりで 8~10秒くらいの時間がかかっているようです。それにしても、Stan は WARNING が細かく、親切設計ですね ^^

idata = az.from_pystan(mcmc)
az.plot_trace(idata);

az.summary(idata)
meansdhdi_3%hdi_97%mcse_meanmcse_sdess_bulkess_tailr_hat
mu7.9305.217-1.81817.5340.0150.022212292.0176914.01.0
tau6.5815.7280.00016.2630.0190.026144668.0174137.01.0
eta[0]0.3900.938-1.3692.1720.0020.001361783.0281559.01.0
eta[1]0.0010.874-1.6741.6430.0010.001389953.0288434.01.0
eta[2]-0.1940.927-1.9221.5730.0010.002396820.0288937.01.0
eta[3]-0.0310.883-1.6761.6720.0010.001357902.0281880.01.0
eta[4]-0.3510.878-1.9951.3420.0020.001332011.0273770.01.0
eta[5]-0.2100.893-1.9101.4740.0010.001415465.0286695.01.0
eta[6]0.3420.886-1.3582.0090.0010.001354287.0272854.01.0
eta[7]0.0560.933-1.7111.8100.0010.002410448.0290257.01.0
theta[0]11.3848.345-2.93528.4630.0160.011305154.0302281.01.0
theta[1]7.8946.282-4.21019.8900.0090.007509134.0350682.01.0
theta[2]6.1347.729-9.33920.4930.0130.010389201.0319471.01.0
theta[3]7.6456.536-4.83920.3170.0090.007486618.0348020.01.0
theta[4]5.1206.355-7.23916.8100.0100.008409586.0336711.01.0
theta[5]6.1426.727-7.24918.5350.0100.008466074.0348823.01.0
theta[6]10.6526.785-1.62824.0270.0110.008368395.0334347.01.0
theta[7]8.4437.879-6.83423.8300.0130.011387466.0300924.01.0

Summary

Google Colab 上でやっているためかもしれませんが、Stan のコードをコンパイルするのに1分くらいの時間がかかりました。サンプルはかなり多く発生させましたが、Chain 1本あたり 8~10秒くらいで走りました。

PyStan

Posted by 管理者