Numpyroでベイズ統計モデリング~事後予測分布~
今回はRとStanで始めるベイズ統計モデリングによるデータ分析入門の実装勉強第2回。
第3部第3章の「モデルを用いた予測」+αの実装です。
データ、モデル推定
ここは前回と同様のため主要コードのみ掲載。
#ライブラリ import jax.numpy as jnp import numpy as np import jax.random as random import pandas as pd import matplotlib.pyplot as plt import seaborn as sns sns.set(style='darkgrid',palette='bright') !pip install numpyro import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS numpyro.set_host_device_count(4) #データインポート file_bear_sales_2 = pd.read_csv("3-2-1-beer-sales-2.csv") #モデル定義 def model( N, sales, temperature ): Intercept = numpyro.sample("Intercept",dist.Normal(0,100)) beta = numpyro.sample("beta",dist.Normal(0,100)) sigma = numpyro.sample("sigma",dist.HalfNormal(100)) with numpyro.plate("N",N): numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales) #MCMCによる事後分布サンプリング data_dict = { "N":len(file_bear_sales_2), "temperature":file_bear_sales_2["temperature"].values, "sales":file_bear_sales_2["sales"].values } kernel = NUTS(model) sample_kwargs = dict( sampler=kernel, num_warmup=2000, num_samples=2000, num_chains=4, chain_method="parallel" ) mcmc = MCMC(**sample_kwargs) mcmc.run(random.PRNGKey(0), **data_dict)
事後予測分布の生成
事後分布のサンプリングからパラメータのMCMCサンプルを得るところから。
numpyro.infer.Predictive
インスタンスの利用と、予測したい説明変数を生成し、作成したモデルで事後予測分布を取得する準備をします。
今回は本紙に則り、気温が11度~30度までの区間の事後予測分布を得ることを目的とします。
mcmc_samples=mcmc.get_samples() predictive = numpyro.infer.Predictive(model, mcmc_samples) temperature_pred = jnp.arange(11,31)
あとはpredictive
インスタンスの引数に乱数と予測したいモデルの説明変数を与えるだけです。
目的変数となる観測データ(observations
)はNone
を指定し、モデルから出力することを明示します。
ppc_samples = predictive(random.PRNGKey(0),N = len(temperature_pred), temperature = temperature_pred, sales=None)
これで事後予測分布が取得できました。
arviz
による可視化のため、
InferenceData
オブジェクトに変換します。
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)
可視化による事後予測分布のチェック
あとは arviz
のプロットに必要な InferenceData
オブジェクトを渡すだけです。
軸ラベルなどが適当ですが、arviz
のラベルガイドに基づけば柔軟に対応できそうです。
※読み込み大変なので、ここではデフォルトでご容赦ください。
az.plot_forest(idata_ppc.posterior_predictive["sales"], var_names=["sales"], hdi_prob=0.95, combined=True, colors='b');
軸ラベルが分かりづらいですが、本紙と同じく、気温11度~30度までの各予測分布を可視化できました。
特定の事後予測分布を並列で可視化することも可能です。
ここでは、本紙と同じく気温11度と気温30度の事後予測分布を可視化します。
az.plot_forest(idata_ppc.posterior_predictive["sales"][:,:,[0,19]], kind="ridgeplot", var_names=["sales"], hdi_prob=0.95, ridgeplot_overlap=0.9, ridgeplot_truncate = False, ridgeplot_quantiles=[.5], combined=True, ridgeplot_alpha=0.5, figsize=(8,5), colors='b');
本紙のR
における bayesplot
とは微妙に見た目違いますが、同じものが出力できました。
arviz
はベイズモデリングの結果解釈でかなり有用になりそうです。
おまけ
本紙第3部第3章にはありませんが、回帰直線による予測区間の可視化も可能です。
sales_pred = idata_ppc.posterior_predictive['sales'] az.plot_hdi(temperature_pred, sales_pred,fill_kwargs={'alpha': 0.3}) plt.plot(temperature_pred, sales_pred.mean(axis=0).mean(axis=0),color="orange") sns.scatterplot(x=file_bear_sales_2["temperature"], y=file_bear_sales_2["sales"],s=50,color="gray")
おわりに
stan
とは若干流儀が違いますが、かなり簡易にモデルを利用した予測分布を得ることができました。
予測分布取得用のモデルを追記する必要がない分、慣れればこっちのほうが使いやすそうな気もします。
あとは arviz
が優秀。もうちょい使いこなせるようにしたいです。
次は第3部第6章、ダミー変数と分散分析モデルがテーマです。