先日CmdStanPyをインストールしてみました。(詳細は下記参照)
publicjournal.hatenablog.com
RのStanやPyStanと結構違うので、どういうものかと一通り使ってみました。ひとまず“Hello, World”を参照にしながら動作確認をしました。
同じような例題が書籍「Pythonによるベイズ統計モデリング」のP48~57にあるので、同書と同じことをCmdStanPyを使ってやってみたいと思います。
例題
こんな例題を考えます。
コイン投げ問題を考える。コインを10回投げて裏と表を記録する。表を1、裏を0として記録したところ、[1,0,1,0,0,0,0,0,1,0]となった。このデータをもとに、コインが偏っているかを判断する。コインの偏りを調べるために、次のモデルを用いる。
プログラムの全貌
プログラムは次の通り。
Stanのファイル
Stanのファイルは次の通り
解説
統計モデルの作成
上記のStanファイルにそって統計モデルを作成します。
モデルオブジェクトの生成
先ほどのStanファイルを読み込みます。
PyStanならばpystan.stan(model_code=XXX, data=XXXX, iter=1000, chains=4)
などと書いていましたが、CmdStanPyならばStanファイルを指定して次の通りになります。
### Stanファイルを読み込んでオブジェクトを生成する stan_file_path = 'sample_model.stan' model = CmdStanModel(stan_file=stan_file_path)
model_code
でモデルをベタ書きできるのかは調査中です。
サンプリングの実行
sample()
でハミルトニアンモンテカルロ法のノーUターンサンプラが走るそうです。
### MCMCを実行して事後分布をサンプリングする iter_sampling = 1000 # サンプリングの数.デフォルト1000 chains = 4 # 並列サンプリングの数.デフォルト4 # データの作成 front_and_back_list = [1,0,1,0,0,0,0,0,1,0] data = { "N" : len(front_and_back_list), "y" : front_and_back_list } # 実行 fit_sm = model.sample(data=data, iter_sampling=iter_sampling, iter_warmup = iter_warmup, chains = chains, seed=1234)
sample()
の引数として、よく使うのは次のものかと思います。
data
:データの辞書を渡す。これは他のStanと同じ。iter_sampling
:サンプル数を指定する。他のStanのiter
と同じっぽい。iter_warmup
:バーンインの数を指定する。他のStanのwarmup
と同じっぽい。chains
:並列して走るチェーンの数。これは他のStanと同じ。seed
:seed値を設定。これは他のStanと同じ。
下記を参考にしました。
またノーUターンサンプラは下記に詳細が載っています。
診断と結果の解釈
ここら辺はもっと便利なものがあるのかどうか分かりませんが、ひとまずPyMC3の収束診断っぽいものを作ってみました。
### 事後分布の診断 la = fit_sm.stan_variables() plt.figure(figsize=(10,3)) bins=15 for i, k in enumerate(la.keys()): boxplot_list =[] for j in range(chains): x = la[k][j*iter_sampling:(j+1)*iter_sampling] hist, bin_edges = np.histogram(x, bins=bins) # 度数分布に変換 #print(len(hist),len(bin_edges)) plt.subplot(len(la.keys()), 3, 3*i+1) plt.plot(bin_edges[:bins], hist, alpha=0.6, lw=0.6) plt.subplot(len(la.keys()), 3, 3*i+2) plt.plot(x, alpha=0.5, lw=0.8) boxplot_list.append(x) plt.subplot(len(la.keys()), 3, 3*i+3) plt.boxplot(boxplot_list) plt.show() ### 結果を解釈する samples = az.from_cmdstanpy(posterior = fit_sm) print(az.summary(samples))
また結果の要約についてもarviz
というライブラリを使うと比較的簡単に出来るようです。
Data variables: theta float64 1.002 mean sd hdi_3% hdi_97% mcse_mean mcse_sd ess_bulk ess_tail r_hat theta 0.335 0.133 0.095 0.577 0.003 0.002 1462.0 1934.0 1.0