はじめに
条件付き平均処置効果を算出する因果推論コンペ等で良い成績を収めている手法としてBayesian Additive Regression Trees(BART)が有名です。
BARTについて名前以上のことを理解していないので、調べて簡単にまとめました。
Bayesian Additive Regression Trees
BARTは因果推論のための機械学習モデルではなく、予測モデルの一つです。
未知の関数 を近似するための決定木ベースのアンサンブルモデルです。他のアンサンブル手法と同様に、各木は弱学習器として機能し、特徴量と結果の関係の一部だけを表現します。 決定木は非常に解釈しやすく柔軟なモデルですが、過学習しやすいという欠点があります。BARTでは過学習を避けるために正則化事前分布を使用し、各木が共変量と予測変数の間の関係の一部しか説明できないようにしています。
はじめに、BARTでは以下のような関係も学習することを考えていきます。
BARTでは、上記の関係を次のような形で推定します。
式からわかるように決定木をm個アンサンブルすることによって予測を実行するモデルである。 GBDTは残差を最小にするような決定木を段階的に学習していくが、BARTは段階的な学習は行わない。(段階的ではないが順に学習はしていく)
そして、残差に正規分布を仮定しており、ツリーの合計モデルとそのモデルのパラメータの事前正規化の 2 つの部分で構成されます。
決定木のアンサンブル
一つの二分木を以下のように表現する。
ここでは決定技における各枝ごとの目的変数の値である。 なおは決定木の構造であり、特定の特徴量で2分割するときの閾値を保持している。
事前分布の設定
事前分布を設定することで個々の決定木が過剰に影響しないようにする正則化の作用がある。 これにより加法表現による表現力の獲得と過学習を回避することができる。
事前の独立性と対称性
正規化事前分布の指定を簡素化するために、は互いに独立しており、またとも独立している。 このことは次のように表現できる。
また次のように説明できる。
この関係が成り立つ場合、を事前分布として指定する必要がある。
の事前分布
決定木の構造の決定には三つの要素が影響しているとして事前分布を考える。
1. ノードの深さ
決定木のノードの深さ。以下の分布により設定。
正則化により各ツリーの複雑性を低く保つために、α=0.95 および β=2 を使用しています。 この場合は1, 2, 3, 4, 5つ以上の終端ノードを持つツリーに対する事前確率はそれぞれ0.05、0.55、0.28、0.09、0.03となる。
2. ノードの分割対象とする変数の分布
どの変数で分岐を行うかを選択する。可能性のある変数から一様(等確率)で選択。
3. 分割する値の分布
分割を行う場合の閾値。可能性のある分割値の離散集合に対する一様事前分布を用いる。
(Extremely Randomized Treesに近い状態と言える?)
の事前分布
決定木の各枝の値の事前分布は以下のように提案されている。
ここでと設定する。
この事前分布により各決定木の出力が過剰な大きさにならず、正則化の役目を果たすとしている。
また、Chipmanはは1~3の間にすると良い結果が得られるとしている。
の事前分布
基本的に、この目的のために自由度 とスケール の事前分布をデータに基づいた粗い過大評価 使用して調整します。 の自然な選択肢は次の2つです:
その後、適切な形状を得るために の値を3から10の間で選び、事前分布の に関する第 分位点が に位置するように の値を選びます。
Chipmanはデフォルト設定としてを推奨している。 また は の過剰適合につながるため推奨されていない。
モデルの推定
MCMCのサンプリングによる分布推定を用いて、観測値から以下の事後分布を推定する。
ここで表記を簡単にするために、 を を除いたすべてのツリーの集合とし、同様に を定義します。したがって、 は m−1 個のツリーの集合となり、 はそれに対応する終端ノードのパラメータです。ここでのサンプリングではを条件づけてで連続して m 回サンプリングすることを意味します。
またに関しては、観測値とツリー情報に従ってサンプリングすることになる。
のサンプリングは逆ガンマ分布からのサンプルであり、一般的な方法で容易に得ることができます。 一方でのサンプリングは、通常の分布関数のサンプリングと異なり難しいため、以下の関係を考える。
これは番目のツリーを除外したフィットに基づく部分残差のn次元ベクトルです。
この部分残差を用いるとサンプリングは、次にように修正することができます。
実際にはの各サンプリングを、次の2つの連続したステップで実行することになります。
またのサンプリングは、やや複雑ですがCGM98のMetropolis-Hastings (MH) アルゴリズムを使用して実行します。 このアルゴリズムでは現在のツリーに基づいて次の4つの操作のいずれかを使用して新しいツリーを提案します。操作とそれに対応する提案確率は以下の通りです
- 終端ノードの成長(0.25)
- 終端ノードのペアの剪定(0.25)
- 非終端規則の変更(0.40)
- 親子間での規則の交換(0.10)
因果効果の推定
条件付き処置効果の推定は、処置を行う場合と行わない場合の結果変数の期待値の差をとることにより行う。
またBARTの学習に用いるの共変量に傾向スコアを追加して因果効果を推定する手法も提案されている。
この方法で推定される条件付き平均処置効果の精度が改善するとされている。
実行
RでBARTを実行するパッケージとして、BARTパッケージが代表的である模様。
BARTパッケージを用いたコードは以下のページを参考にした。
はじめにサンプルデータの生成。
library(BART) library(tidyverse) library(tidybayes) library(tidytreatment) sim <- simulate_su_hill_data(n = 200, treatment_linear = FALSE, omega = 0, add_categorical = TRUE, coef_categorical_treatment = c(0,0,1), coef_categorical_nontreatment = c(-1,0,-1) ) dat <- sim$data
生成プロセスを確認する。
> sim$formulas $treatment_assignment expression(0.4*x5 + 0.2*x6 + 0.4*x7 + 0.2*x8 + 0.4*x9 + 0.2*x10 + 0.8*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x5 * x6) + 0.3*I(x5 * x6 * x7) + 0.8*I(x7^2) + 0.2*I(x7^3) + 0.4*I(x8^2) + 0.3*I(x7 * x8) + 0.8*I(x9^2) + 0.5*I(x9 * x10)) $response_treatment expression(0.5*x5 + 2*x7 + 0.5*x9 + 2*x10 + 0.4*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x7^2) + 0.5*I(x8^2) + 0.5*I(x9^2) + 0.7*I(x9 * x10) + 1*I(c1=='3')) $response_nontreatment expression(0.5*x5 + 2*x7 + 0.5*x9 + 2*x10 + 0.4*I(x5^2) + 0.8*I(x6^2) + 0.5*I(x7^2) + 0.5*I(x8^2) + 0.5*I(x9^2) + 0.7*I(x9 * x10) + -1*I(c1=='1') + -1*I(c1=='3')) $generic ~x1 + x2 + I(x1^2) + I(x2^2) + I(x2 * x6) + x5 + x6 + x7 + x8 + x9 + x10 + I(x5^2) + I(x6^2) + I(x5 * x6) + I(x5 * x6 * x7) + I(x7^2) + I(x7^3) + I(x8^2) + I(x7 * x8) + I(x9^2) + I(x9 * x10) <environment: 0x7fad347e6870>
モデル推定
4段階で因果効果を推定する。
Hahn、Murray、Carvalho (2020)の手順に従った以下の方法で因果効果の推定を行って行きます。
- 変数選択モデルの作成:共変量(処置変数を除く)に対する結果を回帰する
- 変数選択モデルから結果変数に影響がある共変量のサブセットを選択
- 傾向スコアモデルの作成:ステップ2で選択された共変量のみを使用して傾向スコアを推定するプロビット/ロジットモデルを作成
- 処置効果モデルを適合する:ステップ3の元の共変量と傾向スコアを使用する
# STEP 1 VS Model: 共変量を用いた回帰モデル var_select_bart <- wbart(x.train = select(dat,-y,-z), y.train = pull(dat, y), sparse = TRUE, # 変数の選択 nskip = 1000, # burn inのステップ数 ndpost = 3000 # サンプリング回数 )
# STEP 2: 変数選択 var_select <- covar_ranking %>% filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>% pull(variable) var_select <- unique(gsub("c1[1-3]$","c1", var_select))
> var_select [1] "c1" "x5" "x6" "x7" "x8" "x9" "x10"
# STEP 3 PS Model: 傾向スコアモデルの作成 prop_bart <- pbart( x.train = select(dat, all_of(var_select)), y.train = pull(dat, z), nskip = 1000, ndpost = 3000 ) dat$ps <- prop_bart$prob.train.mean
# Step 4 TE Model: BART + Psocre te_model <- wbart( x.train = select(dat,-y), y.train = pull(dat, y), nskip = 1000, ndpost = 2000, keepevery = 100 )
推定結果の処理
処置効果の推定結果を抽出して、ヒストグラムを表示する。 この時表示するのは、全ての観測サンプルに対して実施したサンプリングの値の全てが混ざって状態である。
posterior_treat_eff <- treatment_effects(te_model, treatment = "z", newdata = dat) posterior_treat_eff %>% ggplot() + geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") + ggtitle("処置効果のヒストグラム") + xlab("CATE")+ theme_bw(base_family = "HiraKakuPro-W3")
次に各個体から取得した処置効果の2000件のサンプリングの中央値のヒストグラムを表示する。
posterior_treat_eff %>% summarise(cte_hat = median(cte)) %>% ggplot() + geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") + ggtitle("処置効果のヒストグラム (各対象における中央値)") + xlab("CATE")+ theme_bw(base_family = "HiraKakuPro-W3")
今回の場合は二峰性となっていた。
特定の共変量を条件付けた場合の因果効果の分布も確認することができる。
posterior_fitted <- fitted_draws(te_model, value = "fit", include_newdata = FALSE) treatment_var_and_c1 <- dat %>% select(z,c1) %>% mutate(.row = 1:n(), z = as.factor(z)) posterior_fitted %>% left_join(treatment_var_and_c1, by = ".row") %>% ggplot() + stat_halfeye(aes(x = z, y = fit)) + facet_wrap(~c1, labeller = as_labeller( function(x) paste("c1 =",x) ) ) + xlab("処置") + ylab("CATE") + theme_bw() + ggtitle("C1と処置変数を条件付けた場合の結果変数の事後分布")
次に各個体の処置効果の信頼区間を表示する。
posterior_treat_eff %>% select(-z) %>% point_interval() %>% arrange(cte) %>% mutate(.orow = 1:n()) %>% ggplot() + geom_interval(aes(x = .orow, y= cte, ymin = .lower, ymax = .upper)) + geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) + ggtitle("CATE分布(95%信頼区間)") + ylab("CATE")+ theme_bw() + coord_flip() + scale_colour_brewer() + theme(axis.title.y = element_blank(), axis.text.y = element_blank(), axis.ticks.y = element_blank(), legend.position = "none")
処理効果推定における変数重要度の表示
特定の変数が処置変数と合わせてツリーに(平均的に)何回含まれたかを数えることで重要度としています。
treatment_interactions <- covariate_with_treatment_importance(te_model, treatment = "z") treatment_interactions %>% ggplot() + geom_bar(aes(x = reorder(variable, avg_inclusion), y = avg_inclusion), stat = "identity") + ggtitle("処置Zとの相互作用が強い変数") + ylab("Inclusion counts") + xlab("") + theme(axis.text.x = element_text(angle = 45, hjust=1)) + theme_bw(base_family = "HiraKakuPro-W3")
モデルの診断
推定したモデルにおける結果変量の残差を確認
res <- residual_draws(te_model, response = pull(dat, y), include_newdata = FALSE) res %>% point_interval(.residual, y, .width = c(0.95) ) %>% select(-y.lower, -y.upper) %>% ggplot() + geom_pointinterval(aes(x = y, y = .residual, ymin = .residual.lower, ymax = .residual.upper), alpha = 0.2) + scale_fill_brewer() + ggtitle("観測値と残差の傾向") + theme_bw(base_family = "HiraKakuPro-W3")
観測値とモデル出力の相関を確認。
res %>% summarise(.fitted = mean(.fitted), y = first(y)) %>% ggplot(aes(x = y, y = .fitted)) + geom_point() + geom_smooth(method = "lm") + ggtitle("観測値とモデルの出力") + theme_bw(base_family = "HiraKakuPro-W3")
Q-Qプロットの確認
res %>% summarise(.residual = mean(.residual)) %>% ggplot(aes(sample = .residual)) + geom_qq() + geom_qq_line() + ggtitle("Q-Qプロット") + theme_bw(base_family = "HiraKakuPro-W3")