Simulation task: recovering causal effects from post-treatment selection induced by missing outcome data

selection bias
mediation
regression
IPW
doubly-robust
TMLE
A simulation exercise on missing data, selection bias, causal inference and TMLE
Published August 15, 2023


Objective

The final goal is to estimate the effects of ADHD treatment on school performance, where:

  • the estimands: average treatment effect (ATE) and conditional average treatment effect (CATE)
  • the exposure: stimulant medication for ADHD
  • the outcome: raw score at the mathematics national test at grade 8
  • the target population: ADHD-diagnosed Norwegian schoolchildren between grades 6 and 8

For this task, we employ:

  • The structural causal models (SCM) framework (Pearl 2009)
  • A synthetic observational dataset, generated by an SCM learnt from real-world data (functional mechanisms are treated as unknown in the analysis)
  • A back-door admissible set of pre-treatment variables; i.e., the assumption of no latent confounding
  • A missing-outcome mechanism that allows recoverability via IPW and regression adjustment (functional specification is unknown in the analysis)
  • The targeted minimum-loss estimation (TMLE) framework (van der Laan and Rose 2011)

Synthetic data

Let \(\mathcal{G}'\) be a directed acyclic graph (DAG) built from domain knowledge and temporal-order constraints, involving the exposure \(A\), the outcome \(Y\), a set of confounders \(H\), and mediators \(M\). We previously fit flexible models (random forests) to learn the causal mechanisms \(\hat{f}_V:\text{supp}\, \text{pa}(V;\mathcal{G}')\times \text{supp}\, U_V\rightarrow\text{supp}\, V\) from real-world data on the subject. This allows us to replay those mechanisms and generate synthetic data from seeds (noise terms and exogenous variables) in a controlled environment, along with counterfactual outcomes. In this exercise, the exogenous variables mimic the marginal distributions of their real-world counterparts, but not their joint distribution.
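A minimal sketch of how one such mechanism could be replayed is given below. The names fit.V, parents.V and sigma.V are hypothetical placeholders; the actual fitted forests and the way noise enters each mechanism are not reproduced here.

Code
# Illustrative sketch only: realizing a learned mechanism f_V for a continuous
# node V as a random-forest prediction on its parents plus an exogenous draw
library(ranger)

simulate.node = function(fit.V, parents.V, sigma.V){
  # Structural assignment V = f_V(pa(V), U_V)
  predict(fit.V, data = parents.V)$predictions + rnorm(nrow(parents.V), 0, sigma.V)
}
# Counterfactual columns arise by replaying the same mechanisms (same noise)
# with the treatment column flipped to 1 - A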

By design, conditional ignorability is satisfied in the synthetic system using the full confounder set. It might not be satisfied in the system where the real-world data come from.

Generated variables can be grouped in three categories:

1. Pre-treatment / exogenous variables and the exposure:

  • mom.vuln = \(H_1\in\mathbb{R}\) = maternal vulnerability index: a PCA-based index summarizing mother’s diagnoses (+), level of education (-), age (-), and number of children (-)
  • sex.girl = \(H_2\in\{0,1\}\) = child’s sex at birth, female = 1
  • med.pre6 = \(H_3\in\{0,1\}\) = prescription for ADHD medication before grade 6
  • pre.diag = \(H_4\in\{0,1\}\) = diagnoses for comorbid disorders (ADHD-related, internalizing or externalizing) registered before grade 6
  • ntr.grd5 = \(H_5\in\mathbb{R}\) = raw score obtained at the national test for grade 5 (average math and reading)
  • reg.gpsp = \(H_6\in\mathbb{N}\) = number of registrations for health services (GP and specialist) before grade 6
  • med.6to8 = \(A\) = prescription for ADHD medication between grades 6-8

2. Mediator variables and the outcome:

  • ptr.diag = \(M_1\in\{0,1\}\) = diagnoses for comorbid disorders (ADHD-related, internalizing or externalizing) registered between grades 6-8
  • ptr.gpsp = \(M_2\in\mathbb{N}\) = number of registrations for health services (GP and specialist) between grades 6-8
  • ntr.grd8 = \(Y\) = raw score obtained at the national test for grade 8 (math)

3. Counterfactual variables:

  • ptr.diag.cntr = \(M_1^{1-A}\) = diagnoses for comorbid disorders (ADHD-related, internalizing or externalizing) registered between grades 6-8, had the individual taken the opposite treatment
  • ptr.gpsp.cntr = \(M_2^{1-A}\) = number of registrations for health services (GP and specialist) between grades 6-8, had the individual taken the opposite treatment
  • ntr.grd8.cntr = \(Y^{1-A}\) = raw score obtained at the national test for grade 8 (math), had the individual taken the opposite treatment
  • ITE = \(Y^{1}-Y^{0}\) = individual treatment effect

A dataset (synth.data) of \(N=5\,000\) samples was generated:

Code
# Load packages --------------------------------------------------------------
library(data.table)   # Processes dataframes
library(dplyr)        # Processes dataframes
library(kableExtra)   # Styles tables
library(speedglm)     # Performs fast fitting for GLM
library(ranger)       # Performs random forest learning
library(nnls)         # Performs non-negative least squares
library(Rsolnp)       # Augmented Lagrange optimizer
library(sl3)          # Performs super-learning
library(tmle3)        # Performs TMLE
library(tmle3mediate) # Performs TMLE for mediation analysis
library(ggplot2)      # Plots

# Read data -------------------------------------------------------------------
load(file="syntheticADHDdata.RData")

# Glimpse of the data ---------------------------------------------------------
synth.data = data.table(synth.data)
synth.data %>% head() %>% kable(caption = "Table 1: A glimpse at the synthetic data")
Table 1: A glimpse at the synthetic data
mom.vuln sex.girl med.pre6 pre.diag ntr.grd5 reg.gpsp med.6to8 ptr.diag ptr.gpsp ntr.grd8 ptr.diag.cntr ptr.gpsp.cntr ntr.grd8.cntr ITE
0.65 0 1 0 19.64 14 0 0 42 16.77 0 44 19.29 2.52
1.31 1 1 0 25.85 1 1 0 21 26.72 0 19 21.64 5.08
0.08 0 1 1 19.59 56 0 0 7 21.38 1 46 23.27 1.89
0.65 0 1 1 15.12 23 1 1 0 17.62 1 0 14.48 3.14
0.99 1 0 1 9.30 27 0 0 4 15.02 1 22 15.16 0.14
2.02 0 0 1 19.83 128 0 1 47 19.62 1 75 20.18 0.56

No missingness case

A known SCM allows the computation of unit-level counterfactuals: we can see both worlds at the same time. In particular, we can compute the oracle sample average treatment effect (\(^S\)ATE): the within-sample average of the individual treatment effects (unobserved in real-world data). \[ ^S\text{ATE} = N^{-1}\sum_{i=1}^N \text{ITE}^{(i)} = N^{-1}\sum_{i=1}^N (Y^{(i)}_{A=1}-Y^{(i)}_{A=0}) \] The \(^S\)ATE is itself a consistent (but infeasible in practice) estimator of the population ATE (\(^P\)ATE).

Likewise, we can compute the oracle \(^S\)CATE for \(X\)-specific effects. If \(X\) is categorical, we can partition units into strata: \(I(x):=\{i\in 1\dots N : X^{(i)}=x\}\). \[ ^S\text{CATE}(x) = |I(x)|^{-1}\sum_{i\in I(x)} \text{ITE}^{(i)} = |I(x)|^{-1}\sum_{i\in I(x)} (Y^{(i)}_{A=1}-Y^{(i)}_{A=0}) \]
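Since the ITE column is available in the synthetic data (a luxury never available with real-world data), these oracle quantities can be computed directly; a quick sketch:

Code
# Oracle sample ATE: within-sample mean of the individual treatment effects
synth.data[, .(SATE = mean(ITE))]

# Oracle sample CATEs by binary pre-treatment strata
synth.data[, .(SCATE = mean(ITE), n = .N), by = med.pre6]
synth.data[, .(SCATE = mean(ITE), n = .N), by = sex.girl]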

Let us compare the oracle effects against the results from two approaches:

  • \(t\)-tests comparing the means of raw test scores between treated and non-treated. This estimator would be biased due to confounding

  • The TMLE estimator using super-learners (Phillips et al. 2023), combining:

    • flexible libraries of base-learners to model \(A\) and \(Y\) given the confounders \(H\). We employ linear models and random forests (the ground-truth DGP lies in this model space)

    • meta-learners to combine the base-learners, weighted by their out-of-sample (cross-validated) predictive performance to avoid over-fitting (a logistic-likelihood meta-learner for \(A\), NNLS for \(Y\))

Code
###############################################################################
# Solution via t-tests (biased by confounding)
###############################################################################

# Function to generate C.I. data from t-tests ---------------------------------
# Param: pre-treatment binary variable to stratify
# Return: data ready to plot as confidence intervals

get.ci.data = function(varname){
  
  # Temporary data frame
  dplot1 = synth.data
  dplot1$dummy = dplot1[,get(varname)]
    
  # Two-sample t-test: biased! -------------------
  
  # all, med vs no-med
  ttest.pop = t.test(dplot1[med.6to8==1, ntr.grd8], # treated
                     dplot1[med.6to8==0, ntr.grd8]) # control
  # subpop 0, med vs no-med
  ttest.sp0 = t.test(dplot1[med.6to8==1 & dummy==0, ntr.grd8], # treated (sp 0)
                     dplot1[med.6to8==0 & dummy==0, ntr.grd8]) # control (sp 0)
  # subpop 1, med vs no-med
  ttest.sp1 = t.test(dplot1[med.6to8==1 & dummy==1, ntr.grd8], # treated (sp 1)
                     dplot1[med.6to8==0 & dummy==1, ntr.grd8]) # control (sp 1)

  # One-sample t-test: oracle! -------------------
  
  # oracle ITE, all
  otest.pop = t.test(dplot1[, ITE])            
  # oracle ITE, subpop 0
  otest.sp0 = t.test(dplot1[dummy==0, ITE])    
  # oracle ITE, subpop 1
  otest.sp1 = t.test(dplot1[dummy==1, ITE])    
  
  # data frame put together ----------------------
  
  dplot1 = data.table( 
    # label for type of analysis
    type = rep(c('t-test','oracle'), each=3),
    # label for subpopulation
    subpop = rep(c('all','subpop 0','subpop 1'), 2),        
    # lower bound
    low = c(ttest.pop$conf.int[1], ttest.sp0$conf.int[1], ttest.sp1$conf.int[1],  
            otest.pop$conf.int[1], otest.sp0$conf.int[1], otest.sp1$conf.int[1]),
    # upper bound
    upr = c(ttest.pop$conf.int[2], ttest.sp0$conf.int[2], ttest.sp1$conf.int[2],  
            otest.pop$conf.int[2], otest.sp0$conf.int[2], otest.sp1$conf.int[2]))
  
  # add midpoint and return dataframe
  dplot1 = dplot1[, midp := (low+upr)/2] 
  return(dplot1)                          
}

###############################################################################
# Solution via adjusted regression / TMLE
###############################################################################

# Learning algorithms employed to learn treatment/outcome mechanisms -----------

# Linear model
lrnr_lm = make_learner(Lrnr_glm_fast)
# Random forest model
lrnr_rf = make_learner(Lrnr_ranger)      

# Meta-learners: to stack together predictions from the learners ---------------

# Combine Y predictions with non-negative least squares
meta_Y = make_learner(Lrnr_nnls)   

# Combine A predictions with logit likelihood (augmented Lagrange optimizer)
meta_A = make_learner(Lrnr_solnp,                     
  loss_function = loss_loglik_binomial,              
  learner_function = metalearner_logistic_binomial)   

# Super-learners: learners + meta-learners together ----------------------------

# Outcome super-learning
super_Y = Lrnr_sl$new(learners = list(lrnr_lm, lrnr_rf), 
                      metalearner = meta_Y)
# Exposure super-learning
super_A = Lrnr_sl$new(learners = list(lrnr_lm, lrnr_rf), 
                      metalearner = meta_A)
# Super-learners put together
super_list = list(A = super_A,
                  Y = super_Y)              

# Confounder set labels
H = colnames(synth.data[,1:6])

# Function to generate C.I. data from TMLE--------------------------------------
# Param: pre-treatment binary variable to stratify
# Param: data
# Return: data from confidence intervals

aux_tml_method = function(namevar, data=synth.data){
  tmle_CATE = tmle3(
  tmle_spec = tmle_stratified(tmle_ATE(1,0)), # Causal contrast CATE: Y1-Y0
  node_list = list(                           # Variables involved in the query
    W = setdiff(H,namevar),                   # Label for non-stratum confounders
    V = namevar,                              # Label for stratifying variable
    A = "med.6to8",                           # Label for exposure
    Y = "ntr.grd8"),                          # Label for outcome
  data = data,                                # Data
  learner_list = super_list)                  # Super-learners specified
  return(tmle_CATE$summary)                   # Return summary
}

###############################################################################
# Generate confidence intervals to plot
###############################################################################

# t-test on prior treatment
ttest.med = cbind.data.frame(get.ci.data('med.pre6'),  # C.I. data
                             strata='prior treatment') # Strata variable

# t-test on sex strata
ttest.sex = cbind.data.frame(get.ci.data('sex.girl'), # C.I. data
                             strata='sex at birth')   # Strata variable

# TMLE on prior treatment
tml.cate.med = aux_tml_method('med.pre6')                                       # TMLE results
tlme.med = cbind.data.frame(type='TMLE',                                        # method label
                            subpop=c('all','subpop 0','subpop 1'),              # subpop label
                            tml.cate.med[order(tml.cate.med$param),c(9,10,8)],  # C.I. data
                            strata='prior treatment')                           # Strata variable
# TMLE on sex strata
tml.cate.sex = aux_tml_method('sex.girl')                                       # TMLE results
tlme.sex = cbind.data.frame(type='TMLE',                                        # method label
                            subpop=c('all','subpop 0','subpop 1'),              # subpop label
                            tml.cate.sex[order(tml.cate.sex$param),c(9,10,8)],  # C.I. data
                            strata='sex at birth')                              # Strata variable

# Unify column names
colnames(tlme.med) = colnames(ttest.med)
colnames(tlme.sex) = colnames(ttest.sex)

# Put all together
ci.data.1 = rbind.data.frame(ttest.med, ttest.sex, tlme.med, tlme.sex)

The following plot presents the point-estimate and 95% confidence interval for the ATE/CATE resulting from oracle \(^S\)ATE/\(^S\)CATE, \(t\)-tests, and TMLE.

Code
# Plot confidence intervals ---------------------------------------------------
ci.data.1 %>% 
  ggplot(aes(x = subpop, y = midp, colour = type)) +                    
  geom_errorbar(aes(ymax = upr, ymin = low), position = "dodge") +
  geom_point(position = position_dodge(0.9)) +
  geom_hline(yintercept=0, linetype="dashed", color = "red") +
  labs(y='ATE/CATE estimate',
       x='Subpopulation (subpop 1 = prior treatment or girls, by panel)') +
  facet_wrap(~strata) + theme_classic()

From the previous plot we can infer:

1. There is some amount of confounding

  • Confidence intervals from \(t\)-tests (green) do not always cover the sample-based oracle value (red). This is probably due to confounding bias, as \(t\)-tests do not adjust for the confounders \(H\)

  • In general, confounding bias is negative: unadjusted effects are below the oracle result

  • Confounding bias seems to be particularly problematic when stratifying on prior prescription for ADHD medication, which might be a strong predictor of the exposure

  • \(t\)-test results might lead one to conclude that the treatment is actually deleterious for those who had not started it before grade 6 (because of late diagnosis?). In fact, the treatment has a positive, but extremely small, average effect in that subpopulation

  • \(t\)-test results might also lead one to conclude that the treatment effect is negligible for boys. In fact, the effect for them is small but positive

2. TMLE with covariate adjustment

  • Under some mild conditions, including back-door admissibility of \(H\), TMLE produces consistent estimators. That is, the amount of confounding bias becomes negligible in large samples. We can see that confidence intervals from TMLE (blue) cover the sample-based oracle value (red) in all cases

  • TMLE results lead to the correct interpretation of the treatment effect for boys and for the subpopulation that did not start treatment before grade 6

3. In general, the effects of ADHD treatment on school performance are positive but small

  • The estimated ATE is about 1.7, on an outcome with mean 21.0 and standard deviation 7.0 (the outcome's marginal moments can be checked directly against the synthetic data, as sketched below)
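Code
# Marginal distribution of the grade-8 score, for effect-size context
synth.data[, .(Y.mean = mean(ntr.grd8), Y.sd = sd(ntr.grd8))]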

Outcome-missingness mechanisms

Let us simulate the \(Y\)-missingness mechanism as a binary random variable \(R_Y\), a childless node in the associated \(m\)-graph \(\mathcal{G}\). For a particular individual \(i\), \(R_Y^{(i)}=1\) indicates that the corresponding outcome realization \(Y^{(i)}\) is observed, while \(R_Y^{(i)}=0\) signifies that \(Y^{(i)}\) is missing. This mechanism has a probit specification:

\[ R_Y = \mathbb{I}[\theta_0 + \theta_H^\top H + \theta_A A + \theta_M^\top M + U_R > 0],\quad U_R\sim N(0,1) \]

Fixing the parameters \(\theta\) to certain values (unknown in the analysis), a realization of \(U_R\) produces 751 (15%) missing outcomes:

Code
###############################################################################
# Building an endogenous selection/missingness mechanism
###############################################################################

# Parameter for the glm-probit specification, only main effects
set.seed(9)
noise.R = rnorm(n=5e3, 0, 1)
theta.i = 3.77 
theta.c = c(-0.01, -1.97, -2.03, 0.18, -0.03, 0.01, -0.05, 0.64, -0.01)
lin.pred = as.matrix(synth.data[,1:9]) %*% theta.c + theta.i + noise.R

# Realization of missingness mechanism
synth.data$RY = ifelse(lin.pred > 0, 1, 0)

# Count how many selected / observed -------------------------------------------
table(synth.data$RY) %>% 
  kable(col.names = c("R(Y)","n"),
    caption = "Table 2: Missing/selected unit-outcomes",
    table.attr = "quarto-disable-processing=true") %>% 
  kable_styling(full_width = FALSE)
Table 2: Missing/selected unit-outcomes
R(Y) n
0 751
1 4249

Under the TMLE framework, let us first consider four options to recover the ATE/CATE:

  1. TMLE-cc: TMLE with complete cases. We discard the entire record of any unit with \(R_Y=0\)

    • This approach would produce a consistent estimator under MCAR. However, as \(R_Y\) is MAR in our case, results would be biased
Code
###############################################################################
# Complete-case analysis for TMLE
###############################################################################

# TMLE on prior treatment
tml.cate.med.cc = aux_tml_method('med.pre6',                                            
                                 data=synth.data[RY==1,])                               # TMLE results
tlme.med.cc = cbind.data.frame(type='TMLE-cc',                                          # method label
                               subpop=c('all','subpop 0','subpop 1'),                   # subpop label
                               tml.cate.med.cc[order(tml.cate.med.cc$param),c(9,10,8)], # C.I. data
                               strata='prior treatment')                                # Strata variable
# TMLE on sex strata
tml.cate.sex.cc = aux_tml_method('sex.girl',
                                 data=synth.data[RY==1,])                               # TMLE results
tlme.sex.cc = cbind.data.frame(type='TMLE-cc',                                          # method label
                               subpop=c('all','subpop 0','subpop 1'),                   # subpop label
                               tml.cate.sex.cc[order(tml.cate.sex.cc$param),c(9,10,8)], # C.I. data
                               strata='sex at birth')                                   # Strata variable
  2. TMLE-im: TMLE with single median-imputation. Missing outcomes are imputed as \(Y^*\), the median of the observed \(Y\)

    • This approach would also produce biased results, as all missing slots are imputed with the same numeric value, independent of \(H,A\). A preferable approach would be multiple model-based imputation (consistent under MAR). Yet, single median-imputation is the built-in preprocessing method in the tmle3 package

    \[ Y^* \leftarrow \text{med}(Y\mid R_Y=1) \]

Code
###############################################################################
# Median-based single imputation for missing outcome
###############################################################################

# Replace Y-values for NA whenever R(Y)=0
observed.data = copy(synth.data)
observed.data$ntr.grd8 = ifelse(observed.data$RY==0, NA, observed.data$ntr.grd8)

# Apply the process_missing function from the tmle3 package
processed = process_missing(observed.data,                     # data
                            node_list = list(W = H,            # confounders
                                             A = "med.6to8",   # exposure
                                             Y = "ntr.grd8"),  # outcome
                            complete_nodes = c('W','A'))       # complete variables

# TMLE on prior treatment
tml.cate.med.im = aux_tml_method('med.pre6',
                                 data=processed$data)                                   # TMLE results
tlme.med.im = cbind.data.frame(type='TMLE-im',                                          # method label
                               subpop=c('all','subpop 0','subpop 1'),                   # subpop label
                               tml.cate.med.im[order(tml.cate.med.im$param),c(9,10,8)], # C.I. data
                               strata='prior treatment')                                # Strata variable
# TMLE on sex strata
tml.cate.sex.im = aux_tml_method('sex.girl',
                                 data=processed$data)                                   # TMLE results
tlme.sex.im = cbind.data.frame(type='TMLE-im',                                          # method label
                               subpop=c('all','subpop 0','subpop 1'),                   # subpop label
                               tml.cate.sex.im[order(tml.cate.sex.im$param),c(9,10,8)], # C.I. data
                               strata='sex at birth')                                   # Strata variable

# Unify column names
colnames(tlme.med.cc) = colnames(ttest.med)
colnames(tlme.sex.cc) = colnames(ttest.sex)
colnames(tlme.med.im) = colnames(ttest.med)
colnames(tlme.sex.im) = colnames(ttest.sex)

# Put all together
ci.data.2 = rbind.data.frame(ttest.med[ttest.med$type=='oracle',],
                             ttest.sex[ttest.sex$type=='oracle',],
                             tlme.med.cc, tlme.sex.cc,
                             tlme.med.im, tlme.sex.im)
  3. TMLE-mm: TMLE with an explicit model for \(M\) (only considering ptr.diag)

    • Let us assume we are sure that post-treatment diagnoses (ptr.diag) have a direct effect on attrition, but post-treatment visits (ptr.gpsp) do not. Since ptr.diag is binary, we can pose an explicit model for it using a super-learner for binary variables. If this assumption is sound, this approach should remove most of the selection bias.

    \[ \begin{aligned} &\text{Outcome model on selected } & Q_1(H,A,M) = \mathbb{E}[Y\mid H,A,M,R_Y=1] & \\ &\text{M-model (\texttt{ptr.diag})} & m(H,A) = \mathbb{P}[M=1\mid H,A] &\\ &\text{M-averaged expected outcome} & Q_2(H,a) = m(H,a)Q_1(H,a,1) +(1-m(H,a))Q_1(H,a,0) &\\ &\text{Initial ATE} & \psi_0 = \mathbb{E}_H\Delta_a Q_2(H,a)&\\ &\text{Propensity score (nuisance)} & e(H) = \mathbb{P}[A=1\mid H] &\\ &\text{Selection score (nuisance)} & s(H,A,M) = \mathbb{P}[R_Y=1\mid H,A,M] &\\ \end{aligned} \]

Code
###############################################################################
# Recovering causal effect via regression and explicit M-model
###############################################################################

# The library of base learners: GAM and random forests
stack = Stack$new(Lrnr_gam$new(), Lrnr_ranger$new())

# Super-learner for continuous outcome: NNLS
superl = Lrnr_sl$new(learners = stack, metalearner = Lrnr_nnls$new())

# Super-learner for binary variables: default logit
superb = Lrnr_sl$new(learners = stack)

# Function to generate C.I. data from TMLE--------------------------------------
# Param: pre-treatment binary variable to stratify
# Param: data
# Return: data from confidence intervals

aux_tml_method_mmod = function(namevar, input.data){
  
  # Remove namevar from adjustment set
  H = setdiff(H, namevar)
  
  # Learning the Q-function ---------------------------------------------------
  
  # Training task for Q, fit superlearner 
  # on selected-outcome data, including M1 as predictors
  train.Q = make_sl3_Task(
    data = input.data[RY==1,],
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  Q_fit = superl$train(task = train.Q)
  
  # Training task for M, all data
  train.M = make_sl3_Task(
    data = input.data,
    outcome = 'ptr.diag',
    covariates = c(H,'med.6to8')
  )
  M_fit = superb$train(task = train.M)
  
  # Case A=1 ------------------------------------------------------------------
  
  # Data to make counterfactual predictions 
  temp.dt = copy(input.data)
  
  # Prediction task for p(M=1 |A=1)
  temp.dt$med.6to8 = 1
  pred.M = make_sl3_Task(
    data = temp.dt,
    outcome = 'ptr.diag',
    covariates = c(H,'med.6to8')
  )
  pM = M_fit$predict(task = pred.M)
  
  # Prediction task for E(Y |A=1, M=1)
  temp.dt$ptr.diag = 1
  pred.Q1 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  
  # Prediction task for E(Y |A=1, M=0)
  temp.dt$ptr.diag = 0
  pred.Q0 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  
  # Save INTEGRAL dP(M |A=1) E(Y |A=1, M)
  input.data$Q_1m = pM * Q_fit$predict(task = pred.Q1) +
    (1-pM) * Q_fit$predict(task = pred.Q0)
  
  # Case A=0 ------------------------------------------------------------------
  
  # Prediction task for p(M=1 |A=0)
  temp.dt$med.6to8 = 0
  pred.M = make_sl3_Task(
    data = temp.dt,
    outcome = 'ptr.diag',
    covariates = c(H,'med.6to8')
  )
  pM = M_fit$predict(task = pred.M)
  
  # Prediction task for E(Y |A=0, M=1)
  temp.dt$ptr.diag = 1
  pred.Q1 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  
  # Prediction task for E(Y |A=0, M=0)
  temp.dt$ptr.diag = 0
  pred.Q0 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  
  # Save INTEGRAL dP(M |A=0) E(Y |A=0, M)
  input.data$Q_0m = pM * Q_fit$predict(task = pred.Q1) +
    (1-pM) * Q_fit$predict(task = pred.Q0)
  
  # Qm: M-averaged prediction evaluated at the observed treatment level
  input.data[,Qm := med.6to8*Q_1m + (1-med.6to8)*Q_0m]
  
  # Nuisance parameters: A  ----------------------------------------------------
  
  # Training task for A, fit superlearner 
  train.A = make_sl3_Task(
    data = input.data,
    outcome = 'med.6to8',
    covariates = H
  )
  A_fit = superb$train(task = train.A)
  input.data$p.score = A_fit$predict(task = train.A)
  
  # Nuisance parameters: R(Y)  -------------------------------------------------
  
  # Training task for R(Y), fit superlearner 
  train.R = make_sl3_Task(
    data = input.data,
    outcome = 'RY',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  R_fit = superb$train(task = train.R)
  input.data$p.select = R_fit$predict(task = train.R)
  
  # Count parameters  ---------------------------------------------------------
  N = nrow(input.data)
  n = nrow(input.data[RY==1,])
  r = n/N
  
  # Clever covariate: stabilized inverse weights for selection and treatment ---
  # (stored as a new column named H, used below in the fluctuation regression)
  input.data[, H:= r/(p.select) * (med.6to8/p.score - (1-med.6to8)/(1-p.score)) ]
  
  # Updating model  ---------------------------------------------------
  updating = lm(ntr.grd8 ~ -1 + offset(Qm) + H, data=input.data)
  eps = as.numeric(coef(updating))
  
  # Update Q-values
  input.data[, Q_1f:= Q_1m + eps * r/(p.select*p.score)]
  input.data[, Q_0f:= Q_0m - eps * r/(p.select*(1-p.score))]
  input.data[, D:= H*(ntr.grd8 - Qm) + Q_1m - Q_0m]
  
  # Compute ATE, standard errors and CI
  ate = mean(input.data$Q_1f - input.data$Q_0f)
  ate.se = sd(input.data$D, na.rm=T)/sqrt(n)
  ate.lo = ate - qnorm(0.975)*ate.se
  ate.up = ate + qnorm(0.975)*ate.se
  
  # Return CI
  return(c(ate.lo, ate.up, ate))
}

# Generate CI data
ate = aux_tml_method_mmod('', observed.data)
cate.nop =  aux_tml_method_mmod('med.pre6', observed.data[med.pre6==0,])
cate.yes =  aux_tml_method_mmod('med.pre6', observed.data[med.pre6==1,])
cate.boy =  aux_tml_method_mmod('sex.girl', observed.data[sex.girl==0,])
cate.grl =  aux_tml_method_mmod('sex.girl', observed.data[sex.girl==1,])

# Put all together
ci.data.3 = ci.data.2[13:18,] # template rows with subpop/strata labels
ci.data.3$type = 'TMLE-mm'
ci.data.3$midp = c(ate[3], cate.nop[3], cate.yes[3], ate[3], cate.boy[3], cate.grl[3])
ci.data.3$low = c(ate[1], cate.nop[1], cate.yes[1], ate[1], cate.boy[1], cate.grl[1])
ci.data.3$upr = c(ate[2], cate.nop[2], cate.yes[2], ate[2], cate.boy[2], cate.grl[2])
  4. TMLE-nr: TMLE with nested regressions.

    • As in TMLE-mm, assume that post-treatment diagnoses (ptr.diag) have a direct effect on attrition, but post-treatment visits (ptr.gpsp) do not. Instead of positing an explicit model for \(M\), we regress the first-level predictions \(\hat{Q}_1\) on the confounders within each treatment level (a nested, or sequential, regression). This avoids modelling \(p(M\mid H,A)\) and lets both mediators enter the selected-outcome model. If the assumption on the selection mechanism is sound, this approach should also remove most of the selection bias.

    \[ \begin{aligned} &\text{Outcome model on selected } & Q_1(H,A,M) = \mathbb{E}[Y\mid H,A,M,R_Y=1] & \\ &\text{High-level regression} & Q_2(H,a) = \mathbb{E}[\hat{Q}_1\mid H,A=a] &\\ &\text{Initial ATE} & \psi_0 = \mathbb{E}_H\Delta_a Q_2(H,a)&\\ &\text{Propensity score (nuisance)} & e(H) = \mathbb{P}[A=1\mid H] &\\ &\text{Selection score (nuisance)} & s(H,A,M) = \mathbb{P}[R_Y=1\mid H,A,M] &\\ \end{aligned} \]

Code
###############################################################################
# Recovering causal effect via nested regressions
###############################################################################

# Function to generate C.I. data from TMLE--------------------------------------
# Param: pre-treatment binary variable to stratify
# Param: data
# Return: data from confidence intervals

aux_tml_method_nest = function(namevar, input.data){
  
  # Remove namevar from adjustment set
  H = setdiff(H, namevar)
  
  # Learning the Q-function ---------------------------------------------------
  
  # Training task for Q, fit superlearner 
  # on selected-outcome data, including M1 as predictors
  train.Q = make_sl3_Task(
    data = input.data[RY==1,],
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag','ptr.gpsp')
  )
  
  Q_fit = superl$train(task = train.Q)
  
  # Data to make counterfactual predictions 
  temp.dt = copy(input.data)
  
  # Prediction task for E(Y |A=1, M)
  temp.dt$med.6to8 = 1
  pred.Q1 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag','ptr.gpsp')
  )
  
  input.data$Q1.pred = Q_fit$predict(task = pred.Q1)
  
  # Prediction task for E(Y |A=0, M)
  temp.dt$med.6to8 = 0
  pred.Q0 = make_sl3_Task(
    data = temp.dt,
    outcome = 'ntr.grd8',
    covariates = c(H,'med.6to8','ptr.diag','ptr.gpsp')
  )
  
  input.data$Q0.pred = Q_fit$predict(task = pred.Q0)
  
  # Training task for Q21
  train.Q21 = make_sl3_Task(
    data = input.data,
    outcome = 'Q1.pred',
    covariates = H # No treatment here: it is already fixed in the prediction task
  )
  Q21_fit = superl$train(task = train.Q21)
  input.data$Q1.fin = Q21_fit$predict(task = train.Q21)
  
  # Training task for Q20
  train.Q20 = make_sl3_Task(
    data = input.data,
    outcome = 'Q0.pred',
    covariates = H # No treatment here: it is already fixed in the prediction task
  )
  Q20_fit = superl$train(task = train.Q20)
  input.data$Q0.fin = Q20_fit$predict(task = train.Q20)
  
  # Qa: nested-regression prediction evaluated at the observed treatment level
  input.data[, Qa:= med.6to8*Q1.fin + (1-med.6to8)*Q0.fin]
  
  # Nuisance parameters: A  ----------------------------------------------------
  
  # Training task for A, fit superlearner 
  train.A = make_sl3_Task(
    data = input.data,
    outcome = 'med.6to8',
    covariates = H
  )
  A_fit = superb$train(task = train.A)
  input.data$p.score = A_fit$predict(task = train.A)
  
  # Nuisance parameters: R(Y)  -------------------------------------------------
  
  # Training task for R(Y), fit superlearner 
  train.R = make_sl3_Task(
    data = input.data,
    outcome = 'RY',
    covariates = c(H,'med.6to8','ptr.diag')
  )
  R_fit = superb$train(task = train.R)
  input.data$p.select = R_fit$predict(task = train.R)
  
  # Count parameters  ---------------------------------------------------------
  N = nrow(input.data)
  n = nrow(input.data[RY==1,])
  r = n/N
  
  # Clever covariate: stabilized inverse weights for selection and treatment ---
  # (stored as a new column named H, used below in the fluctuation regression)
  input.data[, H:= r/(p.select) * (med.6to8/p.score - (1-med.6to8)/(1-p.score)) ]
  
  # Updating model  ---------------------------------------------------
  updating = lm(ntr.grd8 ~ -1 + offset(Qa) + H, data=input.data)
  eps = as.numeric(coef(updating))
  
  # Update Q-values
  input.data[, Q_1f:= Q1.fin + eps * r/(p.select*p.score)]
  input.data[, Q_0f:= Q0.fin - eps * r/(p.select*(1-p.score))]
  input.data[, D:= H*(ntr.grd8 - Qa) + Q1.fin - Q0.fin]
  
  # Compute ATE, standard errors and CI
  ate = mean(input.data$Q_1f - input.data$Q_0f)
  ate.se = sd(input.data$D, na.rm=T)/sqrt(n)
  ate.lo = ate - qnorm(0.975)*ate.se
  ate.up = ate + qnorm(0.975)*ate.se
  
  # Return CI
  return(c(ate.lo, ate.up, ate))
}

# Generate CI data
ate = aux_tml_method_nest('', observed.data)
cate.nop =  aux_tml_method_nest('med.pre6', observed.data[med.pre6==0,])
cate.yes =  aux_tml_method_nest('med.pre6', observed.data[med.pre6==1,])
cate.boy =  aux_tml_method_nest('sex.girl', observed.data[sex.girl==0,])
cate.grl =  aux_tml_method_nest('sex.girl', observed.data[sex.girl==1,])

# Put all together
ci.data.4 = ci.data.2[13:18,] # template rows with subpop/strata labels
ci.data.4$type = 'TMLE-nr'
ci.data.4$midp = c(ate[3], cate.nop[3], cate.yes[3], ate[3], cate.boy[3], cate.grl[3])
ci.data.4$low = c(ate[1], cate.nop[1], cate.yes[1], ate[1], cate.boy[1], cate.grl[1])
ci.data.4$upr = c(ate[2], cate.nop[2], cate.yes[2], ate[2], cate.boy[2], cate.grl[2])

# All methods together
ci.data = rbind(ci.data.2, ci.data.3, ci.data.4)

The following plot presents the point-estimate and 95% confidence interval from the approaches presented above:

Code
# Plot confidence intervals ---------------------------------------------------
ci.data %>% 
  ggplot(aes(x = subpop, y = midp, colour = type)) +                    
  geom_errorbar(aes(ymax = upr, ymin = low), position = "dodge") +
  geom_point(position = position_dodge(0.9)) +
  geom_hline(yintercept=0, linetype="dashed", color = "red") +
  labs(y='ATE/CATE estimate',
       x='Subpopulation (subpop 1 = prior treatment or girls, by panel)') +
  facet_wrap(~strata) + theme_classic()

Regression adjustment works

  • Confidence intervals from complete-case analysis (brown) and single median-imputation (green) fail to cover the oracle \(^S\)ATE (red) in most cases, particularly for the population ATE.
  • A regression-based solution via an explicit model for ptr.diag (blue) does a better job, and only struggles with the small subpopulations in the data (no prior treatment or girls, about 30%).
  • A nested-regressions approach (purple) performs well for the ATE, but struggles a little with the small subpopulations.

All the TMLE-based methods above employ only one fluctuation parameter (on \(Q_2\)). Better performance can be obtained by adding fluctuation parameters on \(Q_1\) or \(p(M\mid H,A)\) as well, and by incorporating cross-fitting schemes, as sketched below.
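A minimal sketch of the cross-fitting idea follows. It is an illustration of out-of-fold nuisance prediction only, not a full CV-TMLE (fold handling, targeting and variance estimation are simplified), and it assumes the observed.data, H and superb objects defined earlier are in scope:

Code
# Sketch: 5-fold cross-fitted propensity scores (out-of-fold predictions)
set.seed(1)
K = 5
folds = sample(rep(1:K, length.out = nrow(observed.data)))

p.score.cf = rep(NA_real_, nrow(observed.data))
for (k in 1:K) {
  # Train the exposure super-learner on all folds except k
  train.k = make_sl3_Task(
    data = observed.data[folds != k, ],
    outcome = 'med.6to8',
    covariates = H
  )
  fit.k = superb$train(task = train.k)
  # Predict on the held-out fold only
  test.k = make_sl3_Task(
    data = observed.data[folds == k, ],
    outcome = 'med.6to8',
    covariates = H
  )
  p.score.cf[folds == k] = fit.k$predict(task = test.k)
}
# p.score.cf could then replace the in-sample p.score in the updating step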

References

Pearl, Judea. 2009. Causality: Models, Reasoning, and Inference. 2nd ed. Cambridge, UK: Cambridge University Press. https://doi.org/10.1017/CBO9780511803161.
Phillips, Rachael V, Mark J van der Laan, Hana Lee, and Susan Gruber. 2023. “Practical considerations for specifying a super learner.” International Journal of Epidemiology 52 (4): 1276–85. https://doi.org/10.1093/ije/dyad023.
van der Laan, M. J., and S. Rose. 2011. Targeted Learning: Causal Inference for Observational and Experimental Data. Springer Series in Statistics. Springer New York. https://books.google.no/books?id=RGnSX5aCAgQC.