
Decision-Focused Learning in Restless Multi-Armed Bandits with Application to Maternal and Child Care Domain

Kai Wang    Shresth Verma    Aditya Mate    Sanket Shah    Aparna Taneja    Neha Madhiwala    Aparna Hegde    Milind Tambe
Abstract

This paper studies restless multi-armed bandit (RMAB) problems with unknown arm transition dynamics but with known correlated arm features. The goal is to learn a model that predicts transition dynamics from features, where the Whittle index policy solves the RMAB problem using the predicted transitions. However, prior works often learn the model by maximizing predictive accuracy instead of final RMAB solution quality, causing a mismatch between training and evaluation objectives. To address this shortcoming, we propose a novel approach for decision-focused learning in RMAB that directly trains the predictive model to maximize the Whittle index solution quality. We present three key contributions: (i) we establish the differentiability of the Whittle index policy to support decision-focused learning; (ii) we significantly improve the scalability of previous decision-focused learning approaches in sequential problems; (iii) we apply our algorithm to the service call scheduling problem in a real-world maternal and child health domain. Our algorithm is the first for decision-focused learning in RMAB that scales to large real-world problems.


1 Introduction

Restless multi-armed bandits (RMABs) (weber1990index; tekin2012online) are composed of a set of heterogeneous arms and a planner who can pull multiple arms under a budget constraint at each time step to collect rewards. Different from classic stochastic multi-armed bandits (gittins2011multi; bubeck2012regret), the state of each arm in an RMAB can change even when the arm is not pulled: each arm follows a Markovian process, transitioning between states with transition probabilities that depend on the arm and on the pulling decision. Rewards are associated with the arm states, and the planner's goal is to plan a sequential pulling policy that maximizes the total reward received from all arms. RMABs are commonly used to model sequential scheduling problems where limited resources must be strategically assigned to different tasks over time to maximize performance. Examples include machine maintenance (glazebrook2006some), cognitive radio sensing (bagheri2015restless), and healthcare (mate2022field).

In this paper, we study offline RMAB problems with unknown transition dynamics but with given arm features. The goal is to learn a mapping from arm features to transition dynamics, which can be used to infer the dynamics of unseen RMAB problems and plan accordingly. Prior works (mate2022field; sun2018cell) often learn the transition dynamics from historical pulling data by maximizing predictive accuracy. However, RMAB performance is evaluated by the quality of the solution derived from the predicted transition dynamics, which leads to a mismatch between the training objective and the evaluation objective. Decision-focused learning (wilder2019melding) has been proposed to directly optimize the solution quality rather than predictive accuracy, by integrating a one-shot optimization problem (donti2017task; perrault2020end) or a sequential problem (wang2021learning; futoma2020popcorn) as a differentiable layer in the training pipeline. Unfortunately, while decision-focused learning can successfully optimize the evaluation objective, it is extremely computationally expensive due to the presence of the optimization problem inside the training loop. For RMAB problems specifically, the computation cost of decision-focused learning arises from the complexity of the sequential problems formulated as Markov decision processes (MDPs); the PSPACE hardness of finding the optimal policy (papadimitriou1994complexity) therefore limits its applicability.

Our main contribution is a novel and scalable approach for decision-focused learning in RMAB problems using the Whittle index policy, a commonly used approximate solution for RMABs. Our three key contributions are: (i) we establish the differentiability of the Whittle index policy to support decision-focused learning that directly optimizes RMAB solution quality; (ii) we show that our approach of differentiating through the Whittle index policy improves the scalability of decision-focused learning in RMABs; (iii) we apply our algorithm to an anonymized maternal and child health RMAB dataset previously collected by armman to evaluate the performance of our algorithm in simulation.

We establish the differentiability of the Whittle index by showing that it can be expressed as the solution to a full-rank linear system reduced from the Bellman equations with transition dynamics as entries, which allows us to compute the derivative of the Whittle index with respect to the transition dynamics. On the other hand, to execute the Whittle index policy, the standard selection process of pulling the arms with the top-k Whittle indices is non-differentiable. We relax this non-differentiable process with a differentiable soft top-k selection. Our differentiable Whittle index policy enables decision-focused learning in RMAB problems to backpropagate from the final policy performance to the predictive model. We significantly improve the scalability of decision-focused learning: the computation cost of our algorithm, $O(NM^{\omega+1})$, scales linearly in the number of arms $N$ and polynomially in the number of states $M$ with $\omega \approx 2.373$, while previous work scales exponentially as $O(M^{\omega N})$. This significant reduction in computation cost is crucial for extending decision-focused learning to RMAB problems with a large number of arms.

In our experiments, we apply decision-focused learning to RMAB problems to optimize an importance sampling-based evaluation on synthetic datasets as well as an anonymized RMAB dataset from a maternal and child health program previously collected by (armman); these datasets are the basis for comparing different methods in simulation. We compare decision-focused learning with the two-stage method that trains to minimize the predictive loss. The two-stage method achieves the best predictive loss but significantly degraded solution quality. In contrast, decision-focused learning reaches a slightly worse predictive loss but a much better importance sampling-based solution quality, and the improvement generalizes to the simulation-based evaluation built from the data. Lastly, the scalability improvement is the crux of applying decision-focused learning to real-world RMAB problems: our algorithm can run decision-focused learning on the maternal and child health dataset with hundreds of arms, whereas the state of the art is 100-fold slower even with 20 arms and grows exponentially worse.

Related Work

Restless multi-armed bandits with given transition dynamics

This line of research primarily focuses on solving RMAB problems to obtain a sequential policy. Solving RMAB problems optimally is known to be PSPACE-hard (papadimitriou1994complexity). An approximate solution was proposed by whittle1988restless, which uses a Lagrangian relaxation to decompose the arms and computes the associated Whittle indices to define a policy. Under the indexability condition (akbarzadeh2019restless; wang2019opportunistic), this Whittle index policy is asymptotically optimal (weber1990index). In practice, the Whittle index policy usually provides a near-optimal solution to RMAB problems.

Restless multi-armed bandits with missing transition dynamics

When the transition dynamics are unknown in RMAB problems but an interactive environment is available, prior works (tekin2012online; liu2012learning; oksanen2015order; dai2011non) treat this as an online learning problem that aims to maximize the expected reward. However, these approaches become infeasible when interacting with the environment is expensive, e.g., in healthcare problems (mate2022field). In this work, we consider the offline RMAB problem, where each arm comes with a feature vector correlated with its transition dynamics, and a mapping from features to dynamics can be learned from past data.

Decision-focused learning

The predict-then-optimize framework (elmachtoub2021smart) is composed of a predictive problem that makes predictions on the parameters of a downstream optimization problem, and an optimization problem that uses the predicted parameters to produce a solution, where the overall objective is the quality of the proposed solution. The standard two-stage learning method solves the predictive and optimization problems separately, leading to a mismatch between the predictive loss and the evaluation metric (huang2019addressing; lambert2020objective; johnson2019survey). In contrast, decision-focused learning (wilder2019melding; mandi2020smart; elmachtoub2020decision) trains the predictive model to directly optimize the solution quality by integrating the optimization problem as a differentiable layer (amos2017optnet; agrawal2019differentiable) in the training pipeline. Our offline RMAB problem is a predict-then-optimize problem: we first learn (offline) a mapping from arm features to transition dynamics from historical data (mate2022field; sun2018cell), and the RMAB problem is then solved using the predicted transition dynamics. Prior work (mate2022field) is limited to two-stage learning for offline RMAB problems. While decision-focused learning in sequential problems was primarily studied in the context of MDPs (wang2021learning; futoma2020popcorn), these approaches come with an expensive computation cost that quickly becomes infeasible for large RMAB problems.

2 Model: Restless Multi-armed Bandit

An instance of the restless multi-armed bandit (RMAB) problem is composed of a set of $N$ arms, each modeled as an independent Markov decision process (MDP). The $i$-th arm in an RMAB problem is defined by a tuple $(\mathcal{S},\mathcal{A},R_i,P_i)$. $\mathcal{S}$ and $\mathcal{A}$ are the state and action spaces, identical across all arms. $R_i, P_i : \mathcal{S}\times\mathcal{A}\times\mathcal{S}\rightarrow\mathbb{R}$ are the reward and transition functions associated with arm $i$. We consider a finite state space with $|\mathcal{S}|=M$ fully observable states and action set $\mathcal{A}=\{0,1\}$, corresponding to not pulling or pulling the arm, respectively. For each arm $i$, the reward is denoted by $R_i(s_i,a_i,s'_i)=R(s_i)$, i.e., the reward $R(s_i)$ only depends on the current state $s_i$, where $R:\mathcal{S}\rightarrow\mathbb{R}$ is a vector of size $M$. Given the state $s_i$ and action $a_i$, $P_i(s_i,a_i)=[P_i(s_i,a_i,s'_i)]_{s'_i\in\mathcal{S}}$ defines the probability distribution of transitioning to all possible next states $s'_i\in\mathcal{S}$.

In an RMAB problem, at each time step $t\in[T]$, the learner observes $\boldsymbol{s}_t=[s_{t,i}]_{i\in[N]}\in\mathcal{S}^N$, the states of all arms. The learner then chooses action $\boldsymbol{a}_t=[a_{t,i}]_{i\in[N]}\in\mathcal{A}^N$, the pulling decisions for all arms, which must satisfy a budget constraint $\sum\nolimits_{i\in[N]} a_{t,i}\leq K$, i.e., the learner can pull at most $K$ arms at each time step. Once the action is chosen, the arms receive action $\boldsymbol{a}_t$ and transition under $P$, yielding rewards $\boldsymbol{r}_t=[r_{t,i}]_{i\in[N]}$. We denote a full trajectory by $\tau=(\boldsymbol{s}_1,\boldsymbol{a}_1,\boldsymbol{r}_1,\cdots,\boldsymbol{s}_T,\boldsymbol{a}_T,\boldsymbol{r}_T)$. The total reward is the sum of discounted rewards across $T$ time steps and $N$ arms, i.e., $\sum\nolimits_{t=1}^{T}\gamma^{t-1}\sum\nolimits_{i\in[N]} r_{t,i}$, where $0<\gamma\leq 1$ is the discount factor.

A policy is denoted by $\pi$, where $\pi(\boldsymbol{a}\mid\boldsymbol{s})$ is the probability of choosing action $\boldsymbol{a}$ given state $\boldsymbol{s}$. Additionally, we define $\pi(a_i=1\mid\boldsymbol{s})$ to be the marginal probability of pulling arm $i$ given state $\boldsymbol{s}$, where $\pi(\boldsymbol{s})=[\pi(a_i=1\mid\boldsymbol{s})]_{i\in[N]}$ is the vector of arm-pulling probabilities. We use $\pi^*$ to denote the optimal policy that maximizes the cumulative reward, and $\pi^{\text{solver}}$ to denote a near-optimal policy solver.
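To make the model concrete, the following minimal sketch (our own illustration, not the paper's implementation) simulates one step of an RMAB and accumulates the discounted reward defined above; the array shapes and helper names are assumptions.

```python
import numpy as np

def rmab_step(states, actions, P, rng):
    """Advance all N arms by one step.

    states:  (N,) current state indices in {0, ..., M-1}
    actions: (N,) binary pulling decisions with sum(actions) <= K
    P:       (N, M, 2, M) transition tensor, P[i, s, a, s']
    """
    next_states = np.empty_like(states)
    for i in range(len(states)):
        probs = P[i, states[i], actions[i]]        # distribution over next states of arm i
        next_states[i] = rng.choice(len(probs), p=probs)
    return next_states

def total_discounted_reward(visited_states, R, gamma):
    """Compute sum_t gamma^(t-1) sum_i R(s_{t,i}) for visited_states of shape (T, N)."""
    discounts = gamma ** np.arange(visited_states.shape[0])
    return float((discounts * R[visited_states].sum(axis=1)).sum())
```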

3 Problem Statement

This paper studies the RMAB problem where we do not know the transition probabilities $P=\{P_i\}_{i\in[N]}$ in advance. Instead, we are given a set of features $\boldsymbol{x}=\{x_i\in\mathcal{X}\}_{i\in[N]}$, each corresponding to one arm. The goal is to learn a mapping $f_w:\mathcal{X}\rightarrow\mathcal{P}$, parameterized by weights $w$, to make predictions on the transition probabilities $P=f_w(\boldsymbol{x})\coloneqq\{f_w(x_i)\}_{i\in[N]}$. The predicted transition probabilities are later used to solve the RMAB problem to derive a policy $\pi=\pi^{\text{solver}}(f_w(\boldsymbol{x}))$. The performance of the model $f$ is evaluated by the performance of the proposed policy $\pi$.

3.1 Training and Testing Datasets

To learn the mapping $f_w$, we are given a set of RMAB instances as training examples $\mathcal{D}_{\text{train}}=\{(\boldsymbol{x},\mathcal{T})\}$, where each instance is composed of an RMAB problem with features $\boldsymbol{x}$ correlated with the unknown transition probabilities $P$, and a set of realized trajectories $\mathcal{T}=\{\tau^{(j)}\}_{j\in J}$ generated by a given behavior policy $\pi_{\text{beh}}$ that determined how arms were pulled in the past. The testing set $\mathcal{D}_{\text{test}}$ is defined similarly but hidden at training time.

3.2 Evaluation Metrics

Predictive loss

To measure the correctness of the transition probabilities $P=\{P_i\}_{i\in[N]}$, we define the predictive loss as the average negative log-likelihood of the given trajectories $\mathcal{T}$, i.e., $\mathcal{L}(P,\mathcal{T})\coloneqq -\log\Pr(\mathcal{T}\mid P)=-\mathbb{E}_{\tau\sim\mathcal{T}}\sum\nolimits_{t\in[T]}\log P(\boldsymbol{s}_t,\boldsymbol{a}_t,\boldsymbol{s}_{t+1})$. Therefore, we can define the predictive loss of a model $f_w$ on dataset $\mathcal{D}$ by:

$\mathbb{E}_{(\boldsymbol{x},\mathcal{T})\sim\mathcal{D}}\,\mathcal{L}(f_w(\boldsymbol{x}),\mathcal{T})$    (1)
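For concreteness, a sketch of the predictive loss in Equation 1 is given below, assuming PyTorch, a predicted transition tensor of shape $(N, M, 2, M)$, and trajectories stored as integer tensors; the function name and data layout are our own assumptions.

```python
import torch

def predictive_loss(P_pred, trajectories):
    """Average negative log-likelihood of the observed transitions (Equation 1).

    P_pred:       (N, M, 2, M) predicted transition probabilities
    trajectories: list of (states, actions) pairs with shapes (T+1, N) and (T, N)
    """
    arm_idx = torch.arange(P_pred.shape[0])
    total, count = 0.0, 0
    for states, actions in trajectories:
        for t in range(actions.shape[0]):
            s, a, s_next = states[t], actions[t], states[t + 1]
            probs = P_pred[arm_idx, s, a, s_next]   # per-arm transition likelihoods
            total = total - torch.log(probs + 1e-8).sum()
            count += 1
    return total / max(count, 1)
```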

Policy evaluation

On the other hand, given transition probabilities $P$, we can solve the RMAB problem to derive a policy $\pi^{\text{solver}}(P)$. We can use the historical trajectories $\mathcal{T}$ to evaluate how well the policy performs, denoted by $\text{Eval}(\pi^{\text{solver}}(P),\mathcal{T})$. Given a dataset $\mathcal{D}$, we can evaluate the predictive model $f_w$ on $\mathcal{D}$ by:

$\mathbb{E}_{(\boldsymbol{x},\mathcal{T})\sim\mathcal{D}}\,\text{Eval}(\pi^{\text{solver}}(f_w(\boldsymbol{x})),\mathcal{T})$    (2)

Two common types of policy evaluation are importance sampling-based off-policy policy evaluation and simulation-based evaluation, which will be discussed in Section 5.

3.3 Learning Methods

Figure 1: This flowchart visualizes different methods of learning the predictive model. Two-stage learning directly compares the predicted transition probabilities with the given data to define a predictive loss to run gradient descent. Decision-focused learning instead goes through a policy solver using Whittle index policy to estimate the final evaluation and run gradient ascent.

Two-stage learning

To learn the predictive model $f_w$, we can minimize Equation 1 by computing the gradient $\frac{d\mathcal{L}(f_w(\boldsymbol{x}),\mathcal{T})}{dw}$ to run gradient descent. However, this training objective (Equation 1) differs from the evaluation objective (Equation 2), which often leads to suboptimal performance.
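A minimal two-stage training loop might look as follows (a sketch assuming PyTorch and the `predictive_loss` helper sketched above); the RMAB policy is only computed after training, so no gradient flows through the solver.

```python
import torch

def train_two_stage(f_w, train_set, epochs=50, lr=1e-2):
    opt = torch.optim.Adam(f_w.parameters(), lr=lr)
    for _ in range(epochs):
        for features, trajectories in train_set:
            P_pred = f_w(features)                      # predicted transition probabilities
            loss = predictive_loss(P_pred, trajectories)
            opt.zero_grad()
            loss.backward()                             # gradient of Equation 1 only
            opt.step()
    return f_w
```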

Decision-focused learning

In contrast, we can directly run gradient ascent to maximize Equation 2 by computing the gradient $\frac{d\,\text{Eval}(\pi^{\text{solver}}(f_w(\boldsymbol{x})),\mathcal{T})}{dw}$. However, to compute this gradient, we need to differentiate through the policy solver $\pi^{\text{solver}}$ and the corresponding optimal solution. Unfortunately, finding the optimal policy in RMABs is expensive and the policy is high-dimensional. Both of these challenges prevent us from computing the gradient needed for decision-focused learning.

4 Decision-focused Learning in RMABs

In this paper, instead of grappling with the optimal policy, we consider the Whittle index policy (whittle1988restless), the dominant solution paradigm for RMAB problems. The Whittle index policy is easier to compute and has been shown to perform well in practice. In this section, we establish that it is also possible to backpropagate through the Whittle index policy. This differentiability allows us to run decision-focused learning to directly maximize performance on the RMAB problem.

4.1 Whittle Index and Whittle Index Policy

Informally, the Whittle index of an arm captures the added value derived from pulling that arm. The key idea is to determine the Whittle indices of all arms and to pull the arms with the highest values of the index.

To evaluate the value of pulling an arm $i$, we consider the notion of a 'passive subsidy': a hypothetical exogenous compensation $m$ rewarded for not pulling the arm (i.e., for choosing action $a=0$). The Whittle index is defined as the smallest subsidy necessary to make not pulling as rewarding as pulling, assuming indexability (liu2010indexability):

Definition 4.1 (Whittle index).

Given state $u\in\mathcal{S}$, we define the Whittle index associated with state $u$ by:

$W_i(u) \coloneqq \inf\nolimits_m \{Q_i^m(u;a=0) = Q_i^m(u;a=1)\}$    (3)

where the value functions are defined by the following Bellman equations, augmented with subsidy $m$ for action $a=0$:

$V^m_i(s) = \max\nolimits_a Q_i^m(s;a)$    (4)
$Q_i^m(s;a) = m\boldsymbol{1}_{a=0} + R(s) + \gamma\sum\nolimits_{s'} P_i(s,a,s')\,V^m_i(s')$    (5)

Given the Whittle indices of all arms and all states $W=[W_i(u)]_{i\in[N],u\in\mathcal{S}}$, the Whittle index policy is denoted by $\pi^{\text{whittle}}:\mathcal{S}^N\rightarrow[0,1]^N$, which takes the states of all arms as input, computes their Whittle indices, and outputs the probabilities of pulling each arm. This policy is applied at every time step to pull arms based on the index values.
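As a concrete (non-differentiable) reference, the forward computation of a Whittle index can be sketched as a binary search over the passive subsidy $m$, with value iteration solving Equations 4-5 for each candidate subsidy; the subsidy bounds, tolerance, and iteration counts below are illustrative assumptions.

```python
import numpy as np

def q_values(P_i, R, m, gamma, iters=500):
    """Q^m_i(s, a) for one arm under passive subsidy m (Equations 4-5)."""
    M = R.shape[0]
    V = np.zeros(M)
    for _ in range(iters):
        Q = np.empty((M, 2))
        Q[:, 0] = m + R + gamma * P_i[:, 0, :] @ V   # not pulling, with subsidy
        Q[:, 1] = R + gamma * P_i[:, 1, :] @ V       # pulling
        V = Q.max(axis=1)
    return Q

def whittle_index(P_i, R, u, gamma, lo=-10.0, hi=10.0, tol=1e-4):
    """Smallest subsidy m making a=0 and a=1 equally good in state u (Equation 3)."""
    while hi - lo > tol:
        m = (lo + hi) / 2
        Q = q_values(P_i, R, m, gamma)
        if Q[u, 1] > Q[u, 0]:      # pulling still strictly better: increase the subsidy
            lo = m
        else:
            hi = m
    return (lo + hi) / 2
```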

4.2 Decision-focused Learning Using Whittle Index Policy

Instead of using the optimal policy $\pi^*$ to run decision-focused learning with its expensive computation cost, we use the Whittle index policy $\pi^{\text{whittle}}$ as an approximate solution to determine how to pull arms. In this case, in order to run decision-focused learning, we need to compute the derivative of the evaluation metric by the chain rule:

$\frac{d\,\text{Eval}(\pi^{\text{whittle}},\mathcal{T})}{dw} = \frac{d\,\text{Eval}(\pi^{\text{whittle}},\mathcal{T})}{d\pi^{\text{whittle}}}\,\frac{d\pi^{\text{whittle}}}{dW}\,\frac{dW}{dP}\,\frac{dP}{dw}$    (6)

where $W$ denotes the Whittle indices of all states under the predicted transition probabilities $P$, and $\pi^{\text{whittle}}$ is the Whittle index policy induced by $W$. The flowchart is illustrated in Figure 1.

The term $\frac{d\,\text{Eval}(\pi^{\text{whittle}},\mathcal{T})}{d\pi^{\text{whittle}}}$ can be computed via the policy gradient theorem (sutton1998introduction), and the term $\frac{dP}{dw}$ can be computed using auto-differentiation. However, two challenges remain: (i) how to differentiate through the Whittle index policy to get $\frac{d\pi^{\text{whittle}}}{dW}$, and (ii) how to differentiate through the Whittle index computation to derive $\frac{dW}{dP}$.

4.3 Differentiability of Whittle Index Policy

A common choice of Whittle index policy is defined by:

Definition 4.2 (Strict Whittle index policy).
$\pi^{\text{strict}}_W(\boldsymbol{s}) = \boldsymbol{1}_{\text{top-k}([W_i(s_i)]_{i\in[N]})} \in \{0,1\}^N$    (7)

which selects arms with the top-k Whittle indices to pull.

However, the strict top-k operation in the strict Whittle index policy is non-differentiable, which prevents us from computing a meaningful estimate of $\frac{d\pi^{\text{whittle}}}{dW}$ in Equation 6. We circumvent this issue by relaxing the top-k selection to a soft top-k selection (xie2020differentiable), which can be expressed as an optimal transport problem with regularization, making it differentiable. We apply soft top-k to define a new, differentiable soft Whittle index policy:

Definition 4.3 (Soft Whittle index policy).
$\pi^{\text{soft}}_W(\boldsymbol{s}) = \text{soft-top-k}([W_i(s_i)]_{i\in[N]}) \in [0,1]^N$    (8)

Using the soft Whittle index policy, the policy becomes differentiable and we can compute $\frac{d\pi^{\text{whittle}}}{dW}$.
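The following sketch illustrates one way to implement a differentiable soft-top-k in the spirit of the optimal transport relaxation of xie2020differentiable, using a few log-domain Sinkhorn iterations in PyTorch; the anchor construction, regularization strength, and iteration count are our own assumptions rather than the exact construction used in the paper.

```python
import math
import torch

def soft_top_k(scores, k, eps=0.1, iters=200):
    """Differentiable relaxation of selecting the k largest scores.

    Returns soft selection probabilities in [0, 1]^N summing (approximately) to k.
    """
    N = scores.shape[0]
    anchors = torch.stack([scores.min(), scores.max()]).detach()  # "low" / "high" targets
    C = (scores.unsqueeze(1) - anchors.unsqueeze(0)) ** 2         # (N, 2) transport costs
    log_mu = torch.full((N,), -math.log(N))                       # uniform mass on arms
    log_nu = torch.log(torch.tensor([(N - k) / N, k / N]))        # mass on "out" / "in"
    f, g = torch.zeros(N), torch.zeros(2)
    for _ in range(iters):                                        # log-domain Sinkhorn updates
        f = eps * (log_mu - torch.logsumexp((g.unsqueeze(0) - C) / eps, dim=1))
        g = eps * (log_nu - torch.logsumexp((f.unsqueeze(1) - C) / eps, dim=0))
    plan = torch.exp((f.unsqueeze(1) + g.unsqueeze(0) - C) / eps) # (N, 2) transport plan
    return N * plan[:, 1]                                         # mass sent to "selected"
```

Gradients flow to the Whittle indices through the cost matrix and the unrolled Sinkhorn iterations, which is what makes the soft policy usable inside Equation 6.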

4.4 Differentiability of Whittle Index

The second challenge is the differentiability of the Whittle index itself. Whittle indices are often computed using value iteration with binary search (qian2016restless; mate2020collapsing) or a mixed integer linear program. However, these operations are not differentiable, so we cannot compute the derivative $\frac{dW}{dP}$ in Equation 6 directly.

Figure 2: We establish the differentiability of Whittle index policy using a soft top-k selection to construct a soft Whittle index policy, and the differentiability of Whittle index by expressing Whittle index as a solution to a linear system in Equation 11.

Main idea

After computing the Whittle indices and the value functions of each arm $i$, the key idea is to construct linear equations that link the Whittle index to the transition matrix $P_i$. Specifically, we achieve this by resolving the $\max$ operator in Equation 4 of Definition 4.1, determining the optimal actions $a$ from the pre-computed value functions. Plugging these back into Equation 5 and manipulating as shown below yields linear equations in the Whittle index $W_i(u)$ and the transition matrix $P_i$, which can be expressed as a full-rank linear system in $P_i$ with the Whittle index as a solution. This makes the Whittle index differentiable in $P_i$.

Selecting Bellman equation

Let $u$ and arm $i$ be the target state and target arm for which we compute the Whittle index. Assume we have precomputed the Whittle index $m=W_i(u)$ for state $u$ and the corresponding value functions $[V^m_i(s)]_{s\in\mathcal{S}}$ for all states under the same passive subsidy $m=W_i(u)$. Equation 5 can be combined with Equation 4 to get:

$V^m_i(s) \geq \begin{cases} m + R(s) + \gamma\sum\nolimits_{s'\in\mathcal{S}} P_i(s,a=0,s')\,V^m_i(s') \\ R(s) + \gamma\sum\nolimits_{s'\in\mathcal{S}} P_i(s,a=1,s')\,V^m_i(s') \end{cases}$    (9)

where $m=W_i(u)$.

For each $s\in\mathcal{S}$, at least one of the inequalities in Equation 9 holds with equality, because one of the actions must be optimal and match the state value function $V^m_i(s)$. We can identify which equality holds by simply plugging in the precomputed value functions $[V^m_i(s)]_{s\in\mathcal{S}}$. Furthermore, for the target state $u$, both equalities must hold because, by the definition of the Whittle index, the passive subsidy $m=W_i(u)$ makes both actions equally optimal, i.e., in Equation 3, $V^m_i(u)=Q^m_i(u;a=0)=Q^m_i(u;a=1)$ for $m=W_i(u)$.

Thus Equation 9 can be written in matrix form:

$\begin{bmatrix} \boldsymbol{V}^m_i \\ \boldsymbol{V}^m_i \end{bmatrix} \geq \begin{bmatrix} \boldsymbol{1}_M & \gamma\boldsymbol{P}_i(\mathcal{S},a=0,\mathcal{S}) \\ \boldsymbol{0}_M & \gamma\boldsymbol{P}_i(\mathcal{S},a=1,\mathcal{S}) \end{bmatrix} \begin{bmatrix} m \\ \boldsymbol{V}^m_i \end{bmatrix} + \begin{bmatrix} \boldsymbol{R}(\mathcal{S}) \\ \boldsymbol{R}(\mathcal{S}) \end{bmatrix}$    (10)

where $\boldsymbol{V}^m_i \coloneqq [V^m_i(s)]_{s\in\mathcal{S}}$, $\boldsymbol{R}(\mathcal{S}) = [R(s)]_{s\in\mathcal{S}}$, and $\boldsymbol{P}_i(\mathcal{S},a,\mathcal{S}) \coloneqq [P_i(s,a,s')]_{s,s'\in\mathcal{S}} \in \mathbb{R}^{M\times M}$.

By the aforementioned discussion, at least $M+1$ of the inequalities in Equation 10 hold with equality, while there are only $M+1$ variables ($m\in\mathbb{R}$ and $\boldsymbol{V}^m_i\in\mathbb{R}^M$). Therefore, we rearrange Equation 10 and pick only the rows where equality holds to get:

$A\begin{bmatrix} \boldsymbol{1}_M & \gamma\boldsymbol{P}_i(\mathcal{S},a=0,\mathcal{S}) - I_M \\ \boldsymbol{0}_M & \gamma\boldsymbol{P}_i(\mathcal{S},a=1,\mathcal{S}) - I_M \end{bmatrix} \begin{bmatrix} m \\ \boldsymbol{V}^m_i \end{bmatrix} = A\begin{bmatrix} -\boldsymbol{R}(\mathcal{S}) \\ -\boldsymbol{R}(\mathcal{S}) \end{bmatrix}$    (11)

where $A\in\{0,1\}^{(M+1)\times 2M}$ is a binary matrix with a single $1$ per row that extracts the rows where equality holds. For example, we can set $A_{ij}=1$ if the $j$-th row in Equation 10 corresponds to the equality in Equation 9 for the $i$-th state in the state space $\mathcal{S}$, for $i\in[M]$, and set $A_{(M+1),j}=1$ in the last row to mark the additional equality implied by the Whittle index definition (see Appendix 16 for more details). Matrix $A$ thus picks $M+1$ equalities out of Equation 10 to form Equation 11.

Equation 11 is a full-rank linear system with $m=W_i(u)$ as part of its solution. This expresses $W_i(u)$ as an implicit function of $\boldsymbol{P}$, allowing the computation of $\frac{dW_i(u)}{d\boldsymbol{P}}$ via auto-differentiation and thus establishing the differentiability of the Whittle index. We repeat this process for every arm $i\in[N]$ and every state $u\in\mathcal{S}$. Figure 2 summarizes the differentiable Whittle index policy, and the full procedure is shown in Algorithm 1.
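A sketch of how Equation 11 can be used in practice is shown below, assuming PyTorch: the Whittle index $m$ and value functions are first obtained by a standard (non-differentiable) solver, the tight rows of Equation 9 are identified, and the index is then re-expressed as the solution of the selected linear system so that auto-differentiation yields $\frac{dW_i(u)}{dP_i}$. The function and argument names are our own.

```python
import torch

def differentiable_whittle_index(P_i, R, u, m_val, V_val, gamma):
    """Re-express W_i(u) as the solution of the linear system in Equation 11.

    P_i: (M, 2, M) transition tensor with requires_grad=True; R, V_val: (M,) tensors.
    m_val, V_val: Whittle index and value functions from the non-differentiable solver.
    """
    M = R.shape[0]
    # Identify the tight (optimal) action in each state from the precomputed values.
    Q0 = m_val + R + gamma * P_i[:, 0, :].detach() @ V_val
    Q1 = R + gamma * P_i[:, 1, :].detach() @ V_val
    tight_action = (Q1 >= Q0).long()

    rows, rhs = [], []
    def bellman_row(s, a):
        coef_m = torch.tensor([1.0 if a == 0 else 0.0])
        coef_V = gamma * P_i[s, a, :] - torch.eye(M)[s]
        rows.append(torch.cat([coef_m, coef_V]))
        rhs.append(-R[s].view(1))

    for s in range(M):                           # one tight Bellman equality per state
        bellman_row(s, tight_action[s].item())
    bellman_row(u, 1 - tight_action[u].item())   # both actions are tight at the target state u

    A_sys = torch.stack(rows)                    # (M+1, M+1) full-rank system of Equation 11
    b_sys = torch.cat(rhs)
    solution = torch.linalg.solve(A_sys, b_sys)
    return solution[0]                           # W_i(u) as a differentiable function of P_i
```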

4.5 Computation Cost and Backpropagation

It is well known that the Whittle index policy can be computed much more efficiently than solving the RMAB problem as a single large MDP. Here, we show that using the Whittle index policy also yields a large speedup when backpropagating gradients in decision-focused learning.

In order to use Equation 11 to compute the gradient of the Whittle indices, we need to invert the left-hand side of Equation 11 with dimensionality $M+1$, which takes $O(M^{\omega})$ time, where $\omega\approx 2.373$ (alman2021refined) is the best-known matrix inversion exponent. Therefore, the overall computation for all $N$ arms and $M$ states is $O(NM^{\omega+1})$ per gradient step.

In contrast, standard decision-focused learning differentiates through the optimal policy using the full Bellman equation with $O(M^N)$ variables, where inverting the large Bellman system requires $O(M^{\omega N})$ cost per gradient step. Our algorithm thus reduces the computation cost to a linear dependency on the number of arms $N$, which significantly improves the scalability of decision-focused learning.

4.6 Extension to Partially Observable RMAB

For partially observable RMAB problems, we focus on a subclass known as collapsing bandits (mate2020collapsing). In collapsing bandits, belief states (monahan1982state) represent the posterior belief over the unobservable states. Specifically, for each arm $i$, we use $b_i\in\mathcal{B}=\Delta(\mathcal{S})\subset[0,1]^M$ to denote the posterior belief of the arm, where each entry $b_i(s_i)$ denotes the probability that the true state is $s_i\in\mathcal{S}$. When arm $i$ is pulled, the current true state $s_i\sim b_i$ is revealed, drawn from the posterior belief, with expected reward $b_i^{\top}R$; we can then define transition probabilities directly on the belief states. This process reduces partially observable states to fully observable belief states, with $MT$ states in total since the maximal horizon is $T$. Therefore, we can use the same technique to differentiate through the Whittle indices of partially observable states.
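For intuition, a minimal sketch of the belief-state bookkeeping in collapsing bandits is shown below, assuming NumPy; the reduction used in the paper defines transitions directly on these belief states, and the helper names here are our own.

```python
import numpy as np

def update_belief(belief, P_i, pulled, observed_state=None):
    """belief: (M,) posterior over states of one arm; P_i: (M, 2, M) transitions."""
    if pulled:
        revealed = np.zeros_like(belief)
        revealed[observed_state] = 1.0        # pulling reveals the true state (belief collapses)
        return revealed @ P_i[:, 1, :]        # then propagate through the active dynamics
    return belief @ P_i[:, 0, :]              # otherwise propagate through the passive dynamics

def expected_reward(belief, R):
    """Expected immediate reward b^T R of an arm with belief b."""
    return float(belief @ R)
```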

5 Policy Evaluation Metrics

In this paper, we use two variants of the evaluation metric: importance sampling-based evaluation (sutton1998introduction) and simulation-based (model-based) evaluation.

Importance sampling-based Evaluation

We adopt Consistent Weighted Per-Decision Importance Sampling (CWPDIS) (thomas2015safe) as our importance sampling-based evaluation. Given a target policy $\pi$ and a trajectory $\tau=\{s_1,a_1,r_1,\cdots,s_T,a_T,r_T\}$ executed by the behavior policy $\pi_{\text{beh}}$, the importance sampling weight is defined by $\rho_{ti}=\prod\nolimits_{t'=1}^{t}\frac{\pi(a_{t',i}\mid s_{t'})}{\pi_{\text{beh}}(a_{t',i}\mid s_{t'})}$. We evaluate the policy $\pi$ by:

$\text{Eval}_{\text{IS}}(\pi,\mathcal{T}) = \sum\nolimits_{t\in[T],i\in[N]} \gamma^{t-1} \frac{\mathbb{E}_{\tau\sim\mathcal{T}}\left[r_{t,i}\,\rho_{ti}(\tau)\right]}{\mathbb{E}_{\tau\sim\mathcal{T}}\left[\rho_{ti}(\tau)\right]}$    (12)

Importance sampling-based evaluations are often unbiased but have larger variance due to unstable importance sampling weights. CWPDIS normalizes the importance sampling weights to achieve a consistent estimate.
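A vectorized sketch of the CWPDIS estimator in Equation 12 is given below, assuming NumPy arrays of rewards and per-step importance ratios $\frac{\pi(a_{t,i}\mid s_t)}{\pi_{\text{beh}}(a_{t,i}\mid s_t)}$ with shape (number of trajectories, T, N); the small constant in the denominator is an assumption for numerical stability.

```python
import numpy as np

def cwpdis(rewards, ratios, gamma):
    """Consistent Weighted Per-Decision Importance Sampling (Equation 12)."""
    rho = np.cumprod(ratios, axis=1)                  # rho_{t,i} per trajectory
    num = np.mean(rewards * rho, axis=0)              # E_tau[ r_{t,i} * rho_{t,i} ], shape (T, N)
    den = np.mean(rho, axis=0) + 1e-12                # E_tau[ rho_{t,i} ]
    discounts = gamma ** np.arange(rewards.shape[1])  # gamma^(t-1)
    return float((discounts[:, None] * num / den).sum())
```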

Simulation-based Evaluation

An alternative is to use the given trajectories to construct an empirical transition probability $\bar{P}$, build a simulator, and evaluate the target policy $\pi$ in it. The variance of simulation-based evaluation is small, but it may require additional assumptions on the missing transitions when the empirical transition $\bar{P}$ cannot be fully reconstructed.

Algorithm 1 Decision-focused Learning in RMAB
1:  Input: training set $\mathcal{D}_{\text{train}}$, learning rate $r$, model $f_w$
2:  for epoch $=1,2,\cdots$ and $(x,\mathcal{T})\in\mathcal{D}_{\text{train}}$ do
3:     Predict $P=f_w(x)$ and compute Whittle indices $W(P)$.
4:     Let $\pi^{\text{whittle}}=\pi^{\text{soft}}_W$ and compute $\text{Eval}(\pi^{\text{whittle}},\mathcal{T})$.
5:     Update $w = w + r\,\frac{d\,\text{Eval}(\pi^{\text{whittle}},\mathcal{T})}{d\pi^{\text{whittle}}}\frac{d\pi^{\text{whittle}}}{dW}\frac{dW}{dP}\frac{dP}{dw}$, where $\frac{dW}{dP}$ is computed from Equation 11.
6:  end for
7:  Return: predictive model fwf_{w}
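One gradient step of Algorithm 1 could be sketched as follows in PyTorch, assuming the helpers sketched earlier and two additional assumed helpers: `compute_whittle_indices`, which evaluates indices through the differentiable linear system of Equation 11, and `eval_is`, a differentiable CWPDIS evaluation of the soft policy.

```python
import torch

def dfl_step(f_w, optimizer, features, trajectories, R, k, gamma):
    P_pred = f_w(features)                              # (N, M, 2, M) predicted transitions
    W = compute_whittle_indices(P_pred, R, gamma)       # assumed helper: (N, M), differentiable
    # Soft Whittle index policy: pull probabilities from the indices of the observed states.
    policy = lambda states: soft_top_k(W[torch.arange(len(states)), states], k)
    objective = eval_is(policy, trajectories, gamma)    # assumed differentiable CWPDIS evaluation
    optimizer.zero_grad()
    (-objective).backward()                             # gradient ascent on Equation 2
    optimizer.step()
    return objective.item()
```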

6 Experiments

(a) Predictive loss
(b) IS-based evaluation
(c) Simulation-based evaluation
Figure 3: Comparison of predictive loss, importance sampling-based evaluation, and simulation-based evaluation on all synthetic domains and the real ARMMAN dataset. For the evaluation metrics, we plot the improvement against the no-action baseline that does not pull any arm. Although the two-stage method achieves the smallest predictive loss, decision-focused learning consistently outperforms the two-stage method in both solution quality evaluation metrics across all domains.

We compare two-stage learning (TS) with our decision-focused learning (DF-Whittle), which optimizes the importance sampling-based evaluation directly. We consider three evaluation metrics, predictive loss, importance sampling-based evaluation, and simulation-based evaluation, to evaluate all learning methods. We perform experiments on three synthetic datasets: 2-state fully observable, 5-state fully observable, and 2-state partially observable RMAB problems. We also perform experiments on a real dataset from a maternal and child health program, modeled as a 2-state fully observable RMAB problem with real features and historical trajectories. For each dataset, we use 70%, 10%, and 20% of the RMAB problems as the training, validation, and testing sets, respectively. All experiments are averaged over 50 independent runs.

Synthetic datasets

We consider RMAB problems composed of $N=100$ arms, $M$ states, budget $K=20$, and time horizon $T=10$ with a discount rate of $\gamma=0.99$. The reward function is given by $R=[\frac{i-1}{M-1}]_{i\in[M]}$, while the transition probabilities are generated uniformly at random, with the constraint that pulling the arm ($a=1$) is strictly better than not pulling it ($a=0$) to ensure the benefit of pulling. To generate the arm features, we feed the transition probabilities of each arm to a randomly initialized neural network to generate correlated features of fixed length 16 per arm. The historical trajectories $\mathcal{T}$ with $|\mathcal{T}|=10$ are produced by running a random behavior policy $\pi_{\text{beh}}$. The goal is to predict transition probabilities from the arm features and the training trajectories.
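As an illustration, random transition probabilities satisfying the "pulling is better" constraint could be generated as follows; this is a sketch assuming the constraint means that the expected next-state reward under $a=1$ exceeds that under $a=0$, which is our own reading of the construction.

```python
import numpy as np

def random_transitions(N, M, rng):
    """Random (N, M, 2, M) transition tensor where pulling is better in expectation."""
    R = np.arange(M) / (M - 1)                          # reward vector [(i-1)/(M-1)]
    P = np.empty((N, M, 2, M))
    for i in range(N):
        for s in range(M):
            p0, p1 = rng.dirichlet(np.ones(M)), rng.dirichlet(np.ones(M))
            if p1 @ R < p0 @ R:                         # swap so a=1 leads to better states
                p0, p1 = p1, p0
            P[i, s, 0], P[i, s, 1] = p0, p1
    return P
```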

Real dataset

The Maternal and Child Healthcare Mobile Health program operated by armman aims to improve the dissemination of health information to pregnant women and mothers in order to reduce maternal, neonatal, and child mortality and morbidity. ARMMAN serves expectant and new mothers in disadvantaged communities with a median daily family income of $3.22 per day, which is below the World Bank poverty line (world2020poverty). The program is composed of multiple enrolled beneficiaries and a planner who schedules service calls to improve the overall engagement of beneficiaries; engagement is measured in terms of the total number of automated voice (health-related) messages that the beneficiary engages with. More precisely, this problem is modeled as an $M=2$-state fully observable RMAB problem where each beneficiary's behavior is governed by an MDP with two states, Engaging and Non-Engaging; engagement is determined by whether the beneficiary listens to an automated voice message (average length 115 seconds) for more than 30 seconds. The planner's task is to recommend a subset of beneficiaries every week to receive service calls from health workers to further improve their engagement behavior. We do not know the transition dynamics, but we are given beneficiaries' socio-demographic features with which to predict them.

We use a subset of data from the large-scale anonymized quality improvement study performed by ARMMAN over $T=7$ weeks, obtained from mate2022field, with beneficiary consent. In the study, a cohort of beneficiaries received a Round-Robin policy that scheduled service calls in a fixed order, with a single trajectory ($|\mathcal{T}|=1$) per beneficiary documenting the past calling decisions and engagement behavior. We randomly split the cohort into 8 training groups, 1 validation group, and 3 testing groups, each with $N=639$ beneficiaries and budget $K=18$, formulated as RMAB problems. The demographic features of beneficiaries are used to infer the missing transition dynamics.

Data usage

All the datasets are anonymized. The experiments are a secondary analysis using different evaluation metrics, with approval from the ARMMAN ethics board. There is no actual deployment of the proposed algorithm at ARMMAN. For more details about the dataset and consent for data collection, please refer to Appendices 10 and 11.

Figure 4: Performance improvement of decision-focused learning vs. the two-stage method with varying number of trajectories.
(a) Comparing our algorithm to decision-focused baselines.
(b) Computation cost with varying number of arms $N$.
Figure 5: We compare the computation cost of our decision-focused learning with other baselines and the theoretical complexity $O(NM^{\omega+1})$ with varying number of arms $N$.

7 Experimental Results

Performance improvement and justification of objective mismatch

In Figure 3, we show the performance of the random policy, two-stage learning, and decision-focused learning (DF-Whittle) on three evaluation metrics (predictive loss, importance sampling-based evaluation, and simulation-based evaluation) for all domains. For the evaluation metrics, we plot the improvement against the no-action baseline that does not pull any arms throughout the entire RMAB problem. We observe that two-stage learning consistently converges to a smaller predictive loss, while DF-Whittle significantly outperforms two-stage on all solution quality evaluation metrics (p-value $< 0.05$) by alleviating the objective mismatch issue. This result also provides evidence of the aforementioned objective mismatch: the advantage of two-stage learning in predictive loss does not translate to solution quality.

Significance in maternal and child care domain

For the ARMMAN data in Figure 3, we assume limited resources such that we can only select 18 out of 638 beneficiaries for a service call per week. Both the random and two-stage methods lead to around 15 additional automated voice messages listened to across all beneficiaries (IS-based evaluation) throughout the 7-week program, using $18\times 7=126$ service calls, compared to not scheduling any service calls; this low improvement also reflects the difficulty of maximizing the effectiveness of service calls. In contrast, decision-focused learning achieves an increase of 50 additional voice messages listened to overall; DF-Whittle achieves a much higher increase by strategically assigning the limited service calls using the right objective in the learning method. The improvement is statistically significant (p-value $< 0.05$).

In the testing set, we examine the differences between the beneficiaries selected for service calls by two-stage and by DF-Whittle, and observe some interesting differences. For example, DF-Whittle chooses to make service calls to expectant mothers earlier in gestational age (22% vs 37%) and to a lower proportion of those who have already given birth (2.8% vs 13%) compared to two-stage. In terms of income level, there is no statistically significant difference between two-stage and DF-Whittle (p-value = 0.20; see Appendix 10). In particular, 94% of the mothers selected by both methods are below the poverty line (world2020poverty).

Impact of Limited Data

Figure 4 shows the improvement of decision-focused learning over the two-stage method as the number of available trajectories varies, to evaluate the impact of limited data. We notice that a larger improvement of decision-focused learning over two-stage learning is observed when fewer trajectories are available. We hypothesize that fewer samples imply larger predictive error and more discrepancy between the loss metric and the evaluation metric.

Computation cost comparison

Figure 5(a) compares the computation cost per gradient step of our Whittle index-based decision-focused learning against other decision-focused learning baselines (wang2021learning; futoma2020popcorn) as we vary $N$ (the number of arms) in $M=2$-state RMAB problems. The other baselines fail to run with $N=30$ arms and do not scale to larger problems like the maternal and child care domain with more than 600 people enrolled, while our approach is 100x faster than the baselines, as shown in Figure 5(a), with a linear dependency on the number of arms $N$.

In Figure 5(b), we compare the empirical computation cost of our algorithm with the theoretical complexity $O(NM^{\omega+1})$ for RMAB problems with $N$ arms and $M$ states. The empirical computation cost matches the linear trend in $N$. This significantly improves on the $O(M^{\omega N})$ computation cost of previous work, as discussed in Section 4.5.

8 Conclusion

This paper presents the first decision-focused learning approach for RMAB problems that scales to large real-world datasets. We establish the differentiability of the Whittle index policy in RMABs by providing a new method to differentiate through the Whittle index and using soft top-k to relax the arm selection process. Our algorithm significantly improves both the performance and the scalability of decision-focused learning, and scales to real-world RMAB problem sizes.

Acknowledgments

Aditya Mate was supported by the ARO and was accomplished under Grant Number W911NF-17-1-0370. Sanket Shah and Kai Wang were also supported by W911NF-17-1-0370 and ARO Grant Number W911NF-18-1-0208. The views and conclusions contained in this document are those of the authors and should not be interpreted as representing the official policies, either expressed or implied, of ARO or the U.S. Government. The U.S. Government is authorized to reproduce and distribute reprints for Government purposes notwithstanding any copyright notation herein. Kai Wang was additionally supported by Siebel Scholars.

Appendix

9 Hyperparameter Setting and Computation Infrastructure

We run both decision-focused learning and two-stage learning for 50 epochs on the 2-state and 5-state synthetic domains, 30 epochs on the ARMMAN domain, and 18 epochs on the 2-state partially observable setting. The learning rate $r$ is kept at $0.01$ and $\gamma=0.59$ is used in all experiments. All experiments are performed on an Intel Xeon CPU with 64 cores and 128 GB memory.

Neural Network Structure

The predictive model $f_w$ we use to predict the transition probabilities is a neural network with an intermediate layer of size 64 with ReLU activation, and an output layer of the size of the transition probabilities followed by a softmax layer to produce a probability distribution. Dropout layers are added to avoid overfitting. The same neural network structure is used for all domains and all training methods.

In the synthetic datasets, given the generated transition probabilities, we feed the transition probabilities of each arm into a randomly initialized neural network with two intermediate layers of 64 neurons each and an output dimension of 16 to generate a feature vector of size 16. The randomly initialized neural network uses ReLU nonlinearities followed by a final linear layer.

10 Real ARMMAN Dataset

The large-scale quality improvement study conducted by armman contains 7668 beneficiaries in the Round Robin group. Over a duration of 7 weeks, 20% of the beneficiaries receive at least one active action (LIVE service call). We randomly split the 7668 beneficiaries into 12 groups while preserving the proportion of beneficiaries who received at least one active action. There are 43 features available for every beneficiary describing characteristics such as age, income, education level, call slot preference, language preference, phone ownership, etc.

10.1 Protected and Sensitive Features

ARMMAN's mobile voice call program has long been working with socially disadvantaged populations. ARMMAN does not collect or include constitutionally protected and particularly sensitive categories such as caste and religion. Although such categories are not available, in pursuit of ensuring fairness we worked with public health and field experts to ensure that indicators such as education and income level, which signify markers of socio-economic marginalization, were measured and evaluated for fairness testing.

10.2 Feature List

We provide the full list of 43 features used for predicting transition probability:

  • Enroll gestation age, age (split into 5 categories), income (8 categories), education level (7 categories), language (5 categories), phone ownership (3 categories), call slot preference (5 categories), enrollment channel (3 categories), stage of pregnancy, days since first call, gravidity, parity, stillbirths, live births

10.3 Feature Evaluation

Feature | Two-stage | Decision-focused learning | p-value
age (year) | 25.57 | 24.9 | 0.06
gestation age (week) | 24.28 | 17.21 | 0.00
Table 1: Feature analysis of continuous features. This table summarizes the average feature values of the beneficiaries selected to schedule service calls by different learning methods. The p-value of the continuous features is analyzed using a t-test for difference in means.
Feature | Two-stage | Decision-focused learning | p-value
income (rupee, averaged over multiple categories) | 10560.0 | 11190.0 | 0.20
education (categorical) | 3.32 | 3.16 | 0.21
stage of pregnancy | 0.13 | 0.03 | 0.00
language
language (hindi) | 0.53 | 0.6 | 0.04
language (marathi) | 0.45 | 0.4 | 0.08
phone ownership
phone ownership (women) | 0.86 | 0.82 | 0.04
phone ownership (husband) | 0.12 | 0.16 | 0.03
phone ownership (family) | 0.02 | 0.02 | 1.00
enrollment channel
channel type (community) | 0.7 | 0.47 | 0.00
channel type (hospital) | 0.3 | 0.53 | 0.00
Table 2: Feature analysis of categorical features. This table summarizes the average feature values of the beneficiaries selected to schedule service calls by different learning methods. The p-value of the categorical features is analyzed using a chi-square test for different proportions.

In our simulation, we further analyze the demographic features of participants selected to receive service calls by either the two-stage or the decision-focused learning method. The tables above show the average value of each feature over the participants with scheduled service calls under each method. The p-values of the continuous features are analyzed using a t-test for difference in means; the p-values of the categorical features are analyzed using a chi-square test for different proportions.

In Table 1 and Table 2, we see that there is no statistically significant difference (p-value $> 0.05$) between the average feature values of income and education, meaning that there is no obvious difference in these features between the populations selected by the two methods. We do see statistical significance for some other features, e.g., gestation age, stage of pregnancy, language, phone ownership, and channel type, which may be further analyzed to understand the benefit of decision-focused learning; however, these features do not appear to bear directly upon socio-economic marginalization and are more related to the health status of the beneficiaries.

11 Consent for Data Collection and Analysis

In this section, we provide information about consent related to data collection, analyzing data, data usage and sharing.

11.1 Secondary Analysis and Data Usage

This study falls into the category of secondary analysis of the aforementioned dataset. We use the previously collected engagement trajectories of different beneficiaries participating in the service call program to train the predictive model and evaluate the performance. The proposed algorithm is evaluated via different off-policy policy evaluation methods, including an importance sampling-based method and a simulation-based method discussed in Section 5. This paper does not involve deployment of the proposed algorithm or any other baselines to the service call program. As noted earlier, the experiments are a secondary analysis using different evaluation metrics, with approval from the ARMMAN ethics board.

11.2 Consent for Data Collection and Sharing

Consent for collecting data is obtained from each participant of the service call program. The data collection process is carefully explained to the participants to seek their consent before collecting the data. The data is anonymized before being shared with us to ensure anonymity. Data exchange and use were regulated through clearly defined exchange protocols, including anonymization, read-access only for researchers, restricted use of the data for research purposes only, and approval by ARMMAN's ethics review committee.

11.3 Universal Accessibility of Health Information

To allay further concerns: this simulation study focuses on improving quality of service calls. Even in the intended future application, all participants will receive the same weekly health information by automated message regardless of whether they are scheduled to receive service calls or not. The service call program does not withhold any information from the participants nor conduct any experimentation on the health information. The health information is always available to all participants, and participants can always request service calls via a free missed call service. In the intended future application our algorithm may only help schedule *additional* service calls to help beneficiaries who are likely to drop out of the program.

12 Societal Impacts and Limitations

12.1 Societal Impacts

The improvement shown on the real dataset directly reflects the number of engagements improved by our algorithm under different evaluation metrics. On the other hand, because demographic features are used to predict engagement behavior, we must carefully compare the models learned by the standard two-stage approach and by our decision-focused learning approach to examine whether there is any bias or discrimination concern.

Specifically, the data is collected by ARMMAN, an Indian non-governmental organization, to help mothers during their pregnancy. The ARMMAN dataset we use in this paper does not contain information related to race, religion, caste, or other sensitive features; this information is not available to the machine learning algorithm. Furthermore, examination by ARMMAN staff of the mothers selected for service calls by our algorithm did not reveal any specific bias related to these features. In particular, the program run by ARMMAN targets mothers in economically disadvantaged communities; the majority of the participants (94%) are below the international poverty line determined by The World Bank (world2020poverty). To compare the models learned by the two-stage and DF-Whittle approaches, we further examine the differences between the mothers selected for service calls by each method. We observe some interesting differences. For example, DF-Whittle chooses to make service calls to expectant mothers earlier in gestational age (22% vs 37%) and to a lower proportion of those who have already given birth (2.8% vs 13%) compared to two-stage, but in terms of income level, 94% of the mothers selected by both methods are below the poverty line. This suggests that our approach is not biased based on income level, especially since the entire population comes from economically disadvantaged communities. Our model can identify other features of mothers who are actually in need of service calls.

12.2 Limitations

Impact of limited data and the strength of decision-focused learning

As shown in Section 7 and Figure 4, we notice a smaller improvement of decision-focused learning over the two-stage approach when there is sufficient data available in the training set. This is because the data suffices to train a predictive model with small predictive loss, which implies that the predicted and true transition probabilities are close, with similar Whittle indices and Whittle index policies. In this case, there is less discrepancy between the predictive loss and the evaluation metrics, which suggests less improvement from fixing the discrepancy using decision-focused learning. Compared to the two-stage approach, decision-focused learning is also more expensive to run. Therefore, when data is sufficient, two-stage learning may be enough to achieve comparable performance while maintaining a low training cost.

On the other hand, we notice a larger improvement of decision-focused learning over the two-stage approach when data is limited. With limited data, the predictive loss is less representative, with a larger mismatch relative to the evaluation metrics, so fixing the objective mismatch using decision-focused learning becomes more prominent. Decision-focused learning may therefore be adopted in the limited-data case to significantly improve performance.

Computation cost

As we have shown in Section 4.5, our approach improves the computation cost of decision-focused learning from $O(M^{\omega N})$ to $O(NM^{\omega+1})$, where $N$ is the number of arms and $M$ is the number of states. This cost is linear in the number of arms $N$, allowing us to scale up to real-world RMAB deployments with a large number of arms. Nonetheless, scaling in the number of states $M$ is not cheap: the computation cost still grows between cubically and quartically, as shown in Figure 6. This is particularly significant when working on partially observable RMAB problems, where the partially observable problems are reduced to fully observable problems with a larger number of states. There is room for improving the computation cost in terms of the number of states to make decision-focused learning more scalable to real-world applications.

13 Computation Cost Analysis of Decision-focused Learning

We have shown the computation cost of backpropagating through Whittle indices in Section 4.5. This section covers the remaining computation cost associated with the other components, including the cost of computing Whittle indices in the forward pass and the cost of constructing the soft Whittle index policy using the soft-top-k operator.

13.1 Solving Whittle Index (Forward Pass)

In this section, we discuss the cost of computing the Whittle index in the forward pass. qian2016restless propose to use value iteration and binary search to solve the Bellman equation with $M$ states. Every value iteration updates the current value functions of $M$ states by considering all $M^2$ possible transitions between states, which results in a computation cost of $O(M^2)$ per iteration. Value iteration is run for a constant number of iterations, and the binary search is run for $O(\log\frac{1}{\epsilon})$ iterations to reach precision of order $\epsilon$. In total, the computation cost is of order $O(M^2\log\frac{1}{\epsilon})=O(M^2)$, where we use a fixed precision to drop the dependency on $\epsilon$.

On the other hand, there is a faster way to compute the value functions by solving a linear program with $M$ variables directly. The Bellman equation can be expressed as a linear program where the $M$ variables are the value functions. The best known complexity of solving a linear program with $M$ variables is $O(M^{2+\frac{1}{18}})$ (jiang2020faster). Notice that this complexity is slightly larger than that of value iteration because (i) value iteration is not guaranteed to converge in a constant number of iterations and (ii) the constant associated with the number of value iterations is large.

In total, we need to compute the Whittle indices of $N$ arms for all $M$ possible states in $\mathcal{S}$. The total complexity of value iteration and of the linear program are $O(NM^3)$ (with a large constant) and $O(NM^{3+\frac{1}{18}})$, respectively. In either case, the cost of computing all Whittle indices in the forward pass is still smaller than $O(NM^{1+\omega})$, the cost of backpropagating through all the Whittle indices in the backward pass. Therefore, the backward pass is the bottleneck of the entire process.

13.2 Soft-top-k Operators

In Section 13.1 and Section 4.5, we analyze the cost of computing and backpropagating through Whittle indices of all states and all arms. In this section, we discuss the cost of computing the soft Whittle index policy from the given Whittle indices using soft-top-k operators.

Soft-top-k operators

xie2020differentiable reduce the top-k selection problem to an optimal transport problem that transports a uniform distribution over the N input elements to a distribution in which the k elements with the highest values are assigned probability 1 and all others are assigned probability 0.

This optimal transport problem with N elements can be solved efficiently using Bregman projections (benamou2015iterative) with complexity O(LN), where L is the number of Bregman projection iterations. In the backward pass, xie2020differentiable show that the technique of differentiating through the fixed-point equation (bai2019deep; amos2017optnet) also applies, but a naive implementation requires O(N^{2}) computation. xie2020differentiable therefore provide a faster approach that leverages the associativity of matrix multiplication to lower the backward complexity to O(N).
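As a sketch of this construction (not the implementation of xie2020differentiable; rescaling the scores to [0,1] and using the two-point target support {0,1} are simplifying assumptions), the forward pass can be written as a Sinkhorn-style Bregman projection on an N-by-2 kernel, which makes the O(LN) cost visible.

```python
import numpy as np

def soft_top_k(scores, k, eps=0.1, n_iters=200):
    """Sketch: soft top-k membership via entropy-regularized optimal transport."""
    x = np.asarray(scores, dtype=float)
    N = x.shape[0]
    x = (x - x.min()) / (x.max() - x.min() + 1e-12)  # rescale scores to [0, 1]
    C = np.stack([x ** 2, (x - 1.0) ** 2], axis=1)   # cost to anchors 0 ("out") and 1 ("in")
    mu = np.full(N, 1.0 / N)                         # uniform source marginal
    nu = np.array([(N - k) / N, k / N])              # target marginal: N-k out, k in
    K = np.exp(-C / eps)                             # Gibbs kernel, shape (N, 2)
    u, v = np.ones(N), np.ones(2)
    for _ in range(n_iters):                         # L Bregman/Sinkhorn iterations, O(N) each
        u = mu / (K @ v)
        v = nu / (K.T @ u)
    plan = u[:, None] * K * v[None, :]               # entropic transport plan
    return N * plan[:, 1]                            # soft indicator of being in the top k

# Example: the soft membership concentrates on the two largest scores.
print(soft_top_k([0.9, 0.1, 0.5, 0.8], k=2))
```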

In summary, a single soft-top-k operator requires O(LN) to compute the result in the forward pass and O(N) to compute the derivative in the backward pass. In our case, we apply one soft-top-k operator for every time step t ∈ [T] and for every trajectory in 𝒯. Therefore, the total cost of computing a soft Whittle index policy and the associated importance sampling-based evaluation metric is bounded by O(LNT|𝒯|), which is linear in the number of arms N and significantly smaller than O(NM^{\omega+1}), the cost of backpropagating through all Whittle indices as shown in Section 4.5. Therefore, the computation cost of decision-focused learning is dominated by the Whittle index computation.

13.3 Computation Cost Dependency on the Number of States

Figure 6 compares the empirical computation cost of our algorithm, DF-Whittle, with the theoretical computation cost O(NM^{\omega+1}). We vary the number of states M in Figure 6 and observe that the empirical cost of our algorithm matches the theoretical guarantee. In contrast to the prior work with computation cost O(M^{\omega N}), our algorithm significantly reduces the cost of running decision-focused learning on RMAB problems.

Figure 6: Computation cost comparison to the theoretical guarantee with varying number of states M.

14 Importance Sampling-based Evaluations for ARMMAN Dataset with Single Trajectory

Unlike the synthetic datasets, for which we can generate multiple trajectories of an RMAB problem, in the real service call scheduling problem operated by ARMMAN only one trajectory is available to us for every RMAB problem. Due to the nature of the maternal and child health domain, it is unlikely that exactly the same set of pregnant mothers would participate in the service call scheduling program at different times with the same engagement behavior.

Given this restriction, we must evaluate the performance of a newly proposed policy using the only available trajectory. Unfortunately, the standard CWPDIS estimator in Equation 12 does not work because the importance sampling weights cancel out when there is only one trajectory:

\text{Eval}_{\text{IS}}(\pi,\mathcal{T})=\sum\limits_{t\in[T],i\in[N]}\gamma^{t-1}\frac{\mathbb{E}_{\tau\sim\mathcal{T}}\left[r_{t,i}\rho_{ti}(\tau)\right]}{\mathbb{E}_{\tau\sim\mathcal{T}}\left[\rho_{ti}(\tau)\right]}=\sum\limits_{t\in[T],i\in[N]}\gamma^{t-1}\frac{r_{t,i}\rho_{ti}(\tau)}{\rho_{ti}(\tau)}=\sum\limits_{t\in[T],i\in[N]}\gamma^{t-1}r_{t,i} \qquad (13)

which is a constant independent of the target policy \pi and of the associated importance sampling weights \frac{\pi(a_{t,i}\mid s_{t})}{\pi_{\text{beh}}(a_{t,i}\mid s_{t})} and \rho_{ti}=\prod_{t^{\prime}=1}^{t}\frac{\pi(a_{t^{\prime},i}\mid s_{t^{\prime}})}{\pi_{\text{beh}}(a_{t^{\prime},i}\mid s_{t^{\prime}})}. This implies that we cannot use CWPDIS to evaluate the target policy when there is only one trajectory.

Accordingly, we use the following variant to evaluate the performance:

\text{Eval}_{\text{IS}}(\pi,\mathcal{T})=\sum\limits_{i\in[N],t\in[T]}\gamma^{t-1}\frac{r_{t,i}\rho^{\prime}_{ti}(\tau)}{\mathbb{E}_{t^{\prime}\in[T]}\left[\rho^{\prime}_{t^{\prime}i}(\tau)\right]} \qquad (14)

where the new importance sampling weights are defined by \rho^{\prime}_{t,i}(\tau)=\frac{\pi(a_{t,i}\mid s_{t})}{\pi_{\text{beh}}(a_{t,i}\mid s_{t})}, which, unlike the original weights, are not multiplicative across time steps.

The main motivation of this new evaluation metric is to segment the given trajectory into a set of length-1 trajectories. We can then apply CWPDIS to the newly generated length-1 trajectories to compute a meaningful estimate, because we now have more than one trajectory. This segmented OPE formulation relies on the assumption that the total reward can be decomposed into the contributions of multiple segments, following the idea of trajectory segmentation (krishnan2017transition; ranchod2015nonparametric). The assumption holds when all segments start from the same state distribution. In our ARMMAN dataset, the data consists of trajectories of participants who enrolled in the system several weeks earlier and have therefore (almost) reached a stationary distribution. As a result, the state distribution under the behavior policy, a uniform random policy, does not change over time; our identical-distribution assumption is satisfied, and we can decompose the trajectories into smaller segments to perform evaluation. Empirically, we observed that this temporal decomposition yields a meaningful importance sampling-based evaluation while retaining the consistency benefit of CWPDIS.
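A small sketch of Equation 14 under an assumed array layout (rewards[t, i] for r_{t,i}, and pi_target[t, i], pi_behavior[t, i] for the probabilities each policy assigns to the logged action a_{t,i} in state s_t); this is an illustration, not the deployed evaluation code.

```python
import numpy as np

def segmented_is_eval(rewards, pi_target, pi_behavior, gamma=0.99):
    """Sketch: single-trajectory evaluation following Equation 14."""
    T, N = rewards.shape
    rho = pi_target / pi_behavior               # rho'_{t,i}: per-step, non-multiplicative weights
    rho_mean = rho.mean(axis=0, keepdims=True)  # E_{t' in [T]}[rho'_{t',i}] for every arm i
    discount = gamma ** np.arange(T)[:, None]   # gamma^{t-1}, with t starting at 1
    return np.sum(discount * rewards * rho / rho_mean)
```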

15 Additional Experimental Results

We provide the learning curves of the fully observable 2-state RMAB, the fully observable 5-state RMAB, the partially observable 2-state RMAB, and the real ARMMAN fully observable 2-state RMAB problems in Figures 7, 8, 9, and 10, respectively. Across all domains, the two-stage method consistently converges to a lower predictive loss faster than decision-focused learning (Figures 7(a), 8(a), 9(a), and 10(a)). However, the learned model does not produce a policy that performs well under the importance sampling-based evaluation metric (Figures 7(b), 8(b), 9(b), and 10(b)), and similarly under the simulation-based evaluation metric (Figures 7(c), 8(c), 9(c), and 10(c)).

Figure 7: Comparison between two-stage and decision-focused learning in the synthetic fully observable 2-state RMAB problems. (a) Testing predictive loss v.s. training epoch; (b) testing IS-based evaluation; (c) testing simulation-based evaluation.

Figure 8: Comparison between two-stage and decision-focused learning for fully observable 5-state RMAB problems. (a) Testing predictive loss v.s. training epoch; (b) testing IS-based evaluation; (c) testing simulation-based evaluation.

Figure 9: Comparison between two-stage and decision-focused learning for 2-state partially observable RMAB problems. (a) Testing predictive loss v.s. training epoch; (b) testing IS-based evaluation; (c) testing simulation-based evaluation.

Figure 10: Comparison between two-stage and decision-focused learning in the real ARMMAN service call scheduling problem. (a) Testing predictive loss; (b) testing IS-based evaluation; (c) testing simulation-based evaluation. The pulling action in the real dataset is much sparser, leading to a larger mismatch between predictive loss and evaluation metrics. Two-stage overfits to the predictive loss drastically with no improvement in evaluation metrics. In contrast, decision-focused learning can directly optimize the evaluation metric to avoid the objective mismatch issue.

16 Solving for and Differentiating Through the Whittle Index Computation

To solve for the Whittle index of some state u\in\mathcal{S}, you have to solve the following set of equations:

V(u) = R(u)+m_{u}+\gamma\sum_{s^{\prime}\in\mathcal{S}}P(u,0,s^{\prime})\cdot V(s^{\prime})
V(u) = R(u)+\gamma\sum_{s^{\prime}\in\mathcal{S}}P(u,1,s^{\prime})\cdot V(s^{\prime}) \qquad (15)
V(s) = \max_{a\in\{0,1\}}\left[R(s)+(1-a)m_{u}+\gamma\sum_{s^{\prime}\in\mathcal{S}}P(s,a,s^{\prime})\cdot V(s^{\prime})\right],\quad\forall s\in\mathcal{S}\setminus\{u\} \qquad (16)

Here:

  • \mathcal{S} is the set of all states.

  • R(s) is the reward for being in state s.

  • P(s,a,s^{\prime}) is the probability of transitioning to state s^{\prime} when you begin in state s and take action a.

  • V(s) is the expected value of being in state s.

  • m_{s} is the Whittle index for state s.

One way to interpret these equations is to view them as the Bellman optimality equations of a modified MDP in which the reward function is changed to R^{\prime}_{u}(s,a)=R(s)+(1-a)m_{u}, i.e., you are given a ‘subsidy’ for not acting (Equation 16). Then, to find the Whittle index for state u, you have to find the minimum subsidy for which the value of not acting matches or exceeds the value of acting (whittle1988restless). At this transition point, the value of not acting is equal to the value of acting in that state (Equation 15), leading to the set of equations above.

Now, this set of equations is typically hard to solve because of the \max terms in Equation 16. Specifically, knowing whether \arg\max_{a}=0 or \arg\max_{a}=1 for some state s is equivalent to knowing the optimal policy of this modified MDP; such equations are typically solved using value iteration or variations thereof. However, this problem is slightly more complicated than a standard MDP because one also has to determine the value of m_{s}. The way this problem is traditionally solved in the literature is the following:

  1. One guesses a value for the subsidy m_{s}.

  2. Given this guess, one solves the Bellman optimality equations associated with the modified MDP.

  3. Then, one checks the resultant policy. If it is more valuable to act than to not act in state s, the value of the guess for the subsidy is increased. Else, it is decreased.

  4. Go to Step 2 and repeat until convergence.

Given the monotonicity of the problem and the ability to bound the value of the Whittle index, Step 3 above is typically implemented using binary search. However, even with binary search, this process is quite time-consuming.

In this paper, we provide a much faster solution method for our application of interest. We leverage the small size of our state space to search over the space of policies rather than over the correct value of m_{s}. Concretely, because S=\{0,1\} and A=\{0,1\}, the Whittle index equations for state s=0 above boil down to:

V(0) = R(0)+m_{s_{0}}+\gamma\sum_{s^{\prime}\in\{0,1\}}P(0,0,s^{\prime})\cdot V(s^{\prime})
V(0) = R(0)+\gamma\sum_{s^{\prime}\in\{0,1\}}P(0,1,s^{\prime})\cdot V(s^{\prime})
V(1) = \max_{a\in\{0,1\}}\left[R(1)+(1-a)m_{s_{0}}+\gamma\sum_{s^{\prime}\in\{0,1\}}P(1,a,s^{\prime})\cdot V(s^{\prime})\right] \qquad (17)

These are 3 equations in 3 unknowns (V(0), V(1), m_{s_{0}}). The only hiccup is that Equation 17 contains a \max term, so this set of equations cannot be solved as ordinary linear equations. However, we can ‘unroll’ Equation 17 into 2 different equations:

V(1) = R(1)+m_{s_{0}}+\gamma\sum_{s^{\prime}\in\{0,1\}}P(1,0,s^{\prime})\cdot V(s^{\prime}), \quad\text{or} \qquad (18)
V(1) = R(1)+\gamma\sum_{s^{\prime}\in\{0,1\}}P(1,1,s^{\prime})\cdot V(s^{\prime}) \qquad (19)

Each of these corresponds to evaluating the value function under one of the partial policies s=1\rightarrow a=0 and s=1\rightarrow a=1. To obtain the optimal policy, we simply evaluate both and choose the one with the higher expected value V(1). In practice, we pre-compute the Whittle index and value functions using the binary search and value iteration approach studied by qian2016restless. Therefore, to determine which equation is satisfied, we use the pre-computed value functions to evaluate the expected future return of each action, and use the action with the higher value to form a set of linear equations.

This gives us a set of linear equations whose solution contains the Whittle index. We can therefore express the Whittle index in closed form as a function of the transition probabilities, which is differentiable. This establishes the differentiability of the Whittle index. The technique is equivalent to assuming that the optimal policy does not change when we infinitesimally perturb the input probabilities.

16.1 Worked Example

Figure 11: An MDP with two states, s=0 with R(0)=0 and s=1 with R(1)=1. The transition probabilities of the passive action a=0 are shown in red and those of the active action a=1 in green: from state 0, the passive action transitions to (s^{\prime}=0, s^{\prime}=1) with probabilities (0.8, 0.2) and the active action with probabilities (0.2, 0.8); from state 1, both actions transition with probabilities (0.5, 0.5).

Let us consider the concrete example in Figure 11 with \gamma=0.5. To calculate the Whittle index for state s=0, we have to solve the following set of linear equations:

Choosing the passive action a=0 at s=1:

V(0) = 0+m_{s_{0}}+0.5\cdot[0.8V(0)+0.2V(1)]
V(0) = 0+0.5\cdot[0.2V(0)+0.8V(1)]
V(1) = 1+m_{s_{0}}+0.5\cdot[0.5V(0)+0.5V(1)]
\Rightarrow V(0)\approx 0.65,\; V(1)\approx 1.45,\; m_{s_{0}}\approx 0.25

Choosing the active action a=1 at s=1:

V(0) = 0+m_{s_{0}}+0.5\cdot[0.8V(0)+0.2V(1)]
V(0) = 0+0.5\cdot[0.2V(0)+0.8V(1)]
V(1) = 1+0.5\cdot[0.5V(0)+0.5V(1)]
\Rightarrow V(0)\approx 0.52,\; V(1)\approx 1.18,\; m_{s_{0}}\approx 0.20

Here the first set of equations corresponds to taking action a=0 in state s=1 and the second corresponds to taking action a=1. As the calculation above shows, given subsidy m_{s_{0}}, it is better to choose the passive action (a=0) to obtain a higher expected future value V(1). This can also be verified by pre-computing the Whittle index and the value function. Therefore, we know that the passive action in Equation 18 leads to a higher value, and that equation holds with equality. Thus we can express the Whittle index as the solution of the following set of linear equations:

V(0) = R(0)+m_{s_{0}}+\gamma\sum_{s^{\prime}\in\{0,1\}}P(0,0,s^{\prime})\cdot V(s^{\prime})
V(0) = R(0)+\gamma\sum_{s^{\prime}\in\{0,1\}}P(0,1,s^{\prime})\cdot V(s^{\prime})
V(1) = R(1)+m_{s_{0}}+\gamma\sum_{s^{\prime}\in\{0,1\}}P(1,0,s^{\prime})\cdot V(s^{\prime})

By solving this set of linear equations, we can express the Whittle index m_{s_{0}} as a function of the transition probabilities. Therefore, we can apply auto-differentiation to compute the derivative \frac{dm_{s_{0}}}{dP}.
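To make the differentiation step concrete, here is a minimal PyTorch sketch (an illustration, not the paper's implementation) that forms the linear system above, using the transition probabilities from the worked equations, and backpropagates through its solution; the tensor layout is an assumption.

```python
import torch

gamma = 0.5
R = torch.tensor([0.0, 1.0])
# P[s, a, s']: hypothetical transition probabilities, taken from the worked equations above.
P = torch.tensor([[[0.8, 0.2],    # s = 0, passive action
                   [0.2, 0.8]],   # s = 0, active action
                  [[0.5, 0.5],    # s = 1, passive action
                   [0.5, 0.5]]],  # s = 1, active action
                 requires_grad=True)

# Rows encode: passive at s=0 (with subsidy), active at s=0, passive at s=1 (with subsidy),
# matching the three linear equations above; the unknowns are x = (V(0), V(1), m_{s_0}).
one, zero = torch.ones(()), torch.zeros(())
A = torch.stack([
    torch.stack([1 - gamma * P[0, 0, 0], -gamma * P[0, 0, 1], -one]),
    torch.stack([1 - gamma * P[0, 1, 0], -gamma * P[0, 1, 1], zero]),
    torch.stack([-gamma * P[1, 0, 0], 1 - gamma * P[1, 0, 1], -one]),
])
b = torch.stack([R[0], R[0], R[1]])

x = torch.linalg.solve(A, b)   # x = (V(0), V(1), m_{s_0})
whittle_index = x[2]
whittle_index.backward()       # fills P.grad with the derivative d m_{s_0} / d P
print(P.grad)
```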