Numpyroでベイズ統計モデリング~ロジスティック回帰モデル~
RとStanで始めるベイズ統計モデリングによるデータ分析入門のNumpyro実装第6回。
今回はロジスティック回帰モデルをやっていきます。
準備
import jax.numpy as jnp import numpy as np import jax.random as random import pandas as pd import matplotlib.pyplot as plt import seaborn as sns sns.set(style='darkgrid',palette='bright') import numpyro import numpyro.distributions as dist from numpyro.infer import MCMC, NUTS numpyro.set_host_device_count(4) import arviz as az germination_dat = pd.read_csv("3-9-1-germination.csv")
color_dict = {"shade":"b","sunshine":"r"} g = sns.relplot(data=germination_dat,x="nutrition",y="germination",hue="solar",palette=color_dict) g.fig.set_figheight(6) g.fig.set_figwidth(10)
上記のように、植木鉢の中の10個の種子から、何粒が発芽したかを示すデータとなっています。
変数として solar
が含まれており、sunshine
と shade
で発芽率に差があるようなデータとなっています。
モデル
上記データを以下のようなモデルにします。
これは、リンク関数がロジスティック関数、確率分布に二項分布を利用した一般化線形モデルで、通称ロジスティック回帰モデルと呼ばれます。
※ロジスティック回帰モデルというと、 1
か 0
の予測をするモデルであるという認識もありますが、
その場合、二項分布ではなくベルヌーイ分布(つまり1回の試行のうちの成功確率)を利用するモデルです。(二項分布の n = 1
と同義)
二項分布を利用することで、複数の試行回数における成功確率をモデル化できます。
#ダミー変数化 germination_dat_2 = pd.get_dummies(germination_dat).drop("solar_shade",axis=1) germination_dat_2["Intercept"] = 1 #モデル def model( N, C, X, size, germination, ): beta = numpyro.sample("beta",dist.Normal(0,100),sample_shape=(C,)) with numpyro.plate("N",N): prob = jnp.dot(X,beta) numpyro.sample("germination",dist.BinomialLogits(logits=prob,total_count=size),obs = germination) # 推論 X = germination_dat_2[["Intercept","solar_sunshine","nutrition"]].values data_dict = { "N":X.shape[0], "C":X.shape[1], "X":X, "size":10, "germination":germination_dat_2["germination"].values } kernel = NUTS(model) sample_kwargs = dict( sampler=kernel, num_warmup=2000, num_samples=2000, num_chains=4, chain_method="parallel" ) mcmc = MCMC(**sample_kwargs) mcmc.run(random.PRNGKey(0), **data_dict)
解釈
結果は無事、書籍と同一になりました。 トレースプロットも問題ないようです。
az.summary(mcmc) az.plot_trace(mcmc,compact=False,figsize=(14,8))
ロジスティック回帰モデルの回帰係数の解釈は「対数オッズ比」であるため注意が必要です。
事後予測分布
このモデルの事後予測分布がどのようになってるかみてみたいと思います。
mcmc_samples=mcmc.get_samples() predictive = numpyro.infer.Predictive(model, mcmc_samples) pred_dict = { "N":X.shape[0], "C":X.shape[1], "X":X, "size":10, "germination":None } ppc_samples = predictive(random.PRNGKey(0),**pred_dict) idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples) germination_pred = idata_ppc.posterior_predictive['germination'] #予測分布用データ shade_index = germination_dat_2[germination_dat_2["solar_sunshine"]==0].index sunshine_index = germination_dat_2[germination_dat_2["solar_sunshine"]==1].index shade_nutrition = germination_dat_2.iloc[shade_index,:]["nutrition"] sunshine_nutrition = germination_dat_2.iloc[sunshine_index,:]["nutrition"] shade_germination_pred = ppc_samples["germination"][:,list(shade_index)] sunshine_germination_pred = ppc_samples["germination"][:,list(sunshine_index)] #サンプリング ax = az.plot_hdi(shade_nutrition, shade_germination_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="b",figsize=(10,6)) az.plot_hdi(sunshine_nutrition, sunshine_germination_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="r") sns.scatterplot(data=germination_dat, x="nutrition",y="germination",hue="solar",palette=color_dict) plt.show()
すべてのデータが99%予測分布内に収まっており、
nutrition
が高くなるほど発芽率が高く、sunshine
のときに発芽率が高いことが理解できました。
おわりに
今回までシンプルな一般化線形モデルをつくってきました。
次回からはGLMM、ランダム効果を仮定したモデル、いわゆる階層ベイズモデルです。