俺のOneNote

俺のOneNote

データ分析が仕事な人のOneNote愛とか、分析小話とか。

Numpyroでベイズ統計モデリング~単回帰モデル~

かれこれ数年付き合ってきたstan ですが、
推定の遅さが仕事上かなりネックになっており、 メインで使うPPLを高速にMCMCを回せるNumpyroに変えようと模索中。

以下、RとStanで始めるベイズ統計モデリングによるデータ分析入門のデータ&分析内容を題材に、Numpyroで実装しながら練習する記録です。

※可視化用ライブラリの arviz もちょっと使ってない間にだいぶ高機能になっており、その辺も試していく所存。

準備

環境はgoogle colabです。

import jax.numpy as jnp
import numpy as np
import jax.random as random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style='darkgrid',palette='bright')

#!pip install numpyro
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
numpyro.set_host_device_count(4)

import arviz as az

データの準備・確認

データは以下サポートサイトより

logics-of-blue.com

file_bear_sales_2 = pd.read_csv("3-2-1-beer-sales-2.csv")
file_bear_sales_2.head()

元データはsalestemperatureの1変量による単回帰モデル用。

plt.figure(figsize=(10,5))
sns.scatterplot(x="temperature",y="sales",data=file_bear_sales_2)
plt.show()

データ分布は以下のような形。

モデル

単回帰モデルなので非常に単純。

※はてぶのTexが全く表示されなかったり、よくわからないエラー起こるので作法が謎。

def model(
    sales,
    temperature
):
  Intercept = numpyro.sample("Intercept",dist.Normal(0,100))
  beta = numpyro.sample("beta",dist.Normal(0,100))
  sigma = numpyro.sample("sigma",dist.HalfNormal(100))

  numpyro.sample("sales",dist.Normal(Intercept + beta * temperature, sigma),obs = sales)

Numpyro のいいところはモデルの記述が直感的でわかりやすいところ。 後に取り組む階層モデルもわかりやすい。

一方で、時系列には jax.lax.scan を使うのが肝要なんですが、一気に複雑になる。
ここもいずれ実装して理解を深めていきたい。

data_dict = {
    "temperature":file_bear_sales_2["temperature"].values,
    "sales":file_bear_sales_2["sales"].values
}
kernel = NUTS(model)
sample_kwargs = dict(
    sampler=kernel, 
    num_warmup=2000, 
    num_samples=2000, 
    num_chains=4, 
    chain_method="parallel"
)
mcmc = MCMC(**sample_kwargs)
mcmc.run(random.PRNGKey(0), **data_dict)

MCMC回すのもそうなんですが、コンパイルいらないので、体感ものすごく早い。
4000サンプリングで1chain12秒。
並列処理、GPU利用も簡単なので、アクセラレータを最大限活かせる。

結果

mcmc.print_summary()

結果はほぼ本紙と同一になりました。

arviz でトレースプロットを確認してみる。

idata = az.from_numpyro(mcmc)
az.plot_trace(idata, figsize=(16,9))
plt.show()

r_hat でも確認できるとおり、収束も問題なしです。

所感

Numpyrostan と比べるとかなり高速にMCMCを回せるし、 比較的(TFPとかと比べれば)理解しやすい記述であることもポイントの一つ。
事前分布を必ず明示しなきゃいけないので、stanのデフォルト無情報事前分布に慣れてるとちょっと面倒です(恥

時系列はいずれ取り組みたいが、ちょっとモデル記述がややこしくなるので、そこはやっぱり stan のほうが分かりやすい(単に慣れてるだけ説が濃厚ですが。)

次は事後予測分布の可視化が焦点です。