俺のOneNote

俺のOneNote

データ分析が仕事な人のOneNote愛とか、分析小話とか。

Numpyroでベイズ統計モデリング~ロジスティック回帰モデル~

RとStanで始めるベイズ統計モデリングによるデータ分析入門のNumpyro実装第6回。

今回はロジスティック回帰モデルをやっていきます。

準備

import jax.numpy as jnp
import numpy as np
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid',palette='bright')

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_host_device_count(4)

import arviz as az

germination_dat = pd.read_csv("3-9-1-germination.csv")

color_dict = {"shade":"b","sunshine":"r"}

g = sns.relplot(data=germination_dat,x="nutrition",y="germination",hue="solar",palette=color_dict)
g.fig.set_figheight(6)
g.fig.set_figwidth(10)

上記のように、植木鉢の中の10個の種子から、何粒が発芽したかを示すデータとなっています。
変数として solar が含まれており、sunshineshade で発芽率に差があるようなデータとなっています。

モデル

上記データを以下のようなモデルにします。

これは、リンク関数がロジスティック関数、確率分布に二項分布を利用した一般化線形モデルで、通称ロジスティック回帰モデルと呼ばれます。

※ロジスティック回帰モデルというと、 10 の予測をするモデルであるという認識もありますが、
その場合、二項分布ではなくベルヌーイ分布(つまり1回の試行のうちの成功確率)を利用するモデルです。(二項分布の n = 1 と同義)
二項分布を利用することで、複数の試行回数における成功確率をモデル化できます。

#ダミー変数化
germination_dat_2 = pd.get_dummies(germination_dat).drop("solar_shade",axis=1)
germination_dat_2["Intercept"] = 1


#モデル

def model(
    N,
    C,
    X,
    size,
    germination,
  
):
  beta = numpyro.sample("beta",dist.Normal(0,100),sample_shape=(C,))

  with numpyro.plate("N",N):
    prob = jnp.dot(X,beta)
    numpyro.sample("germination",dist.BinomialLogits(logits=prob,total_count=size),obs = germination)


# 推論
X = germination_dat_2[["Intercept","solar_sunshine","nutrition"]].values
data_dict = {
    "N":X.shape[0],
    "C":X.shape[1],
    "X":X,
    "size":10,
    "germination":germination_dat_2["germination"].values
}

kernel = NUTS(model)
sample_kwargs = dict(
    sampler=kernel, 
    num_warmup=2000, 
    num_samples=2000, 
    num_chains=4, 
    chain_method="parallel"
)
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

解釈

結果は無事、書籍と同一になりました。 トレースプロットも問題ないようです。

az.summary(mcmc)

az.plot_trace(mcmc,compact=False,figsize=(14,8))

ロジスティック回帰モデルの回帰係数の解釈は「対数オッズ比」であるため注意が必要です。

オッズ比 - Wikipedia

事後予測分布

このモデルの事後予測分布がどのようになってるかみてみたいと思います。

mcmc_samples=mcmc.get_samples()
predictive = numpyro.infer.Predictive(model, mcmc_samples)

pred_dict = {
    "N":X.shape[0],
    "C":X.shape[1],
    "X":X,
    "size":10,
    "germination":None
}

ppc_samples = predictive(random.PRNGKey(0),**pred_dict)
idata_ppc = az.from_numpyro(mcmc, posterior_predictive=ppc_samples)
germination_pred =  idata_ppc.posterior_predictive['germination']


#予測分布用データ
shade_index = germination_dat_2[germination_dat_2["solar_sunshine"]==0].index
sunshine_index = germination_dat_2[germination_dat_2["solar_sunshine"]==1].index

shade_nutrition = germination_dat_2.iloc[shade_index,:]["nutrition"]
sunshine_nutrition = germination_dat_2.iloc[sunshine_index,:]["nutrition"]

shade_germination_pred = ppc_samples["germination"][:,list(shade_index)]
sunshine_germination_pred = ppc_samples["germination"][:,list(sunshine_index)]

#サンプリング
ax = az.plot_hdi(shade_nutrition, shade_germination_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="b",figsize=(10,6))
az.plot_hdi(sunshine_nutrition, sunshine_germination_pred, hdi_prob=0.99, plot_kwargs={"ls": "--"},smooth=False,color="r")
sns.scatterplot(data=germination_dat, x="nutrition",y="germination",hue="solar",palette=color_dict)
plt.show()

すべてのデータが99%予測分布内に収まっており、
nutrition が高くなるほど発芽率が高く、sunshine のときに発芽率が高いことが理解できました。

おわりに

今回までシンプルな一般化線形モデルをつくってきました。
次回からはGLMM、ランダム効果を仮定したモデル、いわゆる階層ベイズモデルです。