Distributionally Robust Deep Learning using Hardness Weighted Sampling

Lucas Fidon1, Michael Aertsen2, Thomas Deprest2, Doaa Emam3, Frédéric Guffens2, Nada Mufti4, Esther Van Elslander2, Ernst Schwartz5, Michael Ebner4, Daniela Prayer5, Gregor Kasprian5, Anna L David6, Andrew Melbourne4, Sébastien Ourselin4, Jan Deprest3, Georg Langs5, Tom Vercauteren4
1: School of Biomedical Engineering & Imaging Sciences, King's College London, 2: Department of Radiology, University Hospitals Leuven, 3: Department of Obstetrics and Gynaecology, University Hospitals Leuven, 4: School of Biomedical Engineering & Imaging Sciences, King's College London, 5: Department of Biomedical Imaging and Image-guided Therapy, Medical University of Vienna, 6: Institute for Women's Health, University College London
PIPPI 2021 special issue
Publication date: 2022/07/18

Abstract

Limiting failures of machine learning systems is of paramount importance for safety-critical applications. In order to improve the robustness of machine learning systems, Distributionally Robust Optimization (DRO) has been proposed as a generalization of Empirical Risk Minimization (ERM). However, its use in deep learning has been severely restricted due to the relative inefficiency of the optimizers available for DRO in comparison to the wide-spread variants of Stochastic Gradient Descent (SGD) optimizers for ERM.
We propose SGD with hardness weighted sampling, a principled and efficient optimization method for DRO in machine learning that is particularly suited in the context of deep learning. Similar to a hard example mining strategy in practice, the proposed algorithm is straightforward to implement and computationally as efficient as SGD-based optimizers used for deep learning, requiring minimal overhead computation. In contrast to typical ad hoc hard mining approaches, we prove the convergence of our DRO algorithm for over-parameterized deep learning networks with ReLU activation and finite number of layers and parameters.
Our experiments on fetal brain 3D MRI segmentation and brain tumor segmentation in MRI demonstrate the feasibility and the usefulness of our approach. Using our hardness weighted sampling for training a state-of-the-art deep learning pipeline leads to improved robustness to anatomical variabilities in automatic fetal brain 3D MRI segmentation using deep learning, and to improved robustness to image protocol variations in brain tumor segmentation, with a decrease of 2% of the interquartile range of the Dice scores for the enhancing tumor and the tumor core regions.
Our code is available at https://github.com/LucasFidon/HardnessWeightedSampler

Keywords

Machine Learning · Image Segmentation · Distributionally Robust Optimization

Bibtex:
@article{melba:2022:019:fidon,
  title   = "Distributionally Robust Deep Learning using Hardness Weighted Sampling",
  author  = "Fidon, Lucas and Aertsen, Michael and Deprest, Thomas and Emam, Doaa and Guffens, Frédéric and Mufti, Nada and Van Elslander, Esther and Schwartz, Ernst and Ebner, Michael and Prayer, Daniela and Kasprian, Gregor and David, Anna L and Melbourne, Andrew and Ourselin, Sébastien and Deprest, Jan and Langs, Georg and Vercauteren, Tom",
  journal = "Machine Learning for Biomedical Imaging",
  volume  = "1",
  issue   = "PIPPI 2021 special issue",
  year    = "2022",
  pages   = "1--61",
  issn    = "2766-905X",
  url     = "https://melba-journal.org/2022:019"
}



1 Introduction

Datasets used to train deep neural networks typically contain some underrepresented subsets of cases. These cases are not specifically dealt with by the training algorithms currently used for deep neural networks. This problem has been referred to as hidden stratification (Oakden-Rayner et al., 2020). Hidden stratification has been shown to lead to deep learning models with good average performance but poor performance on underrepresented but clinically relevant subsets of the population (Larrazabal et al., 2020; Oakden-Rayner et al., 2020; Puyol-Antón et al., 2021). In Figure 1 we give an example of hidden stratification in fetal brain MRI. The presence of abnormalities associated with diseases with low prevalence (Aertsen et al., 2019) exacerbates the anatomical variability of the fetal brain between 18 weeks and 38 weeks of gestation.

While it uncovers the issue, the study of Oakden-Rayner et al. (2020) neither investigates its cause nor proposes a method to mitigate it. In addition, the work of Oakden-Rayner et al. (2020) is limited to classification. In standard deep learning pipelines, this hidden stratification is ignored and the model is trained to minimize the mean per-example loss, which corresponds to the standard Empirical Risk Minimization (ERM) problem. As a result, models trained with ERM are more likely to underperform on those examples from the underrepresented subdomains, seen as hard examples. This may lead to unfair AI systems (Larrazabal et al., 2020; Puyol-Antón et al., 2021). For example, state-of-the-art deep learning models for brain tumor segmentation (currently trained using ERM) underperform for cases with confounding effects, such as low grade gliomas, despite achieving good average and median performance (Bakas et al., 2018). For safety-critical systems, such as those used in healthcare, this greatly limits their usage, as ethics guidelines of regulators such as the European Commission (2019) require AI systems to be technically robust and fair prior to their deployment in hospitals.

Distributionally Robust Optimization (DRO) is a robust generalization of ERM that has been introduced in convex machine learning to model the uncertainty in the training data distribution (Chouzenoux et al., 2019; Duchi et al., 2016; Namkoong and Duchi, 2016; Rafique et al., 2018). Instead of minimizing the mean per-example loss on the training dataset, DRO seeks to optimize for the hardest weighted empirical training data distribution around the (uniform) empirical training data distribution. This suggests a link between DRO and Hard Example Mining. However, DRO as a generalization of ERM for machine learning still lacks optimization methods that are principled and computationally as efficient as SGD in the non-convex setting of deep learning. Previously proposed principled optimization methods for DRO consist in alternating between approximate maximization and minimization steps (Jin et al., 2019; Lin et al., 2019; Rafique et al., 2018). However, they differ from SGD methods for ERM by the introduction of additional hyperparameters for the optimizer such as a second learning rate and a ratio between the number of minimization and maximization steps. This makes DRO difficult to use as a drop-in replacement for ERM in practice.

In contrast, efficient weighted sampling methods, including Hard Example Mining (Chang et al., 2017; Loshchilov and Hutter, 2016; Shrivastava et al., 2016) and weighted sampling (Berger et al., 2018; Puyol-Antón et al., 2021), have been empirically shown to mitigate class imbalance issues and to improve deep embedding learning (Harwood et al., 2017; Suh et al., 2019; Wu et al., 2017). However, even though these works typically start from an ERM formulation, it is not clear how those heuristics formally relate to ERM in theory. This suggests that bridging the gap between DRO and weighted sampling methods could lead to a principled Hard Example Mining approach, or conversely to more efficient optimization methods for DRO in deep learning.

Given an efficient solver for the inner maximization problem in DRO, DRO could be addressed by maintaining a solution of the inner maximization problem and using a minimization scheme akin to the standard ERM but over an adaptively weighted empirical distribution. However, even in the case where a closed-form solution is available for the inner maximization problem, it would require performing a forward pass over the entire training dataset at each iteration. This cannot be done efficiently for large datasets. This suggests identifying an approximate, but practically usable, solution for the inner maximization problem based on a closed-form solution.

From a theoretical perspective, analyses of previous optimization methods for non-convex DRO (Jin et al., 2019; Lin et al., 2019; Rafique et al., 2018) made the assumption that the model is either smooth or weakly-convex, but neither of those properties holds for the deep neural networks with ReLU activation functions that are typically used.

In this work, we propose SGD with hardness weighted sampling, a novel, principled optimization method for training deep neural networks with DRO that is inspired by Hard Example Mining and is computationally as efficient as SGD for ERM. Compared to SGD, our method only requires introducing an additional softmax layer and maintaining a stale per-example loss vector to compute sampling probabilities over the training data. This work is an extension of our previous preliminary work (Fidon et al., 2021b) in which we applied the proposed hardness weighted sampler to distributionally robust fetal brain 3D MRI segmentation and studied the link between DRO and the minimization of percentiles of the per-example loss. In this extension, we formally introduce our hardness weighted sampler and we generalize recent results in the convergence theory of SGD with ERM and over-parameterized deep learning networks with ReLU activation functions (Allen-Zhu et al., 2019b, a; Cao and Gu, 2020; Zou and Gu, 2019) to our SGD with hardness weighted sampling for DRO. This is, to the best of our knowledge, the first convergence result for deep learning networks with ReLU trained with DRO. We also formally link DRO in our method with Hard Example Mining. As a result, our method can be seen as a principled Hard Example Mining approach. In terms of experiments, we have extended the evaluation on fetal brain 3D MRI with 69 additional fetal brain 3D MRIs. We have also added experiments on brain tumor segmentation and experiments on image classification with MNIST as a toy example. We show that our method outperforms plain SGD in the case of class imbalance, and improves the robustness of a state-of-the-art deep learning pipeline for fetal brain segmentation and brain tumor segmentation. We evaluate the proposed methodology for the automatic segmentation of white matter, ventricles, and cerebellum based on fetal brain 3D T2w MRI. We used a total of 437 fetal brain 3D MRIs including anatomically normal fetuses, fetuses with spina bifida aperta, and fetuses with other central nervous system pathologies, with gestational ages ranging from 19 weeks to 40 weeks. Our empirical results suggest that the proposed training method based on distributionally robust optimization leads to better percentile values for abnormal fetuses. In addition, qualitative results show that distributionally robust optimization reduces the number of clinically relevant failures of nnU-Net. For brain tumor segmentation, our DRO-based method reduces the interquartile range of the Dice scores by 2% for the segmentation of the enhancing tumor and the tumor core regions.

1.1 Main Mathematical Notations

We summarize here the main mathematical notations. An extended list of notations can be found in Appendix A.

  • Training dataset: $\{({\bm{x}}_i,{\bm{y}}_i)\}_{i=1}^{n}$.

  • $\Delta_n=\left\{(p_i)_{i=1}^{n}\in[0,1]^{n}\,:\,\sum_{i}p_i=1\right\}$ is the $n$-simplex.

  • For ${\bm{q}}=(q_i)\in\Delta_n$ and a function $f$, we denote $\mathbb{E}_{{\bm{q}}}[f({\bm{x}})]:=\sum_{i=1}^{n}q_i f({\bm{x}}_i)$.

  • For ${\bm{q}}\in\Delta_n$ and a function $f$, we denote $\mathbb{V}_{{\bm{q}}}[f({\bm{x}})]:=\sum_{i=1}^{n}q_i\left\lVert f({\bm{x}}_i)-\mathbb{E}_{{\bm{q}}}[f({\bm{x}})]\right\rVert^{2}$.

  • ${\bm{p}}_{\rm{train}}$ is the uniform training data distribution, i.e. ${\bm{p}}_{\rm{train}}=\left(\frac{1}{n}\right)_{i=1}^{n}\in\Delta_n$.

  • $\mathcal{L}$ is the per-example loss function.

  • ERM is short for Empirical Risk Minimization.

  • DRO is short for Distributionally Robust Optimisation.

Figure 1: Illustration of the anatomical variability of the fetal brain across gestational ages and diagnoses. 1: Control (22 weeks); 2: Control (26 weeks); 3: Control (29 weeks); 4: Spina bifida (19 weeks); 5: Spina bifida (26 weeks); 6: Spina bifida (32 weeks); 7: Dandy-Walker malformation with corpus callosum abnormality (23 weeks); 8: Dandy-Walker malformation with ventriculomegaly and periventricular nodular heterotopia (27 weeks); 9: Aqueductal stenosis (34 weeks).

2 Related Works

An optimization method for group-DRO was proposed in (Sagawa et al., 2020). In contrast to the formulation of DRO that we study in this paper, their method requires additional labels that identify the underrepresented groups in the training dataset. However, those labels may not be available, or may even be impossible to obtain, in most applications. Sagawa et al. (2020) show that, when associated with strong regularization of the weights of the network, their group-DRO method can tackle spurious correlations that are known a priori in some classification problems. It is worth noting that, in contrast, no regularization was necessary in our experiments with MNIST.

Biases of convolutional neural networks applied to medical image classification and segmentation have been studied in the literature. State-of-the-art deep neural networks for brain tumor segmentation underperform for cases with confounding effects, such as low grade gliomas (Bakas et al., 2018). It has been shown that scans coming from 15 different studies can be re-assigned to their source with 73.3% accuracy using a random forest classifier (Wachinger et al., 2019). A state-of-the-art deep neural network for the diagnosis of 14 thoracic diseases from X-rays, trained on a dataset with a gender bias, underperforms on X-rays of female patients (Larrazabal et al., 2020). Similarly, a state-of-the-art deep learning pipeline for cardiac MRI segmentation was found to underperform when evaluated on racial groups that were underrepresented in the training dataset (Puyol-Antón et al., 2021). To mitigate this problem, Puyol-Antón et al. (2021) proposed to use a stratified batch sampling approach during training that shares similarities with the group-DRO approach mentioned above (Sagawa et al., 2020). In contrast to our hardness weighted sampler, their stratified batch sampling approach requires additional labels, such as the racial group, that may not be available for the training data. In addition, they do not study the formal relationship between their stratified batch sampling approach and the training optimization problem.

In this work, we focus on DRO with a $\phi$-divergence (Csiszár et al., 2004). In this case, the data distributions considered in the DRO problem (3) are restricted to sharing the support of the empirical training distribution. In other words, the weights assigned to the training data can change, but the training data itself remains unchanged. Another popular formulation is DRO with a Wasserstein distance (Chouzenoux et al., 2019; Duchi et al., 2016; Sinha et al., 2018; Staib and Jegelka, 2017). In contrast to $\phi$-divergences, using a Wasserstein distance in DRO seeks to apply small data augmentations to the training data to make the deep learning model robust to small deformations of the data, but the sampling weights of the training data distribution typically remain unchanged. In this sense, DRO with a $\phi$-divergence and DRO with a Wasserstein distance can be considered as orthogonal endeavours. While we show that DRO with a $\phi$-divergence can be seen as a principled Hard Example Mining method, it has been shown that DRO with a Wasserstein distance can be seen as a principled adversarial training method (Sinha et al., 2018; Staib and Jegelka, 2017).

The effect of multiplicative weighting during training, rather than the weighted sampling used in our algorithm, has been studied empirically by Byrd and Lipton (2019) for image classification. They find that the effect of multiplicative weighting vanishes over training for classification tasks in which zero loss can be achieved on the training dataset. However, multiplicative weighting and weighted sampling affect the optimization dynamics in different ways. This may explain why we did not observe this vanishing effect in our experiments on classification and segmentation. Previous works have also studied empirical and convergence results of DRO for linear models (Hu et al., 2018).

3 Methods

3.1 Background: Deep Learning with Distributionally Robust Optimization

Standard training procedures in machine learning are based on Empirical Risk Minimization (ERM) (Bottou et al., 2018). For a neural network $h$ with parameters ${\bm{\theta}}$, a per-example loss function $\mathcal{L}$, and a training dataset $\left\{({\bm{x}}_i,{\bm{y}}_i)\right\}_{i=1}^{n}$, where the ${\bm{x}}_i$ are the inputs and the ${\bm{y}}_i$ are the labels, the ERM problem corresponds to

$\min_{{\bm{\theta}}}\left\{\mathbb{E}_{{\bm{p}}_{\rm{train}}}\left[\mathcal{L}\left(h({\bm{x}};{\bm{\theta}}),{\bm{y}}\right)\right]=\frac{1}{n}\sum_{i=1}^{n}\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)\right\}$    (1)

where ${\bm{p}}_{\rm{train}}$ is the empirical uniform distribution on the training dataset and $\mathbb{E}_{{\bm{p}}_{\rm{train}}}$ is the expected value operator as defined in Section 1.1. When data augmentation is used, the number of samples $n$ can become infinite. For our theoretical results, we suppose that ${\bm{p}}_{\rm{train}}$ contains a finite number of examples. The extension of our Algorithm 1 to an infinite number of data augmentations using importance sampling is presented in Section 3.2.2. Optionally, $\mathcal{L}$ can contain a parameter regularization term that is only a function of ${\bm{\theta}}$.

The ERM training formulation assumes that ${\bm{p}}_{\rm{train}}$ is an unbiased approximation of the true data distribution. However, this is generally impossible in domains such as medical image computing. This puts models trained with ERM at risk of underperforming on images from parts of the data distribution that are underrepresented in the training dataset.

In contrast, Distributionally Robust Optimization (DRO) is a family of generalizations of ERM in which the uncertainty in the training data distribution is modelled by minimizing the worst-case expected loss over an uncertainty set of training data distributions (Rahimian and Mehrotra, 2019).

In this paper, we consider training deep neural networks with DRO based on a $\phi$-divergence. We denote $\Delta_n:=\left\{(p_i)_{i=1}^{n}\in[0,1]^{n}\,\,|\,\,\sum_{i=1}^{n}p_i=1\right\}$ the set of empirical training data probability vectors under consideration (i.e. the uncertainty set). The different probability vectors in $\Delta_n$ correspond to all the possible weightings of the training dataset. Every ${\bm{p}}=(p_i)_{i=1}^{n}$ in $\Delta_n$ gives a weight to each training example but keeps the examples themselves unchanged. We use the following definition of a $\phi$-divergence in the remainder of the paper.

Definition 1 (Strong Convexity)

Let $f:\Omega\rightarrow{\mathbb{R}}$ be differentiable on $\Omega$, a convex subset of ${\mathbb{R}}$, and let $f^{\prime}$ be the first derivative of $f$. Let $\rho>0$. $f$ is $\rho$-strongly convex if for all $x,y\in\Omega$,
$f(y)\geq f(x)+f^{\prime}(x)(y-x)+\frac{\rho}{2}(y-x)^{2}$.

Definition 2 ($\phi$-divergence)

Let $\phi:{\mathbb{R}}_{+}\rightarrow{\mathbb{R}}$ be twice continuously differentiable on $[0,n]$, $\rho$-strongly convex on $[0,n]$ with $\rho>0$, and satisfying $\phi(z)\geq\phi(1)=0$ for all $z$ and $\phi^{\prime}(1)=0$. The $\phi$-divergence $D_{\phi}$ is defined as, for all ${\bm{p}}=(p_i)_{i=1}^{n},\,{\bm{q}}=(q_i)_{i=1}^{n}\in\Delta_n$,

$D_{\phi}\left({\bm{q}}\|{\bm{p}}\right)=\sum_{i=1}^{n}p_i\,\phi\!\left(\frac{q_i}{p_i}\right)$    (2)

We refer to Example 1 in Section 3.2.1 to highlight that the KL divergence is indeed a $\phi$-divergence.
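To make Definition 2 concrete, the short Python sketch below (an illustration added here for the reader, not part of the original formulation) evaluates $D_{\phi}$ for the choice $\phi:z\mapsto z\log(z)-z+1$ of Example 1 and checks numerically that it coincides with the explicit KL divergence formula of equation (11); the vectors `q` and `p_train` are arbitrary example inputs.

    import numpy as np

    def phi_kl(z):
        # phi(z) = z*log(z) - z + 1, the choice of phi that yields the KL divergence (Example 1)
        return z * np.log(z) - z + 1.0

    def phi_divergence(q, p, phi):
        # D_phi(q || p) = sum_i p_i * phi(q_i / p_i), cf. Definition 2, equation (2)
        q, p = np.asarray(q, dtype=float), np.asarray(p, dtype=float)
        return float(np.sum(p * phi(q / p)))

    n = 4
    p_train = np.full(n, 1.0 / n)              # uniform empirical training distribution
    q = np.array([0.4, 0.3, 0.2, 0.1])         # an alternative weighting of the same training examples

    d_phi = phi_divergence(q, p_train, phi_kl)
    d_kl = float(np.sum(q * np.log(q / p_train)))   # explicit KL divergence, equation (11)
    print(d_phi, d_kl)                          # the two values agree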

The DRO problem for which we propose an optimizer for training deep neural networks can be formally defined as

$\min_{{\bm{\theta}}}\,\left\{R\left({\bm{L}}(h({\bm{\theta}}))\right):=\max_{{\bm{q}}\in\Delta_n}\left(\mathbb{E}_{{\bm{q}}}\left[\mathcal{L}\left(h({\bm{x}};{\bm{\theta}}),{\bm{y}}\right)\right]-\frac{1}{\beta}D_{\phi}\left({\bm{q}}\|{\bm{p}}_{\rm{train}}\right)\right)\right\}$    (3)

where ${\bm{p}}_{\rm{train}}$ is the uniform empirical distribution, and $\beta>0$ is a hyperparameter. The choice of $\beta$ and $\phi$ controls how the unknown training data distribution ${\bm{q}}$ is allowed to differ from ${\bm{p}}_{\rm{train}}$. Here and thereafter, we use the notation ${\bm{L}}(h({\bm{\theta}})):=\left(\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i)\right)_{i=1}^{n}$ to refer to the vector of loss values of the $n$ training samples for the value ${\bm{\theta}}$ of the parameters of the neural network $h$. In the remainder of the paper, we will refer to $R$ as the distributionally robust loss.

Our analysis of the properties of $R$ in the next sections relies on Fenchel duality (Moreau, 1965) and the notion of the Fenchel conjugate (Fenchel, 1949).

Definition 3 (Fenchel Conjugate Function)

Let $f:{\mathbb{R}}^{m}\rightarrow{\mathbb{R}}\cup\{+\infty\}$ be a proper function. The Fenchel conjugate of $f$ is defined as $\forall{\bm{v}}\in{\mathbb{R}}^{m},\,\,f^{*}({\bm{v}})=\sup_{{\bm{x}}\in{\mathbb{R}}^{m}}\langle{\bm{v}},{\bm{x}}\rangle-f({\bm{x}})$, where $\langle\cdot,\cdot\rangle$ is the inner product.

3.2 Hardness Weighted Sampling for Distributionally Robust Deep Learning

In the case where $h$ is a non-convex predictor (such as a deep neural network), existing optimization methods for the DRO problem (3) alternate between approximate minimization and maximization steps (Jin et al., 2019; Lin et al., 2019; Rafique et al., 2018), requiring the introduction of additional hyperparameters compared to SGD. However, these are difficult to tune in practice, and convergence has not been proven for non-smooth deep neural networks such as those with ReLU activation functions.

In this section, we present an SGD-like optimization method for training a deep learning model $h$ with the DRO problem (3). We first highlight, in Section 3.2.1, mathematical properties that allow us to link DRO with stochastic gradient descent (SGD) combined with an adaptive sampling that we refer to as hardness weighted sampling. In Section 3.2.2, we present our Algorithm 1 for distributionally robust deep learning. Then, in Section 3.3, we present theoretical convergence results for our hardness weighted sampling.

3.2.1 A sampling approach to Distributionally Robust Optimization

The goal of this subsection is to show that a stochastic approximation of the gradient of the distributionally robust loss can be obtained by using a weighted sampler. This result is a first step towards our Algorithm 1 for efficient training with the distributionally robust loss presented in the next subsection.

To reformulate $R$ as an unconstrained optimization problem over ${\mathbb{R}}^{n}$ (rather than constraining it to the $n$-simplex $\Delta_n$), we define

$\forall{\bm{p}}\in{\mathbb{R}}^{n},\quad G({\bm{p}})=\frac{1}{\beta}D_{\phi}({\bm{p}}\|{\bm{p}}_{\rm{train}})+\delta_{\Delta_n}({\bm{p}})$    (4)

where $\delta_{\Delta_n}$ is the characteristic function of the $n$-simplex $\Delta_n$, which is a closed convex set, i.e.

$\forall{\bm{p}}\in{\mathbb{R}}^{n},\quad\delta_{\Delta_n}({\bm{p}})=\begin{cases}0&\text{if }{\bm{p}}\in\Delta_n\\ +\infty&\text{otherwise}\end{cases}$    (5)

The distributionally robust loss $R$ in (3) can now be rewritten using the Fenchel conjugate function $G^{*}$ of $G$. This allows us to obtain regularity properties for $R$.

Lemma 4 (Regularity of $R$)

If $\phi$ satisfies Definition 2 (i.e. can be used for a $\phi$-divergence), then $G$ and $R$ satisfy the following:

$G\text{ is }\left(\frac{n\rho}{\beta}\right)\text{-strongly convex}$    (6)
$\forall{\bm{\theta}},\quad R({\bm{L}}(h({\bm{\theta}})))=\max_{{\bm{q}}\in{\mathbb{R}}^{n}}\Big(\langle{\bm{L}}(h({\bm{\theta}})),{\bm{q}}\rangle-G({\bm{q}})\Big)=G^{*}\left({\bm{L}}(h({\bm{\theta}}))\right)$    (7)
$R\text{ is }\left(\frac{\beta}{n\rho}\right)\text{-gradient Lipschitz continuous}$    (8)

Equation (7) follows from Definition 3. Proofs of (6) and (8) can be found in Appendix E. According to (6), the optimization problem in (7) is strictly convex and admits a unique solution in $\Delta_n$, which we denote as

$\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))=\operatorname*{arg\,max}_{{\bm{q}}\in{\mathbb{R}}^{n}}\left(\langle{\bm{L}}(h({\bm{\theta}})),{\bm{q}}\rangle-G({\bm{q}})\right)$    (9)

Thanks to those properties, we can now show the following lemma, which is essential for the theoretical foundation of our Algorithm 1. Equation (10) states that the gradient of the distributionally robust loss $R$ is a weighted sum of the gradients of the per-example losses (i.e. the gradients computed by the backpropagation algorithm in deep learning), with the weights given by the empirical distribution $\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$. We further show that straightforward analytical formulas exist for $\bar{{\bm{p}}}$, and give an example of such a probability distribution for the Kullback-Leibler (KL) divergence.

Lemma 5 (Stochastic Gradient of the Distributionally Robust Loss)

For all ${\bm{\theta}}$, we have

$\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})=\mathbb{E}_{\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))}\left[\nabla_{{\bm{\theta}}}\mathcal{L}\left(h({\bm{x}};{\bm{\theta}}),{\bm{y}}\right)\right]$    (10)

The proof can be found in Appendix F. We now provide a closed-form formula for $\bar{{\bm{p}}}$ given ${\bm{L}}(h({\bm{\theta}}))$ for the KL divergence as the choice of $\phi$-divergence.

Example 1

For $\phi:z\mapsto z\log(z)-z+1$, $D_{\phi}$ is the Kullback-Leibler (KL) divergence:

$D_{\phi}({\bm{q}}\|{\bm{p}})=D_{\mathrm{KL}}({\bm{q}}\|{\bm{p}})=\sum_{i=1}^{n}q_i\log\left(\frac{q_i}{p_i}\right)$    (11)

In this case, we have (see Appendix D for a proof)

$\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))=\mathrm{softmax}\left(\beta\,{\bm{L}}(h({\bm{\theta}}))\right)$    (12)

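As an illustration (ours, with arbitrary example values for `beta` and the per-example losses), the Python sketch below computes the hardness weights of equation (12) from a vector of loss values and evaluates the objective of equation (3) at these weights, which, since $\bar{{\bm{p}}}$ is the maximizer in (9), equals the distributionally robust loss $R$.

    import numpy as np

    def hardness_weights(loss_vector, beta):
        # p_bar = softmax(beta * L), cf. equation (12); shift by the max for numerical stability
        z = beta * np.asarray(loss_vector, dtype=float)
        z = z - z.max()
        w = np.exp(z)
        return w / w.sum()

    beta = 10.0
    losses = np.array([0.2, 0.5, 0.9, 0.3])           # per-example loss values (example inputs)
    p_train = np.full(len(losses), 1.0 / len(losses))

    p_bar = hardness_weights(losses, beta)
    # Objective of equation (3) at p_bar: E_q[loss] - (1/beta) * D_KL(q || p_train)
    robust_loss = float(np.dot(p_bar, losses) - np.sum(p_bar * np.log(p_bar / p_train)) / beta)
    print(p_bar)          # the hardest example (loss 0.9) receives the largest sampling probability
    print(robust_loss)    # distributionally robust loss R(L)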
3.2.2 Proposed Efficient Algorithm for Distributionally Robust Deep Learning

We now describe our algorithm for training deep neural networks with DRO using our hardness weighted sampling.

Algorithm 1 Training procedure for DRO with Hardness Weighted Sampling. Additional operations as compared to standard training algorithms are highlighted in blue.
Require: $\{({\bm{x}}_i,{\bm{y}}_i)\}_{i=1}^{n}$: training dataset with $n>0$ the number of training samples.
Require: $b\in\{1,\ldots,n\}$: batch size.
Require: $\mathcal{L}$: (any) smooth per-example loss function (e.g. cross entropy loss, Dice loss).
Require: $\beta>0$: robustness parameter defining the distributionally robust optimization problem.
Require: ${\bm{\theta}}_0$: initial parameter vector for the model $h$ to train.
Require: ${\bm{L}}_{init}$: initial stale per-example loss values vector.
1: $t\leftarrow 0$    ▷ initialize the time step
2: ${\bm{L}}\leftarrow{\bm{L}}_{init}$    ▷ initialize the vector of stale loss values
3: while ${\bm{\theta}}_t$ has not converged do
4:     ${\bm{p}}_t\leftarrow\mathrm{softmax}(\beta{\bm{L}})$    ▷ online estimation of the hardness weights
5:     $I\sim{\bm{p}}_t$    ▷ hardness weighted sampling of a batch $I$ of $b$ example indices
6:     if importance sampling is not used then
7:         $\forall i\in I,\ w_i\leftarrow 1$
8:     else
9:         $\forall i\in I,\ w_i\leftarrow\exp\left(\beta\left(\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}_t),{\bm{y}}_i)-L_i\right)\right)$    ▷ importance sampling weights
10:        $\forall i\in I,\ w_i\leftarrow\mathrm{clip}\left(w_i,[w_{min},w_{max}]\right)$    ▷ clip the weights for stability
11:    $\forall i\in I,\ L_i\leftarrow\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}_t),{\bm{y}}_i)$    ▷ update the vector of stale loss values
12:    ${\bm{g}}_t\leftarrow\frac{1}{b}\sum_{i\in I}w_i\nabla_{{\bm{\theta}}}\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}_t),{\bm{y}}_i)$
13:    ${\bm{\theta}}_{t+1}\leftarrow{\bm{\theta}}_t-\eta\,{\bm{g}}_t$    ▷ SGD step or any other optimizer (e.g. SGD with momentum, Adam)
14: Output: ${\bm{\theta}}_t$

Equation (10) implies that $\nabla_{{\bm{\theta}}}\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i)$ is an unbiased estimator of the gradient of the distributionally robust loss when $i$ is sampled with respect to $\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$. This suggests that the distributionally robust loss can be minimized efficiently with SGD by sampling mini-batches with respect to $\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$ at each iteration. However, even though a closed-form formula was provided in Example 1 for $\bar{{\bm{p}}}$, evaluating ${\bm{L}}(h({\bm{\theta}}))$ exactly, i.e. doing one forward pass on the whole training dataset at each iteration, is computationally prohibitive for large training datasets.

In practice, we propose to use a stale version of the vector of per-example loss values by maintaining an online history of the loss values of the examples seen during training, $\left(\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}^{(t_i)}),{\bm{y}}_i)\right)_{i=1}^{n}$, where, for all $i$, $t_i$ is the last iteration at which the per-example loss of example $i$ has been computed. Using the Kullback-Leibler divergence as the $\phi$-divergence, this leads to the SGD with hardness weighted sampling algorithm proposed in Algorithm 1.

When data augmentation is used, an infinite number of training examples is virtually available. In this case, we keep one stale loss value per example irrespective of any augmentation as an approximation of the loss for this example under any augmentation.

Importance sampling is often used when sampling with respect to a desired distribution cannot be done exactly (Kahn and Marshall, 1953). In Algorithm 1, an up-to-date estimation of the per-example losses (or equivalently the hardness weights) in a batch is only available after sampling and evaluation through the network. Importance sampling can be used to compensate for the difference between the initial and the updated stale losses within this batch. We propose to use importance sampling in steps 9-10 of Algorithm 1 and highlight that this is especially useful to deal with data augmentation. Indeed, in this case, the stale losses for the examples in the batch are expected to be less accurate as they were estimated under a different augmentation. For efficiency, we use the approximation $w_i=\frac{p_i^{new}}{p_i^{old}}\approx\exp\left(\beta\left(\mathcal{L}(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i)-L_i\right)\right)$, where we have neglected the typically small change in the denominator of the $\mathrm{softmax}$. More details are given in Appendix C. To tackle the typical instabilities that can arise when using importance sampling (Owen and Zhou, 2000), the importance weights are clipped.

Compared to standard SGD-based training optimizers for the mean loss, our algorithm only requires an additional $\mathrm{softmax}$ operation per iteration and the storage of an additional vector of scalars of size $n$ (the number of training examples), thereby making it well suited for deep learning applications. The computational time and memory overheads are studied in Section 4.3.
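For concreteness, below is a minimal PyTorch-style sketch of Algorithm 1 (our illustration, not the reference implementation released at the repository above); the toy linear model, the squared-error per-example loss, and the clipping bounds `w_min`/`w_max` are placeholder choices for the example only.

    import torch

    class HardnessWeightedSampler:
        """Maintains stale per-example losses and samples indices with probability softmax(beta * L)."""
        def __init__(self, num_examples, beta, init_loss=1.0):
            self.beta = beta
            self.stale_losses = torch.full((num_examples,), init_loss)

        def sample(self, batch_size):
            probs = torch.softmax(self.beta * self.stale_losses, dim=0)   # equation (12)
            return torch.multinomial(probs, batch_size, replacement=True)

        def importance_weights(self, indices, new_losses, w_min=0.1, w_max=10.0):
            # w_i ~= exp(beta * (new loss - stale loss)), clipped for stability (steps 9-10)
            w = torch.exp(self.beta * (new_losses - self.stale_losses[indices]))
            return torch.clamp(w, w_min, w_max)

        def update(self, indices, new_losses):
            # refresh the stale per-example loss vector (step 11)
            self.stale_losses[indices] = new_losses.detach()

    # Toy usage on a synthetic regression problem (all names below are placeholders).
    torch.manual_seed(0)
    n, d, b, beta = 100, 5, 8, 10.0
    x_train, y_train = torch.randn(n, d), torch.randn(n, 1)
    model = torch.nn.Linear(d, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
    per_example_loss = lambda pred, target: ((pred - target) ** 2).mean(dim=1)   # smooth per-example loss
    sampler = HardnessWeightedSampler(n, beta)

    for step in range(200):
        idx = sampler.sample(b)                                # hardness weighted sampling
        losses = per_example_loss(model(x_train[idx]), y_train[idx])
        w = sampler.importance_weights(idx, losses.detach())   # optional importance sampling
        sampler.update(idx, losses)
        weighted_loss = (w * losses).mean()                    # weighted mini-batch loss (step 12)
        optimizer.zero_grad()
        weighted_loss.backward()
        optimizer.step()

In a full segmentation pipeline, the sampler simply replaces the uniform shuffling of the data loader; the rest of the training loop, including the choice of optimizer, is left unchanged.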

For the convergence theorem, the stopping criterion is $\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert\leq\epsilon$. However, in our experiments, a fixed number of iterations is used, as implemented in the state-of-the-art method nnU-Net (Isensee et al., 2021).

3.3 Overview of Theoretical Results

In this section, we present convergence guarantees for Algorithm 1 in the framework of over-parameterized deep learning. We further demonstrate properties of our hardness weighted sampling that allow us to clarify its link with Hard Example Mining and with the minimization of percentiles of the per-sample loss on the training data distribution.

3.3.1 Convergence of SGD with Hardness Weighted Sampling for Over-parameterized Deep Neural Networks with ReLU

Convergence results for over-parameterized deep learning have recently been proposed in (Allen-Zhu et al., 2019a). Their work gives convergence guarantees for deep neural networks $h$ with any activation functions (including ReLU), and with any (finite) number of layers $L$ and parameters $m$, under the assumption that $m$ is large enough. In our work, we extend the convergence theory developed by Allen-Zhu et al. (2019a) for ERM and SGD to DRO using the proposed SGD with hardness weighted sampling and stale per-example loss vector (as stated in Algorithm 1). The proof in Appendix I.4 deals with the challenges raised by the non-linearity of $R$ with respect to the per-sample stale loss and the non-uniform dynamic sampling used in Algorithm 1.

Theorem 6 (Convergence of Algorithm 1 for neural networks with ReLU)

Let $\mathcal{L}$ be a smooth per-example loss function, $b\in\{1,\ldots,n\}$ be the batch size, and $\epsilon>0$. If the number of parameters $m$ is large enough, and the learning rate is small enough, then, with high probability over the randomness of the initialization and the mini-batches, Algorithm 1 (without importance sampling) guarantees $\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert\leq\epsilon$ after a finite number of iterations.

A detailed description of the assumptions for this theorem is given in Appendix 12 and its proof can be found in Appendix I.4. Our proof does not cover the case where importance sampling is used. However, our empirical results suggest that convergence guarantees still hold with importance sampling.

3.3.2 Link between Hardness Weighted Sampling and Hard Example Mining

In this section, we discuss the relationship between the proposed hardness weighted sampling for DRO and Hard Example Mining. The following result shows that, with the proposed hardness weighted sampler, hard training examples, i.e. those training examples with relatively high values of the loss, are sampled with higher probability.

Theorem 7

Let $D_{\phi}$ be a $\phi$-divergence that satisfies Definition 2, and ${\bm{L}}=(L_i)_{i=1}^{n}\in{\mathbb{R}}^{n}$ a vector of loss values for the examples $\{{\bm{x}}_1,\ldots,{\bm{x}}_n\}$. The proposed hardness weighted sampling probabilities vector $\bar{{\bm{p}}}({\bm{L}})=\left(\bar{p}_i({\bm{L}})\right)_{i=1}^{n}$ defined as in (9) verifies:

  1. For all $i\in\{1,\ldots,n\}$, $\bar{p}_i$ is an increasing function of $L_i$.

  2. For all $i\in\{1,\ldots,n\}$, $\bar{p}_i$ is a non-increasing function of any $L_j$ for $j\neq i$.

See Appendix G for the proof. The second part of Theorem 7 implies that as the loss of an example diminishes, the sampling probabilities of all the other examples increase. As a result, the proposed SGD with hardness weighted sampling balances exploitation (i.e. sampling the identified hard examples) and exploration (i.e. sampling any example to keep the record of hard examples up to date). Heuristics to enforce this trade-off are often used in Hard Example Mining methods (Berger et al., 2018; Harwood et al., 2017; Wu et al., 2017).
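To make this link with Hard Example Mining concrete, the snippet below is a minimal sketch (not the repository implementation), assuming the KL-divergence case in which the hardness weighted sampling probabilities reduce to a softmax of the scaled stale per-example losses; it checks numerically the two monotonicity properties of Theorem 7.

```python
import numpy as np

def hardness_weighted_probabilities(stale_losses, beta=100.0):
    """Softmax of the scaled stale per-example losses (KL-divergence case, assumed)."""
    logits = beta * np.asarray(stale_losses, dtype=np.float64)
    logits -= logits.max()  # numerical stability
    p = np.exp(logits)
    return p / p.sum()

# Toy check of Theorem 7: making one example harder increases its own sampling
# probability and decreases the sampling probabilities of all other examples.
losses = np.array([0.2, 0.5, 0.3, 0.9])
p_before = hardness_weighted_probabilities(losses, beta=10.0)
losses[0] += 0.4  # example 0 becomes harder
p_after = hardness_weighted_probabilities(losses, beta=10.0)
assert p_after[0] > p_before[0]
assert np.all(p_after[1:] < p_before[1:])

# Drawing a mini-batch of indices with the hardness weighted sampler.
rng = np.random.default_rng(0)
batch_indices = rng.choice(len(losses), size=2, replace=False, p=p_after)
```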

3.3.3 Link between DRO and the Minimization of a Loss Percentile

In this section, we show that the DRO problem (3) using the KL divergence is equivalent to a relaxation of the minimization of the per-example loss percentile defined below in equation (13).

Instead of the average per-example loss (1), for robustness, one might be more interested in minimizing the percentile $l_\alpha$ at $\alpha$ (e.g. 5%) of the per-example loss function. Formally, this corresponds to the minimization problem

\[
\min_{\bm{\theta},\,l_\alpha}\; l_\alpha \qquad \textrm{such that} \qquad p_{\rm train}\left(\mathcal{L}\left(h(\bm{x};\bm{\theta}),\bm{y}\right)\geq l_\alpha\right)\leq\alpha
\tag{13}
\]

where $p_{\rm train}$ is the empirical distribution defined by the training dataset. In other words, if $\alpha=0.05$, the optimal $l_\alpha^*(\bm{\theta})$ of (13) for a given set of parameters $\bm{\theta}$ is the value of the loss such that the per-example loss function is worse than $l_\alpha^*(\bm{\theta})$ $5\%$ of the time. As a result, training the deep neural network using (13) corresponds to minimizing the percentile of the per-example loss function $l_\alpha^*(\bm{\theta})$.

Unfortunately, the minimization problem (13) cannot be solved directly using stochastic gradient descent to train a deep neural network. We now propose a tractable upper bound for $l_\alpha^*(\bm{\theta})$ and show that it can be solved in practice using distributionally robust optimization.

The Chernoff bound (Chernoff et al., 1952) applied to the per-example loss function and the empirical training data distribution states that for all $l_\alpha$ and $\beta>0$

\[
p_{\rm train}\left(\mathcal{L}\left(h(\bm{x};\bm{\theta}),\bm{y}\right)\geq l_\alpha\right) \leq \frac{\exp\left(-\beta l_\alpha\right)}{n}\sum_{i=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h(\bm{x}_i;\bm{\theta}),\bm{y}_i\right)\right)
\tag{14}
\]

To link this inequality to the minimization problem (13), we set $\beta>0$ and

\[
\hat{l}_\alpha(\bm{\theta}) = \frac{1}{\beta}\log\left(\frac{1}{\alpha n}\sum_{i=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h(\bm{x}_i;\bm{\theta}),\bm{y}_i\right)\right)\right)
\tag{15}
\]

In this case, we have

\[
p_{\rm train}\left(\mathcal{L}\left(h(\bm{x};\bm{\theta}),\bm{y}\right)\geq \hat{l}_\alpha(\bm{\theta})\right) \leq \alpha = \frac{\exp\left(-\beta\,\hat{l}_\alpha(\bm{\theta})\right)}{n}\sum_{i=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h(\bm{x}_i;\bm{\theta}),\bm{y}_i\right)\right)
\tag{16}
\]

$\hat{l}_\alpha(\bm{\theta})$ is therefore an upper bound for the optimal $l^*_\alpha(\bm{\theta})$ in equation (13), independently of the value of $\bm{\theta}$. Equation (13) can therefore be relaxed into

\[
\min_{\bm{\theta}}\; \frac{1}{\beta}\log\left(\sum_{i=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h(\bm{x}_i;\bm{\theta}),\bm{y}_i\right)\right)\right)
\tag{17}
\]

where $\beta>0$ is a hyperparameter, and where the term $\frac{1}{\beta}\log\left(\frac{1}{\alpha n}\right)$ was dropped as being independent of $\bm{\theta}$. While $\alpha$ no longer appears explicitly in the optimization problem (17), $\beta$ essentially acts as a substitute for $\alpha$: the higher the value of $\beta$, the more weight the per-example losses with high values receive in (17).

We give a proof in Appendix H that (17) is equivalent to solving the following DRO problem

\[
\min_{\bm{\theta}}\,\max_{\bm{q}\in\Delta_n}\left(\sum_{i=1}^{n}q_i\,\mathcal{L}\left(h(\bm{x}_i;\bm{\theta}),\bm{y}_i\right) - \frac{1}{\beta}D_{KL}\left(\bm{q}\,\|\,\bm{p}_{\rm train}\right)\right)
\tag{18}
\]

This is a special case of the DRO problem (3) where $\phi$ is chosen as the KL-divergence, and it corresponds to the setting of Algorithm 1.
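For illustration, the following sketch (ours, not taken from Appendix H) evaluates the relaxed objective (17) with a numerically stable log-sum-exp and checks two standard limiting behaviours of the KL-regularized robust loss: it tends to the mean per-example loss (ERM) as $\beta$ goes to 0 and to the maximum per-example loss as $\beta$ grows.

```python
import numpy as np
from scipy.special import logsumexp

def relaxed_percentile_loss(per_example_losses, beta):
    """Objective (17): (1/beta) * log( sum_i exp(beta * L_i) ), computed stably."""
    L = np.asarray(per_example_losses, dtype=np.float64)
    return logsumexp(beta * L) / beta

def kl_robust_loss(per_example_losses, beta):
    """Inner maximum of (18) for a uniform empirical distribution
    (standard closed form): (1/beta) * log( (1/n) * sum_i exp(beta * L_i) )."""
    L = np.asarray(per_example_losses, dtype=np.float64)
    return (logsumexp(beta * L) - np.log(len(L))) / beta

losses = np.array([0.1, 0.4, 0.2, 1.5])

# Limiting behaviours: mean loss (ERM) for small beta, worst-case loss for large beta.
assert np.isclose(kl_robust_loss(losses, 1e-4), losses.mean(), atol=1e-3)
assert np.isclose(kl_robust_loss(losses, 1e4), losses.max(), atol=1e-3)

# (17) and the inner maximum of (18) differ only by the constant log(n) / beta.
beta = 100.0
assert np.isclose(relaxed_percentile_loss(losses, beta),
                  kl_robust_loss(losses, beta) + np.log(len(losses)) / beta)
```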

4 Experiments

In this section, we experiment with the proposed hardness weighted sampler for DRO as implemented in Algorithm 1. In subsection 4.1, we give a toy example with the task of automatic classification of digits in the case where the digit 3 is underrepresented in the training dataset. In subsection 4.2, we report the results of our experiments on two medical image segmentation tasks: fetal brain segmentation using 3D MRI, and brain tumor segmentation using 3D MRI.

4.1 Toy Example: MNIST Classification with a Class Imbalance

The goal of this subsection is to illustrate key benefits of training a deep neural network using DRO in comparison to ERM when a part of the sample distribution is underrepresented in the training dataset. We take the MNIST dataset (LeCun, 1998) as a toy example, in which the task is to automatically classify images representing digits between 0 and 9. In addition, we verify the ability of Algorithm 1 to train a deep neural network for DRO and illustrate the behaviour of SGD with hardness weighted sampling for different values of $\beta$.

Material:

We create a bias between the training and testing data distributions of MNIST (LeCun, 1998) by keeping only $1\%$ of the digits 3 in the training dataset, while the testing dataset remains unchanged.
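For reproducibility, a possible way to build this biased training set is sketched below (using torchvision as an assumed dependency; this is not the authors' original script).

```python
import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms

# Hypothetical reconstruction of the biased training set: keep only 1% of the
# images of digit 3; the testing set is left unchanged.
train_set = datasets.MNIST("data", train=True, download=True,
                           transform=transforms.ToTensor())
targets = train_set.targets.numpy()

rng = np.random.default_rng(0)
idx_3 = np.flatnonzero(targets == 3)
idx_other = np.flatnonzero(targets != 3)
kept_3 = rng.choice(idx_3, size=max(1, len(idx_3) // 100), replace=False)

biased_indices = np.concatenate([idx_other, kept_3])
biased_train_set = Subset(train_set, biased_indices.tolist())
```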

For our experiments on MNIST, we used a Wide Residual Network (WRN) (Zagoruyko and Komodakis, 2016). The family of WRN models has proved to be very efficient and flexible, achieving state-of-the-art accuracy on several datasets. More specifically, we used WRN-16-1 (Zagoruyko and Komodakis, 2016, section 2.3). For the optimization we used a learning rate of 0.01. No momentum or weight decay were used. No data augmentation was used. For DRO, no importance sampling was used. We used a GPU NVIDIA GeForce GTX 1070 with 8GB of memory for the experiments on MNIST.

Figure 2: Experiments on MNIST. We compare the learning curves at testing (top panels) and at training (bottom panels) for ERM with SGD (blue) and DRO with our SGD with hardness weighted sampling for different values of $\beta$ ($\beta=0.1$, $\beta=1$, $\beta=10$, $\beta=100$). The models are trained on an imbalanced MNIST dataset (only $1\%$ of the digits 3 kept for training) and evaluated on the original MNIST testing dataset.
Results:

Our experiment suggests that DRO and ERM lead to different optima. Indeed, DRO for $\beta=10$ outperforms ERM by more than $15\%$ of accuracy on the underrepresented class, as illustrated in Figure 2. This suggests that DRO is more robust than ERM to domain gaps between the training and the testing dataset. In addition, Figure 2 suggests that DRO with our SGD with hardness weighted sampling can converge faster than ERM with SGD.

Furthermore, the variations of the learning curves with $\beta$ shown in Figure 2 are consistent with our theoretical insight. As $\beta$ decreases to $0$, the learning curve of DRO with our Algorithm 1 converges to the learning curve of ERM with SGD.

For large values of $\beta$ (here $\beta\geq 10$), instabilities appear before convergence in the testing learning curves, as illustrated in the top panels of Figure 2. However, the bottom left panel of Figure 2 shows that the training loss curves for $\beta\geq 10$ remained stable. We also observe that during iterations where instabilities appear on the testing set, the standard deviation of the per-example loss on the training set is relatively high (i.e. the hardness weighted probability is further away from the uniform distribution). This suggests that the apparent instabilities on the testing set are related to differences between the distributionally robust loss and the mean loss.

4.2 Medical Image Segmentation

In this section, we illustrate the application of Algorithm 1 to improve the robustness of deep learning methods for medical image segmentation. We first discuss the specificities of applying the proposed hardness weighted sampling to medical image segmentation in relation to the use of patch-based sampling. We evaluated the proposed method on two applications: fetal brain 3D MRI segmentation using the FeTA dataset and a private dataset, and brain tumor multi-sequence MRI segmentation using the BraTS 2019 dataset (Bakas et al., 2017a, b).

4.2.1 Hardness Weighted Sampler with Large Images

In medical image segmentation, the images used as input of the deep neural network are typically large 3D volumes. For this reason, state-of-the-art deep learning pipelines use patch-based sampling rather than full-volume sampling during training with ERM (Isensee et al., 2021), as described in subsection 4.2.2.

This raises the question of what the training distribution $p_{\rm train}$ is in the ERM (1) and DRO (3) optimization problems. Here, since the patches are large enough to cover most of the brain, we consider that patches are a good approximation of the whole volumes and that $p_{\rm train}$ is the distribution of the full volumes. Therefore, in the hardness weighted sampler of Algorithm 1, we have only one weight per full volume.

In case the full volumes are too large to be well covered by the patches, one can divide each full volume into a finite number of subvolumes prior to training. For example, for chest CT, one can divide the volumes into left and right lungs (Tilborghs et al., 2020).
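A minimal sketch of this bookkeeping (ours, with placeholder patch extraction and loss computation) is given below: one stale loss is stored per volume, a volume index is drawn with the hardness weighted probabilities, and the stale loss of the sampled volume is overwritten with the loss of the latest patch extracted from it.

```python
import numpy as np

rng = np.random.default_rng(0)
n_volumes = 268              # e.g. the BraTS 2019 training split used here
batch_size = 2
beta = 100.0
stale_losses = np.zeros(n_volumes)   # one stale loss per full volume

def sampling_probabilities(stale_losses, beta):
    """Hardness weighted probabilities (softmax of the scaled stale losses, assumed KL case)."""
    logits = beta * stale_losses
    logits -= logits.max()
    p = np.exp(logits)
    return p / p.sum()

for _ in range(250):                         # one nnU-Net "epoch" is 250 batches
    p = sampling_probabilities(stale_losses, beta)
    volume_ids = rng.choice(n_volumes, size=batch_size, p=p)
    for vol_id in volume_ids:
        # extract_patch / compute_loss are placeholders for the patch-based data
        # loading and the per-patch loss of the actual segmentation pipeline.
        patch_loss = rng.random()
        stale_losses[vol_id] = patch_loss    # update the stale loss of that volume
```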

4.2.2 Material

Fetal Brain Dataset.
Table 1: Training and Testing Fetal Brain 3D MRI Dataset Details. Other Abn: brain structural abnormalities other than spina bifida. There is no overlap of subjects between training and testing.

Train/Test | Origin | Condition | Volumes | Gestational age (in weeks)
Training | Atlas | Control | 18 | [21, 38]
Training | FeTA | Control | 5 | [22, 28]
Training | UHL | Control | 116 | [20, 35]
Training | UHL | Spina Bifida | 28 | [22, 34]
Training | UHL | Other Abn | 10 | [23, 35]
Testing | FeTA | Control | 31 | [20, 34]
Testing | FeTA | Spina Bifida | 38 | [21, 31]
Testing | FeTA | Other Abn | 16 | [20, 34]
Testing | UHL | Control | 76 | [22, 37]
Testing | UHL and MUV | Spina Bifida | 74 | [19, 35]
Testing | UHL | Other Abn | 25 | [21, 40]

A total of 177 (resp. 260) fetal brain 3D MRIs were used for training (resp. testing). Origin, condition, and gestational ages for the training and testing datasets are summarized in Table 1.

We used the 18 control fetal brain 3D MRIs of the spatio-temporal fetal brain atlas (http://crl.med.harvard.edu/research/fetal_brain_atlas/) (Gholipour et al., 2017) for gestational ages ranging from 21 weeks to 38 weeks. We also used 80 volumes from the publicly available FeTA MICCAI challenge dataset (DOI: 10.7303/syn25649159) (Payette et al., 2021, 2022) and the 10 3D MRIs from the testing set of the first release of the FeTA dataset for which manual segmentations are not publicly available. For those 3D MRIs, manual segmentations and corrections of the segmentations were performed by authors MA and LF to reduce the variability against the published segmentation guidelines that were released with the FeTA dataset (Payette et al., 2021). Part of those corrections were performed as part of our previous work (Fidon et al., 2021a, c) and are publicly available (DOI: 10.5281/zenodo.5148611). Brain masks for the FeTA data were obtained via affine registration using two fetal brain atlases (DOI: 10.7303/syn25887675) (Fidon et al., 2021d; Gholipour et al., 2017).

In addition, we used 329 3D MRIs from a private dataset. All images in the private dataset were part of routine clinical care and were acquired at University Hospital Leuven (UHL) and Medical University of Vienna (MUV) due to congenital malformations seen on ultrasound. In total, 102 cases with spina bifida aperta, 35 cases with other central nervous system pathologies, and 192 cases with other malformations, though with normal brain, and referred to as controls, were included. The gestational age at MRI ranged from 19 weeks to 40 weeks. Some of those 3D MRIs and their manual segmentations were used in previous studies (Emam et al., 2021; Fidon et al., 2021d, a; Mufti et al., 2021). We have started to make fetal brain T2w 3D MRIs publicly available (https://www.cir.meduniwien.ac.at/research/fetal/). For each study, at least three orthogonal T2-weighted HASTE series of the fetal brain were collected on a 1.5T scanner using an echo time of 133ms, a repetition time of 1000ms, with no slice overlap nor gap, pixel size 0.39mm to 1.48mm, and slice thickness 2.50mm to 4.40mm. A radiologist attended all the acquisitions for quality control.

The reconstructed fetal brain 3D MRIs were obtained using NiftyMIC (Ebner et al., 2020), a state-of-the-art super-resolution and reconstruction algorithm. The volumes were all reconstructed to a resolution of 0.8 mm isotropic and registered to a fetal brain atlas (Gholipour et al., 2017). The 2D MRIs were also corrected for image intensity bias field as implemented in NiftyMIC. Our pre-processing improves the resolution, and removes motion between neighboring slices and motion artefacts present in the original 2D slices (Ebner et al., 2020). It also facilitates the manual delineation of the fetal brain structures compared to the original 2D slices. We used volumetric brain masks to mask the tissues outside the fetal brain. Those brain masks were obtained using the automatic segmentation methods described in (Ebner et al., 2020; Ranzini et al., 2021).

The labelling protocol used for white matter, intra-axial CSF, and cerebellum is the same as in (Payette et al., 2021). We use the term intra-axial CSF rather than ventricular system because, in addition to the lateral ventricles, third ventricle, and fourth ventricle, it also contains the cavum septum pellucidum and the cavum vergae, which are not part of the ventricular system (Tubbs et al., 2011). The three tissue types were segmented for our private dataset by DE, EVE, FG, LF, MA, NM, and TD under the supervision of MA, a paediatric radiologist specialized in fetal brain anatomy, who quality controlled and corrected all manual segmentations.

Brain Tumor Dataset.

We have used the BraTS 2019 dataset because it is the last edition of the BraTS challenge for which information about the image acquisition center is available at the time of writing. The dataset contains the same four MRI sequences (T1, ceT1, T2, and FLAIR) for 448 cases, corresponding to patients with either a high-grade glioma or a low-grade glioma. All the cases were manually segmented for peritumoral edema, enhancing tumor, and non-enhancing tumor core using the same labeling protocol (Menze et al., 2014; Bakas et al., 2018, 2017c). We split the 323 cases of the BraTS 2019 training dataset into 268 for training and 67 for validation. In addition, the BraTS 2019 validation dataset, which contains 125 cases, was used for testing.

Figure 3: Qualitative Results for Fetal Brain 3D MRI Segmentation using DRO. We have highlighted in white the areas with severe violation of the anatomy by nnU-Net-ERM. Most of them are avoided by our nnU-Net-DRO. nnU-Net-ERM and nnU-Net-DRO differ only by the use of the hardness weighted sampler for the latter. a) Fetus with aqueductal stenosis (34 weeks). b) Fetus with spina bifida aperta (27 weeks). c) Fetus with Blake’s pouch cyst (29 weeks). d) Fetus with tuberous sclerosis complex (34 weeks). e) Fetus with spina bifida aperta (22 weeks). f) Fetus with spina bifida aperta (31 weeks). g) Fetus with spina bifida aperta (28 weeks). For cases a) and b), nnU-Net-ERM (Isensee et al., 2021) completely misses the cerebellum and achieves poor segmentation for the white matter and the ventricles. For case c), a large part of the Blake’s pouch cyst is wrongly included in the ventricular system segmentation by nnU-Net-ERM. This is not the case for the proposed nnU-Net-DRO. For case d), nnU-Net-ERM fails to segment the cerebellum correctly and a large part of the cerebellum is segmented as part of the white matter. In contrast, our nnU-Net-DRO correctly segments the cerebellum and white matter for this case. For cases e), f) and g), nnU-Net-ERM wrongly includes parts of the brainstem in the cerebellum segmentation. nnU-Net-DRO does not make this mistake. We emphasise that the segmentation of the cerebellum for spina bifida aperta is essential for studying and evaluating the effect of surgery in-utero.
Deep Learning Pipeline.

The deep learning pipeline used was based on nnU-Net (Isensee et al., 2021), a generic deep learning pipeline for medical image segmentation that has been shown to outperform other deep learning pipelines on 23 public datasets without the need to manually tune the loss function or the deep neural network architecture. Specifically, we used nnU-Net version 2 in 3D-full-resolution mode, which is the recommended mode for isotropic 3D MRI data. The code is publicly available at https://github.com/MIC-DKFZ/nnUNet.

Like most deep learning pipelines in the literature, nnU-Net is based on ERM. For clarity, in the following we will sometimes refer to the unmodified nnU-Net as nnU-Net-ERM.

The meta-parameters of the deep learning pipeline were determined automatically using the heuristics developed in nnU-Net (Isensee et al., 2021). The 3D CNN selected is based on 3D U-Net (Çiçek et al., 2016) with 5 (resp. 6) levels for fetal brain segmentation (resp. brain tumor segmentation) and 32 features after the first convolution, multiplied by 2 at each level with a maximum set at 320. The 3D CNN uses leaky $\mathrm{ReLU}$ activation, instance normalization (Ulyanov et al., 2016), max-pooling downsampling operations and linear upsampling with learnable parameters. In addition, the network is trained using the addition of the mean Dice loss and the cross entropy, and deep supervision (Lee et al., 2015). The default optimization step is SGD with a momentum of 0.99 and Nesterov update, a batch size of 4 (resp. 2) for fetal brain segmentation (resp. brain tumor segmentation), and a decreasing learning rate defined for each epoch $t$ as

\[
\eta_t = 0.01 \times \left(1 - \frac{t}{t_{max}}\right)^{0.9}
\]

where $t_{max}$ is the maximum number of epochs, fixed at 1000. Note that in nnU-Net, one epoch is defined as equal to 250 batches, irrespective of the size of the training dataset. A patch size of $96\times 112\times 96$ (resp. $128\times 192\times 128$) was selected for fetal brain segmentation (resp. brain tumor segmentation), which is not sufficient to fit the whole brain of all the cases. As a result, a patch-based approach is used, as is often the case in medical image segmentation applications. A large number of data augmentation methods are used: random cropping of a patch, random zoom, gamma intensity augmentation, multiplicative brightness, random rotations, random mirroring along all axes, contrast augmentation, additive Gaussian noise, Gaussian blurring and simulation of low resolution. nnU-Net automatically splits the training data into 5 folds of $80\%$ training/$20\%$ validation. For the experiments on brain tumor segmentation, only the networks corresponding to the first fold were trained. For the experiments on fetal brain segmentation, 5 models were trained, one for each fold, and the predicted class probability maps of the 5 models are averaged at inference to improve robustness (Isensee et al., 2021). GPUs NVIDIA Tesla V100-SXM2 with 16GB of memory were used for the experiments. Training each network took from 4 to 6 days.
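As a small helper, the polynomial learning-rate decay defined above can be written as follows (a direct transcription of the formula, not nnU-Net's source code).

```python
def poly_lr(epoch, t_max=1000, initial_lr=0.01, exponent=0.9):
    """Polynomial learning-rate decay: eta_t = initial_lr * (1 - t / t_max) ** exponent."""
    return initial_lr * (1.0 - epoch / t_max) ** exponent

# Example: learning rate at the start, middle, and penultimate epoch of training.
print(poly_lr(0), poly_lr(500), poly_lr(999))
```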

Our only modifications of the nnU-Net pipeline are the addition of our hardness weighted sampling when "DRO" is indicated and, for some experiments, a modified optimization update rule as indicated in Table 2. Our implementation of the nnU-Net-DRO training procedure is publicly available at https://github.com/LucasFidon/HardnessWeightedSampler. If "ERM" is indicated and nothing is indicated about the optimization update rule, it means that nnU-Net (Isensee et al., 2021) is used without any modification.

Table 2: Evaluation of Distribution Robustness with Respect to the Pathology (260 3D MRIs). nnU-Net-ERM is the unmodified nnU-Net pipeline (Isensee et al., 2021) in which Empirical Risk Minimization (ERM) is used. nnU-Net-DRO is the nnU-Net pipeline modified to use the proposed hardness weighted sampler and in which Distributionally Robust Optimization (DRO) is therefore used. WM: White matter, In-CSF: Intra-axial CSF, Cer: Cerebellum. IQR: interquartile range, pX: Xth percentile of the Dice score distribution in percentage. Best values are in bold and improvements of at least 5 points of percentage are highlighted.

Dice Score (%)

CNS | Method | ROI | Mean | Median | IQR | p25 | p10 | p5
Controls (107 volumes) | nnU-Net-ERM (baseline) | WM | 94.4 | 95.2 | 2.8 | 93.3 | 91.5 | 90.6
Controls (107 volumes) | nnU-Net-ERM (baseline) | In-CSF | 90.3 | 92.4 | 6.4 | 87.8 | 80.7 | 79.0
Controls (107 volumes) | nnU-Net-ERM (baseline) | Cer | 95.7 | 97.0 | 3.4 | 94.2 | 91.3 | 90.4
Controls (107 volumes) | nnU-Net-DRO (ours) | WM | 94.4 | 95.3 | 3.0 | 93.2 | 91.1 | 90.1
Controls (107 volumes) | nnU-Net-DRO (ours) | In-CSF | 90.4 | 92.7 | 6.2 | 87.9 | 81.7 | 79.1
Controls (107 volumes) | nnU-Net-DRO (ours) | Cer | 95.7 | 97.1 | 3.3 | 94.2 | 91.4 | 90.1
Spina Bifida (112 volumes) | nnU-Net-ERM (baseline) | WM | 89.6 | 92.1 | 4.1 | 89.5 | 80.6 | 73.8
Spina Bifida (112 volumes) | nnU-Net-ERM (baseline) | In-CSF | 91.4 | 93.9 | 6.4 | 89.6 | 86.9 | 83.7
Spina Bifida (112 volumes) | nnU-Net-ERM (baseline) | Cer | 76.8 | 87.8 | 11.1 | 80.4 | 15.8 | 0.0
Spina Bifida (112 volumes) | nnU-Net-DRO (ours) | WM | 90.1 | 92.2 | 4.0 | 89.9 | 81.6 | 74.8
Spina Bifida (112 volumes) | nnU-Net-DRO (ours) | In-CSF | 91.6 | 94.1 | 6.4 | 90.0 | 86.7 | 83.6
Spina Bifida (112 volumes) | nnU-Net-DRO (ours) | Cer | 77.8 | 87.9 | 9.7 | 82.0 | 43.3 | 0.0
Other Abn. (41 volumes) | nnU-Net-ERM (baseline) | WM | 90.3 | 92.6 | 4.6 | 90.1 | 88.0 | 71.6
Other Abn. (41 volumes) | nnU-Net-ERM (baseline) | In-CSF | 87.4 | 87.9 | 10.4 | 82.7 | 77.7 | 75.9
Other Abn. (41 volumes) | nnU-Net-ERM (baseline) | Cer | 90.4 | 92.8 | 5.4 | 90.7 | 87.5 | 81.4
Other Abn. (41 volumes) | nnU-Net-DRO (ours) | WM | 90.4 | 92.6 | 4.7 | 90.2 | 88.2 | 73.5
Other Abn. (41 volumes) | nnU-Net-DRO (ours) | In-CSF | 87.9 | 88.1 | 9.5 | 83.3 | 80.4 | 77.7
Other Abn. (41 volumes) | nnU-Net-DRO (ours) | Cer | 91.3 | 93.0 | 5.5 | 90.7 | 87.5 | 82.7
Hyper-parameters of the Hardness Weighted Sampler.

For brain tumor segmentation, we tried the values $\{10, 100, 1000\}$ of $\beta$ with or without importance sampling. Using $\beta=100$ with importance sampling led to the best mean Dice score on the validation split of the training dataset. For fetal brain segmentation, we tried only $\beta=100$ with importance sampling. When importance sampling is used, the clipping values $w_{min}=0.1$ and $w_{max}=10$ are always used. No other values of $w_{max}$ and $w_{min}$ have been tested.
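The exact importance sampling correction (steps 6-8 of Algorithm 1) is not restated in this section; purely as an illustration of the clipping, generic importance weights given target and proposal probabilities can be clipped to the range $[w_{min}, w_{max}] = [0.1, 10]$ reported above as follows.

```python
import numpy as np

def clipped_importance_weights(target_p, proposal_p, sampled_ids,
                               w_min=0.1, w_max=10.0):
    """Generic clipped importance weights (target / proposal) for the sampled examples.
    This illustrates the clipping only; it is not a restatement of Algorithm 1."""
    target_p = np.asarray(target_p, dtype=np.float64)
    proposal_p = np.asarray(proposal_p, dtype=np.float64)
    ratio = target_p[sampled_ids] / proposal_p[sampled_ids]
    return np.clip(ratio, w_min, w_max)
```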

Metrics.

We evaluate the quality of the automatic segmentations using the Dice score (Dice, 1945; Fidon et al., 2017). We are particularly interested in measuring the statistical risk of the results as a way to evaluate the robustness of the different methods.

In the BraTS challenge, this is usually measured using the interquartile range (IQR), which is the difference between the percentiles at $75\%$ and $25\%$ of the metric values (Bakas et al., 2018). We therefore report the mean, the median, and the IQR of the Dice score in Table 3. For fetal brain segmentation, in addition to the mean, median, and IQR, we also report the percentiles of the Dice score at $25\%$, $10\%$, and $5\%$. In Table 2, we report those quantities for the Dice scores of the three tissue types white matter, intra-axial CSF, and cerebellum.

For each method, nnU-Net is trained 5 times using different train/validation splits and different random initializations. The same 5 splits, computed randomly, are used for the two methods. The results for fetal brain 3D MRI segmentation in Table 2 are for the ensemble of the 5 3D U-Nets. Ensembling is known to increase the robustness of deep learning methods for segmentation (Isensee et al., 2021). It also makes the evaluation less sensitive to the random initialization and to the stochastic optimization.
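For reference, a short sketch (ours) of how these per-case statistics can be computed from binary masks and a list of per-case Dice scores is shown below.

```python
import numpy as np

def dice_score(seg, gt):
    """Dice score (in %) between two binary segmentation masks of the same shape."""
    seg, gt = seg.astype(bool), gt.astype(bool)
    denominator = seg.sum() + gt.sum()
    if denominator == 0:   # both masks empty: define the Dice score as 100% (a convention)
        return 100.0
    return 200.0 * np.logical_and(seg, gt).sum() / denominator

def robustness_statistics(dice_scores):
    """Mean, median, IQR and lower percentiles of the per-case Dice scores."""
    d = np.asarray(dice_scores, dtype=np.float64)
    p75, p25, p10, p5 = np.percentile(d, [75, 25, 10, 5])
    return {"mean": d.mean(), "median": np.median(d),
            "IQR": p75 - p25, "p25": p25, "p10": p10, "p5": p5}
```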

Table 3: Dice Score Evaluation on the BraTS 2019 Online Validation Set (125 cases). Metrics were computed using the BraTS online evaluation platform (https://ipp.cbica.upenn.edu/). ERM: Empirical Risk Minimization, DRO: Distributionally Robust Optimization, SGD: plain SGD (no momentum used), Nesterov: SGD with Nesterov momentum, IQR: Interquartile range. The best values overall are in bold and improvements of at least 5 points of percentage when comparing ERM and DRO for the same optimizer are highlighted.

Optim. problem | Optim. update | Enhancing Tumor (Mean / Median / IQR) | Whole Tumor (Mean / Median / IQR) | Tumor Core (Mean / Median / IQR)
ERM | SGD | 71.3 / 86.0 / 20.9 | 90.4 / 92.3 / 6.1 | 80.5 / 88.8 / 17.5
DRO | SGD | 72.3 / 87.2 / 19.1 | 90.5 / 92.6 / 6.0 | 82.1 / 89.7 / 15.2
ERM | Nesterov | 73.0 / 87.1 / 15.6 | 90.7 / 92.6 / 5.4 | 83.9 / 90.5 / 14.3
DRO | Nesterov | 74.5 / 87.3 / 13.8 | 90.6 / 92.6 / 5.9 | 84.1 / 90.0 / 12.5
Figure 4: Dice scores distribution on the BraTS 2019 validation dataset for cases from a center of TCIA (76 cases) and cases from other centers (49 cases). This shows that the lower interquartile range of DRO for the enhancing tumor comes specifically from a lower number of poor segmentations on cases coming from The Cancer Imaging Archive (TCIA). This suggests that DRO can deal with some of the confounding biases present in the training dataset, and lead to a model that is more fair.
Results.

The quantitative comparison of nnU-Net-ERM and nnU-Net-DRO on fetal brain 3D MRI segmentation for the three different central nervous system conditions control, spina bifida, and other abnormalities can be found in Table 2.

For spina bifida and other brain abnormalities, the proposed nnU-Net-DRO achieves the same or higher mean Dice scores than nnU-Net-ERM (Isensee et al., 2021), with +0.5 percentage points (pp) for the white matter and +1pp for the cerebellum of spina bifida cases and +0.9pp for the cerebellum of other abnormalities. In addition, nnU-Net-DRO achieves comparable (at most 0.1pp of difference) or lower IQR than nnU-Net-ERM, with -1.4pp for the cerebellum of spina bifida cases and -0.9pp for the intra-axial CSF of cases with other abnormalities. For controls, the mean, median, and IQR of the Dice scores of nnU-Net-DRO and nnU-Net-ERM differ by less than 0.2pp for the three tissue types. This suggests that nnU-Net-DRO is more robust to anatomical variabilities associated with abnormal brains, while retaining the same segmentation performance on neurotypical cases.

In terms of median Dice score, nnU-Net-DRO and nnU-Net-ERM differ by less than 0.3pp for all tissue types and conditions. Therefore, the differences in terms of mean Dice scores mentioned above are not due to improved segmentation in the middle of the Dice score performance distribution.

The comparison of the percentiles at 25%, 10%, and 5% of the Dice score allows us to compare methods at the tail of the Dice score distribution, where segmentation methods reach their worst-case performance. For spina bifida, nnU-Net-DRO achieves higher percentile values than nnU-Net-ERM for the white matter (+1.0pp for p10 and +1.0pp for p5), and for the cerebellum (+1.6pp for p25 and +27.5pp for p10). And for other brain abnormalities, nnU-Net-DRO achieves higher percentile values than nnU-Net-ERM for the white matter (+1.9pp for p5), for the intra-axial CSF (+0.6pp for p25, +2.3pp for p10 and +1.8pp for p5), and for the cerebellum (+1.3pp for p5). All the other percentile values differ by less than 0.5pp of Dice score between the two methods. This suggests that nnU-Net-DRO achieves better worst-case performance than nnU-Net-ERM for abnormal cases. However, both methods have a percentile at 5% of the Dice score equal to 0 for the cerebellum of spina bifida cases. This indicates that both methods completely miss the cerebellum in 5% of the spina bifida cases.

As can be seen in the qualitative results of Figure 3, there are cases for which nnU-Net-ERM predicts an empty cerebellum segmentation while nnU-Net-DRO achieves a satisfactory cerebellum segmentation. There were no cases for which the converse was true. However, there were also spina bifida cases for which both methods failed to predict the cerebellum. Robust segmentation of the cerebellum for spina bifida is particularly relevant for the evaluation of fetal brain surgery for spina bifida aperta (Aertsen et al., 2019; Danzer et al., 2020; Sacco et al., 2019). All the spina bifida 3D MRIs with missing cerebellum in the automatic segmentations were 3D MRIs from the FeTA dataset (Payette et al., 2021) and represented brains of fetuses with spina bifida before they were operated on. The cerebellum is more difficult to detect using MRI before surgery as compared to early or late after surgery (Aertsen et al., 2019; Danzer et al., 2007). No 3D MRI with the combination of those two factors was present in the training dataset (Table 1). This might explain why DRO did not help improve the segmentation quality for those cases. DRO aims at improving the performance on subgroups that are underrepresented in the training dataset, not subgroups that are not represented at all.

In Table 2, it is worth noting that, overall, the Dice score values decrease for the white matter and the cerebellum between controls and spina bifida and other abnormal cases. This was expected due to the higher anatomical variability in pathological cases. However, the Dice score values for the ventricular system tend to be higher for spina bifida cases than for controls. This can be attributed to the large proportion of spina bifida cases with enlarged ventricles, because the Dice score values tend to be higher for larger regions of interest.

For our experiments on brain tumor segmentation, Table 3 summarizes the performance of training nnU-Net using ERM or using DRO. Here, we experiment with two SGD-based optimizers. For both ERM and DRO, the optimization update rule used was either plain SGD without momentum (SGD), or SGD with a Nesterov momentum equal to 0.99 (Nesterov). In particular, for the latter, this implies that step 12 of Algorithm 1 is modified to use SGD with Nesterov momentum. This was also the case for our experiments on fetal brain 3D MRI segmentation. For DRO, the results presented here are for $\beta=100$ and using importance sampling (step 6 of Algorithm 1).

As illustrated in Table 3, for both ERM and DRO, the use of SGD with Nesterov momentum outperforms plain-SGD for all metrics and all regions of interest. This result was expected for ERM, for which it is common practice in the deep learning literature to use SGD with a momentum. Our results here suggest that the benefit of using a momentum with SGD is retained for DRO.

For both optimizers, DRO outperforms ERM in terms of IQR for the enhancing tumor and the tumor core by approximately 2pp of Dice score, and in terms of mean Dice score for the enhancing tumor by 1pp for plain SGD and 1.5pp for SGD with Nesterov momentum. For plain SGD, DRO also outperforms ERM in terms of mean Dice score for the tumor core by 1.6pp. The IQR is the global statistic used in the BraTS challenge to measure the level of robustness of a method (Bakas et al., 2018). In addition, Figure 4 shows that the lower IQR of DRO for the enhancing tumor comes specifically from a lower number of poor segmentations on cases coming from The Cancer Imaging Archive (TCIA). This suggests that DRO can deal with some of the confounding biases present in the training dataset, and lead to a model that is more fair with respect to the acquisition center of the MRI.

The same improvements are observed independently of the optimization update rule used. This suggests that in practice Algorithm 1 still converges when a momentum is used, even though Theorem 6 was only demonstrated for plain SGD.

The value $\beta=100$ and the use of importance sampling were selected based on the mean Dice score on the validation split of the training dataset. Results for $\beta\in\{10, 100, 1000\}$ with Nesterov momentum and with or without importance sampling can be found in Appendix B, Table 5. The tendency described previously still holds true for the enhancing tumor for $\beta$ equal to 10 or 100, with and without importance sampling. The mean Dice score is improved by 0.4pp to 2.3pp and the IQR is reduced by 1.3pp to 2.3pp for the four DRO models as compared to the ERM model. For the tumor core with $\beta=100$, mean and IQR are improved over ERM with and without importance sampling. However, for $\beta=10$ with importance sampling, there was a loss of performance as compared to ERM for the whole tumor. This problem was not observed with $\beta=10$ without importance sampling. For the other models with $\beta$ equal to 10 or 100, Dice score performance similar to that of ERM was observed for the whole tumor. This suggests that overall the use of ERM or DRO does not affect the segmentation performance for the whole tumor. One possible explanation is that Dice scores for the whole tumor are already high for almost all cases when ERM is used, with a low IQR. In addition, DRO and the hardness weighted sampler are sensitive to the loss function, here the mean-class Dice loss plus cross entropy loss. In the case of brain tumor segmentation, we hypothesise that the loss function is more sensitive to the segmentation performance for the tumor core and the enhancing tumor than for the whole tumor.

When $\beta$ becomes too large ($\beta=1000$), a decrease of the mean and median Dice scores for all regions is observed as compared to ERM. In this case, DRO tends towards the maximization of the worst-case example only, which appears to be unstable using our Algorithm 1. For all values of $\beta$, the use of importance sampling, as described in steps 6-8 of Algorithm 1, improves the IQR of the Dice scores for the enhancing tumor and the tumor core. We therefore recommend using Algorithm 1 with importance sampling.

4.3 Computational Time and Memory Overhead of Algorithm 1

Table 4: Estimated Computational Time and Memory Overhead of the hardness weighted sampler in Algorithm 1. The times (in seconds) are estimated using a batch size of 2 and $\beta=100$, and by taking the average sampling time over 10,000 sampling operations for each number of samples. It is worth noting that the sampling operations are computed on the CPUs, as in most deep learning pipelines. The time and memory overhead of the proposed hardness weighted sampler is negligible for training datasets with up to 1 million samples.
# Samples | $10^2$ | $10^3$ | $10^4$ | $10^5$ | $10^6$ | $10^7$
Time (in sec) | $1.3\times 10^{-4}$ | $1.5\times 10^{-4}$ | $2.6\times 10^{-4}$ | $2.4\times 10^{-3}$ | $2.1\times 10^{-2}$ | $1.8\times 10^{-1}$
Memory (in MB) | $7.6\times 10^{-4}$ | $7.6\times 10^{-3}$ | $7.6\times 10^{-2}$ | $7.6\times 10^{-1}$ | $7.6$ | $76.3$

The main additional computational cost in Algorithm 1 is due to the hardness weighted sampling in steps 4 and 5, which depends on the number $n$ of training examples. In Table 4, we report the computational time and memory overhead of the hardness weighted sampler for different sizes of the training dataset. The additional time required is less than 0.5 second and the additional memory less than 100 MB for up to $n=10^7$, using a batch size of 2 and the function random.choice of Numpy version 1.21.1. The times were estimated using 12 Intel(R) Core(TM) i7-8750H CPU @ 2.20GHz. The additional time and memory incurred by the proposed hardness weighted sampling is therefore negligible for all the datasets used in practice in medical image segmentation. For our brain tumor segmentation training set of n=268 volumes and a batch size of 2, the additional memory usage of Algorithm 1 is only 2144 bytes (one float array of size n) and the additional computational time is approximately $10^{-4}$ seconds per iteration using the Python library numpy, i.e. approximately 0.005% of the total duration of an iteration. The training dataset for fetal brain 3D MRI segmentation being smaller, the additional memory usage and computational time are even lower than for brain tumor segmentation. We have made available a python script in our GitHub repository that allows one to easily compute the additional time and memory incurred by the hardness weighted sampler for any number of samples and batch size.
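The benchmarking script from the repository is not reproduced here; the sketch below shows how such estimates can be obtained with numpy (our code, with the per-example losses drawn at random).

```python
import time
import numpy as np

def benchmark_hardness_weighted_sampler(n_samples, batch_size=2, beta=100.0, n_repeats=100):
    """Rough estimate of the time of one sampling operation and of the extra memory."""
    rng = np.random.default_rng(0)
    stale_losses = rng.random(n_samples)          # one float per training example
    start = time.perf_counter()
    for _ in range(n_repeats):
        logits = beta * stale_losses
        logits -= logits.max()                    # numerical stability
        p = np.exp(logits)
        p /= p.sum()
        rng.choice(n_samples, size=batch_size, p=p)
    time_per_sampling = (time.perf_counter() - start) / n_repeats
    extra_memory_bytes = stale_losses.nbytes      # the only persistent overhead
    return time_per_sampling, extra_memory_bytes

for n in (10**2, 10**4, 10**6):
    t, mem = benchmark_hardness_weighted_sampler(n)
    print(f"n={n:>9d}  time/sampling={t:.2e} s  extra memory={mem / 2**20:.2e} MB")
```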

5 Discussion and Conclusion

In this paper, we have shown that efficient training of deep neural networks with Distributionally Robust Optimization (DRO) with a $\phi$-divergence is possible.

The proposed hardness weighted sampler for training a deep neural network with Stochastic Gradient Descent (SGD) for DRO is as straightforward to implement and as computationally efficient as SGD for Empirical Risk Minimization (ERM). It can be used for deep neural networks with any activation function (including $\mathrm{ReLU}$) and with any per-example loss function. We have shown that the proposed approach can formally be described as a principled Hard Example Mining strategy (Theorem 7) and is related to minimizing the percentile of the per-example loss distribution (13). In addition, we prove the convergence of our method for over-parameterized deep neural networks (Theorem 6), thereby extending the convergence theory of deep learning of Allen-Zhu et al. (2019a). This is, to the best of our knowledge, the first convergence result for training a deep neural network based on DRO.

In practice, we have shown that our hardness weighted sampling method can be easily integrated in a state-of-the-art deep learning framework for medical image segmentation. Interestingly, the proposed algorithm remains stable when SGD with momentum is used. The hardness weighted sampling has one hyperparameter $\beta>0$. Our experiments suggest that similar values of $\beta$ lead to improved robustness in different applications. We hypothesize that good values of $\beta$ are of the order of the inverse of the standard deviation of the vector of per-volume (stale) losses during the training epochs that precede convergence.
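As a concrete reading of this heuristic (our suggestion, not a rule from the paper), one could monitor the stale per-volume losses during training and set $\beta$ accordingly:

```python
import numpy as np

def suggested_beta(stale_losses, eps=1e-8):
    """Heuristic: beta of the order of the inverse of the standard deviation
    of the per-volume (stale) losses observed before convergence."""
    return 1.0 / (np.std(stale_losses) + eps)
```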

The high anatomical variability of the developing fetal brain across gestational ages and pathologies hampers the robustness of deep neural networks trained by maximizing the average per-volume performance. Specifically, it limits the generalization of deep neural networks to abnormal cases for which few cases are available during training. In this paper, we propose to mitigate this problem by training deep neural networks using Distributionally Robust Optimization (DRO) with the proposed hardness weighted sampling. We have validated the proposed training method on a multi-centric dataset of 437 fetal brain T2w 3D MRIs with various diagnoses. nnU-Net trained with DRO achieved improved segmentation results for pathological cases as compared to the unmodified nnU-Net, while achieving similar segmentation performance for the neurotypical cases. Those results suggest that nnU-Net trained with DRO is more robust to anatomical variabilities than the original nnU-Net trained with ERM. In addition, we have performed experiments on the open-source multi-class brain tumor segmentation dataset BraTS (Bakas et al., 2018). Our results on BraTS suggest that DRO can help improve the robustness of deep neural networks for segmentation to variations in the acquisition protocol of the images.

However, we have also found in our experiments that all deep learning models, either trained with ERM or DRO, failed in some cases. For example, the models evaluated all missed the cerebellum in at least 5% of the spina bifida aperta cases. As a result, while our results do suggest that DRO with our method can improve the robustness of deep neural networks for segmentation, they also show that DRO alone with our method does not provide a guarantee of robustness. DRO with a $\phi$-divergence reweights the examples in the training dataset but cannot account for subsets of the true distribution that are not represented at all in the training dataset. We investigate this problem in our following work (Fidon et al., 2022).

We have shown that the additional computational cost of the proposed hardness weighted sampling is small enough to be negligible in practice, requiring less than one second for up to $n=10^8$ examples. The proposed Algorithm 1 is therefore as computationally efficient as state-of-the-art deep learning pipelines for medical image segmentation. However, when data augmentation is used, an infinite number of training examples is virtually available. We mitigate this problem using importance sampling and only one probability per non-augmented example. We found that importance sampling led to improved segmentation results.

We have also illustrated in our experiments that reporting the mean and standard deviation of the Dice score is not enough to evaluate the robustness of deep neural networks for medical image segmentation. A stratification of the evaluation is required to assess for which subgroups of the population and for which image protocols a deep learning model for segmentation can be safely used. In addition, not all improvements of the mean and standard deviation of the Dice score are equally relevant as they can result from improvements of either the best or the worst segmentation cases. Regarding the robustness of automatic segmentation methods across various conditions, one is interested in improvements of segmentation metrics in the tail of the distribution that corresponds to the worst segmentation cases. To this end, one can report the interquartile range (IQR) and measures of risk such as percentiles.


Acknowledgments

This project has received funding from the European Union’s Horizon 2020 research and innovation program under the Marie Skłodowska-Curie grant agreement TRABIT No 765148; Wellcome [203148/Z/16/Z; WT101957], EPSRC [NS/A000049/1; NS/A000027/1]. Tom Vercauteren is supported by a Medtronic / RAEng Research Chair [RCSRF1819\7\34]. Data used in this publication were obtained as part of the RSNA-ASNR-MICCAI Brain Tumor Segmentation (BraTS) Challenge project through Synapse ID (syn25829067).


Ethical Standards

The work follows appropriate ethical standards in conducting research and writing the manuscript, following all applicable laws and regulations regarding treatment of human subjects.


Conflicts of Interest

Sébastien Ourselin is co-founder of Brainminer and non-executive director at Hypervision Surgical. Tom Vercauteren is chief scientific officer at Hypervision Surgical. Michael Ebner is chief executive officer at Hypervision Surgical. Georg Langs is chief scientist and co-founder at Contextflow.

References

  • Aertsen et al. (2019) M Aertsen, J Verduyckt, F De Keyzer, T Vercauteren, F Van Calenbergh, L De Catte, S Dymarkowski, P Demaerel, and J Deprest. Reliability of MR imaging–based posterior fossa and brain stem measurements in open spinal dysraphism in the era of fetal surgery. American Journal of Neuroradiology, 40(1):191–198, 2019.
  • Allen-Zhu et al. (2019a) Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. A convergence theory for deep learning via over-parameterization. In ICML, pages 242–252, 2019a.
  • Allen-Zhu et al. (2019b) Zeyuan Allen-Zhu, Yuanzhi Li, and Zhao Song. On the convergence rate of training recurrent neural networks. In Advances in Neural Information Processing Systems 32, pages 6676–6688. Curran Associates, Inc., 2019b.
  • Bakas et al. (2017a) Spyridon Bakas, Hamed Akbari, Aristeidis Sotiras, Michel Bilello, Martin Rozycki, Justin S Kirby, John B Freymann, Keyvan Farahani, and Christos Davatzikos. Segmentation labels and radiomic features for the pre-operative scans of the TCGA-GBM collection. The Cancer Imaging Archive, 2017a. doi: 10.7937/K9/TCIA.2017.KLXWJJ1Q.
  • Bakas et al. (2017b) Spyridon Bakas, Hamed Akbari, Aristeidis Sotiras, Michel Bilello, Martin Rozycki, Justin S Kirby, John B Freymann, Keyvan Farahani, and Christos Davatzikos. Segmentation labels and radiomic features for the pre-operative scans of the TCGA-LGG collection. The Cancer Imaging Archive, 2017b. doi: 10.7937/K9/TCIA.2017.GJQ7R0EF.
  • Bakas et al. (2017c) Spyridon Bakas, Hamed Akbari, Aristeidis Sotiras, Michel Bilello, Martin Rozycki, Justin S Kirby, John B Freymann, Keyvan Farahani, and Christos Davatzikos. Advancing the cancer genome atlas glioma MRI collections with expert segmentation labels and radiomic features. Scientific data, 4:170117, 2017c.
  • Bakas et al. (2018) Spyridon Bakas, Mauricio Reyes, Andras Jakab, Stefan Bauer, Markus Rempfler, Alessandro Crimi, Russell Takeshi Shinohara, Christoph Berger, Sung Min Ha, Martin Rozycki, et al. Identifying the best machine learning algorithms for brain tumor segmentation, progression assessment, and overall survival prediction in the BRATS challenge. arXiv preprint arXiv:1811.02629, 2018.
  • Berger et al. (2018) Lorenz Berger, Hyde Eoin, M Jorge Cardoso, and Sébastien Ourselin. An adaptive sampling scheme to efficiently train fully convolutional networks for semantic segmentation. In Annual Conference on Medical Image Understanding and Analysis, pages 277–286. Springer, 2018.
  • Bottou et al. (2018) Léon Bottou, Frank E Curtis, and Jorge Nocedal. Optimization methods for large-scale machine learning. Siam Review, 60(2):223–311, 2018.
  • Byrd and Lipton (2019) Jonathon Byrd and Zachary Lipton. What is the effect of importance weighting in deep learning? In ICML, pages 872–881, 2019.
  • Cao and Gu (2020) Yuan Cao and Quanquan Gu. Generalization error bounds of gradient descent for learning overparameterized deep relu networks. In AAAI, 2020.
  • Chang et al. (2017) Haw-Shiuan Chang, Erik Learned-Miller, and Andrew McCallum. Active bias: Training more accurate neural networks by emphasizing high variance samples. In Advances in Neural Information Processing Systems, pages 1002–1012, 2017.
  • Chernoff et al. (1952) Herman Chernoff et al. A measure of asymptotic efficiency for tests of a hypothesis based on the sum of observations. The Annals of Mathematical Statistics, 23(4):493–507, 1952.
  • Chouzenoux et al. (2019) Emilie Chouzenoux, Henri Gérard, and Jean-Christophe Pesquet. General risk measures for robust machine learning. Foundations of Data Science, 1:249, 2019.
  • Çiçek et al. (2016) Özgün Çiçek, Ahmed Abdulkadir, Soeren S Lienkamp, Thomas Brox, and Olaf Ronneberger. 3D U-Net: learning dense volumetric segmentation from sparse annotation. In International conference on medical image computing and computer-assisted intervention, pages 424–432. Springer, 2016.
  • Csiszár et al. (2004) Imre Csiszár, Paul C Shields, et al. Information theory and statistics: A tutorial. Foundations and Trends® in Communications and Information Theory, 1(4):417–528, 2004.
  • Danzer et al. (2007) Enrico Danzer, Mark P Johnson, Michael Bebbington, Erin M Simon, R Douglas Wilson, Larrissa T Bilaniuk, Leslie N Sutton, and N Scott Adzick. Fetal head biometry assessed by fetal magnetic resonance imaging following in utero myelomeningocele repair. Fetal diagnosis and therapy, 22(1):1–6, 2007.
  • Danzer et al. (2020) Enrico Danzer, Luc Joyeux, Alan W Flake, and Jan Deprest. Fetal surgical intervention for myelomeningocele: lessons learned, outcomes, and future implications. Developmental Medicine & Child Neurology, 62(4):417–425, 2020.
  • Dice (1945) Lee R Dice. Measures of the amount of ecologic association between species. Ecology, 26(3):297–302, 1945.
  • Duchi et al. (2016) John Duchi, Peter Glynn, and Hongseok Namkoong. Statistics of robust optimization: A generalized empirical likelihood approach. arXiv preprint arXiv:1610.03425, 2016.
  • Ebner et al. (2020) Michael Ebner, Guotai Wang, Wenqi Li, Michael Aertsen, Premal A Patel, Rosalind Aughwane, Andrew Melbourne, Tom Doel, Steven Dymarkowski, Paolo De Coppi, et al. An automated framework for localization, segmentation and super-resolution reconstruction of fetal brain MRI. NeuroImage, 206:116324, 2020.
  • Emam et al. (2021) Doaa Emam, Michael Aertsen, Lennart Van der Veeken, Lucas Fidon, Prachi Patkee, Vanessa Kyriakopoulou, Luc De Catte, Francesca Russo, Philippe Demaerel, Tom Vercauteren, et al. Longitudinal evaluation of brain development in fetuses with congenital diaphragmatic hernia on MRI: an original research study. 2021.
  • European Commission (2019) European Commission. Ethics guidelines for trustworthy AI. Report, European Commission, 2019.
  • Fenchel (1949) Werner Fenchel. On conjugate convex functions. Canadian Journal of Mathematics, 1(1):73–77, 1949.
  • Fidon et al. (2017) Lucas Fidon, Wenqi Li, Luis C Garcia-Peraza-Herrera, Jinendra Ekanayake, Neil Kitchen, Sébastien Ourselin, and Tom Vercauteren. Generalised Wasserstein dice score for imbalanced multi-class segmentation using holistic convolutional networks. In International MICCAI Brainlesion Workshop, pages 64–76. Springer, 2017.
  • Fidon et al. (2021a) Lucas Fidon, Michael Aertsen, Doaa Emam, Nada Mufti, Frédéric Guffens, Thomas Deprest, Philippe Demaerel, Anna L David, Andrew Melbourne, Sébastien Ourselin, et al. Label-set loss functions for partial supervision: Application to fetal brain 3D MRI parcellation. arXiv preprint arXiv:2107.03846, 2021a.
  • Fidon et al. (2021b) Lucas Fidon, Michael Aertsen, Nada Mufti, Thomas Deprest, Doaa Emam, Frédéric Guffens, Ernst Schwartz, Michael Ebner, Daniela Prayer, Gregor Kasprian, et al. Distributionally robust segmentation of abnormal fetal brain 3D MRI. In Uncertainty for Safe Utilization of Machine Learning in Medical Imaging, and Perinatal Imaging, Placental and Preterm Image Analysis, pages 263–273. Springer, 2021b.
  • Fidon et al. (2021c) Lucas Fidon, Michael Aertsen, Suprosanna Shit, Philippe Demaerel, Sébastien Ourselin, Jan Deprest, and Tom Vercauteren. Partial supervision for the FeTA challenge 2021. arXiv preprint arXiv:2111.02408, 2021c.
  • Fidon et al. (2021d) Lucas Fidon, Elizabeth Viola, Nada Mufti, Anna David, Andrew Melbourne, Philippe Demaerel, Sebastien Ourselin, Tom Vercauteren, Jan Deprest, and Michael Aertsen. A spatio-temporal atlas of the developing fetal brain with spina bifida aperta. Open Research Europe, 2021d.
  • Fidon et al. (2022) Lucas Fidon, Michael Aertsen, Florian Kofler, Andrea Bink, Anna L David, Thomas Deprest, Doaa Emam, Frédéric Guffens, András Jakab, Gregor Kasprian, et al. A Dempster-Shafer approach to trustworthy AI with application to fetal brain MRI segmentation. arXiv preprint arXiv:2204.02779, 2022.
  • Gholipour et al. (2017) Ali Gholipour, Caitlin K Rollins, Clemente Velasco-Annis, Abdelhakim Ouaalam, Alireza Akhondi-Asl, Onur Afacan, Cynthia M Ortinau, Sean Clancy, Catherine Limperopoulos, Edward Yang, et al. A normative spatiotemporal MRI atlas of the fetal brain for automatic segmentation and analysis of early brain growth. Scientific reports, 7(1):1–13, 2017.
  • Harwood et al. (2017) Ben Harwood, BG Kumar, Gustavo Carneiro, Ian Reid, Tom Drummond, et al. Smart mining for deep metric learning. In Proceedings of the IEEE International Conference on Computer Vision, pages 2821–2829, 2017.
  • Hiriart-Urruty and Lemaréchal (2013) Jean-Baptiste Hiriart-Urruty and Claude Lemaréchal. Convex analysis and minimization algorithms I: Fundamentals, volume 305. Springer science & business media, 2013.
  • Hu et al. (2018) Weihua Hu et al. Does distributionally robust supervised learning give robust classifiers? In ICML, 2018.
  • Isensee et al. (2021) Fabian Isensee, Paul F Jaeger, Simon AA Kohl, Jens Petersen, and Klaus H Maier-Hein. nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods, 18(2):203–211, 2021.
  • Jin et al. (2019) Chi Jin, Praneeth Netrapalli, and Michael I Jordan. Minmax optimization: Stable limit points of gradient descent ascent are locally optimal. arXiv preprint arXiv:1902.00618, 2019.
  • Kahn and Marshall (1953) Herman Kahn and Andy W Marshall. Methods of reducing sample size in Monte Carlo computations. Journal of the Operations Research Society of America, 1(5):263–278, 1953.
  • Larrazabal et al. (2020) Agostina J Larrazabal, Nicolás Nieto, Victoria Peterson, Diego H Milone, and Enzo Ferrante. Gender imbalance in medical imaging datasets produces biased classifiers for computer-aided diagnosis. Proceedings of the National Academy of Sciences, 117(23):12592–12594, 2020.
  • LeCun (1998) Yann LeCun. The MNIST database of handwritten digits. http://yann.lecun.com/exdb/mnist/, 1998.
  • Lee et al. (2015) Chen-Yu Lee, Saining Xie, Patrick Gallagher, Zhengyou Zhang, and Zhuowen Tu. Deeply-supervised nets. In Artificial intelligence and statistics, pages 562–570, 2015.
  • Lin et al. (2019) Tianyi Lin, Chi Jin, and Michael I Jordan. On gradient descent ascent for nonconvex-concave minimax problems. arXiv preprint arXiv:1906.00331, 2019.
  • Loshchilov and Hutter (2016) Ilya Loshchilov and Frank Hutter. Online batch selection for faster training of neural networks. ICLR Workshop, 2016.
  • Menze et al. (2014) Bjoern H Menze, Andras Jakab, Stefan Bauer, Jayashree Kalpathy-Cramer, Keyvan Farahani, Justin Kirby, Yuliya Burren, Nicole Porz, Johannes Slotboom, Roland Wiest, et al. The multimodal brain tumor image segmentation benchmark (BraTS). IEEE Transactions on Medical Imaging, 34(10):1993–2024, 2014.
  • Moreau (1965) Jean-Jacques Moreau. Proximité et dualité dans un espace hilbertien. Bulletin de la Société mathématique de France, 93:273–299, 1965.
  • Mufti et al. (2021) Nada Mufti, Michael Aertsen, Michael Ebner, Lucas Fidon, Premal Patel, Muhamad Bin Abdul Rahman, Yannick Brackenier, Gregor Ekart, Virginia Fernandez, Tom Vercauteren, et al. Cortical spectral matching and shape and volume analysis of the fetal brain pre-and post-fetal surgery for spina bifida: a retrospective study. Neuroradiology, pages 1–14, 2021.
  • Namkoong and Duchi (2016) Hongseok Namkoong and John C Duchi. Stochastic gradient methods for distributionally robust optimization with f-divergences. In Advances in Neural Information Processing Systems, pages 2208–2216, 2016.
  • Oakden-Rayner et al. (2020) Luke Oakden-Rayner, Jared Dunnmon, Gustavo Carneiro, and Christopher Ré. Hidden stratification causes clinically meaningful failures in machine learning for medical imaging. In Proceedings of the ACM conference on health, inference, and learning, pages 151–159, 2020.
  • Owen and Zhou (2000) Art Owen and Yi Zhou. Safe and effective importance sampling. Journal of the American Statistical Association, 95(449):135–143, 2000.
  • Payette et al. (2021) Kelly Payette, Priscille de Dumast, Hamza Kebiri, Ivan Ezhov, Johannes C Paetzold, Suprosanna Shit, Asim Iqbal, Romesa Khan, Raimund Kottke, Patrice Grehten, et al. An automatic multi-tissue human fetal brain segmentation benchmark using the fetal tissue annotation dataset. Scientific Data, 8(1):1–14, 2021.
  • Payette et al. (2022) Kelly Payette, Hongwei Li, Priscille de Dumast, Roxane Licandro, Hui Ji, Md Mahfuzur Rahman Siddiquee, Daguang Xu, Andriy Myronenko, Hao Liu, Yuchen Pei, et al. Fetal brain tissue annotation and segmentation challenge results. arXiv preprint arXiv:2204.09573, 2022.
  • Puyol-Antón et al. (2021) Esther Puyol-Antón, Bram Ruijsink, Stefan K Piechnik, Stefan Neubauer, Steffen E Petersen, Reza Razavi, and Andrew P King. Fairness in cardiac mr image analysis: An investigation of bias due to data imbalance in deep learning based segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 413–423. Springer, 2021.
  • Rafique et al. (2018) Hassan Rafique, Mingrui Liu, Qihang Lin, and Tianbao Yang. Non-convex min-max optimization: Provable algorithms and applications in machine learning. arXiv preprint arXiv:1810.02060, 2018.
  • Rahimian and Mehrotra (2019) Hamed Rahimian and Sanjay Mehrotra. Distributionally robust optimization: A review. arXiv preprint arXiv:1908.05659, 2019.
  • Ranzini et al. (2021) Marta Ranzini, Lucas Fidon, Sébastien Ourselin, Marc Modat, and Tom Vercauteren. MONAIfbs: MONAI-based fetal brain MRI deep learning segmentation. arXiv preprint arXiv:2103.13314, 2021.
  • Sacco et al. (2019) Adalina Sacco, Fred Ushakov, Dominic Thompson, Donald Peebles, Pranav Pandya, Paolo De Coppi, Ruwan Wimalasundera, George Attilakos, Anna Louise David, and Jan Deprest. Fetal surgery for open spina bifida. The Obstetrician & Gynaecologist, 21(4):271, 2019.
  • Sagawa et al. (2020) Shiori Sagawa, Pang Wei Koh, Tatsunori B Hashimoto, and Percy Liang. Distributionally robust neural networks for group shifts: On the importance of regularization for worst-case generalization. ICLR, 2020.
  • Shrivastava et al. (2016) Abhinav Shrivastava, Abhinav Gupta, and Ross Girshick. Training region-based object detectors with online hard example mining. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 761–769, 2016.
  • Sinha et al. (2018) Aman Sinha, Hongseok Namkoong, and John Duchi. Certifying some distributional robustness with principled adversarial training. ICLR, 2018.
  • Staib and Jegelka (2017) Matthew Staib and Stefanie Jegelka. Distributionally robust deep learning as a generalization of adversarial training. In NIPS workshop on Machine Learning and Computer Security, 2017.
  • Suh et al. (2019) Yumin Suh, Bohyung Han, Wonsik Kim, and Kyoung Mu Lee. Stochastic class-based hard example mining for deep metric learning. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 7251–7259, 2019.
  • Tilborghs et al. (2020) Sofie Tilborghs, Ine Dirks, Lucas Fidon, Siri Willems, Tom Eelbode, Jeroen Bertels, Bart Ilsen, Arne Brys, Adriana Dubbeldam, Nico Buls, et al. Comparative study of deep learning methods for the automatic segmentation of lung, lesion and lesion type in CT scans of COVID-19 patients. arXiv preprint arXiv:2007.15546, 2020.
  • Tubbs et al. (2011) R Shane Tubbs, Sanjay Krishnamurthy, Ketan Verma, Mohammadali M Shoja, Marios Loukas, Martin M Mortazavi, and Aaron A Cohen-Gadol. Cavum velum interpositum, cavum septum pellucidum, and cavum vergae: a review. Child’s Nervous System, 27(11):1927–1930, 2011.
  • Ulyanov et al. (2016) Dmitry Ulyanov, Andrea Vedaldi, and Victor Lempitsky. Instance normalization: The missing ingredient for fast stylization. arXiv preprint arXiv:1607.08022, 2016.
  • Wachinger et al. (2019) Christian Wachinger, Benjamin Gutierrez Becker, Anna Rieckmann, and Sebastian Pölsterl. Quantifying confounding bias in neuroimaging datasets with causal inference. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 484–492. Springer, 2019.
  • Wu et al. (2017) Chao-Yuan Wu, R Manmatha, Alexander J Smola, and Philipp Krahenbuhl. Sampling matters in deep embedding learning. In Proceedings of the IEEE International Conference on Computer Vision, pages 2840–2848, 2017.
  • Zagoruyko and Komodakis (2016) Sergey Zagoruyko and Nikos Komodakis. Wide residual networks. In Proceedings of the British Machine Vision Conference (BMVC), pages 87.1–87.12. BMVA Press, 2016.
  • Zou and Gu (2019) Difan Zou and Quanquan Gu. An improved analysis of training over-parameterized deep neural networks. In Advances in Neural Information Processing Systems 32, pages 2055–2064. Curran Associates, Inc., 2019.

A Summary of the Notations used in the Proofs

For the ease of reading the proofs we first summarize our notations.

A.1 Probability Theory Notations

  • $\Delta_n = \left\{ (p_i)_{i=1}^n \in [0,1]^n,\ \sum_i p_i = 1 \right\}$

  • Let ${\bm{q}} = (q_i) \in \Delta_n$ and $f$ a function. We denote $\mathbb{E}_{{\bm{q}}}[f({\bm{x}})] := \sum_{i=1}^n q_i f({\bm{x}}_i)$.

  • Let ${\bm{q}} \in \Delta_n$ and $f$ a function. We denote $\mathbb{V}_{{\bm{q}}}[f({\bm{x}})] := \sum_{i=1}^n q_i \left\lVert f({\bm{x}}_i) - \mathbb{E}_{{\bm{q}}}[f({\bm{x}})] \right\rVert^2$ (a short numerical illustration follows this list).

  • ${\bm{p}}_{\rm train}$ is the uniform training data distribution, i.e. ${\bm{p}}_{\rm train} = \left(\frac{1}{n}\right)_{i=1}^n \in \Delta_n$.
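
For concreteness, here is a minimal numerical illustration of these weighted expectation and variance notations; the vectors below are toy values of our own, not taken from the paper.

```python
import numpy as np

q = np.array([0.5, 0.3, 0.2])      # a point of the simplex Delta_3
f_x = np.array([1.0, 2.0, 4.0])    # values f(x_1), f(x_2), f(x_3)

# E_q[f(x)] = sum_i q_i f(x_i) = 1.9
expectation = np.sum(q * f_x)
# V_q[f(x)] = sum_i q_i ||f(x_i) - E_q[f(x)]||^2 = 1.29
variance = np.sum(q * (f_x - expectation) ** 2)
print(expectation, variance)
```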

A.2 Machine Learning Notations

  • $n$ is the number of training examples.

  • $d$ is the dimension of the output.

  • $\mathfrak{d}$ is the dimension of the input.

  • $m$ is the number of nodes in each layer.

  • Training data: $\{({\bm{x}}_i, {\bm{y}}_i)\}_{i=1}^n$, where for all $i \in \{1,\ldots,n\}$, ${\bm{x}}_i \in {\mathbb{R}}^{\mathfrak{d}}$ and ${\bm{y}}_i \in {\mathbb{R}}^d$.

  • $h : {\bm{x}} \mapsto {\bm{y}}$ is the predictor (deep neural network).

  • ${\bm{\theta}}$ is the set of parameters of the predictor.

  • For all $i$, $h_i : {\bm{\theta}} \mapsto h({\bm{x}}_i; {\bm{\theta}})$ is the output of the network for example $i$ as a function of ${\bm{\theta}}$.

  • $\mathcal{L}$ is the per-example loss function.

  • $\mathcal{L}_i : {\bm{v}} \mapsto \mathcal{L}({\bm{v}}, {\bm{y}}_i)$ is the per-example loss function for example $i$.

  • We denote by ${\bm{L}}$ the vector-valued function ${\bm{L}} : ({\bm{v}}_i)_{i=1}^n \mapsto \left(\mathcal{L}_i({\bm{v}}_i)\right)_{i=1}^n$.

  • $b \in \{1,\ldots,n\}$ is the batch size.

  • $\eta > 0$ is the learning rate.

  • ERM is short for Empirical Risk Minimization.

A.3 Distributionally Robust Optimisation Notations

  • For all ${\bm{\theta}}$, $R({\bm{L}}(h({\bm{\theta}}))) = \max_{{\bm{q}} \in \Delta_n} \mathbb{E}_{{\bm{q}}}\left[\mathcal{L}\left(h({\bm{x}}; {\bm{\theta}}), {\bm{y}}\right)\right] - \frac{1}{\beta} D_{\phi}({\bm{q}} \| {\bm{p}}_{\rm train})$ is the Distributionally Robust Loss evaluated at ${\bm{\theta}}$, where $\beta > 0$ is the parameter that adjusts the degree of distributional robustness. For short, we also use the terms distributionally robust loss or simply robust loss for $R({\bm{L}}(h({\bm{\theta}})))$.

  • DRO is short for Distributionally Robust Optimisation.

A.4 Miscellaneous

By abuse of notation, and similarly to (Allen-Zhu et al., 2019a), we use the Bachmann-Landau notations to hide constants that do not depend on our main hyper-parameters. Let $f$ and $g$ be two scalar functions; we write:

$\begin{cases} f \leq O(g) & \iff \ \exists c > 0 \ \text{ s.t. } \ f \leq cg \\ f \geq \Omega(g) & \iff \ \exists c > 0 \ \text{ s.t. } \ f \geq cg \\ f = \Theta(g) & \iff \ \exists c_1 > 0 \text{ and } \exists c_2 > c_1 \ \text{ s.t. } \ c_1 g \leq f \leq c_2 g \end{cases}$

B Evaluation of the Influence of β on the Segmentation Performance for BraTS

Table 5: Detailed evaluation on the BraTS 2019 online validation set (125 cases). All the models in this table were trained using the default SGD with Nesterov momentum of nnU-Net (Isensee et al., 2021). Dice scores were computed using the BraTS online evaluation platform (https://ipp.cbica.upenn.edu/). ERM: Empirical Risk Minimization, DRO: Distributionally Robust Optimization, IS: Importance Sampling is used, IQR: Interquartile Range. The best values are in bold.
| Optimization problem | Enhancing Tumor (Mean / Median / IQR) | Whole Tumor (Mean / Median / IQR) | Tumor Core (Mean / Median / IQR) |
|---|---|---|---|
| ERM | 73.0 / 87.1 / 15.6 | 90.7 / 92.6 / **5.4** | 83.9 / 90.5 / 14.3 |
| DRO β=10 | 74.6 / 86.8 / 14.1 | **90.8** / **93.0** / 5.9 | 83.4 / 90.7 / 14.5 |
| DRO β=10 IS | **75.3** / 86.0 / **13.3** | 90.0 / 91.9 / 7.0 | 82.8 / 89.1 / 14.3 |
| DRO β=100 | 73.4 / 86.7 / 14.3 | 90.6 / 92.6 / 6.2 | **84.5** / **90.9** / 13.7 |
| DRO β=100 IS | 74.5 / **87.3** / 13.8 | 90.6 / 92.6 / 5.9 | 84.1 / 90.0 / **12.5** |
| DRO β=1000 | 74.5 / 84.2 / 33.0 | 89.5 / 91.8 / 5.9 | 71.1 / 87.2 / 41.1 |
| DRO β=1000 IS | 72.2 / 85.7 / 15.0 | 90.3 / 92.2 / 6.3 | 81.1 / 89.4 / 15.1 |

C Importance Sampling Approximation in Algorithm 1

In this section, we give additional details about the approximation made in the computation of the importance weights (step 9 of Algorithm 1).

Let ${\bm{\theta}}$ be the parameters of the neural network $h$, ${\bm{L}} = (L_i)_{i=1}^n$ be the stale per-example loss vector, and let $i$ be an index in the current batch $I$.

We start from the definition of the importance weight $w_i$ for example $i$ and use the formula for the hardness weighted sampling probabilities of Example 1.

$w_i = \frac{p_i^{new}}{p_i^{old}}$    (19)

$= \frac{\exp\left(\beta L_i^{new}\right)}{\exp\left(\beta L_i^{new}\right) + \sum_{j \neq i} \exp\left(\beta L_j^{old}\right)} \times \frac{\sum_{j=1}^{n} \exp\left(\beta L_j^{old}\right)}{\exp\left(\beta L_i^{old}\right)}$

$\approx \exp\left(\beta \left(L_i^{new} - L_i^{old}\right)\right)$

where we have assumed that the two sums of exponentials are approximately equal; this is reasonable because the two sums differ only in the term for example $i$, whose contribution to the normalizing constant is small when $n$ is large.
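
To make the computation concrete, here is a minimal NumPy sketch of hardness weighted sampling combined with the importance weight approximation above; the function and variable names (`stale_losses`, `beta`, etc.) are ours, and this is an illustration of the idea, not the authors' released implementation.

```python
import numpy as np

def hardness_weighted_probs(stale_losses, beta):
    """Sampling probabilities softmax(beta * L) over the stale per-example losses."""
    z = beta * stale_losses
    z = z - z.max()                      # for numerical stability
    e = np.exp(z)
    return e / e.sum()

def importance_weights(new_losses, old_losses, beta):
    """Approximate importance weights w_i ~ exp(beta * (L_i_new - L_i_old)), as in (19)."""
    return np.exp(beta * (new_losses - old_losses))

# Toy usage: n = 5 training examples, batch of size 2.
rng = np.random.default_rng(0)
stale_losses = rng.uniform(0.2, 1.0, size=5)
beta = 10.0
probs = hardness_weighted_probs(stale_losses, beta)
batch = rng.choice(5, size=2, replace=False, p=probs)   # hardness weighted sampling
new_losses = 0.9 * stale_losses[batch]                  # pretend the batch losses decreased
w = importance_weights(new_losses, stale_losses[batch], beta)
stale_losses[batch] = new_losses                        # update the stale loss vector
print(batch, w)
```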

D Proof of Example 1: Formula of the Sampling Probabilities for the KL Divergence

We give here a simple proof of the formula of the sampling probabilities for the KL divergence as $\phi$-divergence (i.e. $\phi : z \mapsto z\log(z) - z + 1$)

$\forall {\bm{\theta}}, \quad \bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}}))) = \mathrm{softmax}\left(\beta {\bm{L}}(h({\bm{\theta}}))\right)$

Proof:

For any ${\bm{\theta}}$, the distributionally robust loss for the KL divergence at ${\bm{\theta}}$ is given by

$R \circ {\bm{L}} \circ h({\bm{\theta}}) = \max_{{\bm{q}} \in \Delta_n} \left( \sum_{i=1}^n q_i \, \mathcal{L}_i \circ h_i({\bm{\theta}}) - \frac{1}{\beta} \sum_{i=1}^n q_i \log\left(n q_i\right) \right)$

$= \max_{{\bm{q}} \in \Delta_n} \sum_{i=1}^n \left( q_i \, \mathcal{L}_i \circ h_i({\bm{\theta}}) - \frac{1}{\beta} q_i \log\left(n q_i\right) \right)$

where we have used that $\frac{1}{p_{{\rm train},i}} = n$ inside the $\log$ function. To simplify the notations, let us denote ${\bm{v}} = (v_i)_{i=1}^n := \left(\mathcal{L}_i \circ h_i({\bm{\theta}})\right)_{i=1}^n$ and $\bar{{\bm{p}}} = (\bar{p}_i)_{i=1}^n := \bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$. Thus $\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$ is, by definition, the solution of the optimization problem

$\max_{{\bm{q}} \in \Delta_n} \sum_{i=1}^n \left( q_i v_i - \frac{1}{\beta} q_i \log\left(n q_i\right) \right)$    (20)

First, let us remark that the function ${\bm{q}} \mapsto \sum_{i=1}^n q_i \log\left(n q_i\right)$ is strictly convex on the non-empty closed convex set $\Delta_n$ as a sum of strictly convex functions. This implies that the optimization problem (20) has a unique solution, and as a result $\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))$ is well defined.

We now reformulate the optimization problem (20) as a convex smooth constrained optimization problem by writing the condition ${\bm{q}} \in \Delta_n$ as constraints.

$\max_{{\bm{q}} \in {\mathbb{R}}^n_+} \ \sum_{i=1}^n \left( q_i v_i - \frac{1}{\beta} q_i \log\left(n q_i\right) \right) \quad \text{s.t.} \quad \sum_{i=1}^n q_i = 1$    (21)

There exists a Lagrange multiplier $\lambda \in {\mathbb{R}}$ such that the solution $\bar{{\bm{p}}}$ of (21) is characterized by

$\left\{\begin{aligned} \forall i \in \{1,\ldots,n\}, \quad & v_i - \frac{1}{\beta}\left(\log\left(n\bar{p}_i\right) + 1\right) + \lambda = 0 \\ & \sum_{i=1}^n \bar{p}_i = 1 \end{aligned}\right.$    (22)

Which we can rewrite as

$\left\{\begin{aligned} \forall i \in \{1,\ldots,n\}, \quad & \bar{p}_i = \frac{1}{n}\exp\left(\beta\left(v_i + \lambda\right) - 1\right) \\ & \frac{1}{n}\sum_{i=1}^n \exp\left(\beta\left(v_i + \lambda\right) - 1\right) = 1 \end{aligned}\right.$    (23)

The last equality gives

$\exp\left(\beta\lambda - 1\right) = \frac{n}{\sum_{i=1}^n \exp\left(\beta v_i\right)}$

and by substituting into the formula for the $\bar{p}_i$ we obtain

$\forall i \in \{1,\ldots,n\}, \quad \bar{p}_i = \frac{1}{n}\exp\left(\beta v_i\right)\exp\left(\beta\lambda - 1\right) = \frac{\exp\left(\beta v_i\right)}{\sum_{j=1}^n \exp\left(\beta v_j\right)}$

which corresponds to $\bar{{\bm{p}}} = \mathrm{softmax}\left(\beta {\bm{v}}\right)$. $\blacksquare$
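
As a quick numerical sanity check of this closed form (a sketch with our own toy values, not part of the paper's code), one can verify that $\mathrm{softmax}(\beta {\bm{v}})$ attains a higher value of the objective in (20) than randomly drawn points of the simplex:

```python
import numpy as np

def kl_regularized_objective(q, v, beta):
    """Objective of (20): <q, v> - (1/beta) * sum_i q_i * log(n * q_i)."""
    n = len(v)
    return np.dot(q, v) - np.sum(q * np.log(n * q)) / beta

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

rng = np.random.default_rng(42)
n, beta = 6, 5.0
v = rng.normal(size=n)               # a generic per-example loss vector
p_bar = softmax(beta * v)            # closed-form maximizer from Example 1
best = kl_regularized_objective(p_bar, v, beta)

for _ in range(1000):                # random candidates on the simplex
    q = rng.dirichlet(np.ones(n))
    assert kl_regularized_objective(q, v, beta) <= best + 1e-9
print("softmax(beta * v) attains the maximum:", best)
```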

E Proof of Lemma 4: Regularity Properties of R

For the ease of reading, let us first recall that given a $\phi$-divergence $D_\phi$ that satisfies Definition 2, we have defined in (3)

$R : \ {\mathbb{R}}^n \rightarrow {\mathbb{R}}, \quad {\bm{v}} \mapsto \max_{{\bm{q}} \in \Delta_n} \sum_i q_i v_i - \frac{1}{\beta} D_\phi({\bm{q}} \| {\bm{p}}_{\rm train})$    (24)

And in (4)

$G : \ {\mathbb{R}}^n \rightarrow {\mathbb{R}}, \quad {\bm{p}} \mapsto \frac{1}{\beta} D_\phi({\bm{p}} \| {\bm{p}}_{\rm train}) + \delta_{\Delta_n}({\bm{p}})$    (25)

where $\delta_{\Delta_n}$ is the characteristic function of the $n$-simplex $\Delta_n$, which is a closed convex set, i.e.

$\forall {\bm{p}} \in {\mathbb{R}}^n, \quad \delta_{\Delta_n}({\bm{p}}) = \begin{cases} 0 & \text{if } {\bm{p}} \in \Delta_n \\ +\infty & \text{otherwise} \end{cases}$    (26)

We now prove Lemma 4 on the regularity of $R$.

Lemma 8 (Regularity of R – Restated from Lemma 4)

Let $\phi$ satisfy Definition 2. Then $G$ and $R$ satisfy

$G \text{ is } \left(\frac{n\rho}{\beta}\right)\text{-strongly convex}$    (27)

$R({\bm{L}}(h({\bm{\theta}}))) = \max_{{\bm{q}} \in {\mathbb{R}}^n} \left( \langle {\bm{L}}(h({\bm{\theta}})), {\bm{q}} \rangle - G({\bm{q}}) \right) = G^*\left({\bm{L}}(h({\bm{\theta}}))\right)$    (28)

$R \text{ is } \left(\frac{\beta}{n\rho}\right)\text{-gradient Lipschitz continuous.}$    (29)
Proof:

ϕitalic-ϕ\phiitalic_ϕ is ρ𝜌\rhoitalic_ρ-strongly convex on [0,n]0𝑛[0,n][ 0 , italic_n ] so

x,y[0,n]2,λ[0,1],ϕ(λx+(1λ)y)λϕ(x)+(1λ)ϕ(y)ρλ(1λ)2|yx|2formulae-sequencefor-all𝑥𝑦superscript0𝑛2formulae-sequencefor-all𝜆01italic-ϕ𝜆𝑥1𝜆𝑦𝜆italic-ϕ𝑥1𝜆italic-ϕ𝑦𝜌𝜆1𝜆2superscript𝑦𝑥2\forall x,y\in[0,n]^{2},\forall\lambda\in[0,1],\phi\left(\lambda x+(1-\lambda)% y\right)\leq\lambda\phi(x)+(1-\lambda)\phi(y)-\frac{\rho\lambda(1-\lambda)}{2}% |y-x|^{2}∀ italic_x , italic_y ∈ [ 0 , italic_n ] start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT , ∀ italic_λ ∈ [ 0 , 1 ] , italic_ϕ ( italic_λ italic_x + ( 1 - italic_λ ) italic_y ) ≤ italic_λ italic_ϕ ( italic_x ) + ( 1 - italic_λ ) italic_ϕ ( italic_y ) - divide start_ARG italic_ρ italic_λ ( 1 - italic_λ ) end_ARG start_ARG 2 end_ARG | italic_y - italic_x | start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT(30)

Let 𝒑=(pi)i=1n𝒑superscriptsubscriptsubscript𝑝𝑖𝑖1𝑛{\bm{p}}=\left(p_{i}\right)_{i=1}^{n}bold_italic_p = ( italic_p start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT, 𝒒=(qi)i=1nΔn𝒒superscriptsubscriptsubscript𝑞𝑖𝑖1𝑛subscriptΔ𝑛{\bm{q}}=\left(q_{i}\right)_{i=1}^{n}\in\Delta_{n}bold_italic_q = ( italic_q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT ∈ roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT, and λ[0,1]𝜆01\lambda\in[0,1]italic_λ ∈ [ 0 , 1 ], using (30) and the convexity of δΔnsubscript𝛿subscriptΔ𝑛\delta_{\Delta_{n}}italic_δ start_POSTSUBSCRIPT roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUBSCRIPT, we obtain:

G(λ𝒑+(1λ)𝒒)𝐺𝜆𝒑1𝜆𝒒\displaystyle G\left(\lambda{\bm{p}}+(1-\lambda){\bm{q}}\right)italic_G ( italic_λ bold_italic_p + ( 1 - italic_λ ) bold_italic_q )=1βni=1nϕ(nλpi+n(1λ)qi)+δΔn(λ𝒑+(1λ)𝒒)absent1𝛽𝑛superscriptsubscript𝑖1𝑛italic-ϕ𝑛𝜆subscript𝑝𝑖𝑛1𝜆subscript𝑞𝑖subscript𝛿subscriptΔ𝑛𝜆𝒑1𝜆𝒒\displaystyle=\frac{1}{{\beta}n}\sum_{i=1}^{n}\phi\left(n\lambda p_{i}+n(1-% \lambda)q_{i}\right)+\delta_{\Delta_{n}}\left(\lambda{\bm{p}}+(1-\lambda){\bm{% q}}\right)= divide start_ARG 1 end_ARG start_ARG italic_β italic_n end_ARG ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT italic_ϕ ( italic_n italic_λ italic_p start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT + italic_n ( 1 - italic_λ ) italic_q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) + italic_δ start_POSTSUBSCRIPT roman_Δ start_POSTSUBSCRIPT italic_n end_POSTSUBSCRIPT end_POSTSUBSCRIPT ( italic_λ bold_italic_p + ( 1 - italic_λ ) bold_italic_q )(31)
λG(𝒑)+(1λ)G(𝒒)1βni=1nρλ(1λ)2|npinqi|2absent𝜆𝐺𝒑1𝜆𝐺𝒒1𝛽𝑛superscriptsubscript𝑖1𝑛𝜌𝜆1𝜆2superscript𝑛subscript𝑝𝑖𝑛subscript𝑞𝑖2\displaystyle\leq\lambda G({\bm{p}})+(1-\lambda)G({\bm{q}})-\frac{1}{{\beta}n}% \sum_{i=1}^{n}\frac{\rho\lambda(1-\lambda)}{2}|np_{i}-nq_{i}|^{2}≤ italic_λ italic_G ( bold_italic_p ) + ( 1 - italic_λ ) italic_G ( bold_italic_q ) - divide start_ARG 1 end_ARG start_ARG italic_β italic_n end_ARG ∑ start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT divide start_ARG italic_ρ italic_λ ( 1 - italic_λ ) end_ARG start_ARG 2 end_ARG | italic_n italic_p start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_n italic_q start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT | start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT
λG(𝒑)+(1λ)G(𝒒)nρβλ(1λ)2𝒑𝒒2absent𝜆𝐺𝒑1𝜆𝐺𝒒𝑛𝜌𝛽𝜆1𝜆2superscriptdelimited-∥∥𝒑𝒒2\displaystyle\leq\lambda G({\bm{p}})+(1-\lambda)G({\bm{q}})-\frac{n\rho}{{% \beta}}\frac{\lambda(1-\lambda)}{2}\left\lVert{\bm{p}}-{\bm{q}}\right\rVert^{2}≤ italic_λ italic_G ( bold_italic_p ) + ( 1 - italic_λ ) italic_G ( bold_italic_q ) - divide start_ARG italic_n italic_ρ end_ARG start_ARG italic_β end_ARG divide start_ARG italic_λ ( 1 - italic_λ ) end_ARG start_ARG 2 end_ARG ∥ bold_italic_p - bold_italic_q ∥ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT

This proves that G𝐺Gitalic_G is nρβ𝑛𝜌𝛽\frac{n\rho}{{\beta}}divide start_ARG italic_n italic_ρ end_ARG start_ARG italic_β end_ARG-strongly convex.

$R = G^*$ is convex, and since $G$ is closed and convex, $R^* = (G^*)^* = G$ (Hiriart-Urruty and Lemaréchal, 2013). We obtain (28) using Definition 3.

We now show that $R$ is Fréchet differentiable on ${\mathbb{R}}^n$. Let ${\bm{v}} \in {\mathbb{R}}^n$.

$G$ is strongly convex, so in particular $G$ is strictly convex. This implies that the following optimization problem has a unique solution, which we denote $\hat{{\bm{p}}}({\bm{v}})$.

$\hat{{\bm{p}}}({\bm{v}}) := \operatorname*{arg\,max}_{{\bm{q}} \in {\mathbb{R}}^n} \left( \langle {\bm{v}}, {\bm{q}} \rangle - G({\bm{q}}) \right)$    (32)

In addition, using the notion of subderivative of convex functions (Hiriart-Urruty and Lemaréchal, 2013, Definition 4.1.5 p.39), we have

$\hat{{\bm{p}}} \in \Delta_n \text{ solution of (32)} \iff 0 \in {\bm{v}} - \partial G(\hat{{\bm{p}}}) \iff {\bm{v}} \in \partial G(\hat{{\bm{p}}}) \iff \hat{{\bm{p}}} \in \partial G^*({\bm{v}}) \iff \hat{{\bm{p}}} \in \partial R({\bm{v}})$

where we have used (Hiriart-Urruty and Lemaréchal, 2013, Proposition 6.1.2 p.39) for the third equivalence, and (28) for the last equivalence.

As a result, $\partial R({\bm{v}}) = \{\hat{{\bm{p}}}({\bm{v}})\}$. This implies that $R$ admits a gradient at ${\bm{v}}$, and

$\nabla_{{\bm{v}}} R({\bm{v}}) = \hat{{\bm{p}}}({\bm{v}})$    (33)

Since this holds for any ${\bm{v}} \in {\mathbb{R}}^n$, we deduce that $R$ is Fréchet differentiable on ${\mathbb{R}}^n$. $\blacksquare$

We are now ready to show that $R$ is $\frac{\beta}{n\rho}$-gradient Lipschitz continuous by using the following lemma (Hiriart-Urruty and Lemaréchal, 2013, Theorem 6.1.2 p.280).

Lemma 9

A necessary and sufficient condition for a convex function $f : {\mathbb{R}}^n \rightarrow {\mathbb{R}}$ to be $c$-strongly convex on a convex set $C$ is that for all $x_1, x_2 \in C$

$\langle s_2 - s_1, x_2 - x_1 \rangle \geq c\left\lVert x_2 - x_1 \right\rVert^2 \quad \text{for all } s_i \in \partial f(x_i), \ i = 1, 2.$

Using this lemma for $f = G$, $c = \frac{n\rho}{\beta}$, and $C = \Delta_n$, we obtain:

For all ${\bm{p}}_1, {\bm{p}}_2 \in \Delta_n$, for all ${\bm{v}}_1 \in \partial G({\bm{p}}_1)$ and ${\bm{v}}_2 \in \partial G({\bm{p}}_2)$,

$\langle {\bm{v}}_2 - {\bm{v}}_1, {\bm{p}}_2 - {\bm{p}}_1 \rangle \geq \frac{n\rho}{\beta}\left\lVert {\bm{p}}_2 - {\bm{p}}_1 \right\rVert^2$

In addition, for $i \in \{1, 2\}$, ${\bm{v}}_i \in \partial G({\bm{p}}_i) \iff {\bm{p}}_i \in \partial R({\bm{v}}_i) = \{\nabla_{{\bm{v}}} R({\bm{v}}_i)\}$.

Using the Cauchy-Schwarz inequality,

$\left\lVert {\bm{v}}_2 - {\bm{v}}_1 \right\rVert \left\lVert {\bm{p}}_2 - {\bm{p}}_1 \right\rVert \geq \langle {\bm{v}}_2 - {\bm{v}}_1, {\bm{p}}_2 - {\bm{p}}_1 \rangle$

We conclude that

$\frac{n\rho}{\beta}\left\lVert \nabla_{{\bm{v}}} R({\bm{v}}_2) - \nabla_{{\bm{v}}} R({\bm{v}}_1) \right\rVert \leq \left\lVert {\bm{v}}_2 - {\bm{v}}_1 \right\rVert$

which implies that $R$ is $\frac{\beta}{n\rho}$-gradient Lipschitz continuous. $\blacksquare$
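
For the KL divergence of Example 1 we have $\nabla_{{\bm{v}}} R({\bm{v}}) = \mathrm{softmax}(\beta {\bm{v}})$ by (33) and Example 1, which makes the gradient-Lipschitz property easy to probe numerically. The sketch below uses $\beta$ as a loose upper bound of our own on the Lipschitz constant $\frac{\beta}{n\rho}$, i.e. it assumes $\rho \geq 1/n$ in the KL case; it is a sanity check, not a proof.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

# Empirical check that ||grad R(v2) - grad R(v1)|| <= beta * ||v2 - v1||
# for the KL case, where grad R(v) = softmax(beta * v).
rng = np.random.default_rng(0)
n, beta = 8, 20.0
for _ in range(10000):
    v1, v2 = rng.normal(size=n), rng.normal(size=n)
    lhs = np.linalg.norm(softmax(beta * v2) - softmax(beta * v1))
    rhs = beta * np.linalg.norm(v2 - v1)
    assert lhs <= rhs + 1e-12
print("gradient-Lipschitz bound verified on random pairs")
```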

F Proof of Lemma 5: Formula of the Distributionally Robust Loss Gradient

We prove Lemma 5 that we restate here for the ease of reading.

Lemma 10 (Stochastic Gradient of the DRO Loss – Restated from Lemma 5)

For all ${\bm{\theta}}$, we have

$\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}}))) = \nabla_{{\bm{v}}} R({\bm{L}}(h({\bm{\theta}})))$    (34)

$\nabla_{{\bm{\theta}}}(R \circ {\bm{L}} \circ h)({\bm{\theta}}) = \mathbb{E}_{\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))}\left[ \nabla_{{\bm{\theta}}} \mathcal{L}\left(h({\textnormal{x}}; {\bm{\theta}}), {\textnormal{y}}\right) \right]$    (35)

where $\nabla_{{\bm{v}}} R$ is the gradient of $R$ with respect to its input.

Proof:

For a given 𝜽𝜽{\bm{\theta}}bold_italic_θ, equality (34) is a special case of (33) for 𝒗=(𝒉(𝜽))𝒗𝒉𝜽{\bm{v}}=\operatorname*{\mathcal{L}}({\bm{h}}({\bm{\theta}}))bold_italic_v = caligraphic_L ( bold_italic_h ( bold_italic_θ ) ).

Then, using the chain rule and (34),

\begin{align*}
\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}) &= \sum_{i=1}^{n}\frac{\partial R}{\partial v_i}({\bm{L}}(h({\bm{\theta}})))\,\nabla_{{\bm{\theta}}}(\mathcal{L}_i\circ h_i)({\bm{\theta}}) \\
&= \sum_{i=1}^{n}\bar{p}_i({\bm{L}}(h({\bm{\theta}})))\,\nabla_{{\bm{\theta}}}(\mathcal{L}_i\circ h_i)({\bm{\theta}}) \\
&= \mathbb{E}_{\bar{{\bm{p}}}({\bm{L}}(h({\bm{\theta}})))}\left[\nabla_{{\bm{\theta}}}\mathcal{L}\left(h(\mathrm{x};{\bm{\theta}}),\mathrm{y}\right)\right]
\end{align*}

This concludes the proof. $\blacksquare$
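
As a sanity check of (35), the following self-contained sketch (our own illustration, not part of the original proofs) verifies numerically, for the KL $\phi$-divergence where $\nabla_{{\bm{v}}}R({\bm{v}})=\mathrm{softmax}(\beta{\bm{v}})$ and $R({\bm{v}})=\frac{1}{\beta}\log\left(\frac{1}{n}\sum_i\exp(\beta v_i)\right)$, that the hardness weighted average of the per-example gradients matches a finite-difference estimate of $\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})$ for a toy linear model; all variable names below are ours.

import numpy as np

rng = np.random.default_rng(0)
n, dim, beta = 16, 4, 2.0
X = rng.normal(size=(n, dim))
y = rng.normal(size=n)
theta = rng.normal(size=dim)

def per_example_losses(theta):
    return 0.5 * (X @ theta - y) ** 2                 # L_i(theta) for a linear model

def dro_loss(theta):
    v = per_example_losses(theta)
    return np.log(np.mean(np.exp(beta * v))) / beta   # R(L(theta)) for the KL case

# Right-hand side of (35): expectation of per-example gradients under p_bar = softmax(beta * L)
v = per_example_losses(theta)
p_bar = np.exp(beta * v) / np.exp(beta * v).sum()
per_example_grads = (X @ theta - y)[:, None] * X      # gradient of 0.5 * (x_i . theta - y_i)^2
grad_rhs = p_bar @ per_example_grads

# Left-hand side of (35): central finite differences of the distributionally robust loss
eps = 1e-6
grad_lhs = np.array([(dro_loss(theta + eps * e) - dro_loss(theta - eps * e)) / (2 * eps)
                     for e in np.eye(dim)])

print(np.allclose(grad_lhs, grad_rhs, atol=1e-5))     # expected: True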

G Proof of Theorem 7: Distributionally Robust Optimization as Principled Hard Example Mining

In this section, we demonstrate that the proposed hardness weighted sampling can be interpreted as a principled hard example mining method.

Let $D_\phi$ be a $\phi$-divergence satisfying Definition 2, and let ${\bm{v}}=(v_i)_{i=1}^{n}\in\mathbb{R}^n$; ${\bm{v}}$ will play the role of a generic loss vector.

$\phi$ is strongly convex and $\Delta_n$ is closed and convex, so the following optimization problem has one and only one solution:

\[ \max_{{\bm{p}}=(p_i)_{i=1}^{n}\in\Delta_n}\;\langle{\bm{v}},{\bm{p}}\rangle-\frac{1}{\beta n}\sum_{i=1}^{n}\phi(np_i) \qquad (36) \]

Making the constraints associated with ${\bm{p}}\in\Delta_n$ explicit, this can be rewritten as

\begin{align*}
\max_{{\bm{p}}=(p_i)_{i=1}^{n}\in\mathbb{R}^n}\quad & \langle{\bm{v}},{\bm{p}}\rangle-\frac{1}{\beta n}\sum_{i=1}^{n}\phi(np_i) \qquad (37)\\
\text{s.t.}\quad & \forall i\in\{1,\ldots,n\},\;p_i\geq 0\\
& \sum_{i=1}^{n}p_i=1
\end{align*}

There exist KKT multipliers $\lambda\in\mathbb{R}$ and $\mu_i\geq 0$ for all $i$, such that the solution $\bar{{\bm{p}}}=(\bar{p}_i)_{i=1}^{n}$ satisfies

\[ \left\{\begin{aligned}
\forall i\in\{1,\ldots,n\},\quad & v_i-\frac{1}{\beta}\phi'(n\bar{p}_i)+\lambda-\mu_i=0\\
\forall i\in\{1,\ldots,n\},\quad & \mu_i\bar{p}_i=0\\
\forall i\in\{1,\ldots,n\},\quad & \bar{p}_i\geq 0\\
& \sum_{i=1}^{n}\bar{p}_i=1
\end{aligned}\right. \qquad (38) \]

Since $\phi$ is continuously differentiable and strongly convex, we have $(\phi')^{-1}=(\phi^*)'$, where $\phi^*$ is the Fenchel conjugate of $\phi$ (see Hiriart-Urruty and Lemaréchal, 2013, Proposition 6.1.2). As a result, (38) can be rewritten as

\[ \left\{\begin{aligned}
\forall i\in\{1,\ldots,n\},\quad & \bar{p}_i=\frac{1}{n}(\phi^*)'\left(\beta(v_i+\lambda-\mu_i)\right)\\
\forall i\in\{1,\ldots,n\},\quad & \mu_i\bar{p}_i=0\\
\forall i\in\{1,\ldots,n\},\quad & \bar{p}_i\geq 0\\
& \frac{1}{n}\sum_{i=1}^{n}(\phi^*)'\left(\beta(v_i+\lambda-\mu_i)\right)=1
\end{aligned}\right. \qquad (39) \]

We now show that the KKT multipliers are uniquely defined.

The $\mu_i$'s are uniquely defined by ${\bm{v}}$ and $\lambda$:

Since for all $i\in\{1,\ldots,n\}$ we have $\mu_i\bar{p}_i=0$, $\bar{p}_i\geq 0$ and $\mu_i\geq 0$, either $\bar{p}_i=0$ or $\mu_i=0$. In the case $\bar{p}_i=0$, it follows from (39) that $(\phi^*)'\left(\beta(v_i+\lambda-\mu_i)\right)=0$.

According to Definition 2, $\phi$ is strongly convex and continuously differentiable, so $\phi'$ and $(\phi^*)'=(\phi')^{-1}$ are continuous and strictly increasing functions. As a result, there exists a unique $\mu_i$ (depending on ${\bm{v}}$ and $\lambda$) such that

\[ (\phi^*)'\left(\beta(v_i+\lambda-\mu_i)\right)=0 \]

Thus (39) can be rewritten as

\[ \left\{\begin{aligned}
\forall i\in\{1,\ldots,n\},\quad & \bar{p}_i=\mathrm{ReLU}\left(\frac{1}{n}(\phi^*)'\left(\beta(v_i+\lambda)\right)\right)=\frac{1}{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda)\right)\right)\\
& \frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda)\right)\right)=1
\end{aligned}\right. \qquad (40) \]
The KKT multiplier $\lambda$ is uniquely defined by ${\bm{v}}$ and is a continuous function of ${\bm{v}}$:

Let $\lambda\in\mathbb{R}$ satisfy (40). We have $\frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda)\right)\right)=1$, so there exists at least one index $i_0$ such that

\[ \mathrm{ReLU}\left((\phi^*)'\left(\beta(v_{i_0}+\lambda)\right)\right)=(\phi^*)'\left(\beta(v_{i_0}+\lambda)\right)\geq 1 \]

Since $(\phi^*)'$ is continuous and strictly increasing, and since $(\phi^*)'\left(\beta(v_{i_0}+\lambda)\right)\geq 1>0$ means that the $\mathrm{ReLU}$ is in its linear regime around this value, the map $\lambda'\mapsto\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_{i_0}+\lambda')\right)\right)$ is continuous and strictly increasing on a neighborhood of $\lambda$. In addition, $\mathrm{ReLU}$ is continuous and increasing, so for all $i\in\{1,\ldots,n\}$, $\lambda'\mapsto\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda')\right)\right)$ is a continuous and increasing function.

As a result, $\lambda'\mapsto\frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda')\right)\right)$ is a continuous function that is increasing on $\mathbb{R}$ and strictly increasing on a neighborhood of $\lambda$. This implies that $\lambda$ is uniquely defined by ${\bm{v}}$, and that ${\bm{v}}\mapsto\lambda({\bm{v}})$ is continuous.
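
To make (40) concrete, here is a minimal self-contained sketch (our own illustration; the function name, the bisection strategy, and the bracketing argument are ours, not from the paper) that recovers $\bar{{\bm{p}}}$ and $\lambda$ numerically from a loss vector for a given $(\phi^*)'$. It assumes a normalized generator with $\phi'(1)=0$, so that $(\phi^*)'(0)=1$ and $\lambda$ can be bracketed by $[-\max_i v_i,\,-\min_i v_i]$; for the KL generator $\phi(x)=x\log x-x+1$, $(\phi^*)'=\exp$ and the result coincides with $\mathrm{softmax}(\beta{\bm{v}})$.

import numpy as np

def hardness_weighted_probas(v, beta, dphi_star_prime, iters=100):
    """Numerically solve the KKT system (40): find lambda by bisection on the
    increasing map F, then return p_bar_i = ReLU((phi*)'(beta*(v_i+lambda)))/n."""
    v = np.asarray(v, dtype=float)
    n = len(v)
    def F(lam):
        return np.mean(np.maximum(dphi_star_prime(beta * (v + lam)), 0.0))
    lo, hi = -v.max(), -v.min()      # bracket, assuming phi'(1) = 0, i.e. (phi*)'(0) = 1
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if F(mid) < 1.0:
            lo = mid
        else:
            hi = mid
    lam = 0.5 * (lo + hi)
    p = np.maximum(dphi_star_prime(beta * (v + lam)), 0.0) / n
    return p / p.sum()               # absorb the residual bisection error

# KL generator phi(x) = x log x - x + 1, hence (phi*)' = exp: p_bar is softmax(beta * v)
rng = np.random.default_rng(0)
losses, beta = rng.random(8), 5.0
p_kl = hardness_weighted_probas(losses, beta, np.exp)
softmax = np.exp(beta * losses) / np.exp(beta * losses).sum()
print(np.allclose(p_kl, softmax))    # expected: True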

G.1 Link between Hardness Weighted Sampling and Hard Example Mining

For any pseudo loss vector ${\bm{v}}=(v_i)_{i=1}^{n}\in\mathbb{R}^n$, there exist a unique KKT multiplier $\lambda$ and a unique $\bar{{\bm{p}}}$ that satisfy (40), so we can define the mapping:

\begin{align*}
\bar{{\bm{p}}}:\;\mathbb{R}^n &\rightarrow \Delta_n \qquad (41)\\
{\bm{v}} &\mapsto \bar{{\bm{p}}}({\bm{v}};\lambda({\bm{v}}))
\end{align*}

where for all ${\bm{v}}$, $\lambda({\bm{v}})$ is the unique $\lambda\in\mathbb{R}$ satisfying (40).

We will now demonstrate that $\bar{p}_{i_0}({\bm{v}})$, for any $i_0\in\{1,\ldots,n\}$, is an increasing function of $v_{i_0}$ and a decreasing function of $v_i$ for $i\neq i_0$. Without loss of generality, we assume $i_0=1$.

Let ${\bm{v}}=(v_i)_{i=1}^{n}\in\mathbb{R}^n$ and $\epsilon>0$. Let us define ${\bm{v}}'=(v_i')_{i=1}^{n}\in\mathbb{R}^n$ such that $v_1'=v_1+\epsilon$ and $v_i'=v_i$ for all $i\in\{2,\ldots,n\}$. As in the proof of the uniqueness of $\lambda$ above, we can show that there exists $\eta>0$ such that the function

\[ F:\lambda'\mapsto\frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda')\right)\right) \]

is continuous and strictly increasing on $[\lambda({\bm{v}})-\eta,\,\lambda({\bm{v}})+\eta]$, and $F(\lambda({\bm{v}}))=1$.

${\bm{v}}\mapsto\lambda({\bm{v}})$ is continuous, so for $\epsilon$ small enough, $\lambda({\bm{v}}')\in[\lambda({\bm{v}})-\eta,\,\lambda({\bm{v}})+\eta]$.

Let us now prove by contradiction that $\lambda({\bm{v}}')\leq\lambda({\bm{v}})$. Suppose instead that $\lambda({\bm{v}}')>\lambda({\bm{v}})$. Then, since $\mathrm{ReLU}\circ(\phi^*)'$ is an increasing function, $F$ is strictly increasing on $[\lambda({\bm{v}})-\eta,\,\lambda({\bm{v}})+\eta]$, and $\epsilon>0$, we obtain

\begin{align*}
1 &= \frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i'+\lambda({\bm{v}}'))\right)\right)\\
&\geq \frac{1}{n}\sum_{i=1}^{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda({\bm{v}}'))\right)\right)\\
&= F(\lambda({\bm{v}}'))\\
&> F(\lambda({\bm{v}}))\\
&= 1
\end{align*}

which is a contradiction. As a result

\[ \lambda({\bm{v}}')\leq\lambda({\bm{v}}) \qquad (42) \]

Using equations (40) and (42), and the fact that $\mathrm{ReLU}\circ(\phi^*)'$ is an increasing function, we obtain for all $i\in\{2,\ldots,n\}$

\begin{align*}
\bar{p}_i({\bm{v}}') &= \frac{1}{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i'+\lambda({\bm{v}}'))\right)\right) \qquad (43)\\
&= \frac{1}{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda({\bm{v}}'))\right)\right)\\
&\leq \frac{1}{n}\mathrm{ReLU}\left((\phi^*)'\left(\beta(v_i+\lambda({\bm{v}}))\right)\right)\\
&\leq \bar{p}_i({\bm{v}})
\end{align*}

In addition

\[ \sum_{i=1}^{n}\bar{p}_i({\bm{v}}')=1=\sum_{i=1}^{n}\bar{p}_i({\bm{v}}) \]

So necessarily

\[ \bar{p}_{i_0}({\bm{v}}')\geq\bar{p}_{i_0}({\bm{v}}) \qquad (44) \]

This holds for any $i_0$ and any ${\bm{v}}$, which concludes the proof. $\blacksquare$
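
This monotonicity can be checked numerically. The short sketch below (our own illustration, restricted to the KL case where $\bar{{\bm{p}}}({\bm{v}})=\mathrm{softmax}(\beta{\bm{v}})$) increases the loss of one example and verifies that its sampling probability increases while all other probabilities decrease, which is exactly the hard example mining behaviour.

import numpy as np

rng = np.random.default_rng(0)
beta = 5.0
v = rng.random(8)                     # per-example losses
v_harder = v.copy()
v_harder[0] += 0.1                    # example 0 becomes harder

def p_bar(losses):                    # KL case: hardness weighted sampling probabilities
    w = np.exp(beta * losses)
    return w / w.sum()

p, p_new = p_bar(v), p_bar(v_harder)
print(p_new[0] > p[0])                # expected: True, the harder example is sampled more often
print(np.all(p_new[1:] < p[1:]))      # expected: True, all other examples are sampled less often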

H Proof of Equivalence between (17) and (18): Link between DRO and Percentile Loss

In the DRO optimization problem of equation (18), the optimal ${\bm{q}}$ for any ${\bm{\theta}}$ has the closed-form expression shown in Appendix D:

\[ \forall{\bm{\theta}},\quad {\bm{q}}^*({\bm{\theta}})=\mathrm{softmax}\left(\left(\beta\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)\right)_{i=1}^{n}\right) \]

Injecting this into equation (18), we obtain

\begin{align*}
\min_{{\bm{\theta}}}\; & \max_{{\bm{q}}\in\Delta_n}\left(\sum_{i=1}^{n}q_i\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)-\frac{1}{\beta}D_{KL}\left({\bm{q}}\,\Big\|\,\frac{1}{n}\mathbf{1}\right)\right)\\
=\min_{{\bm{\theta}}}\; & \left(\sum_{i=1}^{n}q_i^*({\bm{\theta}})\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)-\frac{1}{\beta}\sum_{i=1}^{n}q_i^*({\bm{\theta}})\log\left(\frac{\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)\right)}{\frac{1}{n}\sum_{j=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_j;{\bm{\theta}}),{\bm{y}}_j\right)\right)}\right)\right)\\
=\min_{{\bm{\theta}}}\; & \left(\sum_{i=1}^{n}q_i^*({\bm{\theta}})\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)-\sum_{i=1}^{n}q_i^*({\bm{\theta}})\,\frac{1}{\beta}\log\left(\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)\right)\right)\right.\\
& \left.\;+\frac{1}{\beta}\left(\sum_{i=1}^{n}q_i^*({\bm{\theta}})\right)\times\log\left(\frac{1}{n}\sum_{j=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_j;{\bm{\theta}}),{\bm{y}}_j\right)\right)\right)\right)
\end{align*}

Since the first two terms cancel each other and $\sum_{i=1}^{n}q_i^*({\bm{\theta}})=1$, we obtain

\begin{align*}
\min_{{\bm{\theta}}}\; & \max_{{\bm{q}}\in\Delta_n}\left(\sum_{i=1}^{n}q_i\,\mathcal{L}\left(h({\bm{x}}_i;{\bm{\theta}}),{\bm{y}}_i\right)-\frac{1}{\beta}D_{KL}\left({\bm{q}}\,\Big\|\,\frac{1}{n}\mathbf{1}\right)\right)\\
=\min_{{\bm{\theta}}}\; & \frac{1}{\beta}\log\left(\sum_{j=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_j;{\bm{\theta}}),{\bm{y}}_j\right)\right)\right)-\frac{1}{\beta}\log(n)\\
=\min_{{\bm{\theta}}}\; & \frac{1}{\beta}\log\left(\sum_{j=1}^{n}\exp\left(\beta\,\mathcal{L}\left(h({\bm{x}}_j;{\bm{\theta}}),{\bm{y}}_j\right)\right)\right)
\end{align*}

which is equivalent to the optimization problem (17), because the term $\frac{1}{\beta}\log(n)$ above and the term $\frac{1}{\beta}\log(\alpha n)$ in (17) are independent of ${\bm{\theta}}$. $\blacksquare$
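
For a fixed ${\bm{\theta}}$, the identity above can also be verified numerically. The following self-contained sketch (our own check, using the closed form ${\bm{q}}^*=\mathrm{softmax}(\beta{\bm{L}})$ and treating the per-example losses as a fixed vector) compares the value of the inner maximization to the log-sum-exp expression up to the $\frac{1}{\beta}\log(n)$ constant.

import numpy as np

rng = np.random.default_rng(0)
beta = 3.0
L = rng.random(10)                                   # per-example losses for a fixed theta

q_star = np.exp(beta * L) / np.exp(beta * L).sum()   # closed-form maximizer (softmax)
kl = np.sum(q_star * np.log(len(L) * q_star))        # D_KL(q* || uniform)
inner_max = q_star @ L - kl / beta                   # value of the inner maximization

lse = np.log(np.mean(np.exp(beta * L))) / beta       # (1/beta) log-sum-exp minus (1/beta) log(n)

print(np.allclose(inner_max, lse))                   # expected: True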

I Proof of Theorem 6: Convergence of SGD with Hardness Weighted Sampling for Over-parameterized Deep Neural Networks with ReLU

In this section, we provide the proof of Theorem 6. This generalizes the convergence of SGD for empirical risk minimization in (Allen-Zhu et al., 2019a, Theorem 2) to the convergence of SGD with our proposed hardness weighted sampler for distributionally robust optimization.

We start by describing in detail the assumptions made for our convergence result in Section I.1.

In Section I.2, we restate Theorem 6 using the assumptions and notations previously introduced in Section A.

In Section I.3, we give the proof of the convergence theorem. We focus on providing theoretical tools that could be used to generalize any convergence result for ERM using SGD to DRO using SGD with hardness weighted sampling as described in Algorithm 1.

I.1 Assumptions

Our analysis is based on the results developed in (Allen-Zhu et al., 2019a), which is a simplified version of (Allen-Zhu et al., 2019b). Improving on those theoretical results would automatically improve our results as well.

In the following, we state our assumptions on the neural network $h$ and the per-example loss function $\mathcal{L}$.

Assumption I.1 (Deep Neural Network)

In this section, we use the following notations and assumptions, similar to (Allen-Zhu et al., 2019a); a minimal code sketch of this architecture and its initialization is given after the list:

  • $h$ is a fully connected neural network with $L+2$ layers, $\mathrm{ReLU}$ as activation functions, and $m$ nodes in each hidden layer.

  • For all $i\in\{1,\ldots,n\}$, we denote $h_i:{\bm{\theta}}\mapsto h(\mathrm{x}_i;{\bm{\theta}})$ the $d$-dimensional output scores of $h$ applied to the example $\mathrm{x}_i$ of dimension $\mathfrak{d}$.

  • For all $i\in\{1,\ldots,n\}$, we denote $\mathcal{L}_i:h\mapsto\mathcal{L}(h,\mathrm{y}_i)$, where $\mathrm{y}_i$ is the ground truth associated with example $i$.

  • ${\bm{\theta}}=({\bm{\theta}}_l)_{l=0}^{L+1}$ is the set of parameters of the neural network $h$, where ${\bm{\theta}}_l$ is the set of weights for layer $l$, with ${\bm{\theta}}_0\in\mathbb{R}^{\mathfrak{d}\times m}$, ${\bm{\theta}}_{L+1}\in\mathbb{R}^{m\times d}$, and ${\bm{\theta}}_l\in\mathbb{R}^{m\times m}$ for any other $l$.

  • (Data separation) There exists $\delta>0$ such that for all $i,j\in\{1,\ldots,n\}$, if $i\neq j$ then $\lVert x_i-x_j\rVert\geq\delta$.

  • We assume $m\geq\Omega(d\times\mathrm{poly}(n,L,\delta^{-1}))$ for some sufficiently large polynomial poly, and $\delta\geq O\left(\frac{1}{L}\right)$. We refer the reader to (Allen-Zhu et al., 2019a) for details about the polynomial poly.

  • The parameters ${\bm{\theta}}=({\bm{\theta}}_l)_{l=0}^{L+1}$ are initialized at random such that:

    • $[{\bm{\theta}}_0]_{i,j}\sim\mathcal{N}\left(0,\frac{2}{m}\right)$ for every $(i,j)\in\{1,\ldots,m\}\times\{1,\ldots,\mathfrak{d}\}$

    • $[{\bm{\theta}}_l]_{i,j}\sim\mathcal{N}\left(0,\frac{2}{m}\right)$ for every $(i,j)\in\{1,\ldots,m\}^2$ and $l\in\{1,\ldots,L\}$

    • $[{\bm{\theta}}_{L+1}]_{i,j}\sim\mathcal{N}\left(0,\frac{1}{d}\right)$ for every $(i,j)\in\{1,\ldots,d\}\times\{1,\ldots,m\}$
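
For concreteness, here is a minimal NumPy sketch (our own illustration; the function names are ours, d_in stands for $\mathfrak{d}$ and d_out for $d$) of the over-parameterized ReLU network and the random initialization described above.

import numpy as np

def init_params(d_in, d_out, m, L, rng):
    """Random initialization of Assumption I.1: theta_0 in R^{d_in x m},
    theta_1, ..., theta_L in R^{m x m}, and theta_{L+1} in R^{m x d_out}."""
    theta = [rng.normal(0.0, np.sqrt(2.0 / m), size=(d_in, m))]
    theta += [rng.normal(0.0, np.sqrt(2.0 / m), size=(m, m)) for _ in range(L)]
    theta += [rng.normal(0.0, np.sqrt(1.0 / d_out), size=(m, d_out))]
    return theta

def forward(theta, x):
    """Fully connected network with ReLU activations and d_out output scores."""
    a = np.maximum(x @ theta[0], 0.0)      # first hidden layer
    for W in theta[1:-1]:
        a = np.maximum(a @ W, 0.0)         # remaining hidden layers
    return a @ theta[-1]                   # linear output layer

rng = np.random.default_rng(0)
theta = init_params(d_in=3, d_out=2, m=64, L=4, rng=rng)
print(forward(theta, rng.normal(size=3)).shape)   # (2,)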

Assumption I.2 (Regularity of $\mathcal{L}$)

There exist $C(\nabla\mathcal{L})>0$ and $C(\mathcal{L})>0$ such that for all $i$, $\mathcal{L}_i$ is a $C(\nabla\mathcal{L})$-gradient Lipschitz continuous, $C(\mathcal{L})$-Lipschitz continuous, and bounded (potentially non-convex) function. When the optimization is performed on a closed convex set, the existence of $C(\nabla\mathcal{L})$ implies that there exists a constant $A(\nabla\mathcal{L})>0$ that bounds the gradients of $\mathcal{L}_i$ for all $i$.

I.2 Convergence theorem (restated)

In this section, we restate the convergence Theorem 6 for SGD with hardness weighted sampling and a stale per-example loss vector.

As an intermediate step, we will first generalize the convergence of SGD in (Allen-Zhu et al., 2019a, Theorem 2) to the minimization of the distributionally robust loss using SGD and an exact hardness weighted sampling (10), i.e. with an exact per-example loss vector.

Theorem 11 (Convergence with exact per-example loss vector)

Let the batch size satisfy $1\leq b\leq n$, and let $\epsilon>0$. Under Assumptions I.1 and I.2, suppose there exist constants $C_1,C_2,C_3>0$ such that the number of hidden units satisfies $m\geq C_1\left(d\epsilon^{-1}\times\mathrm{poly}(n,L,\delta^{-1})\right)$, $\delta\geq\frac{C_2}{L}$, and the learning rate is $\eta_{exact}=C_3\left(\min\left(1,\,\frac{\alpha n^2\rho}{\beta C(\mathcal{L})^2+2n\rho C(\nabla\mathcal{L})}\right)\times\frac{b\delta d}{\mathrm{poly}(n,L)\,m\log^2(m)}\right)$. Then there exist constants $C_4,C_5>0$ such that, with probability at least $1-\exp\left(-C_4\log^2(m)\right)$ over the randomness of the initialization and the mini-batches, SGD with hardness weighted sampling and exact per-example loss vector guarantees $\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert\leq\epsilon$ after $T=C_5\left(\frac{Ln^3}{\eta_{exact}\,\delta\,\epsilon^2}\right)$ iterations.

The proof can be found in Appendix I.3.4.

$\alpha=\min_{{\bm{\theta}}}\min_i\bar{p}_i({\bm{L}}({\bm{\theta}}))$ is a lower bound on the sampling probabilities. For the Kullback-Leibler $\phi$-divergence, and for any $\phi$-divergence satisfying Definition 2 with a robustness parameter $\beta$ small enough, we have $\alpha>0$. We refer the reader to (Allen-Zhu et al., 2019a, Theorem 2) for the values of the constants $C_1,C_2,C_3,C_4,C_5$ and the definitions of the polynomials.

Compared to (Allen-Zhu et al., 2019a, Theorem 2), only the learning rate differs. The $\min(1,\,\cdot\,)$ operation in the formula for $\eta_{exact}$ allows us to guarantee that $\eta_{exact}\leq\eta'$, where $\eta'$ is the learning rate of (Allen-Zhu et al., 2019a, Theorem 2).

It is worth noting that for the KL $\phi$-divergence, $\rho=\frac{1}{n}$. In addition, in the limit $\beta\rightarrow 0$, which corresponds to ERM, we have $\alpha\rightarrow\frac{1}{n}$. As a result, we recover exactly Theorem 2 of (Allen-Zhu et al., 2019a), as extended in their Appendix A, for any smooth loss function $\mathcal{L}$ that satisfies Assumption I.2 with $C(\nabla\mathcal{L})=1$.
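
The $\beta\rightarrow 0$ limit can also be observed directly on the KL hardness weighted sampling probabilities; the short check below (our own illustration) shows them converging to the uniform distribution $\frac{1}{n}\mathbf{1}$ used by ERM.

import numpy as np

v = np.array([0.2, 0.9, 0.5])          # per-example losses
for beta in (2.0, 0.5, 0.01):
    p = np.exp(beta * v)
    p /= p.sum()
    print(beta, np.round(p, 3))        # tends to [1/3, 1/3, 1/3] as beta -> 0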

We now restate the convergence of SGD with hardness weighted sampling and a stale per-example loss vector as in Algorithm 1.

Theorem 12 (Convergence with a stale per-example loss vector)

Let the batch size satisfy $1\leq b\leq n$, and let $\epsilon>0$. Under the conditions of Theorem 11, with the same notations, and with the learning rate $\eta_{stale}=C_6\min\left(1,\,\frac{\alpha\rho d^{3/2}\delta b\log\left(\frac{1}{1-\alpha}\right)}{\beta C(\mathcal{L})A(\nabla\mathcal{L})Lm^{3/2}n^{3/2}\log^2(m)}\right)\times\eta_{exact}$ for a constant $C_6>0$, with probability at least $1-\exp\left(-C_4\log^2(m)\right)$ over the randomness of the initialization and the mini-batches, SGD with hardness weighted sampling and stale per-example loss vector guarantees $\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert\leq\epsilon$ after $T=C_5\left(\frac{Ln^3}{\eta_{stale}\,\delta\,\epsilon^2}\right)$ iterations.

The proof can be found in Appendix I.4.

$C(\mathcal{L})>0$ is a constant such that $\mathcal{L}$ is $C(\mathcal{L})$-Lipschitz continuous, and $A(\nabla\mathcal{L})>0$ is a constant that bounds the gradient of $\mathcal{L}$ with respect to its input. $C(\mathcal{L})$ and $A(\nabla\mathcal{L})$ are guaranteed to exist under Assumption I.1.

Compared to Theorem 11 only the learning rate differs. Similarly to Theorem 11, when $\beta$ tends to zero we recover Theorem 2 of (Allen-Zhu et al., 2019a).

It is worth noting that when $\beta$ increases, the ratio $\frac{\alpha\rho d^{3/2}\delta b\log\left(\frac{1}{1-\alpha}\right)}{\beta C(\mathcal{L})A(\nabla\mathcal{L})Lm^{3/2}n^{3/2}\log^{2}(m)}$ decreases. This implies that $\eta_{stale}$ decreases faster than $\eta_{exact}$ as $\beta$ increases. This was to be expected, since the error made by using the stale per-example loss vector instead of the exact loss vector grows with $\beta$.

I.3 Proofs of convergence

In this section, we prove the results of Theorems 11 and 12.

For ease of reading, we first recall the chain rules for the distributionally robust loss that will be used extensively in the following proofs.

Chain rule for the derivative of $R\circ\bm{L}$ with respect to the network outputs $h$:

$$\begin{aligned}
\nabla_{h}(R\circ\bm{L})(h(\bm{\theta})) &= \left(\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta}))\right)_{i=1}^{n}\\
\forall i\in\{1,\ldots,n\},\quad \nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta})) &= \sum_{j=1}^{n}\frac{\partial R}{\partial v_{j}}(\bm{L}(h(\bm{\theta})))\,\nabla_{h_{i}}\mathcal{L}_{j}(h_{j}(\bm{\theta}))\\
&= \bar{p}_{i}(\bm{L}(h(\bm{\theta})))\,\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))
\end{aligned} \tag{45}$$
Chain rule for the derivative of $R\circ\bm{L}\circ h$ with respect to the network parameters $\bm{\theta}$:

$$\begin{aligned}
\nabla_{\bm{\theta}}(R\circ\bm{L}\circ h)(\bm{\theta}) &= \sum_{i=1}^{n}\nabla_{\bm{\theta}}h_{i}(\bm{\theta})\,\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta}))\\
&= \sum_{i=1}^{n}\bar{p}_{i}(\bm{L}(h(\bm{\theta})))\,\nabla_{\bm{\theta}}h_{i}(\bm{\theta})\,\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\\
&= \sum_{i=1}^{n}\bar{p}_{i}(\bm{L}(h(\bm{\theta})))\,\nabla_{\bm{\theta}}(\mathcal{L}_{i}\circ h_{i})(\bm{\theta})
\end{aligned} \tag{46}$$

where for all $i\in\{1,\ldots,n\}$, $\nabla_{\bm{\theta}}h_{i}(\bm{\theta})$ is the transpose of the Jacobian matrix of $h_{i}$ as a function of $\bm{\theta}$.
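As a sanity check of chain rule (46), the following sketch compares the gradient of $R\circ\bm{L}\circ h$ obtained by automatic differentiation with the weighted sum of per-example gradients $\sum_{i}\bar{p}_{i}\nabla_{\bm{\theta}}(\mathcal{L}_{i}\circ h_{i})(\bm{\theta})$. It assumes the KL $\phi$-divergence, for which $R$ is the scaled log-sum-exp of the per-example losses and $\frac{\partial R}{\partial v_{i}}=\bar{p}_{i}$; the linear network $h$, the squared-error per-example losses, and the use of PyTorch are illustrative choices rather than the pipeline of the paper.

```python
# Numerical check of chain rule (46) for the KL case, where
# R(L) = (1/beta) log((1/n) sum_i exp(beta L_i)) and dR/dv_i = pbar_i = softmax(beta L)_i.
import math
import torch

torch.manual_seed(0)
n, d_in, d_out, beta = 5, 4, 3, 2.0
theta = torch.randn(d_in, d_out, requires_grad=True)   # parameters of a toy linear network h
x, y = torch.randn(n, d_in), torch.randn(n, d_out)     # placeholder training examples

def per_example_losses(theta):
    # L_i(h_i(theta)): squared error of example i for the linear network h_i(theta) = x_i @ theta
    return ((x @ theta - y) ** 2).sum(dim=1)

def robust_loss(theta):
    # R(L(h(theta))) for the KL case: scaled log-sum-exp of the per-example losses
    return (torch.logsumexp(beta * per_example_losses(theta), dim=0) - math.log(n)) / beta

# Left-hand side of (46): direct differentiation of R o L o h
grad_direct = torch.autograd.grad(robust_loss(theta), theta)[0]

# Right-hand side of (46): sum_i pbar_i(L(h(theta))) * grad_theta (L_i o h_i)(theta)
p_bar = torch.softmax(beta * per_example_losses(theta), dim=0).detach()
grad_chain = torch.autograd.grad((p_bar * per_example_losses(theta)).sum(), theta)[0]

print(torch.allclose(grad_direct, grad_chain, atol=1e-4))  # expected: True
```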

I.3.1 Proof that $R\circ\bm{L}$ is one-sided gradient Lipschitz

The fact that $R\circ\bm{L}$ is one-sided gradient Lipschitz is a key element of the proof of the semi-smoothness theorem for the distributionally robust loss (Theorem 13).

Under Definition 2 for the $\phi$-divergence, we have shown that $R$ is $\frac{\beta}{n\rho}$-gradient Lipschitz continuous (Lemma 4). In addition, under Assumption I.2, for all $i$, $\mathcal{L}_{i}$ is $C(\mathcal{L})$-Lipschitz continuous and $C(\nabla\mathcal{L})$-gradient Lipschitz continuous.

Let $\bm{z}=(\bm{z}_{i})_{i=1}^{n},\,\bm{z}^{\prime}=(\bm{z}_{i}^{\prime})_{i=1}^{n}\in\mathbb{R}^{dn}$.

We want to show that $R\circ\bm{L}$ is one-sided gradient Lipschitz, i.e. we want to prove the existence of a constant $C>0$, independent of $\bm{z}$ and $\bm{z}^{\prime}$, such that:

$$\langle\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z})-\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z}^{\prime}),\,\bm{z}-\bm{z}^{\prime}\rangle\leq C\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2}$$

We have

$$\begin{aligned}
\langle\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z})-\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z}^{\prime}),\,\bm{z}-\bm{z}^{\prime}\rangle
&= \sum_{i=1}^{n}\langle\nabla_{\bm{z}_{i}}(R\circ\bm{L})(\bm{z})-\nabla_{\bm{z}_{i}}(R\circ\bm{L})(\bm{z}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle\\
&= \sum_{i=1}^{n}\langle\bar{p}_{i}(\bm{L}(\bm{z}))\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i})-\bar{p}_{i}(\bm{L}(\bm{z}^{\prime}))\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle\\
&= \sum_{i=1}^{n}\bar{p}_{i}(\bm{L}(\bm{z}))\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i})-\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle\\
&\quad+\sum_{i=1}^{n}\left(\bar{p}_{i}(\bm{L}(\bm{z}))-\bar{p}_{i}(\bm{L}(\bm{z}^{\prime}))\right)\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle
\end{aligned} \tag{47}$$

where, for all $i\in\{1,\ldots,n\}$, we have used the chain rule

$$\nabla_{\bm{z}_{i}}(R\circ\bm{L})(\bm{z})=\sum_{j=1}^{n}\frac{\partial R}{\partial v_{j}}(\bm{L}(\bm{z}))\,\nabla_{\bm{z}_{i}}\mathcal{L}_{j}(\bm{z}_{j})=\bar{p}_{i}(\bm{L}(\bm{z}))\,\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i})$$

Let

$$A=\left|\sum_{i=1}^{n}\bar{p}_{i}(\bm{L}(\bm{z}))\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i})-\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle\right|$$

For all $i$, $\mathcal{L}_{i}$ is $C(\nabla\mathcal{L})$-gradient Lipschitz continuous, so using the Cauchy-Schwarz inequality and the fact that $\bar{p}_{i}(\bm{L}(\bm{z}))\leq 1$ for all $i$,

$$A\leq\sum_{i=1}^{n}C(\nabla\mathcal{L})\left\lVert\bm{z}_{i}-\bm{z}_{i}^{\prime}\right\rVert^{2}=C(\nabla\mathcal{L})\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2} \tag{48}$$

Let

$$B=\left|\sum_{i=1}^{n}\left(\bar{p}_{i}(\bm{L}(\bm{z}))-\bar{p}_{i}(\bm{L}(\bm{z}^{\prime}))\right)\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle\right|$$

Using the triangle inequality:

$$\begin{aligned}
B &\leq\left|\sum_{i=1}^{n}\left(\bar{p}_{i}(\bm{L}(\bm{z}))-\bar{p}_{i}(\bm{L}(\bm{z}^{\prime}))\right)\left(\mathcal{L}_{i}(\bm{z}_{i})-\mathcal{L}_{i}(\bm{z}_{i}^{\prime})\right)\right|\\
&\quad+\left|\sum_{i=1}^{n}\left(\bar{p}_{i}(\bm{L}(\bm{z}))-\bar{p}_{i}(\bm{L}(\bm{z}^{\prime}))\right)\left(\mathcal{L}_{i}(\bm{z}_{i}^{\prime})+\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle-\mathcal{L}_{i}(\bm{z}_{i})\right)\right|\\
&\leq\left\langle\nabla_{\bm{v}}R(\bm{L}(\bm{z}))-\nabla_{\bm{v}}R(\bm{L}(\bm{z}^{\prime})),\,\bm{L}(\bm{z})-\bm{L}(\bm{z}^{\prime})\right\rangle\\
&\quad+2\sum_{i=1}^{n}\left|\mathcal{L}_{i}(\bm{z}_{i}^{\prime})+\langle\nabla_{\bm{z}_{i}}\mathcal{L}_{i}(\bm{z}_{i}^{\prime}),\,\bm{z}_{i}-\bm{z}_{i}^{\prime}\rangle-\mathcal{L}_{i}(\bm{z}_{i})\right|\\
&\leq\frac{\beta}{n\rho}\left\lVert\bm{L}(\bm{z})-\bm{L}(\bm{z}^{\prime})\right\rVert^{2}+2\,\frac{C(\nabla\mathcal{L})}{2}\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2}\\
&\leq\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+C(\nabla\mathcal{L})\right)\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2}
\end{aligned} \tag{49}$$

Combining equations (47), (48) and (49) we finally obtain

$$\langle\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z})-\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z}^{\prime}),\,\bm{z}-\bm{z}^{\prime}\rangle\leq\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2} \tag{50}$$

From there, we can obtain the following inequality, which will be used in the proof of the semi-smoothness property in Theorem 13:

$$\begin{aligned}
&R(\bm{L}(\bm{z}^{\prime}))-R(\bm{L}(\bm{z}))-\langle\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z}),\,\bm{z}^{\prime}-\bm{z}\rangle\\
&\quad=\int_{t=0}^{1}\langle\nabla_{\bm{z}}(R\circ\bm{L})\left(\bm{z}+t(\bm{z}^{\prime}-\bm{z})\right)-\nabla_{\bm{z}}(R\circ\bm{L})(\bm{z}),\,\bm{z}^{\prime}-\bm{z}\rangle\,dt\\
&\quad\leq\frac{1}{2}\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\left\lVert\bm{z}-\bm{z}^{\prime}\right\rVert^{2}
\end{aligned} \tag{51}$$
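As an illustration, the sketch below checks inequality (50) numerically for the KL $\phi$-divergence (so $\rho=\frac{1}{n}$ and $\bar{\bm{p}}$ is a softmax of the scaled losses) with the smooth per-example loss $\mathcal{L}_{i}(\bm{z}_{i})=\frac{1}{2}\lVert\bm{z}_{i}-\bm{y}_{i}\rVert^{2}$, for which $C(\nabla\mathcal{L})=1$ and $C(\mathcal{L})$ can be taken as a bound on $\lVert\bm{z}_{i}-\bm{y}_{i}\rVert$ over the points considered. All numerical values are illustrative placeholders.

```python
# Sanity check of the one-sided gradient Lipschitz bound (50) for the KL case with a
# squared-error per-example loss. Targets y and the points z, z' are placeholders.
import numpy as np

rng = np.random.default_rng(0)
n, d, beta = 8, 3, 2.0
rho = 1.0 / n                                            # rho = 1/n for the KL phi-divergence
y = rng.normal(size=(n, d))

def losses(z):
    return 0.5 * np.sum((z - y) ** 2, axis=1)            # per-example losses L(z)

def grad_RoL(z):
    scaled = beta * losses(z)
    p_bar = np.exp(scaled - scaled.max())
    p_bar /= p_bar.sum()                                 # pbar(L(z)) = softmax(beta L(z))
    return p_bar[:, None] * (z - y)                      # pbar_i * grad L_i(z_i), see (45)

z, zp = rng.normal(size=(n, d)), rng.normal(size=(n, d))
lhs = np.sum((grad_RoL(z) - grad_RoL(zp)) * (z - zp))
C_L = max(np.linalg.norm(z - y, axis=1).max(), np.linalg.norm(zp - y, axis=1).max())
rhs = (beta * C_L**2 / (n * rho) + 2.0) * np.sum((z - zp) ** 2)
print(lhs <= rhs)                                        # expected: True
```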

I.3.2 Semi-smoothness property of the distributionally robust loss

We prove the following result, which generalizes Theorem 4 in (Allen-Zhu et al., 2019a) to the distributionally robust loss.

Theorem 13 (Semi-smoothness of the distributionally robust loss)

Let $\omega\in\left[\Omega\left(\frac{d^{3/2}}{m^{3/2}L^{3/2}\log^{3/2}(m)}\right),\,O\left(\frac{1}{L^{4.5}\log^{3}(m)}\right)\right]$, and let $\bm{\theta}^{(0)}$ be initialized randomly as described in Assumption I.1. With probability at least $1-\exp\left(-\Omega(m\omega^{3/2}L)\right)$ over the initialization, we have, for all $\bm{\theta},\bm{\theta}^{\prime}\in\left(\mathbb{R}^{m\times m}\right)^{L}$ with $\left\lVert\bm{\theta}-\bm{\theta}^{(0)}\right\rVert_{2}\leq\omega$ and $\left\lVert\bm{\theta}-\bm{\theta}^{\prime}\right\rVert_{2}\leq\omega$,

$$\begin{aligned}
R(\bm{L}(h(\bm{\theta}^{\prime}))) &\leq R(\bm{L}(h(\bm{\theta})))+\langle\nabla_{\bm{\theta}}(R\circ\bm{L}\circ h)(\bm{\theta}),\,\bm{\theta}^{\prime}-\bm{\theta}\rangle\\
&\quad+\left\lVert\nabla_{h}(R\circ\bm{L})(h(\bm{\theta}))\right\rVert_{2,1}O\left(\frac{L^{2}\omega^{1/3}\sqrt{m\log(m)}}{\sqrt{d}}\right)\left\lVert\bm{\theta}^{\prime}-\bm{\theta}\right\rVert_{2,\infty}\\
&\quad+O\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)\left\lVert\bm{\theta}^{\prime}-\bm{\theta}\right\rVert_{2,\infty}^{2}
\end{aligned} \tag{52}$$

where, for each layer $l\in\{1,\ldots,L\}$, $\bm{\theta}_{l}$ is the vector of parameters of layer $l$, and

$$\begin{aligned}
\left\lVert\bm{\theta}^{\prime}-\bm{\theta}\right\rVert_{2,\infty} &= \max_{l}\left\lVert\bm{\theta}_{l}^{\prime}-\bm{\theta}_{l}\right\rVert_{2}\\
\left\lVert\bm{\theta}^{\prime}-\bm{\theta}\right\rVert_{2,\infty}^{2} &= \left(\max_{l}\left\lVert\bm{\theta}_{l}^{\prime}-\bm{\theta}_{l}\right\rVert_{2}\right)^{2}=\max_{l}\left\lVert\bm{\theta}_{l}^{\prime}-\bm{\theta}_{l}\right\rVert_{2}^{2}\\
\left\lVert\nabla_{h}(R\circ\bm{L})(h(\bm{\theta}))\right\rVert_{2,1} &= \sum_{i=1}^{n}\left\lVert\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta}))\right\rVert_{2}\\
&= \sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\bm{L}(h(\bm{\theta})))\,\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\right\rVert_{2}\quad\text{(chain rule (45))}
\end{aligned}$$

To compare this semi-smoothness result to the one in (Allen-Zhu et al., 2019a, Theorem 4), let us first remark that

$$\left\lVert\nabla_{h}(R\circ\bm{L})(h(\bm{\theta}))\right\rVert_{2,1}\leq\sqrt{n}\left\lVert\nabla_{h}(R\circ\bm{L})(h(\bm{\theta}))\right\rVert_{2,2}$$

As a result, our semi-smoothness result is analogous to (Allen-Zhu et al., 2019a, Theorem 4), up to an additional multiplicative factor $\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)$ in the last term of the right-hand side. It is worth noting that there is also, implicitly, an additional multiplicative factor $C(\nabla\mathcal{L})$ in Theorem 3 of (Allen-Zhu et al., 2019a), since (Allen-Zhu et al., 2019a) make the assumption that $C(\nabla\mathcal{L})=1$ (see Allen-Zhu et al., 2019a, Appendix A).
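The mixed norms used in Theorem 13 and in the remark above can be made concrete with a small sketch; the matrix $V$ below is a placeholder whose rows play the role of the per-example gradients $\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta}))$.

```python
# Mixed norms of Theorem 13 and the inequality ||V||_{2,1} <= sqrt(n) ||V||_{2,2}.
import numpy as np

rng = np.random.default_rng(0)
n, d = 6, 4
V = rng.normal(size=(n, d))               # placeholder for the stacked per-example gradients

row_norms = np.linalg.norm(V, axis=1)
norm_2_1 = row_norms.sum()                # ||V||_{2,1}: sum of the per-row Euclidean norms
norm_2_2 = np.linalg.norm(V)              # ||V||_{2,2}: Frobenius norm
norm_2_inf = row_norms.max()              # ||V||_{2,inf}: max of the per-row Euclidean norms

print(norm_2_1 <= np.sqrt(n) * norm_2_2)  # expected: True (Cauchy-Schwarz)
```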

Let $\bm{\theta},\bm{\theta}^{\prime}\in\left(\mathbb{R}^{m\times m}\right)^{L}$ satisfy the conditions of Theorem 13.

Let $A=R(\bm{L}(h(\bm{\theta}^{\prime})))-R(\bm{L}(h(\bm{\theta})))-\langle\nabla_{\bm{\theta}}(R\circ\bm{L}\circ h)(\bm{\theta}),\,\bm{\theta}^{\prime}-\bm{\theta}\rangle$ be the quantity we want to bound.

Using (51) for $\bm{z}=h(\bm{\theta})$ and $\bm{z}^{\prime}=h(\bm{\theta}^{\prime})$, we obtain

$$\begin{aligned}
A &\leq\frac{1}{2}\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\left\lVert h(\bm{\theta}^{\prime})-h(\bm{\theta})\right\rVert_{2}^{2}\\
&\quad+\langle\nabla_{h}(R\circ\bm{L})(h(\bm{\theta})),\,h(\bm{\theta}^{\prime})-h(\bm{\theta})\rangle\\
&\quad-\langle\nabla_{\bm{\theta}}(R\circ\bm{L}\circ h)(\bm{\theta}),\,\bm{\theta}^{\prime}-\bm{\theta}\rangle
\end{aligned} \tag{53}$$

Then using the chain rule (46)

$$\begin{aligned}
A &\leq\frac{1}{2}\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\left\lVert h(\bm{\theta}^{\prime})-h(\bm{\theta})\right\rVert_{2}^{2}\\
&\quad+\sum_{i=1}^{n}\langle\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta})),\,h_{i}(\bm{\theta}^{\prime})-h_{i}(\bm{\theta})-\left(\nabla_{\bm{\theta}}h_{i}(\bm{\theta})\right)^{T}(\bm{\theta}^{\prime}-\bm{\theta})\rangle
\end{aligned} \tag{54}$$

For all $i\in\{1,\ldots,n\}$, let us denote $\breve{loss}_{i}:=\nabla_{h_{i}}(R\circ\bm{L})(h(\bm{\theta}))$ to match the notation used in (Allen-Zhu et al., 2019a) for the derivative of the loss with respect to the output of the network for example $i$ of the training set.

With this notation, we obtain exactly equation (11.3) in (Allen-Zhu et al., 2019a), up to the multiplicative factor $\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)$ for the distributionally robust loss.

Since the rest of the proof of Theorem 4 in (Allen-Zhu et al., 2019a) is independent of the formula for $\breve{loss}_{i}$, we can conclude the proof of our Theorem 13 as in (Allen-Zhu et al., 2019a, Appendix A).

I.3.3 Gradient bounds for the distributionally robust loss

We prove the following result, which generalizes Theorem 3 in (Allen-Zhu et al., 2019a) to the distributionally robust loss.

Theorem 14 (Gradient Bounds for the Distributionally Robust Loss)

Let $\omega\in O\left(\frac{\delta^{3/2}}{n^{9/2}L^{6}\log^{3}(m)}\right)$, and let $\bm{\theta}^{(0)}$ be initialized randomly as described in Assumption I.1. With probability at least $1-\exp\left(-\Omega(m\omega^{3/2}L)\right)$ over the initialization, we have, for all $\bm{\theta}\in\left(\mathbb{R}^{m\times m}\right)^{L}$ with $\left\lVert\bm{\theta}-\bm{\theta}^{(0)}\right\rVert_{2}\leq\omega$,

$$\begin{aligned}
&\forall i\in\{1,\ldots,n\},\;\forall l\in\{1,\ldots,L\},\;\forall\hat{\bm{L}}\in\mathbb{R}^{n},\\
&\qquad\left\lVert\bar{p}_{i}(\hat{\bm{L}})\nabla_{\bm{\theta}_{l}}(\mathcal{L}_{i}\circ h_{i})(\bm{\theta})\right\rVert_{2}^{2}\leq O\left(\frac{m}{d}\left\lVert\bar{p}_{i}(\hat{\bm{L}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\right\rVert_{2}^{2}\right)\\
&\forall l\in\{1,\ldots,L\},\;\forall\hat{\bm{L}}\in\mathbb{R}^{n},\\
&\qquad\left\lVert\sum_{i=1}^{n}\bar{p}_{i}(\hat{\bm{L}})\nabla_{\bm{\theta}_{l}}(\mathcal{L}_{i}\circ h_{i})(\bm{\theta})\right\rVert_{2}^{2}\leq O\left(\frac{mn}{d}\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{\bm{L}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\right\rVert_{2}^{2}\right)\\
&\qquad\left\lVert\sum_{i=1}^{n}\bar{p}_{i}(\hat{\bm{L}})\nabla_{\bm{\theta}_{L}}(\mathcal{L}_{i}\circ h_{i})(\bm{\theta})\right\rVert_{2}^{2}\geq\Omega\left(\frac{m\delta}{dn^{2}}\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{\bm{L}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\right\rVert_{2}^{2}\right)
\end{aligned} \tag{55}$$

It is worth noting that the loss vector $\hat{\bm{L}}$ used for computing the robust probabilities $\bar{\bm{p}}(\hat{\bm{L}})=\left(\bar{p}_{i}(\hat{\bm{L}})\right)_{i=1}^{n}$ does not have to be equal to $\bm{L}(h(\bm{\theta}))$.

We will use this in the proof for robust SGD with a stale per-example loss vector.
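For intuition, here is a minimal sketch, in the spirit of Algorithm 1, of how a stale per-example loss vector $\hat{\bm{L}}$ is used: mini-batch indices are sampled from $\bar{\bm{p}}(\hat{\bm{L}})$ (softmax form, KL case), and only the entries of $\hat{\bm{L}}$ corresponding to the sampled examples are refreshed after the forward pass, so that $\hat{\bm{L}}$ generally differs from $\bm{L}(h(\bm{\theta}))$. The loss computation and the sampling with replacement below are illustrative placeholders rather than the exact implementation of the paper.

```python
# Sketch of hardness weighted sampling with a stale per-example loss vector L_hat.
import numpy as np

rng = np.random.default_rng(0)
n, b, beta, n_iters = 100, 8, 5.0, 50
stale_losses = np.ones(n)                         # L_hat, initialized arbitrarily

def batch_forward_losses(indices):
    # Placeholder for the forward pass returning the per-example losses of the sampled batch
    return rng.uniform(0.0, 2.0, size=len(indices))

for _ in range(n_iters):
    logits = beta * stale_losses
    p_bar = np.exp(logits - logits.max())
    p_bar /= p_bar.sum()                          # pbar(L_hat): hardness weighted probabilities
    batch = rng.choice(n, size=b, replace=True, p=p_bar)
    batch_losses = batch_forward_losses(batch)
    # ... the SGD update on the sampled mini-batch would go here ...
    stale_losses[batch] = batch_losses            # refresh only the sampled entries of L_hat
```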

The adaptation of the proof of Theorem 3 in (Allen-Zhu et al., 2019a) is straightforward.

Let $\bm{\theta}\in\left(\mathbb{R}^{m\times m}\right)^{L}$ satisfy the conditions of Theorem 14, and let $\hat{\bm{L}}\in\mathbb{R}^{n}$.

Let us denote $\bm{v}:=\left(\bar{p}_{i}(\hat{\bm{L}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}(\bm{\theta}))\right)_{i=1}^{n}$. Applying the proof of Theorem 3 in (Allen-Zhu et al., 2019a) to our $\bm{v}$ gives:

\begin{align*}
\forall i\in\{1,\ldots,n\},\ \forall l\in\{1,\ldots,L\},\quad & \left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}_{l}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}\leq O\left(\frac{m}{d}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\right)\\
\forall l\in\{1,\ldots,L\},\ \forall\hat{{\bm{L}}}\in{\mathbb{R}}^{n},\quad & \left\lVert\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}_{l}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}\leq O\left(\frac{mn}{d}\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\right)\\
& \left\lVert\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}_{L}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}\geq\Omega\left(\frac{m\delta}{dn}\max_{i}\left(\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\right)\right)
\end{align*}

In addition

\[
\max_{i}\left(\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\right)\geq\frac{1}{n}\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}
\]

This allows us to conclude the proof of our Theorem 14. \blacksquare

I.3.4 Convergence of SGD with Hardness Weighted Sampling and exact per-example loss vector

We can now prove Theorem 11.

As in the proof of convergence of SGD for the mean loss (Theorem 2 in (Allen-Zhu et al., 2019a)), the convergence of SGD for the distributionally robust loss relies mainly on the semi-smoothness property (Theorem 13) and the gradient bound (Theorem 14) that we have proved above for the distributionally robust loss.

Let ${\bm{\theta}}\in\left({\mathbb{R}}^{m\times m}\right)^{L}$ satisfy the conditions of Theorem 11, and let $\hat{{\bm{L}}}$ be the exact per-example loss vector at ${\bm{\theta}}$, i.e.

\[
\hat{{\bm{L}}}=\left(\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right)_{i=1}^{n}\tag{56}
\]

For the batch size $b\in\{1,\ldots,n\}$, let $S=\{i_{j}\}_{j=1}^{b}$ be a batch of indices drawn i.i.d. (hence with replacement) from $\bar{{\bm{p}}}(\hat{{\bm{L}}})$, i.e.

\[
\forall j\in\{1,\ldots,b\},\quad i_{j}\overset{\text{i.i.d.}}{\sim}\bar{{\bm{p}}}(\hat{{\bm{L}}})\tag{57}
\]

Let ${\bm{\theta}}^{\prime}\in\left({\mathbb{R}}^{m\times m}\right)^{L}$ be the values of the parameters after a stochastic gradient descent step at ${\bm{\theta}}$ for the batch $S$, i.e.

\[
{\bm{\theta}}^{\prime}={\bm{\theta}}-\eta\,\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\tag{58}
\]

where $\eta>0$ is the learning rate.
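As an illustration only (not part of the proof), a minimal PyTorch-style sketch of one such update, drawing the batch $S$ i.i.d. from the robust probabilities $\bar{{\bm{p}}}(\hat{{\bm{L}}})$ as in (57) and applying the step (58), could look as follows. The names model, per_example_loss, data_x, data_y and robust_probs are assumptions of the sketch, not notation from the paper.

    import torch

    def hardness_weighted_sgd_step(model, per_example_loss, data_x, data_y,
                                   robust_probs, batch_size, lr):
        """One SGD step with a batch sampled i.i.d. from the robust probabilities."""
        # (57): draw b indices i.i.d. from p_bar(L_hat); i.i.d. draws are with replacement.
        idx = torch.multinomial(robust_probs, batch_size, replacement=True)

        # Per-example losses L_i(h_i(theta)) on the sampled batch.
        losses = per_example_loss(model(data_x[idx]), data_y[idx])  # shape: (batch_size,)

        # (58): theta' = theta - eta * (1/b) * sum_{i in S} grad of (L_i o h_i)(theta).
        model.zero_grad()
        losses.mean().backward()
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is not None:
                    p -= lr * p.grad
        return losses.detach()

In practice, the per-example losses returned here would typically also be used to refresh the stale per-example loss vector mentioned above; that bookkeeping is omitted from the sketch.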

Assuming that ${\bm{\theta}}$ and ${\bm{\theta}}^{\prime}$ satisfy the conditions of Theorem 13, we obtain

\begin{align*}
R({\bm{L}}(h({\bm{\theta}}^{\prime})))\leq{}& R({\bm{L}}(h({\bm{\theta}})))-\eta\left\langle\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}),\,\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rangle\tag{59}\\
&+\eta\sqrt{n}\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}))\right\rVert_{2,2}O\left(\frac{L^{2}\omega^{1/3}\sqrt{m\log(m)}}{\sqrt{d}}\right)\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\\
&+\eta^{2}O\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}
\end{align*}

where we refer to (46) for the form of $\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})$ and to (45) for the form of $\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}))$.

In addition, we make the assumption that, for the set of values of ${\bm{\theta}}$ considered, the hardness weighted sampling probabilities admit a positive lower bound

\[
\alpha=\min_{{\bm{\theta}}}\min_{i}\,\bar{p}_{i}({\bm{L}}({\bm{\theta}}))>0\tag{60}
\]

This is always satisfied under assumption I.2 for the Kullback-Leibler $\phi$-divergence, and for any $\phi$-divergence satisfying Definition 2 with a small enough robustness parameter $\beta$.
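For instance, here is a short check under the additional assumptions (specific to this sketch) that the hardness weighted probabilities take the softmax form of the KL case and that the per-example losses lie in a bounded interval $[\ell_{\min},\ell_{\max}]$:
\[
\bar{p}_{i}({\bm{L}})=\frac{e^{\beta L_{i}}}{\sum_{j=1}^{n}e^{\beta L_{j}}}\geq\frac{e^{\beta\ell_{\min}}}{n\,e^{\beta\ell_{\max}}}=\frac{e^{-\beta(\ell_{\max}-\ell_{\min})}}{n}>0
\]
so that $\alpha\geq\frac{1}{n}e^{-\beta(\ell_{\max}-\ell_{\min})}$, which recovers $\alpha=\frac{1}{n}$ in the limit $\beta\rightarrow 0$.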

Let $\mathbb{E}_{S}$ be the expectation with respect to $S$. Applying $\mathbb{E}_{S}$ to (59), we obtain

\begin{align*}
\mathbb{E}_{S}\left[R({\bm{L}}(h({\bm{\theta}}^{\prime})))\right]\leq{}& R({\bm{L}}(h({\bm{\theta}})))-\eta\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert_{2,2}^{2}\tag{61}\\
&+\eta\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}))\right\rVert_{2,2}O\left(\frac{nL^{2}\omega^{1/3}\sqrt{m\log(m)}}{\sqrt{d}}\right)\sqrt{\sum_{i=1}^{n}\max_{l}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}_{l}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert^{2}}\\
&+\eta^{2}O\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)\frac{1}{\alpha}\sum_{i=1}^{n}\max_{l}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}_{l}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert^{2}
\end{align*}

where we have used the following results:

  • For any integer $k\geq 1$ and all $\left({\bm{a}}_{i}\right)_{i=1}^{n}\in\left({\mathbb{R}}^{k}\right)^{n}$, we have (see the proof in I.3.5, and the numerical sketch after this list)

    \[
    \mathbb{E}_{S}\left[\frac{1}{b}\sum_{i\in S}{\bm{a}}_{i}\right]=\mathbb{E}_{\bar{p}(\hat{{\bm{L}}})}\left[{\bm{a}}_{i}\right]\tag{62}
    \]
  • Using (62) for $\left({\bm{a}}_{i}\right)_{i=1}^{n}=\left(\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right)_{i=1}^{n}$, and the chain rule (46)

    \[
    \mathbb{E}_{S}\left[\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right]=\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})=\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\tag{63}
    \]
  • Using the triangle inequality

    \[
    \left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\leq\frac{1}{b}\sum_{i\in S}\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\tag{64}
    \]

    and using (62) for $\left(a_{i}\right)_{i=1}^{n}=\left(\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\right)_{i=1}^{n}$,

    \begin{align*}
    \mathbb{E}_{S}\left[\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\right]&\leq\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}\tag{65}\\
    &\leq\sum_{i=1}^{n}\max_{l}\left\lVert\nabla_{{\bm{\theta}}_{l}}(\bar{p}_{i}(\hat{{\bm{L}}})\,\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}\\
    &\leq\sqrt{n}\sqrt{\sum_{i=1}^{n}\max_{l}\left\lVert\nabla_{{\bm{\theta}}_{l}}(\bar{p}_{i}(\hat{{\bm{L}}})\,\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}}
    \end{align*}

    where we have used the Cauchy-Schwarz inequality for the last inequality.

  • Using (64) and the convexity of the function $x\mapsto x^{2}$

    \[
    \left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}\leq\frac{1}{b}\sum_{i\in S}\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}\tag{66}
    \]

    and using (62) for $\left(a_{i}\right)_{i=1}^{n}=\left(\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}\right)_{i=1}^{n}$,

    \begin{align*}
    \mathbb{E}_{S}\left[\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}\right]&\leq\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2,\infty}^{2}\tag{67}\\
    &\leq\sum_{i=1}^{n}\frac{1}{\bar{p}_{i}(\hat{{\bm{L}}})}\max_{l}\left\lVert\nabla_{{\bm{\theta}}_{l}}(\bar{p}_{i}(\hat{{\bm{L}}})\,\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}\\
    &\leq\frac{1}{\alpha}\sum_{i=1}^{n}\max_{l}\left\lVert\nabla_{{\bm{\theta}}_{l}}(\bar{p}_{i}(\hat{{\bm{L}}})\,\mathcal{L}_{i}\circ h_{i})({\bm{\theta}})\right\rVert_{2}^{2}
    \end{align*}
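The following minimal Python sketch (purely illustrative; all names and values are arbitrary) checks the identity (62) numerically by Monte Carlo, drawing the $b$ indices of $S$ i.i.d. from a probability vector $\bar{{\bm{p}}}$:

    import numpy as np

    # Numerical check of (62): E_S[(1/b) * sum_{i in S} a_i] = E_{p_bar}[a_i]
    # when the b indices in S are drawn i.i.d. from p_bar.
    rng = np.random.default_rng(0)
    n, b, k = 10, 4, 3
    a = rng.normal(size=(n, k))                   # the vectors a_1, ..., a_n in R^k
    p_bar = rng.random(n)
    p_bar /= p_bar.sum()                          # a probability vector p_bar

    exact = p_bar @ a                             # E_{p_bar}[a_i]
    num_draws = 200_000
    estimate = np.zeros(k)
    for _ in range(num_draws):
        S = rng.choice(n, size=b, replace=True, p=p_bar)
        estimate += a[S].mean(axis=0)
    estimate /= num_draws

    print(np.max(np.abs(estimate - exact)))       # small, up to Monte Carlo error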
Important Remark:

It is worth noting the appearance in (67) of $\alpha$ defined in (60). If we were using uniform sampling as in ERM (i.e. DRO in the limit $\beta\rightarrow 0$), we would have $\alpha=\frac{1}{n}$. So although our inequality (67) may seem crude, it is consistent with equation (13.2) in (Allen-Zhu et al., 2019a) and the corresponding inequality in the case of ERM.

The rest of the proof of convergence consists in showing that $\eta\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}})\right\rVert_{2,2}^{2}$ dominates the last two terms in (59). As a result, we can already state that either the robustness parameter $\beta$ or the learning rate $\eta$ will have to be small enough to control $\alpha$.

Indeed, combining (59) with the chain rule (46) and the gradient bound of Theorem 14, where we use our $\hat{{\bm{L}}}$ defined in (56), we obtain

\begin{align*}
\mathbb{E}_{S}\left[R({\bm{L}}(h({\bm{\theta}}^{\prime})))\right]\leq{}& R({\bm{L}}(h({\bm{\theta}})))-\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\tag{68}\\
&+\eta\,O\left(\frac{nL^{2}\omega^{1/3}\sqrt{m\log(m)}}{\sqrt{d}}\right)O\left(\sqrt{\frac{m}{d}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\\
&+\eta^{2}O\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)O\left(\frac{m}{d\alpha}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\\
\leq{}& R({\bm{L}}(h({\bm{\theta}})))-\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}\\
&+O\left(\frac{\eta nL^{2}m\,\omega^{1/3}\sqrt{\log(m)}}{d}+K\frac{\eta^{2}(n/\alpha)L^{2}m^{2}}{d^{2}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}
\end{align*}

where we have used

\[
K:=\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\tag{69}
\]

There are only two differences compared to equation (13.2) in (Allen-Zhu et al., 2019a):

  • in the last fraction we have $n/\alpha$ instead of $n^{2}$ (see remark I.3.4 for more details), and an additional multiplicative term $K$. So in total, this term differs by a multiplicative factor $\frac{\alpha n}{K}$ from the analogous term in the proof of (Allen-Zhu et al., 2019a).

  • we have $\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}))\right\rVert_{2}^{2}$ instead of $F(\mathbf{W}^{(t)})$. In fact they are analogous since, in equation (13.2) of (Allen-Zhu et al., 2019a), $F(\mathbf{W}^{(t)})$ is the squared norm of the mean loss for the $L^{2}$ loss. We do not make such a strong assumption on the choice of $\mathcal{L}$ (see assumption I.2). It is worth noting that the same analogy is used in (Allen-Zhu et al., 2019a, Appendix A), where they extend their result to the mean loss with objective functions other than the $L^{2}$ loss.

Our choice of learning rate in Theorem 12 can be rewritten as

\begin{align*}
\eta_{\textup{exact}}&=\Theta\left(\frac{\alpha n^{2}\rho}{\beta C(\mathcal{L})^{2}+2n\rho\,C(\nabla\mathcal{L})}\times\frac{b\delta d}{\textup{poly}(n,L)\,m\log^{2}(m)}\right)\tag{70}\\
&=\Theta\left(\frac{\alpha n}{K}\times\frac{b\delta d}{\textup{poly}(n,L)\,m\log^{2}(m)}\right)\\
&\leq\frac{\alpha n}{K}\times\eta^{\prime}
\end{align*}

We also have

\[
\eta_{\textup{exact}}\leq\eta^{\prime}\tag{71}
\]

where $\eta^{\prime}$ is the learning rate chosen in the proof of Theorem 2 in (Allen-Zhu et al., 2019a). We refer the reader to (Allen-Zhu et al., 2019a) for the details of the constants hidden in the $\Theta$ notation and for the exact form of the polynomial $\textup{poly}(n,L)$.

As a result, for $\eta=\eta_{\textup{exact}}$, the term $\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)$ dominates the other term on the right-hand side of inequality (68), as in the proof of Theorem 2 in (Allen-Zhu et al., 2019a).

This implies that the conditions of Theorem 14 are satisfied for all ${\bm{\theta}}^{(t)}$, and that for every iteration $t>0$ we have

\[
\mathbb{E}_{S_{t}}\left[R({\bm{L}}(h({\bm{\theta}}^{(t+1)})))\right]\leq R({\bm{L}}(h({\bm{\theta}}^{(t)})))-\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}\tag{72}
\]

Using a result in Appendix A of (Allen-Zhu et al., 2019a), and since under assumption I.2 the distributionally robust loss is non-convex and bounded, we obtain, for all $\epsilon^{\prime}>0$,

\[
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(T)}))\right\rVert_{2,2}\leq\epsilon^{\prime}\quad\textup{if}\quad T=O\left(\frac{dn^{2}}{\eta\delta m\epsilon^{\prime 2}}\right)\tag{73}
\]

where according to (45)

\[
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(T)}))\right\rVert_{2,2}^{2}=\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(T)}))\right\rVert_{2}^{2}\tag{74}
\]

However, we are interested in a bound on $\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(T)})\right\rVert_{2,2}$ rather than a bound on $\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(T)}))\right\rVert_{2,2}$. Using the gradient bound of Theorem 14 and the chain rules (46) and (45),

\[
\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(T)})\right\rVert_{2,2}\leq c_{1}\sqrt{\frac{Lmn}{d}}\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(T)}))\right\rVert_{2,2}\tag{75}
\]

where $c_{1}>0$ is the constant hidden in $O\left(\sqrt{\frac{Lmn}{d}}\right)$.

So with $\epsilon^{\prime}=\frac{1}{c_{1}}\sqrt{\frac{d}{Lmn}}\,\epsilon$, we finally obtain

\begin{align*}
\left\lVert\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(T)})\right\rVert_{2,2}&\leq c_{1}\sqrt{\frac{Lmn}{d}}\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(T)}))\right\rVert_{2,2}\tag{76}\\
&\leq c_{1}\sqrt{\frac{Lmn}{d}}\,\epsilon^{\prime}\\
&\leq\epsilon
\end{align*}

if

\[
T=O\left(\frac{dn^{2}}{\eta\delta m\epsilon^{\prime 2}}\right)=O\left(\frac{dn^{2}}{\eta\delta m}\,\frac{Lmn}{d\epsilon^{2}}\right)=O\left(\frac{Ln^{3}}{\eta\delta\epsilon^{2}}\right)\tag{77}
\]

which concludes the proof. \blacksquare
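For completeness, the middle equality in (77) simply spells out the substitution of $\epsilon^{\prime}=\frac{1}{c_{1}}\sqrt{\frac{d}{Lmn}}\,\epsilon$ chosen above, with the constant $c_{1}$ absorbed into the $O$ notation:
\[
\epsilon^{\prime 2}=\frac{d}{c_{1}^{2}Lmn}\,\epsilon^{2}
\qquad\Rightarrow\qquad
\frac{dn^{2}}{\eta\delta m\,\epsilon^{\prime 2}}
=\frac{dn^{2}}{\eta\delta m}\times\frac{c_{1}^{2}Lmn}{d\,\epsilon^{2}}
=\frac{c_{1}^{2}Ln^{3}}{\eta\delta\,\epsilon^{2}}
=O\left(\frac{Ln^{3}}{\eta\delta\epsilon^{2}}\right)
\]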

I.3.5 Proof of technical lemma 1


For any integer $k\geq 1$ and all $\left({\bm{a}}_{i}\right)_{i=1}^{n}\in\left({\mathbb{R}}^{k}\right)^{n}$, we have

\[
\begin{aligned}
\mathbb{E}_{S}\left[\frac{1}{b}\sum_{i\in S}{\bm{a}}_{i}\right]
&=\sum_{1\leq i_{1},\ldots,i_{b}\leq n}\left[\left(\prod_{k=1}^{b}\bar{p}_{i_{k}}(\hat{{\bm{L}}})\right)\frac{1}{b}\sum_{j=1}^{b}{\bm{a}}_{i_{j}}\right]\\
&=\frac{1}{b}\sum_{1\leq i_{1},\ldots,i_{b}\leq n}\left[\sum_{j=1}^{b}\bar{p}_{i_{j}}(\hat{{\bm{L}}})\,{\bm{a}}_{i_{j}}\left(\prod_{\substack{k=1\\ k\neq j}}^{b}\bar{p}_{i_{k}}(\hat{{\bm{L}}})\right)\right]\\
&=\frac{1}{b}\sum_{j=1}^{b}\left[\sum_{1\leq i_{1},\ldots,i_{b}\leq n}\bar{p}_{i_{j}}(\hat{{\bm{L}}})\,{\bm{a}}_{i_{j}}\left(\prod_{\substack{k=1\\ k\neq j}}^{b}\bar{p}_{i_{k}}(\hat{{\bm{L}}})\right)\right]\\
&=\frac{1}{b}\sum_{j=1}^{b}\left[\left(\sum_{i_{j}=1}^{n}\bar{p}_{i_{j}}(\hat{{\bm{L}}})\,{\bm{a}}_{i_{j}}\right)\prod_{\substack{k=1\\ k\neq j}}^{b}\left(\sum_{i_{k}=1}^{n}\bar{p}_{i_{k}}(\hat{{\bm{L}}})\right)\right]\\
&=\frac{1}{b}\sum_{j=1}^{b}\left(\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\,{\bm{a}}_{i}\right)\\
&=\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\,{\bm{a}}_{i}\\
&=\mathbb{E}_{\bar{{\bm{p}}}(\hat{{\bm{L}}})}\left[{\bm{a}}_{i}\right]
\end{aligned}
\tag{78}
\]

where the products run over the $b$ indices of the mini-batch, and each factor $\sum_{i_{k}=1}^{n}\bar{p}_{i_{k}}(\hat{{\bm{L}}})$ equals $1$.
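For readers who prefer a numerical check, the identity (78) is easy to verify empirically: when a mini-batch of size $b$ is drawn i.i.d. with replacement from $\bar{{\bm{p}}}(\hat{{\bm{L}}})$, the expected mini-batch average equals the $\bar{{\bm{p}}}(\hat{{\bm{L}}})$-weighted mean. The following minimal Python sketch (not part of the released code; all names and parameter values are illustrative) performs this check with a generic distribution in place of $\bar{{\bm{p}}}(\hat{{\bm{L}}})$.

```python
# Monte Carlo check of technical lemma 1 / equation (78): if a mini-batch S of size b is
# drawn i.i.d. with replacement from a distribution p_bar over the n examples, then the
# expected mini-batch average of any per-example vectors a_i equals the p_bar-weighted mean.
import numpy as np

rng = np.random.default_rng(0)
n, b, k = 50, 8, 3                                   # number of examples, batch size, dim of a_i
logits = rng.normal(size=n)
p_bar = np.exp(logits) / np.exp(logits).sum()        # a generic sampling distribution
a = rng.normal(size=(n, k))                          # arbitrary per-example vectors a_i

num_batches = 200_000
batches = rng.choice(n, size=(num_batches, b), p=p_bar)   # i.i.d. sampling with replacement
mc_estimate = a[batches].mean(axis=1).mean(axis=0)        # estimate of E_S[(1/b) sum_{i in S} a_i]
weighted_mean = p_bar @ a                                 # sum_i p_bar_i * a_i

print(np.max(np.abs(mc_estimate - weighted_mean)))        # small (Monte Carlo error only)
```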

I.4 Convergence of SGD with Hardness Weighted Sampling and stale per-example loss vector

The proof of the convergence of Algorithm 1 under the conditions of Theorem 12 follows the same structure as the proof of the convergence of Robust SGD with an exact per-example loss vector in Section I.3.4. We reuse the intermediate results of Section I.3.4 where possible and focus on the differences between the two proofs that arise from the inexactness of the per-example loss vector.

Let $t$ be the iteration number, and let ${\bm{\theta}}^{(t)}\in\left({\mathbb{R}}^{m\times m}\right)^{L}$ be the parameters of the deep neural network at iteration $t$. We define the stale per-example loss vector at iteration $t$ as

\[
\hat{{\bm{L}}}=\left(\mathcal{L}_{i}\bigl(h_{i}({\bm{\theta}}^{(t_{i}(t))})\bigr)\right)_{i=1}^{n}
\tag{79}
\]

where, for all $i$, $t_{i}(t)<t$ is the latest iteration before $t$ at which the per-example loss value of example $i$ was updated, or equivalently the last iteration before $t$ at which example $i$ was drawn into a mini-batch.

We also define the exact per-example loss vector, which is unknown to Algorithm 1, as

\[
\breve{{\bm{L}}}=\left(\mathcal{L}_{i}\bigl(h_{i}({\bm{\theta}}^{(t)})\bigr)\right)_{i=1}^{n}
\tag{80}
\]
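To make this bookkeeping concrete, here is a minimal Python sketch of a stale-loss sampler in the spirit of Algorithm 1. It is an illustration only, not the authors' released implementation, and it assumes the hardness weighted sampling probabilities are a softmax of the stale per-example losses with robustness parameter $\beta$.

```python
# Minimal sketch (illustration only) of the stale per-example loss bookkeeping:
# only the entries of the examples drawn in the current mini-batch are refreshed,
# so at iteration t the stored value for example i dates from iteration t_i(t) < t.
import numpy as np

class StaleHardnessWeightedSampler:
    def __init__(self, num_examples, beta=1.0, init_loss=0.0):
        self.beta = beta
        # The stale per-example loss vector \hat{L}.
        self.stale_losses = np.full(num_examples, float(init_loss))

    def probabilities(self):
        # Sampling probabilities computed from the *stale* loss vector (assumed softmax form).
        z = self.beta * (self.stale_losses - self.stale_losses.max())  # shift for numerical stability
        p = np.exp(z)
        return p / p.sum()

    def sample_batch(self, batch_size, rng):
        # Draw the mini-batch i.i.d. with replacement from the current probabilities.
        return rng.choice(len(self.stale_losses), size=batch_size, p=self.probabilities())

    def update(self, indices, losses):
        # Refresh only the entries of the examples that were just forwarded; all other
        # entries keep the value they had at an earlier iteration (hence the staleness).
        self.stale_losses[np.asarray(indices)] = np.asarray(losses)
```

In a training loop, `sample_batch` would build the mini-batch $S$, the forward pass would compute the per-example losses of the sampled examples, and `update` would overwrite the corresponding entries of the stale vector before the SGD step of the form (81) below.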

Similarly to (58) we define

\[
{\bm{\theta}}^{(t+1)}={\bm{\theta}}^{(t)}-\eta\,\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})
\tag{81}
\]

and using Theorem 13, similarly to (59), we obtain

\[
\begin{aligned}
R({\bm{L}}(h({\bm{\theta}}^{(t+1)})))\leq\ & R({\bm{L}}(h({\bm{\theta}}^{(t)})))-\eta\left\langle\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(t)}),\ \frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rangle\\
&+\eta\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}O\!\left(\frac{L^{2}\omega^{1/3}\sqrt{m\log(m)}}{\sqrt{d}}\right)\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,\infty}\\
&+\eta^{2}O\!\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)\left\lVert\frac{1}{b}\sum_{i\in S}\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,\infty}^{2}
\end{aligned}
\tag{82}
\]

We can still define α𝛼\alphaitalic_α as in (60)

\[
\alpha=\min_{{\bm{\theta}}}\min_{i}\bar{p}_{i}({\bm{L}}({\bm{\theta}}))>0
\tag{83}
\]

where we are guaranteed that $\alpha>0$ under assumptions I.1.

Since Theorem 14 is independent of the choice of $\hat{{\bm{L}}}$, taking the expectation with respect to $S$, similarly to (68), we obtain

\[
\begin{aligned}
\mathbb{E}_{S}\left[R({\bm{L}}(h({\bm{\theta}}^{(t+1)})))\right]\leq\ & R({\bm{L}}(h({\bm{\theta}}^{(t)})))-\eta\left\langle\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(t)}),\ \sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rangle\\
&+\eta\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}O\!\left(\frac{L^{2}\omega^{1/3}\sqrt{nm\log(m)}}{\sqrt{d}}\right)\sqrt{\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}}\\
&+\eta^{2}O\!\left(\left(\frac{\beta C(\mathcal{L})^{2}}{n\rho}+2C(\nabla\mathcal{L})\right)\frac{nL^{2}m}{d}\right)O\!\left(\frac{m}{d\alpha}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}
\end{aligned}
\tag{84}
\]

where the differences with respect to (68) come from the fact that $\hat{{\bm{L}}}$ is not the exact per-example loss vector here, i.e. $\hat{{\bm{L}}}\neq\breve{{\bm{L}}}$, which leads to

\[
\begin{aligned}
\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(t)})
&=\sum_{i=1}^{n}\hat{p}_{i}(\breve{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\\
&\neq\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})
\end{aligned}
\tag{85}
\]

and

\[
\begin{aligned}
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}
&=\sum_{i=1}^{n}\left\lVert\hat{p}_{i}(\breve{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}\\
&\neq\sum_{i=1}^{n}\left\lVert\hat{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}
\end{aligned}
\tag{86}
\]

Let

\[
K^{\prime}=C(\mathcal{L})\,A(\nabla\mathcal{L})\,O\!\left(\frac{\beta Lm^{3/2}\log^{2}(m)}{\alpha n^{1/2}\rho d^{3/2}b\log\left(\frac{1}{1-\alpha}\right)}\right)
\tag{87}
\]

where $C(\mathcal{L})>0$ is a constant such that $\mathcal{L}$ is $C(\mathcal{L})$-Lipschitz continuous, and $A(\nabla\mathcal{L})>0$ is a constant that bounds the gradient of $\mathcal{L}$ with respect to its input. $C(\mathcal{L})$ and $A(\nabla\mathcal{L})$ are guaranteed to exist under assumptions I.1.

We can prove that, with probability at least $1-\exp\left(-\Omega\left(\log^{2}(m)\right)\right)$:

  • according to lemma I.4.1

    \[
    \left\lVert\hat{p}(\hat{{\bm{L}}})-\hat{p}(\breve{{\bm{L}}})\right\rVert_{2}=\sqrt{\sum_{i=1}^{n}\left(\hat{p}_{i}(\hat{{\bm{L}}})-\hat{p}_{i}(\breve{{\bm{L}}})\right)^{2}}\leq\eta\alpha K^{\prime}
    \tag{88}
    \]
  • according to lemma I.4.2

    \[
    \begin{aligned}
    &\left|\left\langle\nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(t)})-\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\ \sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rangle\right|\\
    &\qquad\leq\eta\,\frac{m}{d}K^{\prime}\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2}^{2}
    \end{aligned}
    \tag{89}
    \]
  • according to lemma I.4.3

    \[
    \left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}\leq\left(\sqrt{n}+\eta K^{\prime}\right)\sqrt{\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2}^{2}}
    \tag{90}
    \]

Combining these three inequalities with (84), we obtain

\[
\begin{aligned}
\mathbb{E}_{S}\left[R({\bm{L}}(h({\bm{\theta}}^{(t+1)})))\right]-R({\bm{L}}(h({\bm{\theta}}^{(t)})))\leq\ 
&\eta\left[-\Omega\!\left(\frac{m\delta}{dn^{2}}\right)+O\!\left(\frac{nL^{2}m\omega^{1/3}\sqrt{\log(m)}}{d}\right)\right]\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}\\
&+\eta^{2}O\!\left(K\frac{(n/\alpha)L^{2}m^{2}}{d^{2}}+\left(1+\frac{m}{d}\right)K^{\prime}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}
\end{aligned}
\tag{91}
\]

One can see that, compared to (68), the only additional term is $\left(1+\frac{m}{d}\right)K^{\prime}$.

Using our choice of η𝜂\etaitalic_η,

\[
\eta=\eta_{stale}\leq O\left(\frac{\delta}{n^{2}K^{\prime}}\,\eta_{exact}\right)
\tag{92}
\]

where $\eta_{exact}$ is the learning rate of Theorem 11, we have

\[
\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)\geq O\left(\eta^{2}\left(1+\frac{m}{d}\right)K^{\prime}\right)
\tag{93}
\]

As a result, $\eta^{2}\left(1+\frac{m}{d}\right)K^{\prime}$ is dominated by the term $\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)$.

In addition, since $\eta_{stale}\leq\eta_{exact}$, $\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)$ also dominates the other terms, as in the proof of Theorem 11.

As a consequence, we obtain, as in (72), that for any iteration $t>0$

\[
\mathbb{E}_{S_{t}}\left[R({\bm{L}}(h({\bm{\theta}}^{(t+1)})))\right]\leq R({\bm{L}}(h({\bm{\theta}}^{(t)})))-\Omega\left(\frac{\eta m\delta}{dn^{2}}\right)\sum_{i=1}^{n}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\nabla_{h_{i}}\mathcal{L}_{i}(h_{i}({\bm{\theta}}^{(t)}))\right\rVert_{2}^{2}
\tag{94}
\]

This concludes the proof, using the same arguments as at the end of the proof of Theorem 11, starting from (72). $\blacksquare$

I.4.1 Proof of technical lemma 2


Using Lemma 5 and Lemma 4 we obtain

\[
\begin{aligned}
\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}
&=\left\lVert\nabla_{v}R(\hat{{\bm{L}}})-\nabla_{v}R(\breve{{\bm{L}}})\right\rVert_{2}\\
&\leq\frac{\beta}{n\rho}\left\lVert\hat{{\bm{L}}}-\breve{{\bm{L}}}\right\rVert_{2}
\end{aligned}
\tag{95}
\]

Using assumptions I.2 and (Allen-Zhu et al., 2019a, Claim 11.2), we obtain

\[
\begin{aligned}
\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}
&\leq\frac{\beta}{n\rho}\sqrt{\sum_{i=1}^{n}\left(\mathcal{L}_{i}\circ h_{i}({\bm{\theta}}^{(t)})-\mathcal{L}_{i}\circ h_{i}({\bm{\theta}}^{(t_{i}(t))})\right)^{2}}\\
&\leq\frac{\beta}{n\rho}\,C(\mathcal{L})C(h)\sqrt{\sum_{i=1}^{n}\left\lVert{\bm{\theta}}^{(t)}-{\bm{\theta}}^{(t_{i}(t))}\right\rVert_{2,2}^{2}}\\
&\leq C(\mathcal{L})\,O\!\left(\frac{\beta Lm^{1/2}}{n\rho d^{1/2}}\right)\sqrt{\sum_{i=1}^{n}\left\lVert{\bm{\theta}}^{(t)}-{\bm{\theta}}^{(t_{i}(t))}\right\rVert_{2,2}^{2}}
\end{aligned}
\tag{96}
\]

where $C(\mathcal{L})$ is the Lipschitz constant of the per-example loss $\mathcal{L}$ (see assumptions I.2) and $C(h)$ is the Lipschitz constant of the deep neural network $h$ with respect to its parameters ${\bm{\theta}}$.

By unrolling the recurrence formula (81) for ${\bm{\theta}}^{(t)}$, we obtain

\[
\begin{aligned}
\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}
&\leq C(\mathcal{L})\,O\!\left(\frac{\beta Lm^{1/2}}{n\rho d^{1/2}}\right)\sqrt{\sum_{i=1}^{n}\left\lVert{\bm{\theta}}^{(t_{i}(t))}-\left(\sum_{\tau=t_{i}(t)}^{t-1}\frac{\eta}{b}\sum_{j\in S_{\tau}}\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(\tau)})\right)-{\bm{\theta}}^{(t_{i}(t))}\right\rVert_{2,2}^{2}}\\
&\leq\eta\,C(\mathcal{L})\,O\!\left(\frac{\beta Lm^{1/2}}{n\rho d^{1/2}}\right)\sqrt{\sum_{i=1}^{n}\left\lVert\sum_{\tau=t_{i}(t)}^{t-1}\frac{1}{b}\sum_{j\in S_{\tau}}\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(\tau)})\right\rVert_{2,2}^{2}}
\end{aligned}
\]

Let $A(\nabla\mathcal{L})$ be a bound on the gradient of the per-example loss function. Using Theorem 14 and the chain rule,

\[
\forall j,\ \forall\tau,\quad\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(\tau)})\right\rVert_{2,2}\leq A(\nabla\mathcal{L})\,O\left(\frac{m}{d}\right)
\tag{97}
\]

Using the triangle inequality,

\[
\begin{aligned}
\left\lVert\sum_{\tau=t_{i}(t)}^{t-1}\frac{1}{b}\sum_{j\in S_{\tau}}\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(\tau)})\right\rVert_{2,2}
&\leq\sum_{\tau=t_{i}(t)}^{t-1}\frac{1}{b}\sum_{j\in S_{\tau}}\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(\tau)})\right\rVert_{2,2}\\
&\leq\sum_{\tau=t_{i}(t)}^{t-1}A(\nabla\mathcal{L})\,O\left(\frac{m}{d}\right)\\
&\leq A(\nabla\mathcal{L})\,O\left(\frac{m}{d}\right)(t-t_{i}(t))
\end{aligned}
\tag{98}
\]

As a result, we obtain

\[
\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\leq\eta\,C(\mathcal{L})A(\nabla\mathcal{L})\,O\!\left(\frac{\beta Lm^{3/2}}{n\rho d^{3/2}}\right)\sqrt{\sum_{i=1}^{n}(t-t_{i}(t))^{2}}
\tag{99}
\]

For all $i$ and any $\tau$, the probability that example $i$ is not in the mini-batch $S_{\tau}$ is at most $\left(1-\alpha\right)^{b}$.

Therefore, for any $k\geq 1$ and for any $t$,

\[
P\left(t-t_{i}(t)\geq k\right)\leq\left(1-\alpha\right)^{kb}
\tag{100}
\]

For $k\geq\frac{1}{b}\Omega\left(\frac{\log^{2}(m)}{\log\left(\frac{1}{1-\alpha}\right)}\right)$, we have $\left(1-\alpha\right)^{kb}\leq\exp\left(-\Omega\left(\log^{2}(m)\right)\right)$, and thus with probability at least $1-\exp\left(-\Omega\left(\log^{2}(m)\right)\right)$,

\[
\forall t,\quad t-t_{i}(t)\leq O\left(\frac{\log^{2}(m)}{b\log\left(\frac{1}{1-\alpha}\right)}\right)
\tag{101}
\]

As a result, we finally obtain that, with probability at least $1-\exp\left(-\Omega\left(\log^{2}(m)\right)\right)$,

\[
\begin{aligned}
\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}
&\leq\eta\,C(\mathcal{L})A(\nabla\mathcal{L})\,O\!\left(\frac{\beta Lm^{3/2}}{n\rho d^{3/2}}\right)\sqrt{n}\,O\!\left(\frac{\log^{2}(m)}{b\log\left(\frac{1}{1-\alpha}\right)}\right)\\
&\leq\eta\alpha\,C(\mathcal{L})A(\nabla\mathcal{L})\,O\!\left(\frac{\beta Lm^{3/2}\log^{2}(m)}{\alpha n^{1/2}\rho d^{3/2}b\log\left(\frac{1}{1-\alpha}\right)}\right)\\
&\leq\eta\alpha K^{\prime}
\end{aligned}
\tag{102}
\]
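As a sanity check of the staleness argument behind (100)-(101), the following small, self-contained simulation (not from the paper; all parameter values are arbitrary) estimates the probability that a given example is absent from $k$ consecutive mini-batches drawn i.i.d. with replacement and compares it with the bound $(1-\alpha)^{kb}$.

```python
# Small simulation of the argument behind (100): each of the b i.i.d. draws of a mini-batch
# selects example i with probability at least alpha, so the probability that example i is
# absent from k consecutive mini-batches is at most (1 - alpha)^{k b}; hence the staleness
# t - t_i(t) of the per-example loss has a geometric tail.
import numpy as np

rng = np.random.default_rng(0)
n, b, k, num_runs = 100, 8, 10, 50_000
p = np.full(n, 1.0 / n)          # uniform sampling for simplicity, so alpha = min_i p_i = 1/n
alpha = p.min()

# For each run, draw k mini-batches of size b and record whether example 0 never appears.
draws = rng.choice(n, size=(num_runs, k, b), p=p)
empirical = np.all(draws != 0, axis=(1, 2)).mean()

bound = (1.0 - alpha) ** (k * b)
print(f"P(example 0 absent from {k} consecutive batches) ~ {empirical:.4f} <= bound {bound:.4f}")
```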

I.4.2 Proof of technical lemma 3


Let us first denote

\begin{align}
A &= \left|\left\langle \nabla_{{\bm{\theta}}}(R\circ{\bm{L}}\circ h)({\bm{\theta}}^{(t)})-\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rangle\right| \tag{103}\\
&= \left|\left\langle \sum_{i=1}^{n}\left(\bar{p}_{i}(\breve{{\bm{L}}})-\bar{p}_{i}(\hat{{\bm{L}}})\right)\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{i=1}^{n}\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rangle\right| \nonumber
\end{align}

Using the Cauchy-Schwarz inequality,

\begin{align}
A &= \left|\sum_{i=1}^{n}\left(\bar{p}_{i}(\breve{{\bm{L}}})-\bar{p}_{i}(\hat{{\bm{L}}})\right)\left\langle\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rangle\right| \tag{104}\\
&\leq \left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\sqrt{\sum_{i=1}^{n}\left\langle\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rangle^{2}} \nonumber
\end{align}
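To make the step above explicit, this is the finite-dimensional Cauchy-Schwarz inequality applied to two vectors of $\mathbb{R}^{n}$, assuming, as in the rest of this appendix, that $\hat{{\bm{p}}}(\cdot)$ denotes the vector of probabilities $\left(\bar{p}_{i}(\cdot)\right)_{i=1}^{n}$:
\begin{align*}
\left|\sum_{i=1}^{n} a_{i}\, b_{i}\right| \leq \left\lVert{\bm{a}}\right\rVert_{2}\sqrt{\sum_{i=1}^{n} b_{i}^{2}},
\qquad\text{with } a_{i}=\bar{p}_{i}(\breve{{\bm{L}}})-\bar{p}_{i}(\hat{{\bm{L}}})
\;\text{ and }\; b_{i}=\left\langle\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rangle,
\end{align*}
where each $b_{i}$ is the quantity denoted $B$ below.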

Let

\begin{equation}
B=\left\langle\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)}),\;\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rangle \tag{105}
\end{equation}

Using the Cauchy-Schwarz inequality again,

\begin{equation}
B\leq\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}\left\lVert\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rVert_{2,2} \tag{106}
\end{equation}

As a result, the bound on $A$ becomes

\begin{align}
A &\leq \left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\left\lVert\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rVert_{2,2}\sqrt{\sum_{i=1}^{n}\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}^{2}} \tag{107}\\
&\leq \left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\left\lVert\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rVert_{2,2}\sqrt{\sum_{i=1}^{n}\frac{1}{\alpha^{2}}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}^{2}} \nonumber\\
&\leq \frac{1}{\alpha}\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\left\lVert\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{j}\circ h_{j})({\bm{\theta}}^{(t)})\right\rVert_{2,2}^{2} \nonumber
\end{align}
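The second inequality replaces each per-example gradient norm by its hardness-weighted counterpart; a minimal justification of the termwise bound, under the assumption (implicit in the step above) that $\bar{p}_{i}(\hat{{\bm{L}}})\geq\alpha$ for every $i$:
\begin{align*}
\left\lVert\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}
= \frac{1}{\bar{p}_{i}(\hat{{\bm{L}}})}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}
\leq \frac{1}{\alpha}\left\lVert\bar{p}_{i}(\hat{{\bm{L}}})\,\nabla_{{\bm{\theta}}}(\mathcal{L}_{i}\circ h_{i})({\bm{\theta}}^{(t)})\right\rVert_{2,2}.
\end{align*}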

Using the triangle inequality, Theorem 14, and Lemma I.4.1, we finally obtain

\begin{align}
A &\leq \frac{m}{\alpha d}\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2}\sum_{j=1}^{n}\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}^{2} \tag{108}\\
&\leq \eta\,\frac{m}{d}\,K^{\prime}\sum_{j=1}^{n}\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}^{2} \nonumber
\end{align}

I.4.3 Proof of technical lemma 4


We have

\begin{align}
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}
&= \sum_{j=1}^{n}\bar{p}_{j}(\breve{{\bm{L}}})\left\lVert\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2} \tag{109}\\
&= \sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\left\lVert\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2} \nonumber\\
&\quad + \sum_{j=1}^{n}\left(\frac{\bar{p}_{j}(\breve{{\bm{L}}})-\bar{p}_{j}(\hat{{\bm{L}}})}{\bar{p}_{j}(\hat{{\bm{L}}})}\right)\bar{p}_{j}(\hat{{\bm{L}}})\left\lVert\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2} \nonumber
\end{align}
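The second equality is an add-and-subtract decomposition of the weights, written out here for a single term (the denominators are well defined since the $\bar{p}_{j}(\hat{{\bm{L}}})$ are strictly positive, as assumed throughout this proof):
\begin{align*}
\bar{p}_{j}(\breve{{\bm{L}}})
= \bar{p}_{j}(\hat{{\bm{L}}}) + \left(\bar{p}_{j}(\breve{{\bm{L}}})-\bar{p}_{j}(\hat{{\bm{L}}})\right)
= \bar{p}_{j}(\hat{{\bm{L}}}) + \left(\frac{\bar{p}_{j}(\breve{{\bm{L}}})-\bar{p}_{j}(\hat{{\bm{L}}})}{\bar{p}_{j}(\hat{{\bm{L}}})}\right)\bar{p}_{j}(\hat{{\bm{L}}}).
\end{align*}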

Using the Cauchy-Schwarz inequality,

\begin{align}
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}
&\leq \left(\sqrt{n}+\sqrt{\sum_{j=1}^{n}\left(\frac{\bar{p}_{j}(\breve{{\bm{L}}})-\bar{p}_{j}(\hat{{\bm{L}}})}{\bar{p}_{j}(\hat{{\bm{L}}})}\right)^{2}}\right)\sqrt{\sum_{j=1}^{n}\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}^{2}} \tag{110}
\end{align}
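The $\sqrt{n}$ factor comes from applying the Cauchy-Schwarz inequality to each of the two sums in (109) separately; a sketch of the first application, with the all-ones vector as one factor and using that the probabilities are nonnegative:
\begin{align*}
\sum_{j=1}^{n}\bar{p}_{j}(\hat{{\bm{L}}})\left\lVert\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}
= \sum_{j=1}^{n} 1\cdot\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}
\leq \sqrt{n}\,\sqrt{\sum_{j=1}^{n}\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}^{2}}.
\end{align*}
The second sum in (109) is bounded in the same way, with the vector of relative weight differences in place of the all-ones vector.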

Using Lemma I.4.1,

\begin{align}
\sqrt{\sum_{j=1}^{n}\left(\frac{\bar{p}_{j}(\breve{{\bm{L}}})-\bar{p}_{j}(\hat{{\bm{L}}})}{\bar{p}_{j}(\hat{{\bm{L}}})}\right)^{2}}
&\leq \frac{1}{\alpha}\left\lVert\hat{{\bm{p}}}(\hat{{\bm{L}}})-\hat{{\bm{p}}}(\breve{{\bm{L}}})\right\rVert_{2} \tag{111}\\
&\leq \eta K^{\prime} \nonumber
\end{align}

Therefore, we finally obtain

\begin{align}
\left\lVert\nabla_{h}(R\circ{\bm{L}})(h({\bm{\theta}}^{(t)}))\right\rVert_{1,2}
&\leq \left(\sqrt{n}+\eta K^{\prime}\right)\sqrt{\sum_{j=1}^{n}\left\lVert\bar{p}_{j}(\hat{{\bm{L}}})\,\nabla_{h_{j}}\mathcal{L}_{j}(h_{j}({\bm{\theta}}^{(t)}))\right\rVert_{2,2}^{2}} \tag{112}
\end{align}