Distributional Gaussian Processes Layers for Out-of-Distribution Detection

Sebastian G. Popescu1, David J. Sharp1, James H. Cole2, Konstantinos Kamnitsas1, Ben Glocker1
1: Imperial College London, 2: University College London
IPMI 2021 special issue
Publication date: 2022/06/29

Abstract

Machine learning models deployed on medical imaging tasks must be equipped with out-of-distribution detection capabilities in order to avoid erroneous predictions. It is unclear whether out-of-distribution detection models reliant on deep neural networks are suitable for detecting domain shifts in medical imaging. Gaussian Processes can reliably separate in-distribution data points from out-of-distribution data points via their mathematical construction. Hence, we propose a parameter-efficient Bayesian layer for hierarchical convolutional Gaussian Processes that incorporates Gaussian Processes operating in Wasserstein-2 space to reliably propagate uncertainty. This directly replaces convolving Gaussian Processes with a distance-preserving affine operator on distributions. Our experiments on brain tissue segmentation show that the resulting architecture approaches the performance of well-established deterministic segmentation algorithms (U-Net), which has not been achieved with previous hierarchical Gaussian Processes. Moreover, by applying the same segmentation model to out-of-distribution data (i.e., images with pathology such as brain tumors), we show that our uncertainty estimates result in out-of-distribution detection that outperforms the capabilities of previous Bayesian networks and reconstruction-based approaches that learn normative distributions.

Keywords

gaussian processes · image segmentation · out-of-distribution detection




1 Introduction

Deep learning methods have achieved state-of-the-art results on a plethora of tasks ranging from medical image segmentation to clinical risk assessment (Zhou et al., 2021; Tang, 2019; Imai et al., 2020). However, their application in clinical settings remains challenging due to issues pertaining to lack of reliability and miscalibration of confidence estimates. Reliably estimating uncertainty in predictions is also of vital interest in adjacent machine learning fields such as reinforcement learning, to guide exploration, or in active learning, to guide the selection of data points for the next iteration of labelling. Most research into incorporating uncertainty into medical image segmentation has gravitated around modelling inter-rater variability and the inherent aleatoric uncertainty associated with the dataset, which can be caused by noise or inter-class ambiguities, alongside modelling the uncertainty in parameters (Czolbe et al., 2021). However, less focus has been placed on how models behave when processing unexpected inputs which differ from the characteristics of the training data. Such inputs, often called anomalies, outliers or out-of-distribution samples, could possibly lead to deleterious effects in healthcare applications, where predictive models may encounter data that is corrupted or from patients with diseases that the model is not trained for (Curth et al., 2019; Mårtensson et al., 2020).

Out-of-distribution (OOD) detection in medical imaging has mostly been approached through the lens of reconstruction-based techniques involving some form of encoder-decoder network trained on normative datasets (Chen et al., 2019, 2021). Conversely, we focus on enhancing task-specific models (e.g., a segmentation model) with reliable uncertainty quantification that enables outlier detection. Standard deep neural networks (DNNs), despite their high predictive performance, often exhibit unreasonably high confidence in their predictions when processing unseen samples that do not lie on the data manifold of the training set (e.g., in the presence of pathology when the training data consists only of healthy subjects, or, more generally, in the presence of motion artifacts never seen in training images). To alleviate this, Bayesian approaches that place posteriors over weights (including MC Dropout (Gal and Ghahramani, 2016b)) or in function space (including Repulsive Deep Ensembles (D’Angelo and Fortuin, 2021)) have been proposed (Wilson and Izmailov, 2020). However, neither priors on weights nor priors in function space necessarily lend themselves to reliable OOD detection by virtue of inspecting the predictive variance of the model, as was shown in Henning et al. (2021). The authors argue that neither infinite-width networks, trained via the Neural Network Gaussian Process (NNGP) kernel (Lee et al., 2017), nor finite-width networks trained via Hamiltonian Monte Carlo (Neal et al., 2011) are reliable for OOD detection, since the associated NNGP kernel is not correlated with distances between objects in input space. This loss of distance-awareness after encoding data has catastrophic effects on OOD detection, as we will soon see. Similarly, Foong et al. (2019) describe a limitation in the expressiveness of the predictive uncertainty estimates given by mean-field variational inference (MFVI) when applied as the inference technique for Bayesian Neural Networks (BNNs). Concretely, MFVI fails to offer quality uncertainty estimates in regions between well-separated clusters of data, which the authors coin in-between uncertainty, with potentially catastrophic consequences for active learning, Bayesian optimisation or robustness to out-of-distribution data. In this paper we follow an alternative approach, using Gaussian Processes (GP) as the building block for deep Bayesian networks.

The use of GPs for image classification has garnered interest in recent years. Hybrid approaches, whereby a DNN’s embedding mechanism is trained end-to-end with a GP as the classification layer, were the first attempts to unify the two approaches (Bradshaw et al., 2017). The first convolutional kernel was proposed in Van der Wilk et al. (2017), constructed by aggregating patch response functions. This approach was subsequently stacked into feed-forward GP layers applied in a convolutional manner, with promising improvements in accuracy compared to their shallow counterparts (Blomqvist et al., 2018).

We expand on the aforementioned work by introducing a simpler convolutional mechanism which does not require convolving GPs at each layer, thereby alleviating the computational cost of optimizing over inducing points’ locations residing in high-dimensional spaces, alongside the issues brought about by multi-output GPs. We propose a plug-in Bayesian layer more amenable to CNN architectures. More concretely, we seek to replace each individual component of a standard convolutional layer in convolutional neural networks (CNNs), namely the convolved filters and the non-linear activation function. Firstly, we impose constraints on the filter such that we have an upper bound on distances after the convolution with respect to distances between the same objects beforehand. This ensures that objects which were close in previous layers remain close going forward, which, as we shall see later on, is a fundamental property for reliable OOD detection. Moreover, directly using convolved filters as opposed to convolved GPs (Blomqvist et al., 2018) solves the issue of optimizing high-dimensional inducing points’ locations, while introducing a simpler mechanism by which we can introduce correlations between channels (Nguyen et al., 2014). Secondly, we replace the element-wise non-linear activation functions with Distributional Gaussian Processes (DistGP) (Bachoc et al., 2017) used in a one-to-one mapping manner, essentially acting as a non-parametric activation function. A variant of DistGP used in a hierarchical setting akin to Deep Gaussian Processes (DGP) (Damianou and Lawrence, 2013) was shown to be better at detecting OOD due to both kernel and architecture design (Popescu et al., 2020). In this paper we will show that our proposed module is also suited for OOD detection on both toy/image data and biomedical scans.

In the remainder of this section we provide a deeper exploration of uncertainties used in literature for biomedical image segmentation, subsequently introducing the concept of distance-awareness and imposing smoothness constraints on learned representations in a deep network as prerequisites for reliable OOD detection. These two properties will be key to motivate the imposed constraints and architecture choice of our proposed probabilistic module later on.

1.1 Uncertainty quantification for biomedical imaging segmentation

While prediction uncertainty can be computed for standard neural networks by using the softmax probability, these uncertainty estimates are often overconfident (Guo et al., 2017; McClure et al., 2019). Research into Bayesian models has focused on a separation of uncertainty into two different types, aleatoric (data intrinsic) and epistemic (model parameter uncertainty). To formalize this difference, we consider a multi-class classification problem, with classes denoted as $\{y_1, \cdots, y_C\}$ and model parameters denoted by $\theta$. We have the following predictive equation at testing time:

$$p(y_c \mid x^*, \mathbf{D}) = \int \underbrace{p(y_c \mid x^*, \theta)}_{\text{Aleatoric Uncertainty}}\; \underbrace{p(\theta \mid \mathbf{D})}_{\text{Epistemic Uncertainty}}\; d\theta \tag{1}$$

Aleatoric uncertainty is irreducible, given by noise in the data acquisition process, and has been considered in medical image segmentation (Monteiro et al., 2020), whereas epistemic uncertainty can be reduced by providing more data during model training; this has also been studied in segmentation tasks (Nair et al., 2020). Previous work proposed to account for the uncertainty in the learned model parameters using approximate Bayesian inference over the network weights (Kendall et al., 2015). However, it was shown that this method may produce samples that vary pixel by pixel and thus may not capture complex spatially correlated structures in the distribution of segmentation maps. The probabilistic U-Net (Kohl et al., 2018) produces samples with limited diversity due to the fact that stochasticity is introduced only at the highest resolution level of the U-Net. To solve this issue, Baumgartner et al. (2019) introduce a hierarchical structure between the different levels of the U-Net, hence injecting stochasticity at each level. Another improvement on the probabilistic U-Net comes from adding variational dropout (Kingma et al., 2015) to the last layer to gain epistemic uncertainty quantification properties (Hu et al., 2019). All of the models introduced above rely on multiple annotations per image, with the intended goal of capturing the uncertainty in annotations by sampling from some form of latent variables which encode information about the whole image at varying scales of the U-Net. However, none of these previous works test how their models behave in the presence of outliers.

1.2 Distributional Uncertainty as a proxy for OOD detection

Besides the dichotomy consisting of aleatoric and epistemic uncertainty, reliably highlighting inputs which have undergone a domain shift (Lakshminarayanan et al., 2017) or are out-of-distribution (Hendrycks and Gimpel, 2016) has garnered considerable interest in recent years. Succinctly, the aim is to measure the degree to which a model knows when it does not know: if a network trained on a specific dataset is evaluated at testing time on a completely different dataset (potentially from a different modality or another application domain), then the network should output high predictive uncertainty on this set of data points that lie far from the training data manifold.

A problem with introducing this new type of uncertainty is how to disentangle it from epistemic uncertainty. For example, in the Deep Ensembles paper (Lakshminarayanan et al., 2017), the authors propose to measure the disagreement between the different sub-models of the deep ensemble, $\sum_{m=1}^{M} KL\left[p(y \mid x; \theta_m) \,\|\, \mathbb{E}\left[p(y \mid x)\right]\right]$, for $M$ sub-models with associated parameters $\theta_m$, where $\mathbb{E}\left[p(y \mid x)\right] = \frac{1}{M}\sum_{m=1}^{M} p(y \mid x; \theta_m)$ is the prediction of the ensemble. We remind ourselves that epistemic uncertainty can be reduced by adding more data. By this logic, epistemic uncertainty cannot be reduced outside the data manifold of our dataset, since we do not add data points which do not stem from the same data generating pipeline (this is not true in the case of OOD detection models which explicitly use OOD samples during training/testing (Liang et al., 2017; Hafner et al., 2020)). Hence, epistemic uncertainty can only be reduced inside the data manifold and should be zero outside the data manifold (assuming the model is distance-aware, which we will subsequently define). Conversely, our chosen measure for OOD detection should grow outside the data manifold and be close to zero inside the data manifold. With this in mind, the disagreement metric introduced in Lakshminarayanan et al. (2017) cannot achieve this separation, confounding the two types of uncertainty.
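For illustration, a minimal numpy sketch (not the authors' implementation) of this disagreement measure, assuming we already have per-member predictive probabilities for a single input:

```python
import numpy as np

def ensemble_disagreement(probs, eps=1e-12):
    """Disagreement of a deep ensemble: sum of KL divergences between each
    member's predictive distribution and the ensemble mean prediction.

    probs: array of shape (M, C) -- M ensemble members, C classes.
    """
    mean_pred = probs.mean(axis=0)  # E[p(y|x)]
    kl_terms = np.sum(probs * (np.log(probs + eps) - np.log(mean_pred + eps)), axis=1)
    return kl_terms.sum()

# Toy usage: confident agreement vs. scattered (OOD-like) predictions.
agree = np.array([[0.97, 0.02, 0.01]] * 5)
scatter = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9],
                    [0.3, 0.4, 0.3], [0.1, 0.8, 0.1]])
print(ensemble_disagreement(agree), ensemble_disagreement(scatter))
```

Note that the same low disagreement can arise both for confident in-distribution inputs and for inputs where all members agree on a flat distribution, which is exactly the confounding described above.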

Malinin and Gales (2018) introduced for the first time the separation of total uncertainty into three components: epistemic, aleatoric and distributional uncertainty. To make the distinction clearer, the authors argue that aleatoric uncertainty is a "known-unknown", whereby the model confidently states that an input data point is hard to classify (class overlap). In contrast, distributional uncertainty is an "unknown-unknown", due to the fact that the model is unfamiliar with the input space region that the test data comes from, and is thereby unable to make confident predictions.

Figure 1: Probability simplex for a 3-class classification problem, where each corner corresponds to a class. Each point represents a categorical distribution, with brighter colors indicating higher density of the underlying ensemble. Epistemic uncertainty captures uncertainty in model parameters caused by lack of data or model non-identifiability, with the ensemble of predictions being concentrated in a corner of the probability simplex, albeit with increased diversity. Aleatoric uncertainty captures class overlap, with the ensemble of predictions being confidently mapped to the point of highest predictive entropy. Distributional uncertainty captures domain shift, with the ensemble of predictions being centred in the middle of the simplex with the highest possible degree of diversity.

We briefly introduce the uncertainty decomposition mechanism introduced in Malinin and Gales (2018). Considering equation (1), by using Monte Carlo integration of the above equation and computing the predictive entropy, we would not be able to discern between high predictive entropy due to aleatoric uncertainty (class overlap) or distributional uncertainty (dataset/domain shift). Hence, Malinin and Gales (2018) propose to introduce a latent variable $\mu$ over the categorical variables corresponding to each class, parametrized as a distribution over distributions on a simplex, $p(\mu \mid x^*, \theta)$. The intuition behind this Dirichlet distribution over the probability simplex is that OOD points should be scattered, whereas in-distribution points should concentrate. We can now re-write our predictive equation as:

$$p(y_c \mid x^*, \mathbf{D}) = \int\!\!\int \underbrace{p(y_c \mid x^*, \mu)}_{\text{Aleatoric Uncertainty}}\; \underbrace{p(\mu \mid x^*, \theta)}_{\text{Distributional Uncertainty}}\; \underbrace{p(\theta \mid \mathbf{D})}_{\text{Epistemic Uncertainty}}\; d\mu\, d\theta \tag{2}$$

The authors argue that using a measure of spread of the ensemble (after sampling from $p(\theta \mid \mathbf{D})$) will be more informative. We remind ourselves that the Mutual Information between variables $X$ and $Y$ can be expressed in terms of the difference between entropy and conditional entropy: $I(X;Y) = H(P(X)) - H(P(X \mid Y))$. Hence we can use the Mutual Information between model predictions and Dirichlet parameters to obtain a better measure of uncertainty. Integrating out $\theta$ in the main equation, we get:

$$\underbrace{I\left[y, \mu \mid x^*, \mathbf{D}\right]}_{\text{Distributional Uncertainty}} = \underbrace{H\left[\mathbb{E}_{p(\mu \mid \mathbf{D})}\, p(y \mid x^*, \mu)\right]}_{\text{Total Uncertainty}} - \underbrace{\mathbb{E}_{p(\mu \mid \mathbf{D})}\left[H\left[p(y \mid x^*, \mu)\right]\right]}_{\text{Aleatoric Uncertainty}} \tag{3}$$

Connections between this uncertainty disentanglement framework specialized for DNN-parametrized Dirichlet distributions and uncertainty disentanglement in GP will be subsequently made clearer in subsection 2.3. Distributional uncertainty will be the key uncertainty score used throughout this paper to assess whether input data points are inside or outside the data manifold.
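To make equation (3) concrete, a minimal sketch of the decomposition, assuming we already have an ensemble of sampled categorical predictions (e.g., draws of $\mu$) for a single test input:

```python
import numpy as np

def entropy(p, eps=1e-12, axis=-1):
    return -np.sum(p * np.log(p + eps), axis=axis)

def uncertainty_decomposition(probs):
    """Decompose predictive uncertainty following equation (3).

    probs: array of shape (S, C) -- S sampled categorical distributions, C classes.
    Returns (total, aleatoric, distributional) uncertainty.
    """
    total = entropy(probs.mean(axis=0))         # H[E[p(y | x*, mu)]]
    aleatoric = entropy(probs, axis=-1).mean()  # E[H[p(y | x*, mu)]]
    distributional = total - aleatoric          # mutual information I[y, mu | x*, D]
    return total, aleatoric, distributional

# Class overlap: all samples agree on a flat distribution -> aleatoric dominates.
# Domain shift: samples scattered over the simplex -> distributional dominates.
overlap = np.tile([1 / 3, 1 / 3, 1 / 3], (4, 1))
shifted = np.array([[0.98, 0.01, 0.01], [0.01, 0.98, 0.01],
                    [0.01, 0.01, 0.98], [0.34, 0.33, 0.33]])
print(uncertainty_decomposition(overlap))
print(uncertainty_decomposition(shifted))
```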

1.3 Distance awareness and smoothness for out-of-distribution detection

Perhaps the greatest inspiration and motivation for this paper resides in the theoretical framework introduced in Liu et al. (2020), in which the authors outline key mathematical conditions for provably reliable OOD detection in DNNs. We commence by briefly outlining the ideas introduced in the aforementioned paper.

For an abstract data generating distribution $p(y \mid x)$, where $y$ is a scalar and $x \in \mathbb{X} \subset \mathbb{R}^D$, we equip the input data manifold with a suitable metric $\|\cdot\|_X$. We consider our training data $D = \{(x_i, y_i)_{1:n}\}$ to be sampled from a subset of the full input space, $x_{in-d} \in \mathbb{X}_{in-d} \subset \mathbb{X}$, where the abbreviation in-d stands for in-distribution. With this in mind, we can consider the in-distribution data generating distribution $p_{in-d}(y \mid x) = p(y \mid x, x \in \mathbb{X}_{in-d})$, respectively the out-of-distribution data generating distribution $p_{ood}(y \mid x) = p(y \mid x, x \notin \mathbb{X}_{in-d})$. Hence, it is safe to assume that the full data generating distribution $p(y \mid x)$ is composed as a mixture of the in-distribution and OOD generating distributions:

$$\begin{align}
p(y \mid x) &= p(y, x \in \mathbb{X}_{in-d} \mid x) + p(y, x \notin \mathbb{X}_{in-d} \mid x) \tag{4}\\
&= p(y \mid x, x \in \mathbb{X}_{in-d})\, p(x \in \mathbb{X}_{in-d}) + p(y \mid x, x \notin \mathbb{X}_{in-d})\, p(x \notin \mathbb{X}_{in-d}) \tag{5}\\
&= p_{in-d}(y \mid x)\, p(x \in \mathbb{X}_{in-d}) + p_{ood}(y \mid x)\, p(x \notin \mathbb{X}_{in-d}) \tag{6}
\end{align}$$

Evidently, during training we only learn $p_{in-d}(y \mid x)$, since we only have access to $D \subset \mathbb{X}_{in-d}$. Therefore, our model is completely in the dark with regards to $p_{ood}(y \mid x)$. These two data generating distributions are more often than not fundamentally different (e.g., having trained a model on T1w MRI scans and subsequently feeding it T2w MRI scans, an imaging modality with an almost inverse intensity scaling across brain tissues). With this in mind, Liu et al. (2020) argue that the optimal strategy is for $p_{ood}(y \mid x)$ to be predicted as a uniform distribution, thus signalling the model's lack of knowledge of this different input domain. We can now recall the distinction made by Malinin and Gales (2018) between "known-unknowns" (aleatoric uncertainty, e.g., class overlap) and "unknown-unknowns" (distributional uncertainty, e.g., domain shift), both of which have an uninformative predictive distribution (maximum predictive entropy). However, to disentangle these two types of uncertainty, we need a second-order type of uncertainty that scatters logit samples when distributional uncertainty is high, respectively concentrates logit samples at maximum predictive entropy in the case of high aleatoric uncertainty. We now formalize this desideratum via a condition called "distance awareness" in Liu et al. (2020).

Definition 1 (Definition 1 in Liu et al. (2020))

We consider the predictive distribution $p(y^* \mid x^*)$ for an unseen point $x^*$ at testing time, for a model trained on $\mathbb{X}_{in-d} \subset \mathbb{X}$, with the data manifold being equipped with metric $\|\cdot\|_X$. Then, we can affirm that $p(y^* \mid x^*)$ is distance-aware if there exists a summary statistic $u(x^*)$ of $p(y^* \mid x^*)$ that embeds the distance between $\mathbb{X}_{in-d}$ and $x^*$:

$$u(x^*) = v\left[\mathbb{E}_{x \sim \mathbb{X}_{in-d}}\left[\|x^* - x\|_X^2\right]\right] \tag{7}$$

where $v$ is a monotonic function that increases with distance.

Definition 1 does not make any assumptions related to the architecture of the model from which the predictive distribution stems. In practice we would have the following composition to arrive at the logits: $\mathrm{logit}(x^*) = f \circ enc(x^*)$, where $enc(\cdot)$ represents a network that outputs the representation learning layer and $f(\cdot)$ is the output layer. In Liu et al. (2020) the authors propose the following two conditions to ensure that the composition is distance-aware:

  • $f(\cdot)$ is distance-aware

  • $\mathbb{E}_{x \sim \mathbb{X}_{in-d}}\left[\|x^* - x\|_X^2\right] \approx \mathbb{E}_{x \sim \mathbb{X}_{in-d}}\left[\|enc(x^*) - enc(x)\|_{enc(X)}^2\right]$

The last condition means that distances between data points in input space should be correlated with distances in the learned representation, which is equipped with a $\|\cdot\|_{enc(X)}$ metric. In our work, we will use GPs as $f$ because, as we will see in section 2.1, GPs are distance-aware functions. This enables us to build a distance-aware model that is more appropriate for OOD detection. Whereas GPs satisfy the distance-awareness condition for the last-layer predictor, we are still left with the question of how to keep distances in the learned representation correlated with distances in the input layer. This will be dealt with subsequently.

Network smoothness constraints

Throughout this paper we will take the general term "smoothness" of a model to mean the degree to which changes in the input have an effect on the output at a particular layer. The question now shifts to how we can quantify the smoothness of a network/function. In mathematical analysis, a function $f: \mathbb{X} \to \mathbb{Y}$ is said to be k-smooth if its first k derivatives $\{f^{\prime}, f^{\prime\prime}, \cdots, f^{(k)}\}$ exist and are continuous. We denote functions with this property as being of class $C^k$. For example, Gaussian Processes using squared exponential kernels are $C^{\infty}$, since the squared exponential kernel is infinitely differentiable. However, such a definition and quantification of smoothness would not help us ensure the second condition of distance-awareness. For this, we shall use Lipschitz continuity, which is defined as follows: considering two metric spaces $\mathbb{X}$ and $\mathbb{Y}$ equipped with metrics $\|\cdot\|_X$ and $\|\cdot\|_Y$, a function $f: \mathbb{X} \to \mathbb{Y}$ is Lipschitz continuous if there exists a real $K \geq 0$ such that for all $x, y \in \mathbb{X}$ we have $\|f(x), f(y)\|_Y \leq K\|x, y\|_X$. Intuitively, for Lipschitz functions there is an upper limit on how much outputs can change with respect to distances in input space. It is worth highlighting that Lipschitz continuity as defined here is a global property. There are also locally Lipschitz continuous functions, which respect the aforementioned condition only in a neighbourhood of $x$, namely $B_r(x) = \{y \in \mathbb{X} : \|x, y\|_X \leq r\}$. Lastly, bi-Lipschitz continuity is defined as $\frac{1}{K}\|x, y\|_X \leq \|f(x), f(y)\|_Y \leq K\|x, y\|_X$, a property that avoids learning trivially smooth functions and maintains useful information (Rosca et al., 2020).

With this in mind, recent work (Liu et al., 2020; van Amersfoort et al., 2021) has enforced the bi-Lipschitz property on feature extractors, thereby ensuring a strong correlation between distances between data points in input space and distances in the representation learning layer.
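As a simple illustration of this property, a minimal sketch that empirically probes the ratio of representation-space to input-space distances for an encoder; the randomly initialized affine-plus-ReLU map below is a hypothetical stand-in for a trained feature extractor:

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in encoder: random affine map followed by a ReLU.
W = rng.normal(size=(2, 16)) / np.sqrt(2)
enc = lambda x: np.maximum(x @ W, 0.0)

def lipschitz_ratios(encoder, x, n_pairs=1000):
    """Empirical ratios ||enc(a) - enc(b)|| / ||a - b|| over random pairs.

    For a bi-Lipschitz encoder these ratios stay within [1/K, K];
    ratios collapsing towards 0 indicate feature collapse and a loss
    of distance awareness in the learned representation.
    """
    i = rng.integers(0, len(x), size=n_pairs)
    j = rng.integers(0, len(x), size=n_pairs)
    keep = i != j
    d_in = np.linalg.norm(x[i[keep]] - x[j[keep]], axis=1)
    d_out = np.linalg.norm(encoder(x[i[keep]]) - encoder(x[j[keep]]), axis=1)
    return d_out / d_in

x = rng.normal(size=(500, 2))
r = lipschitz_ratios(enc, x)
print(f"min ratio {r.min():.3f}, max ratio {r.max():.3f}")
```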

1.4 Contributions

This work makes the following main contributions:

  • We introduce a Bayesian layer that can act as a drop-in replacement for standard convolutional layers. Operating on stochastic layers with Gaussian distributions, we upper bound the convolved affine operators in Wasserstein-2 space, thus ensuring Lipschitz continuity. To introduce non-linearities, we apply DistGP element-wise on the output of the constrained affine operator, thereby obtaining non-parametric “activation functions” which ensure adequate quantification of distributional uncertainty at each layer.

  • We derive theoretical requirements for the model to not suffer from feature collapse, with additional empirical results to support the theory.

  • We demonstrate that a GP-based convolutional architecture can achieve competitive results in segmentation tasks in comparison to a U-Net.

  • We show improved OOD detection results on both general OOD tasks and on medical images compared to previous OOD approaches such as reconstruction-based models.

2 Background

In this section we provide a brief review of the theoretical toolkit required for the remainder of the paper. We commence by laying out foundational material on GPs, followed by an introduction to attempts to sparsify GPs. Subsequently, we introduce an uncertainty disentanglement framework for sparse Gaussian Processes. We briefly define Wasserstein-2 distances and show how they can be used to define kernels operating on Gaussian distributions. Lastly, we introduce recent re-formulations of deep GPs through the lens of OOD detection.

2.1 Primer on Gaussian Processes

A Gaussian Process can be seen as a generalization of multivariate Gaussian random variables to infinite sets. We now define this statement in more detail. We consider $f(x)$ to be a stochastic field, with $x \in \mathbb{R}^d$, and define $m(x) = \mathbb{E}\left[f(x)\right]$ and $C(x_i, x_j) = \mathrm{Cov}\left[f(x_i), f(x_j)\right]$. We denote a Gaussian Process (GP) $f(x)$ as:

$$f(x) \sim GP\left(m(x), C(x_i, x_j)\right) \tag{8}$$

The covariance function must generate non-negative-definite covariance matrices; more specifically, it has to satisfy $\sum_{i,j} a_i a_j C(x_i, x_j) \geq 0$ for any finite set $\{x_1, \cdots, x_n\}$ and any real-valued coefficients $\{a_1, \cdots, a_n\}$. Throughout this paper we will only consider second-order stationary processes, which have constant mean and $\mathrm{Cov}\left[f(x_i), f(x_j)\right] = C\left(\|x_i - x_j\|\right)$. Such covariance functions are invariant to translations.

The squared exponential (radial basis function) kernel defines a general class of stationary covariance functions:

$$k^{SE}(x_i, x_j) = \sigma^2 \exp\left[-\sum_{d=1}^{D} \frac{\left(x_{i,d} - x_{j,d}\right)^2}{l_d^2}\right] \tag{9}$$

where we have written its definition in the anisotropic case. The emphasis on the domain will make more sense in subsequent subsections, where we will introduce kernels operating on Gaussian measures. Intuitively, the lengthscale values $\{l_1^2, \cdots, l_D^2\}$ control, along each dimension of input space, how quickly the correlation between function values decays as the distance between points increases. Such a covariance function has the property of Automatic Relevance Determination (ARD) (Neal, 2012). Lastly, the kernel variance $\sigma^2$ controls the variance of the process, more specifically the amplitude of function samples.
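A minimal numpy sketch of the anisotropic squared exponential kernel in equation (9):

```python
import numpy as np

def k_se(x1, x2, lengthscales, variance):
    """ARD squared exponential kernel, cf. equation (9).

    x1: (N, D), x2: (M, D), lengthscales: (D,), variance: scalar.
    Returns the (N, M) covariance matrix.
    """
    diff = x1[:, None, :] - x2[None, :, :]                 # (N, M, D)
    sq_dist = np.sum((diff / lengthscales) ** 2, axis=-1)  # per-dimension scaling
    return variance * np.exp(-sq_dist)

# Nearby points are highly correlated; distant points are nearly uncorrelated.
x = np.array([[0.0], [0.1], [5.0]])
print(k_se(x, x, lengthscales=np.array([1.0]), variance=1.0))
```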

A GP has the following joint distribution over finite subsets $\mathbb{X}_1 \subset \mathbb{X}$ with function values $f(x_1) \in \mathbb{Y}$, and analogously for $\mathbb{X}_2$; their union is denoted as $x = \{x_1, \cdots, x_n\}$.

$$\begin{pmatrix} f(x_1) \\ f(x_2) \end{pmatrix} \sim \mathcal{N}\left[\begin{pmatrix} m(x_1) \\ m(x_2) \end{pmatrix}, \begin{pmatrix} k(x_1, x_1) & k(x_1, x_2) \\ k(x_2, x_1) & k(x_2, x_2) \end{pmatrix}\right] \tag{10}$$

The following observation model is used:

$$p(y \mid f, x) = \prod_{n=1}^{N} p(y_n \mid f(x_n)) \tag{11}$$

where, given a supervised learning scenario, the dataset $D = \{x_i, y_i\}_{i=1,\cdots,n}$ can be denoted in shorthand as $D = \{x, y\}$. In the case of probabilistic regression, we make the assumption that the noise is additive, independent and Gaussian, such that the latent function $f(x)$ and the observed noisy outputs $y$ are defined by the following equation:

$$y_i = f(x_i) + \epsilon_i, \quad \text{where } \epsilon_i \sim \mathcal{N}\left(0, \sigma^2_{noise}\right) \tag{12}$$

To train a GP for regression tasks, one performs Type 2 Maximum Likelihood, i.e., maximization of the marginal likelihood given by the following equations:

$$\begin{align}
p(y) &= \mathcal{N}\left(y \mid m, K_{ff} + \sigma^2_{noise}\mathbb{I}_n\right) \tag{13}\\
\log p(y) &\propto -\left(y - m\right)^{\top}\left(K_{ff} + \sigma^2_{noise}\mathbb{I}_n\right)^{-1}\left(y - m\right) - \log\left|K_{ff} + \sigma^2_{noise}\mathbb{I}_n\right| \tag{14}
\end{align}$$

by treating the kernel hyperparameters as point-mass.
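A minimal numpy sketch of evaluating this objective for a zero-mean GP (the constant terms dropped by the proportionality in equation (14) are included; the kernel can be any covariance function, e.g., the squared exponential above):

```python
import numpy as np

def log_marginal_likelihood(x, y, kernel, noise_var):
    """Log marginal likelihood of a zero-mean GP, cf. equations (13)-(14)."""
    n = len(y)
    K = kernel(x, x) + noise_var * np.eye(n)
    L = np.linalg.cholesky(K)                            # K = L L^T
    alpha = np.linalg.solve(L.T, np.linalg.solve(L, y))  # alpha = K^{-1} y
    return (-0.5 * y @ alpha                             # data-fit term
            - np.sum(np.log(np.diag(L)))                 # -0.5 * log|K|
            - 0.5 * n * np.log(2 * np.pi))               # normalising constant

# Toy usage: kernel hyperparameters and noise would be chosen by maximizing this quantity.
rng = np.random.default_rng(1)
x = rng.uniform(-3, 3, size=(20, 1))
y = np.sin(x[:, 0]) + 0.1 * rng.normal(size=20)
rbf = lambda a, b: np.exp(-np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1))
print(log_marginal_likelihood(x, y, rbf, noise_var=0.01))
```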

We are interested in the posterior $p(f(x^*) \mid y)$, since the goal is to predict for unseen data points $x^*$ which differ from the training set. We know that the joint prior over training set observations and test set latent function values is given by:

$$\begin{pmatrix} y \\ f(x^*) \end{pmatrix} \sim \mathcal{N}\left[\begin{pmatrix} m(x) \\ m(x^*) \end{pmatrix}, \begin{pmatrix} k(x, x) + \sigma^2_{noise}\mathbb{I}_n & k(x, x^*) \\ k(x^*, x) & k(x^*, x^*) \end{pmatrix}\right] \tag{15}$$

Now we can simply apply the conditional rule for multivariate Gaussians to obtain:

$$\begin{align}
p\left(f(x^*) \mid y\right) = \mathcal{N}\Big(f(x^*) \mid\; & m(x^*) + K_{f^*f}\left[K_{ff} + \sigma^2_{noise}\mathbb{I}_n\right]^{-1}\left[y - m(x)\right], \notag\\
& K_{f^*f^*} - K_{f^*f}\left[K_{ff} + \sigma^2_{noise}\mathbb{I}_n\right]^{-1}K_{ff^*}\Big) \tag{16}
\end{align}$$

An illustration of this predictive distribution is given in Figure 2.

GP predictive variance as distributional uncertainty.

GPs are clearly distance-aware, provided we use a translation-invariant kernel. The summary statistic (see Definition 1) for an unseen point is given by $u(x^*) = K_{f^*f^*} - K_{f^*f}K_{ff}^{-1}K_{ff^*}$, which increases monotonically as a function of distance from the training data (see Figure 2). Throughout this paper, we will use the non-parametric variance of sparse variants of GPs as a proxy for distributional uncertainty, which will be used to assess whether inputs are inside or outside the distribution.
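A minimal sketch of the predictive equations (16) for a zero-mean GP, using the predictive variance as the distributional uncertainty score $u(x^*)$; the isotropic squared exponential kernel below is an illustrative choice:

```python
import numpy as np

def k_se(x1, x2, ls=1.0, var=1.0):
    d2 = np.sum(((x1[:, None, :] - x2[None, :, :]) / ls) ** 2, axis=-1)
    return var * np.exp(-d2)

def gp_predict(x_train, y_train, x_test, noise_var=0.05):
    """Predictive mean and variance of a zero-mean GP, cf. equation (16)."""
    K = k_se(x_train, x_train) + noise_var * np.eye(len(x_train))
    K_star = k_se(x_test, x_train)            # K_{f* f}
    K_inv = np.linalg.inv(K)
    mean = K_star @ K_inv @ y_train
    cov = k_se(x_test, x_test) - K_star @ K_inv @ K_star.T
    return mean, np.diag(cov)                 # the variance acts as the OOD proxy u(x*)

rng = np.random.default_rng(2)
x_tr = rng.uniform(-2, 2, size=(30, 1))
y_tr = np.sin(2 * x_tr[:, 0]) + 0.05 * rng.normal(size=30)
x_te = np.array([[0.0], [10.0]])              # in-distribution point vs. far-away point
mean, var = gp_predict(x_tr, y_tr, x_te)
print(var)                                    # variance grows far from the training data
```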

Figure 2: Left: GP prior samples using a radial basis function kernel, which ensures a smooth hypothesis space of functions. Right: GP samples conditioned on observations using a radial basis function kernel. Predictive variance increases as the input moves further away from the observations.

The usage of GPs on real-world datasets is hindered by matrix inversion operations, which require $\mathcal{O}(n^3)$ time and $\mathcal{O}(n^2)$ memory during training, where $n$ is the number of data points in the training set. In the next subsection we will see how to avoid incurring this expensive computational budget.

2.2 Sparse Variational Gaussian Processes

In this subsection we succinctly review commonly used probabilistic sparse approximations for Gaussian process regression. Quinonero-Candela and Rasmussen (2005) provide a unifying view of sparse approximations by placing each method into a common framework that analyzes their posterior and their effective prior, which will be defined shortly.

One way to avoid the computationally expensive operations associated with the $K_{ff}$ matrix is to modify the joint prior $p(f, f^*)$ so that the respective terms depend on a matrix of lower rank, where the joint prior is defined as:

$$\begin{pmatrix} f(x) \\ f(x^*) \end{pmatrix} \sim \mathcal{N}\left[\begin{pmatrix} m(x) \\ m(x^*) \end{pmatrix}, \begin{pmatrix} k(x, x) & k(x, x^*) \\ k(x^*, x) & k(x^*, x^*) \end{pmatrix}\right] \tag{17}$$

We introduce an additional set of $M$ latent variables $U = [U_1, \cdots, U_M] \in \mathbb{Y}$ with associated input locations $Z = [Z_1, \cdots, Z_M] \in \mathbb{X}$. Throughout this paper, the former will be referred to as inducing point values and the latter as inducing point locations.

Due to the consistency property of Gaussian processes (i.e., for a probabilistic model as defined in equation (10) we have $p(f(x_1)) = \int p(f(x_1), f(x_2))\, df(x_2)$, which ensures that if we marginalize out a subset of elements, the distribution of the remainder is unchanged), one can marginalize out $U$ to recover the initial joint prior $p(f, f^*)$:

\[
p(f, f^*) = \int p(f, f^*, U)\; dU = \int p(f, f^* \mid U)\, p(U)\; dU \qquad (18)
\]

where $p(U) = \mathcal{N}(U \mid 0, K_{uu})$ and $K_{uu}$ is the kernel covariance matrix evaluated at $Z$.

All sparse approximations to GPs originate from the following approximation:

\[
p(f, f^*) \approx q(f, f^*) = \int q(f^* \mid U)\, q(f \mid U)\, p(U)\; dU \qquad (19)
\]

which translates into conditional independence between training and testing latent variables given $U$. Intuitively, the name "inducing points" for $\{Z, U\}$ stems from precisely this property: $U$ induces the values of both the training and the testing set.

Titsias (2009) introduced the first variational lower bound comprising a probabilistic regression model over inducing points. More specifically, the author applied variational inference in an augmented probability space that comprises the training-set latent function values $F$ alongside the inducing point values $U$, using the following generative process in the case of a regression task:

\[
\begin{aligned}
p(U) &= \mathcal{N}\left(U \mid 0, K_{uu}; Z\right) \qquad (20)\\
p(F \mid U) &= \mathcal{N}\left(F \mid K_{fu}K_{uu}^{-1}U,\; K_{ff} - Q_{ff}; Z, X\right) \qquad (21)\\
p(y \mid F) &= \mathcal{N}\left(y \mid F, \sigma^{2}_{noise}\right) \qquad (22)
\end{aligned}
\]

where $Q_{ff} = K_{fu}K_{uu}^{-1}K_{uf}$. We have explicitly denoted the dependence on $Z$ or $X$; however, to declutter notation these dependencies will be dropped unless it is not evident what a given distribution depends on.

In terms of performing exact inference in this new model, i.e., computing the posterior $p(f \mid y)$ and the marginal likelihood $p(y)$, nothing changes under the augmentation of the probability space by $U$, as we can marginalize $p(F) = \int p(F, U)\, dU$ due to the marginalization properties of Gaussian processes. Succinctly, $p(F)$ is not changed by modifying the values of $U$, even though $p(F \mid U)$ and $p(U)$ do change. This is the fundamental difference between the variational parameters $U$ and the hyperparameters of the model $\{\sigma^2_{noise}, \sigma^2, l^2_1, \cdots, l^2_D\}$: introducing more variational parameters does not change the definition of the model prior to the augmentation of the probability space.

Stochastic Variational Inference (SVI) (Hoffman et al., 2013) enables the application of VI to extremely large datasets by performing inference over a set of global variables that induce a factorisation over observations and latent variables, as in the Bayesian formulation of neural networks with (implicit or explicit) distributions over weight matrices. GPs do not naturally exhibit this property, but sparse GP approximations with inducing points $U$ impose an approximate prior over training and testing latent functions, which we recall here:

\[
p(f, f^*) \approx q(f, f^*) = \int p(f \mid U)\, p(f^* \mid U)\, p(U)\; dU \qquad (23)
\]

this translates into a fully factorized model with respect to observations at training and testing time, conditioned on the global variables $U$.
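As an illustration of why this factorisation matters in practice, the short sketch below (our own hedged example, not taken from the paper; it assumes the per-point expected log-likelihood terms have already been computed) shows the standard unbiased minibatch estimate used with such factorized bounds: the minibatch sum of data-fit terms is rescaled by $N/|B|$, while the KL term appears once.

```python
import numpy as np

def minibatch_elbo(expected_loglik_batch, N, kl_term):
    # Unbiased ELBO estimate: rescale the minibatch data-fit sum by N / |B|,
    # then subtract the KL term once (it is shared by all observations).
    B = len(expected_loglik_batch)
    return (N / B) * np.sum(expected_loglik_batch) - kl_term

# Hypothetical numbers: three expected log-likelihood terms from a dataset of N = 10,000.
print(minibatch_elbo(np.array([-0.7, -1.2, -0.9]), N=10_000, kl_term=35.0))
```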

Our goal is to approximate the true posterior distribution $p(F, U \mid y) = p(F \mid U, y)\, p(U \mid y)$ by introducing a variational distribution $q(F, U)$ and minimizing the Kullback–Leibler divergence:

\[
KL\left[q(F, U)\,\|\,p(F, U \mid y)\right] = \int q(F, U)\log\frac{q(F, U)}{p(F, U \mid y)}\; dF\, dU \qquad (24)
\]

where the approximate posterior factorizes as $q(F, U) = p(F \mid U)\, q(U)$ and $q(U)$ is an unconstrained variational distribution over $U$. Following the standard VI framework, we maximize the following variational lower bound on the log marginal likelihood:

\[
\begin{aligned}
\log p(y) &\geq \int p(F \mid U)\, q(U)\log\frac{p(y \mid F)\, p(F \mid U)\, p(U)}{p(F \mid U)\, q(U)}\; dF\, dU \qquad (25)\\
&= \int q(U)\left[\int \log p(y \mid F)\, p(F \mid U)\; dF + \log\frac{p(U)}{q(U)}\right] dU \qquad (26)
\end{aligned}
\]

We can now solve for the integral over $F$:

\[
\begin{aligned}
\int \log p(y \mid F)\, p(F \mid U)\; dF &= \mathbb{E}_{p(F \mid U)}\left[-\frac{n}{2}\log(2\pi\sigma^{2}_{noise}) - \frac{1}{2\sigma^{2}_{noise}}Tr\left[yy^{\top} - 2yF^{\top} + FF^{\top}\right]\right] \qquad (27)\\
&= -\frac{n}{2}\log(2\pi\sigma^{2}_{noise}) - \frac{1}{2\sigma^{2}_{noise}}Tr\Big[yy^{\top} - 2y\left(K_{fu}K_{uu}^{-1}U\right)^{\top} \\
&\qquad\quad + \left(K_{fu}K_{uu}^{-1}U\right)\left(K_{fu}K_{uu}^{-1}U\right)^{\top} + K_{ff} - Q_{ff}\Big] \qquad (28)\\
&= \log\mathcal{N}\left(y \mid K_{fu}K_{uu}^{-1}U,\; \sigma^{2}_{noise}\mathbb{I}_{n}\right) - \frac{1}{2\sigma^{2}_{noise}}Tr\left[K_{ff} - Q_{ff}\right] \qquad (29)
\end{aligned}
\]

We can now rewrite our variational lower bound as follows:

\[
\log p(y) \geq \int q(U)\log\frac{\mathcal{N}\left(y \mid K_{fu}K_{uu}^{-1}U,\; \sigma^{2}_{noise}\mathbb{I}_{n}\right) p(U)}{q(U)}\; dU - \frac{1}{2\sigma^{2}_{noise}}Tr\left[K_{ff} - Q_{ff}\right] \qquad (30)
\]

The variational posterior is explicit in this case, namely $q(F, U) = p(F \mid U; X, Z)\, q(U)$, where $q(U) = \mathcal{N}(U \mid m_U, S_U)$. Here, $m_U$ and $S_U$ are free variational parameters. Due to the Gaussian nature of both terms we can marginalize $U$ to arrive at $q(F) = \int p(F \mid U)\, q(U)\, dU = \mathcal{N}(F \mid \tilde{U}(x), \tilde{\Sigma}(x))$, where:

\[
\begin{aligned}
\tilde{U}(x) &= K_{fu}K_{uu}^{-1}m_{U} \qquad (31)\\
\tilde{\Sigma}(x) &= K_{ff} - K_{fu}K_{uu}^{-1}\left[K_{uu} - S_{U}\right]K_{uu}^{-1}K_{uf} \qquad (32)
\end{aligned}
\]

The lower bound can be re-expressed as follows:

\[
\log p(y) \geq \int q(U)\log\mathcal{N}\left(y \mid K_{fu}K_{uu}^{-1}U,\; \sigma^{2}_{noise}\mathbb{I}_{n}\right) dU - KL\left[q(U)\,\|\,p(U)\right] - \frac{1}{2\sigma^{2}_{noise}}Tr\left[K_{ff} - Q_{ff}\right] \qquad (33)
\]

We proceed to integrate out $U$, arriving at the following lower bound:

\[
\begin{aligned}
\mathcal{L}_{SVGP} &= \log\mathcal{N}\left(y \mid K_{fu}K_{uu}^{-1}m_{U},\; \sigma^{2}_{noise}\mathbb{I}_{n}\right) - \frac{1}{2\sigma^{2}_{noise}}Tr\left[K_{fu}K_{uu}^{-1}S_{U}K_{uu}^{-1}K_{uf}\right] \\
&\quad - \frac{1}{2\sigma^{2}_{noise}}Tr\left[K_{ff} - Q_{ff}\right] - KL\left[q(U)\,\|\,p(U)\right] \qquad (34)
\end{aligned}
\]

where we can easily see that the last equation factorizes with respect to individual observations. This variational lower bound will be denoted as the sparse variational GP (SVGP) bound. It is maximized with respect to the variational parameters $\{m_U, S_U\}$ and the hyperparameters of the model $\{Z, \sigma^2_{noise}, \sigma^2, l^2_1, \cdots, l^2_D\}$. An illustration of an SVGP trained on the “banana” dataset is given in Figure 3, showing behaviour similar to a full GP while using only a fraction of the training set to obtain a similar predictive distribution at testing time.
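For concreteness, the sketch below (illustrative NumPy code, not the paper's implementation; the kernel, inducing locations $Z$ and variational parameters $m_U, S_U$ are placeholder values) evaluates the SVGP predictive equations (31)-(32) at a batch of test inputs:

```python
# Hedged sketch of the SVGP predictive mean and covariance (equations (31)-(32)).
import numpy as np

def rbf(a, b, variance=1.0, lengthscale=1.0):
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

def svgp_predict(x, Z, m_U, S_U, jitter=1e-6):
    Kuu = rbf(Z, Z) + jitter * np.eye(len(Z))
    Kfu = rbf(x, Z)
    A = np.linalg.solve(Kuu, Kfu.T).T            # Kfu Kuu^{-1}
    mean = A @ m_U                               # eq. (31)
    cov = rbf(x, x) - A @ (Kuu - S_U) @ A.T      # eq. (32)
    return mean, cov

M, D = 10, 2
rng = np.random.default_rng(0)
Z = rng.standard_normal((M, D))                  # inducing locations (hypothetical)
m_U = np.zeros(M)                                # variational mean
S_U = 0.1 * np.eye(M)                            # variational covariance
mu, Sigma = svgp_predict(rng.standard_normal((5, D)), Z, m_U, S_U)
print(mu.shape, Sigma.shape)                     # (5,), (5, 5)
```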

Figure 3: (a) Predictive mean and variance of SVGP. Inducing points (teal stars) are tasked to compress the information present in the entire training set such that the predictive equations conditioned on them are similar to those conditioned on the entire training set. (b) Predictive mean and variance of GP. Not all training points are crucial in devising the decision boundary.

2.3 Uncertainty decomposition in SVGP through evidential learning lens

In subsection 1.2 we introduced the rationale behind the uncertainty decomposition framework of Malinin and Gales (2018). We now expand on how uncertainties are separated in deep evidential learning models (Amini et al., 2019) and draw an analogy to how uncertainties are decomposed in an SVGP.

For multi-class classification tasks, evidential learning directly parametrizes predictive distributions over the probability simplex. Hence, in comparison to Bayesian deep learning or deep ensembles, it avoids parametrizing the logit space and subsequently feeding it through a softmax function. Dirichlet distributions provide an obvious choice for defining a distribution over the $(K-1)$-dimensional probability simplex, having the following p.d.f.: $Dir(\mu, \alpha) = \frac{1}{\beta(\alpha)}\prod_{c=1}^{K}\mu_c^{\alpha_c - 1}$, where $\beta(\alpha) = \frac{\prod_{c=1}^{K}\Gamma(\alpha_c)}{\Gamma(\alpha_0)}$ and $\alpha_0 = \sum_{c=1}^{K}\alpha_c$ with $\alpha_c \geq 0$. $\alpha_0$ is called the precision; analogously to the precision of a Gaussian distribution, a larger $\alpha_0$ indicates a sharper distribution.

Dirichlet networks involve having a NN predict the concentration parameters of the Dirichlet distribution, $\alpha = f_{\theta}(x)$, where predictions are made as follows: $\tilde{y} = \underset{c}{\arg\max}\ \{\frac{\alpha_c}{\alpha_0}\}_{c=1}^{K}$.

We recall the uncertainty decomposition framework laid out in subsection 1.2:

\[
\underbrace{I[y, \mu \mid x^*, \mathbf{D}]}_{\text{Distributional Uncertainty}} = \underbrace{H\left[\mathbb{E}_{p(\mu \mid \mathbf{D})}\, p(y \mid x^*, \mu)\right]}_{\text{Total Uncertainty}} - \underbrace{\mathbb{E}_{p(\mu \mid \mathbf{D})}\left[H\left[p(y \mid x^*, \mu)\right]\right]}_{\text{Aleatoric Uncertainty}} \qquad (35)
\]

In the case of Dirichlet networks these uncertainty measures have analytic formulas:

\[
\begin{aligned}
\mathbb{E}_{p(\mu \mid \mathbf{D})}\left[H\left[p(y \mid x^*, \mu)\right]\right] &= -\sum_{c=1}^{K}\frac{\alpha_c}{\alpha_0}\left[\psi(\alpha_c + 1) - \psi(\alpha_0 + 1)\right] \qquad (36)\\
I[y, \mu \mid x^*, \mathbf{D}] &= -\sum_{c=1}^{K}\frac{\alpha_c}{\alpha_0}\left[\log\frac{\alpha_c}{\alpha_0} - \psi(\alpha_c + 1) + \psi(\alpha_0 + 1)\right] \qquad (37)
\end{aligned}
\]

where $\psi$ is the digamma function. Epistemic uncertainty quantifies the spread of the Dirichlet distribution, hence $\alpha_0$ can be used to measure it (Charpentier et al., 2021).
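The following small NumPy/SciPy sketch (with illustrative concentration values) evaluates equations (36) and (37) for a given concentration vector $\alpha$:

```python
# Hedged sketch of the Dirichlet uncertainty decomposition (eqs. (36)-(37)).
import numpy as np
from scipy.special import digamma

def dirichlet_uncertainties(alpha):
    alpha0 = alpha.sum()
    p = alpha / alpha0
    aleatoric = -np.sum(p * (digamma(alpha + 1) - digamma(alpha0 + 1)))        # eq. (36)
    distributional = -np.sum(p * (np.log(p) - digamma(alpha + 1)
                                  + digamma(alpha0 + 1)))                      # eq. (37)
    return aleatoric, distributional

print(dirichlet_uncertainties(np.array([20.0, 1.0, 1.0])))  # sharp, confident Dirichlet
print(dirichlet_uncertainties(np.array([1.0, 1.0, 1.0])))   # flat Dirichlet: higher aleatoric term
```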

Exact inference is not tractable for GPs on classification tasks due to the non-conjugacy between the GP prior and the non-Gaussian likelihood (Categorical or Bernoulli). Therefore, approximations are required, such as the Laplace approximation (Williams and Barber, 1998), Expectation Propagation (Minka, 2013) or VI (Hensman et al., 2015). Milios et al. (2018) proposed a method that circumvents these approximate inference techniques by recasting the classification problem as a regression one, for which exact inference is possible. We now briefly lay out the Dirichlet-based GP classification algorithm.

We consider class probabilities on the simplex, $\pi = [\pi_1, \cdots, \pi_K] \sim Dir(\alpha)$. We can transform a multi-class classification task into a multi-output regression scenario: if $y_c = 1$ in a one-hot encoding, then we assign $\alpha_c = 1 + \alpha_{\epsilon}$, and otherwise $\alpha_c = \alpha_{\epsilon}$, for $0 \leq \alpha_{\epsilon} \ll 1$. The model has the following generative process:

\[
\begin{aligned}
\pi &\sim Dir(\alpha) \qquad (38)\\
p(y \mid \pi) &= Cat(\pi) \qquad (39)
\end{aligned}
\]

To sample from the Dirichlet distribution we use the following routine: $\pi_c = \frac{x_c}{\sum_{k=1}^{K} x_k}$ with $x_c \sim \Gamma(\alpha_c, 1)$ following a Gamma distribution. From this sampling procedure, we can see that the generative process translates into independent Gamma likelihoods for each class. Intuitively, at this point in the derivation we need a GP to produce $x_c \geq 0$, since Gamma distributions are only defined on $\mathbb{R}^{+}$. Since the GP marginal over a subset of data is a multivariate normal, it will not satisfy this constraint. To obtain GP function samples that respect this constraint, we can transform them with an $\exp$ function. With this in mind, we know that $x \sim \text{log-normal}(x \mid \mu, \sigma^2) \overset{d}{=} \exp\left(\mathcal{N}(x \mid \mu, \sigma^2)\right)$, with $\mathbb{E}[x] = \exp\left(\mu + \frac{\sigma^2}{2}\right)$ and $V[x] = \left[\exp(\sigma^2) - 1\right]\exp\left(2\mu + \sigma^2\right)$. Hence, we can approximate $x_c \sim \Gamma(\alpha_c, 1)$ with $\tilde{x_c} \sim \text{log-normal}(\tilde{\mu_c}, \tilde{\sigma^2_c})$. To ensure a good approximation, the authors in Milios et al. (2018) propose moment matching:

\[
\begin{aligned}
\mathbb{E}[x_c] = \alpha_c &= \exp\left(\tilde{\mu_c} + \frac{\tilde{\sigma^2_c}}{2}\right) = \mathbb{E}[\tilde{x_c}] \qquad (40)\\
V[x_c] = \alpha_c &= \left[\exp(\tilde{\sigma^2_c}) - 1\right]\exp\left(2\tilde{\mu_c} + \tilde{\sigma^2_c}\right) = V[\tilde{x_c}] \qquad (41)
\end{aligned}
\]

with equality if $\tilde{\mu_c} = \log\alpha_c - \frac{\tilde{\sigma^2_c}}{2}$ and $\tilde{\sigma^2_c} = \log\left(\frac{1}{\alpha_c} + 1\right)$. We can re-express this approximation by taking the natural logarithm, obtaining $\log\tilde{x_c} \sim \mathcal{N}(\tilde{\mu_c}, \tilde{\sigma^2_c})$. This translates into a heteroskedastic regression model $\tilde{\mu_c} = f_c + \mathcal{N}(0, \tilde{\sigma^2_c})$, where $f_c \sim GP(0, K_{ff})$. Hence, one can now apply the standard inference scheme for a full GP, or sparsify the model and apply the SVGP framework. At testing time, the expectation of the class probabilities is:

\[
\mathbb{E}\left[\pi_{i,c}\right] = \int \frac{\exp(f_{i,c})}{\sum_{k=1}^{K}\exp(f_{i,k})}\; q(f_{i,c})\; df_{i,c} \qquad (42)
\]

which can be approximated via Monte Carlo integration. In the sparse scenario, $q(f_{i,c}) = \mathcal{N}\left(\tilde{U}(x_i), \tilde{\Sigma}(x_i)\right)$, similar to the predictive equations introduced in subsection 2.2. In conclusion, when using Dirichlet-based GPs for classification one can obtain estimates of aleatoric and distributional uncertainty on the probability simplex similar to those in equations (36) and (37) for Dirichlet networks. However, for the purposes of this paper we measure distributional uncertainty in the space of logits, as the formulas are simpler to compute and more intuitive from the viewpoint of distance-awareness.
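To make this pipeline concrete, the sketch below (our own hedged example, not the paper's implementation; `alpha_eps` plays the role of $\alpha_{\epsilon}$ and the predictive moments are placeholder values) first converts one-hot labels into the heteroskedastic regression targets $(\tilde{\mu_c}, \tilde{\sigma^2_c})$ and then approximates equation (42) by Monte Carlo over a diagonal Gaussian $q(f_i)$:

```python
import numpy as np

def dirichlet_regression_targets(y_onehot, alpha_eps=0.01):
    # alpha_c = 1 + alpha_eps for the observed class, alpha_eps otherwise
    alpha = y_onehot + alpha_eps
    sigma2 = np.log(1.0 / alpha + 1.0)          # matched log-normal (noise) variance
    mu = np.log(alpha) - sigma2 / 2.0           # matched log-normal mean (regression target)
    return mu, sigma2

def expected_class_probs(mean, var, n_samples=2000, seed=0):
    # Monte Carlo estimate of eq. (42) for one input, assuming a diagonal Gaussian over logits
    rng = np.random.default_rng(seed)
    logits = mean + np.sqrt(var) * rng.standard_normal((n_samples, len(mean)))
    logits -= logits.max(axis=1, keepdims=True)              # numerical stability
    probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    return probs.mean(axis=0)

y = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])             # two one-hot labels, K = 3
print(dirichlet_regression_targets(y))                       # per-class targets and noise variances
print(expected_class_probs(np.array([2.0, 0.0, -1.0]),       # hypothetical predictive mean
                           np.array([0.3, 0.3, 0.3])))       # hypothetical predictive variance
```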

As we have previously stated, GPs are distance-aware. Thus, they can reliably notice departures from the training set manifold. For SVGP we decompose the model uncertainty into two components:

\[
\begin{aligned}
h(\cdot) &= \mathcal{N}\left(h \mid 0,\; K_{ff} - K_{fu}K_{uu}^{-1}K_{uf}\right) \qquad (43)\\
g(\cdot) &= \mathcal{N}\left(g \mid K_{fu}K_{uu}^{-1}m_{U},\; K_{fu}K_{uu}^{-1}S_{U}K_{uu}^{-1}K_{uf}\right) \qquad (44)
\end{aligned}
\]

The variance of $h(\cdot)$ captures the shift from within to outside the data manifold and will be denoted as distributional uncertainty. The variance of $g(\cdot)$ is termed here within-data uncertainty and encapsulates the uncertainty present inside the data manifold. A visual depiction of the two is provided in Figure 14 (bottom). To capture the overall uncertainty in $h(\cdot)$, thereby also capturing the spread of samples from it, we can calculate its differential entropy:

\[
h(h) = \frac{n}{2}\log 2\pi + \frac{1}{2}\log\left|K_{ff} - K_{fu}K_{uu}^{-1}K_{uf}\right| + \frac{n}{2} \qquad (45)
\]

In practice we only use the diagonal terms of the Schur complement, hence the log-determinant term simplifies considerably. Intuitively, if the terms on the diagonal of the Schur complement have higher values, so will the distributional differential entropy. This OOD measure in logit space will be used throughout the rest of the paper.
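The sketch below (illustrative NumPy code; the kernel and inducing locations are placeholders) computes this per-point distributional differential entropy from the diagonal of the Schur complement and shows that it is larger for inputs far from the inducing locations:

```python
import numpy as np

def rbf(a, b, variance=1.0, lengthscale=1.0):
    d2 = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

def distributional_entropy(x, Z, jitter=1e-6):
    Kuu = rbf(Z, Z) + jitter * np.eye(len(Z))
    Kfu = rbf(x, Z)
    # diagonal of the Schur complement K_ff - K_fu K_uu^{-1} K_uf
    schur_diag = rbf(x, x).diagonal() - np.sum(Kfu * np.linalg.solve(Kuu, Kfu.T).T, axis=1)
    # per-point differential entropy of a univariate Gaussian with that variance
    return 0.5 * np.log(2 * np.pi * schur_diag) + 0.5

rng = np.random.default_rng(0)
Z = rng.standard_normal((20, 2))                  # inducing locations near the data manifold
x_in = rng.standard_normal((3, 2))                # in-distribution-like inputs
x_out = 10.0 + rng.standard_normal((3, 2))        # far-away inputs
print(distributional_entropy(x_in, Z))
print(distributional_entropy(x_out, Z))           # noticeably larger entropies
```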

2.4 Deep Gaussian Processes fail in propagating distributional uncertainty

Deep Gaussian Processes (DGP) were first introduced in Damianou and Lawrence (2013) as a multi-layered, hierarchical formulation of GPs. Composition of processes retains the theoretical properties of the underlying stochastic processes (such as the Kolmogorov extension theorem) while also ensuring a more diverse hypothesis space of process priors, at least in theory, as we shall see later.

We can view the DGP as a composition of functions, keeping in mind that this is only one way of defining this class of probabilistic models (Dunlop et al., 2018):

\[
f_L(x) = f_L \circ \cdots \circ f_1(x) \qquad (46)
\]

with $f_l \sim \mathcal{GP}\left(m_l, k_l(\cdot, \cdot)\right)$. Assuming a likelihood function, we can write the joint prior as:

\[
p\left(y, \{f_l\}_{l=1}^{L}; X\right) = \underbrace{p(y \mid f_L)}_{\text{likelihood}}\; \underbrace{\prod_{l=1}^{L} p(f_l \mid f_{l-1})}_{\text{prior}} \qquad (47)
\]

with $p(f_l \mid f_{l-1}) = \mathcal{GP}\left(m_l(f_{l-1}), k_l(f_{l-1}, f_{l-1})\right)$, where, in the case of squared exponential kernels, we have the following formula for the $l$-th layer:

\[
k^{SE}_l(f_{l,i}, f_{l,j}) = \sigma^2_l \exp\left[-\sum_{d=1}^{D_l} \frac{\left(f_{l,i,d} - f_{l,j,d}\right)^2}{l^2_{l,d}}\right]
\]

where $D_l$ represents the dimensionality of $F_l$ and we introduce layer-specific kernel hyperparameters $\{\sigma^2_l, l^2_{l,1}, \cdots, l^2_{l,D_l}\}$.
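A small NumPy sketch of this kernel (illustrative, not the paper's code) with a layer-specific variance and per-dimension lengthscales is given below:

```python
import numpy as np

def se_kernel(a, b, variance, lengthscales):
    # a: (n, D_l), b: (m, D_l); lengthscales: per-dimension l_{l,d} of shape (D_l,)
    diff = (a[:, None, :] - b[None, :, :]) / lengthscales
    return variance * np.exp(-np.sum(diff ** 2, axis=-1))

F_l = np.random.randn(4, 3)                       # hidden-layer activations, D_l = 3 (hypothetical)
K = se_kernel(F_l, F_l, variance=1.5, lengthscales=np.array([0.5, 1.0, 2.0]))
print(K.shape)                                    # (4, 4)
```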

Analytically integrating this Bayesian hierarchical model is intractable, as it requires integrating out Gaussians that appear in a non-linear manner. Moreover, to enable faster inference we can augment each layer $l$ with $M_l$ inducing point locations $Z_{l-1}$ and inducing point values $U_l$, resulting in the following augmented joint prior:

\[
p\left(y, \{f_l\}_{l=1}^{L}, \{U_l\}_{l=1}^{L}; X, \{Z_l\}_{l=0}^{L-1}\right) = \underbrace{p(y \mid f_L)}_{\text{likelihood}}\; \underbrace{\prod_{l=1}^{L} p(f_l \mid f_{l-1}, U_l; Z_{l-1})\, p(U_l)}_{\text{prior}} \qquad (48)
\]

where $p(f_l \mid f_{l-1}, U_l; Z_{l-1}) = \mathcal{N}\left(f_l \mid m_l(f_{l-1}) + K_{fu}K_{uu}^{-1}\left(U_l - m_l(Z_{l-1})\right),\; K_{ff} - K_{fu}K_{uu}^{-1}K_{uf}\right)$. To perform SVI we introduce a factorised variational approximate posterior $q\left(\{U_l\}_{l=1}^{L}\right) = \prod_{l=1}^{L}\mathcal{N}\left(U_l \mid m_{U_l}, S_{U_l}\right)$. Using a derivation similar to that of the uncollapsed evidence lower bound for SVGPs, we arrive at the ELBO for DGPs:

\[
\mathcal{L}_{DGP} = \mathbb{E}_{q(\{f_l\}_{l=1}^{L})}\left[\log p(y \mid f_L)\right] - \sum_{l=1}^{L} KL\left[q(U_l)\,\|\,p(U_l)\right] \qquad (49)
\]

where $q\left(\{f_l\}_{l=1}^{L}\right) = \prod_{l=1}^{L} q(f_l \mid f_{l-1})$ and $q(f_l \mid f_{l-1}) = \mathcal{N}\left(f_l \mid \tilde{U_l}(f_{l-1}), \tilde{\Sigma_l}(f_{l-1})\right)$, with:

$$\tilde{U}_{l}(f_{l-1})=m_{l}(f_{l-1})+K_{fu}K_{uu}^{-1}\left[m_{U_{l}}-m_{l}(Z_{l-1})\right] \tag{50}$$
$$\tilde{\Sigma}_{l}(f_{l-1})=K_{ff}-K_{fu}K_{uu}^{-1}\left[K_{uu}-S_{U_{l}}\right]K_{uu}^{-1}K_{uf} \tag{51}$$

This composition of functions is approximated via Monte Carlo integration as introduced in the doubly stochastic variational inference framework for training DGPs (Salimbeni and Deisenroth, 2017).
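To make the doubly stochastic scheme concrete, below is a minimal NumPy sketch of propagating samples layer by layer, with per-layer moments following equations (50) and (51) under a zero mean function; the kernel, class names and toy data are illustrative assumptions rather than the implementation used for our experiments.

```python
import numpy as np

def rbf(A, B, lengthscale=1.0, variance=1.0):
    # Squared-exponential kernel between the rows of A and B.
    d2 = ((A[:, None, :] - B[None, :, :]) ** 2).sum(-1)
    return variance * np.exp(-0.5 * d2 / lengthscale ** 2)

class SVGPLayer:
    """One sparse variational GP layer with a zero mean function."""
    def __init__(self, Z, m_U, S_U):
        self.Z, self.m_U, self.S_U = Z, m_U, S_U

    def moments(self, F_prev):
        # Predictive moments of q(f_l | f_{l-1}), cf. equations (50) and (51).
        Kfu = rbf(F_prev, self.Z)
        Kuu = rbf(self.Z, self.Z) + 1e-6 * np.eye(len(self.Z))
        A = Kfu @ np.linalg.inv(Kuu)
        mean = A @ self.m_U
        var = np.ones(len(F_prev)) - np.einsum('nm,mk,nk->n', A, Kuu - self.S_U, A)
        return mean, np.maximum(var, 1e-8)

def sample_dgp(layers, X, rng):
    # Doubly stochastic forward pass: sample every hidden layer, feed samples onward.
    F = X
    for layer in layers:
        mean, var = layer.moments(F)
        F = (mean + np.sqrt(var) * rng.standard_normal(len(mean)))[:, None]
    return F.ravel()

rng = np.random.default_rng(0)
X = np.linspace(-3, 3, 50)[:, None]
layers = [SVGPLayer(Z=np.linspace(-3, 3, 8)[:, None],
                    m_U=rng.standard_normal(8),
                    S_U=0.1 * np.eye(8)) for _ in range(2)]
draws = np.stack([sample_dgp(layers, X, rng) for _ in range(20)])
print(draws.mean(0)[:3], draws.var(0)[:3])   # Monte Carlo estimate of q(f_L) moments
```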

In Popescu et al. (2020) the authors argued that total uncertainty in the hidden layers of a DGP will be higher for OOD data points in comparison to in-distribution data points only under a set of conditions. We briefly lay out the details here.

Without loss of generality for deeper architectures, we can consider the case of a DGP with two layers and zero mean functions which has the following posterior predictive equation:

$$q(F_{2})(x)=\int p(F_{2}\mid F_{1})\,q(F_{1}(x))\,dF_{1} \tag{52}$$

where $q(F_{1}(x))=\mathcal{N}_{f_{1}}\left(\tilde{U}_{1}(x),\tilde{\Sigma}_{1}(x)\right)$. This is akin to approximating a GP with uncertain inputs, here multivariate Gaussians. Girard (2004) lays out a framework for obtaining Gaussian approximations of GPs with uncertain inputs (in our case the uncertainty stems from the previous layer of the DGP), which, adapted to our setting, yields the following approximate moments for $q(F_{2})(x)$:

$$m(F_{2})=\tilde{U}_{2}(\tilde{U}_{1}(x)) \tag{53}$$
$$v(F_{2})=\tilde{\Sigma}_{2}(\tilde{U}_{1}(x))+\tilde{\Sigma}_{1}(x)\left[\frac{1}{2}\frac{\partial^{2}\tilde{\Sigma}_{2}(F_{1})}{\partial F_{1}^{2}}\Big|_{F_{1}=\tilde{U}_{1}(x)}+\left(\frac{\partial\tilde{U}_{2}(F_{1})}{\partial F_{1}}\right)^{2}\Big|_{F_{1}=\tilde{U}_{1}(x)}\right] \tag{56}$$

Popescu et al. (2020) consider a realistic scenario that occurs frequently in practice, whereby the inducing points $Z_{l}$ of a particular layer are spread out so as to cover the entire spectrum of possible samples from the previous layer $F_{l-1}$. More precisely, we can consider an OOD data point $x_{ood}$ in input space such that $\tilde{\Sigma}_{1}(x_{ood})=\sigma^{2}$ and $\tilde{U}_{1}(x_{ood})=0$, respectively an in-distribution point $x_{in\text{-}d}$ such that $\tilde{\Sigma}_{1}(x_{in\text{-}d})=V_{in}\leq\sigma^{2}$ and $\tilde{U}_{1}(x_{in\text{-}d})=M_{in}$. We also assume that $Z_{2}$ are placed equidistantly in $[-3\sigma,3\sigma]$.
The authors then show that the total variance in the second layer will be higher for $x_{ood}$ than for $x_{in\text{-}d}$ if the following holds: $\left(\frac{\partial\tilde{U}_{2}(F_{1})}{\partial F_{1}}\right)^{2}\Big|_{F_{1}=M_{in}} \Big/ \left(\frac{\partial\tilde{U}_{2}(F_{1})}{\partial F_{1}}\right)^{2}\Big|_{F_{1}=0} \leq \frac{\sigma^{2}}{V_{in}}$. One can readily see that this inequality holds whenever the absolute first-order derivative of the parametric component of the SVGP is higher around 0 than at any other value at which it might be evaluated. This should be read in conjunction with the fact that $\frac{\sigma^{2}}{V_{in}}\geq 1$, since the total variance of in-distribution points is reduced compared to the prior variance.
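As a quick sanity check of this argument, the sketch below propagates the first-layer moments through equations (53) and (56) using finite-difference derivatives of assumed second-layer moment functions; the functional forms and numerical values are illustrative assumptions, not quantities taken from a trained model.

```python
import numpy as np

def d1(f, x, h=1e-4):  # first derivative via central differences
    return (f(x + h) - f(x - h)) / (2 * h)

def d2(f, x, h=1e-4):  # second derivative via central differences
    return (f(x + h) - 2 * f(x) + f(x - h)) / h ** 2

# Assumed second-layer posterior moment functions (illustrative only).
U2     = lambda f: np.tanh(2.0 * f)        # parametric mean, steepest around f = 0
Sigma2 = lambda f: 0.2 + 0.05 * f ** 2     # smooth second-layer variance

def layer2_variance(M1, V1):
    """Equation (56): total variance of F_2 given first-layer moments (M1, V1)."""
    return Sigma2(M1) + V1 * (0.5 * d2(Sigma2, M1) + d1(U2, M1) ** 2)

v_ood = layer2_variance(M1=0.0, V1=1.0)   # OOD point: mapped to 0 with prior variance sigma^2 = 1
v_in  = layer2_variance(M1=1.5, V1=0.1)   # in-distribution point: confident, away from 0
print(v_ood, v_in, v_ood > v_in)          # the inequality holds for these assumed functions
```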

Figure 4: Layer-wise decomposition of uncertainty into parametric/epistemic and non-parametric/distributional components for a zero-mean-function DGP, alongside first-order derivatives. OOD points in input space $x_{\text{out}}$ get mapped on average to $0$ in $f_{1}(x_{\text{out}})$, which has a high absolute first-order derivative, causing the parametric uncertainty in $f_{2}$ to be high for $x_{\text{out}}$.

To gain some intuition as to what occurs in practice, we can consider a two-layer DGP trained on a toy regression task, where we decompose the resulting posterior SVGP predictive equation into its parametric and non-parametric components for each layer with respect to input space (first two rows of Figure 4). To investigate whether our trained DGP respects the above inequality for propagating higher total uncertainty for OOD data points compared to in-distribution data points, we also need to compute the first-order derivatives with respect to the input coming from the previous layer (last row of Figure 4). We encourage the reader to inspect McHutchon (2013) for an in-depth introduction to first-order derivatives of GPs. We notice that the total variance is indeed higher for OOD data points in the final layer, brought about by the high absolute value of the first-order derivative around 0 in the second layer (OOD data points in the first hidden layer have an expected value of 0). Intuitively, OOD data points in input space will have higher total uncertainty in output space due to the higher diversity of function values in the second layer, a diversity caused by the high non-parametric uncertainty in the first hidden layer. Conversely, for in-distribution points the total variance in the first hidden layer is relatively small, hence the sampling is close to deterministic, implicitly meaning that it accesses only a very restricted set of function values in the second layer and thus yields a relatively small total variance. Lastly, we remind ourselves that for GPs the non-parametric/distributional uncertainty serves as a proxy for OOD detection. From Figure 4 we can see that distributional uncertainty collapses in the second layer for any value in input space. This implies that DGPs are not distance-aware.

Figure 5: Layer-wise decomposition of uncertainty into parametric/epistemic and non-parametric/distributional components for a zero-mean-function DGP. Outlier points are sampled close to inlier points in $f_{1}$, causing a collapse of their non-parametric variance, since inducing points in $f_{1}$ are close to both outlier and inlier samples.

To understand what is causing this pathology, we take a simple case study of a DGP (zero mean function) with two hidden layers trained on a toy regression dataset (Figure 5). Taking a clear outlier in input space, say the data point situated at -7.5, it gets correctly identified as an outlier in the mapping from input space to hidden layer space, as given by its distributional variance. However, its outlier property dissipates in the next layer after sampling, as it gets mapped to regions where the next GP assigns inducing point locations. This is due to points inside the data manifold getting confidently mapped between -2.0 and 1.0 in hidden layer space. Consequently, what was initially correctly identified as an outlier will now have its final distributional uncertainty close to zero. Adding further layers will only compound this pathology.

2.5 Wasserstein-2 kernels for probability measures

As we have seen in the previous subsection, analytically integrating out the prior of a DGP is intractable when using kernels operating in Euclidean space. However, the hidden layers of a DGP are intrinsically defined over probability measures (Gaussian in this case). This leads us to ask whether we can obtain an analytically tractable formulation of DGPs by using kernels operating on probability measures, for which we need a metric on probability measures, introduced next.

The Wasserstein space on $\mathbb{R}$ can be defined as the set $W_{2}(\mathbb{R})$ of probability measures on $\mathbb{R}$ with a finite moment of order two. We denote by $\Pi(\mu,\nu)$ the set of all probability measures $\pi$ over the product set $\mathbb{R}\times\mathbb{R}$ with marginals $\mu$ and $\nu$, which are probability measures in $W_{2}(\mathbb{R})$. The transportation cost between two measures $\mu$ and $\nu$ is defined as:

$$T_{2}(\mu,\nu)=\inf_{\pi\in\Pi(\mu,\nu)}\int[x-y]^{2}\,d\pi(x,y) \tag{57}$$

This transportation cost allows us to endow the set $W_{2}(\mathbb{R})$ with a metric by defining the quadratic Wasserstein distance between $\mu$ and $\nu$ as:

$$W_{2}(\mu,\nu)=T_{2}(\mu,\nu)^{1/2} \tag{58}$$
Theorem 2 (Theorem IV.1. in Bachoc et al. (2017))

Let $k_{W}:W_{2}(\mathbb{R})\times W_{2}(\mathbb{R})\rightarrow\mathbb{R}$ be the Wasserstein-2 RBF kernel defined as follows:

$$k^{W_{2}}(\mu,\nu)=\sigma^{2}\exp\left(\frac{-W_{2}^{2}(\mu,\nu)}{l^{2}}\right) \tag{59}$$

then $k^{W_{2}}(\mu,\nu)$ is a positive definite kernel for any $\mu,\nu\in W_{2}(\mathbb{R})$, where $\sigma^{2}$ is the kernel variance and $l^{2}$ the lengthscale.

A detailed proof of this theorem can be found in Bachoc et al. (2017).

The product of positive definite kernels is again a positive definite kernel, hence we arrive at the automatic relevance determination (ARD) kernel based on Wasserstein-2 distances:

$$k^{W_{2}}\left([\mu_{d}]_{d=1}^{D},[\nu_{d}]_{d=1}^{D}\right)=\sigma^{2}\exp\left(\sum_{d=1}^{D}\frac{-W_{2}^{2}(\mu_{d},\nu_{d})}{l_{d}^{2}}\right) \tag{60}$$
Wasserstein-2 Distance between Gaussian distributions.

Gaussian measures fulfill the condition of a finite second-order moment and are therefore a clear example of probability measures for which Wasserstein metrics are well defined. The squared Wasserstein-2 distance between two multivariate Gaussian distributions $\mathcal{N}(m_{1},\Sigma_{1})$ and $\mathcal{N}(m_{2},\Sigma_{2})$ has been shown to take the closed form $\|m_{1}-m_{2}\|_{2}^{2}+\mathrm{Tr}\left[\Sigma_{1}+\Sigma_{2}-2\left(\Sigma_{1}^{1/2}\Sigma_{2}\Sigma_{1}^{1/2}\right)^{1/2}\right]$ (Dowson and Landau, 1982), which in the case of univariate Gaussians simplifies to $|m_{1}-m_{2}|^{2}+|\sqrt{\Sigma_{1}}-\sqrt{\Sigma_{2}}|^{2}$. This last formulation will be used throughout this paper.
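For concreteness, the sketch below implements the univariate closed form and the resulting ARD kernel of equation (60) for diagonal Gaussian measures; the array shapes and hyperparameter values are placeholder assumptions.

```python
import numpy as np

def w2_sq_univariate(m1, v1, m2, v2):
    # Squared Wasserstein-2 distance between univariate Gaussians N(m1, v1) and N(m2, v2).
    return (m1 - m2) ** 2 + (np.sqrt(v1) - np.sqrt(v2)) ** 2

def w2_ard_kernel(mu_a, var_a, mu_b, var_b, variance=1.0, lengthscales=None):
    """Equation (60): ARD RBF kernel between D-dimensional diagonal Gaussian measures.
    mu_a, var_a have shape (N, D); mu_b, var_b have shape (M, D)."""
    if lengthscales is None:
        lengthscales = np.ones(mu_a.shape[1])
    w2 = w2_sq_univariate(mu_a[:, None, :], var_a[:, None, :],
                          mu_b[None, :, :], var_b[None, :, :])   # (N, M, D)
    return variance * np.exp(-(w2 / lengthscales ** 2).sum(-1))

# Toy usage with placeholder moments
rng = np.random.default_rng(0)
mu, var = rng.standard_normal((5, 3)), rng.uniform(0.1, 1.0, (5, 3))
K = w2_ard_kernel(mu, var, mu, var)
print(np.allclose(K, K.T), np.all(np.linalg.eigvalsh(K) > -1e-9))   # symmetric and PSD
```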

2.6 Distributional Deep Gaussian Processes & OOD detection

In the previous subsection we have seen that DGPs can easily fail in propagating distributional uncertainty forward. We now focus on the variant of DGPs introduced in Popescu et al. (2020) that was proven both theoretically and empirically to propagate distributional uncertainty throughout the hierarchy, thus ensuring distance-awareness properties. The insights gained from this subsection will constitute the departure point for our proposed model in the next section.

Distributional Gaussian Processes (DistGP) were first introduced by Bachoc et al. (2017) to describe a shallow GP that operates on probability measures using a Wasserstein-2 based kernel as defined in equation (60).

We introduce the generative process of Distributional Deep Gaussian Processes (DDGP) for two layers:

$$p(F_{1})\sim\mathcal{N}\left(0,K_{ff}\right) \tag{61}$$
$$F_{1}^{sth}=m(F_{1})+\sqrt{v(F_{1})}\,\epsilon,\quad\epsilon\sim\mathcal{N}\left(0,\mathbb{I}_{n}\right) \tag{62}$$
$$F_{1}^{det}=\mathcal{N}\left(m(F_{1}),\mathrm{diag}\left[v(F_{1})\right]\right) \tag{63}$$
$$p(F_{2})\sim\mathcal{N}\left(0,k_{hybrid}\left(\{F_{1}^{sth},F_{1}^{det}\},\{F_{1}^{sth},F_{1}^{det}\}\right)\right) \tag{64}$$

where the hybrid kernel is defined as follows:

$$k^{hybrid}\left(\mu_{i},\mu_{j}\right)=k^{E}(x_{i},x_{j})\exp\left(\sum_{d=1}^{D}\frac{-W_{2}^{2}(\mu_{i,d},\mu_{j,d})}{l_{d}^{2}}\right) \tag{65}$$

where we denote by $\mu_{i}=\mathcal{N}(m(F_{1}(x_{i})),\sigma_{1}^{2})$ and $\mu_{j}=\mathcal{N}(m(F_{1}(x_{j})),\sigma_{1}^{2})$ the first two moments obtained through the $F_{1}^{det}$ operation in the generative process. Intuitively, this generative process implies keeping track of a stochastic and a deterministic component of the same SVGP at any given hidden layer, while the first layer is governed by a standard SVGP operating on Euclidean data. It is worth pointing out that for this probabilistic construction the inducing points $\{Z_{l}\}_{l=1}^{L}$ have to reside in the space of multivariate Gaussians, hence $Z_{l}\sim\mathcal{N}\left(Z_{l}\mid\mu_{Z_{l}},\Sigma_{Z_{l}}\right)$. The first two moments are treated as hyperparameters that are optimized during training.
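A minimal sketch of the hybrid kernel of equation (65) for diagonal Gaussian activations follows; it assumes a squared-exponential form for the Euclidean component $k^{E}$ applied to the sampled features, and all names and toy values are illustrative.

```python
import numpy as np

def _sq_w2(m1, v1, m2, v2):
    # Squared W2 distance between univariate Gaussians (Section 2.5).
    return (m1 - m2) ** 2 + (np.sqrt(v1) - np.sqrt(v2)) ** 2

def hybrid_kernel(f_sth, mu, var, w2_lengthscales, euclid_lengthscale=1.0, variance=1.0):
    """Equation (65): Euclidean RBF on sampled features times a W2 RBF on their moments.
    f_sth, mu, var are (N, D) arrays of samples, means and variances from the previous layer."""
    d2 = ((f_sth[:, None, :] - f_sth[None, :, :]) ** 2).sum(-1)
    k_euclid = np.exp(-0.5 * d2 / euclid_lengthscale ** 2)
    w2 = _sq_w2(mu[:, None, :], var[:, None, :], mu[None, :, :], var[None, :, :])
    k_w2 = np.exp(-(w2 / w2_lengthscales ** 2).sum(-1))
    return variance * k_euclid * k_w2

# Placeholder usage on assumed first-layer outputs
rng = np.random.default_rng(1)
mu, var = rng.standard_normal((4, 2)), rng.uniform(0.1, 0.5, (4, 2))
f_sth = mu + np.sqrt(var) * rng.standard_normal((4, 2))
print(hybrid_kernel(f_sth, mu, var, w2_lengthscales=np.ones(2)).shape)   # (4, 4)
```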

Figure 6: Conceptual difference between the Euclidean and the hybrid kernel.

We can again consider an OOD data point $x_{ood}$ in input space such that $\tilde{\Sigma}_{1}(x_{ood})=\sigma^{2}$ and $\tilde{U}_{1}(x_{ood})=0$, respectively an in-distribution point $x_{in\text{-}d}$ such that $\tilde{\Sigma}_{1}(x_{in\text{-}d})=V_{in}\leq\sigma^{2}$ and $\tilde{U}_{1}(x_{in\text{-}d})=M_{in}$, and assume that $Z_{2}$ are placed equidistantly in $[-3\sigma,3\sigma]$. The authors go on to show that the total variance in the second layer will be higher for $x_{ood}$ than for $x_{in\text{-}d}$ if $\sigma^{2}\gg Z_{2,var}$ and $V_{in}\approx Z_{2,var}$. To better understand this behaviour, consider a two-layer DGP and a two-layer DDGP, and assume an in-distribution point has low total variance in hidden layer $F_{1}$, respectively an OOD point has high total variance. In the DGP case, upon sampling from $q(F_{1}(x_{in\text{-}d}))$ and $q(F_{1}(x_{ood}))$ we can end up with samples that are equally distant from the inducing point locations $Z_{1}$.
If this occurs, then the non-parametric variance (proxy for distributional variance) will be equal for $x_{in\text{-}d}$ and $x_{ood}$ in $F_{2}$. Hence, what was initially flagged as OOD in the first hidden layer will be considered in-distribution by the second hidden layer. Conversely, in the DDGP case, and under the assumption that the variance of the distributional inducing point locations $Z_{2}$ is almost equal in distribution to the total variance of in-distribution points in $F_{2}$, the Wasserstein-2 component of the hybrid kernel will notice that there is a higher distance between the now distributional inducing point locations and $x_{ood}$ than between them and $x_{in\text{-}d}$. Then, the non-parametric variance of $x_{ood}$ will be higher than that of $x_{in\text{-}d}$. A visual depiction of this case study is illustrated in Figure 6.

3 Distributional GP Layers

Shift towards single-pass uncertainty quantification

Early methods for uncertainty quantification in Bayesian deep learning (BDL) focused on estimating the variance of samples from different sub-models, such as dropout (Gal and Ghahramani, 2016b), deep ensembles (Lakshminarayanan et al., 2017) or sampling posterior network weights from a hypernetwork (Pawlowski et al., 2017). This results in slow uncertainty estimation at testing time, which can be critical in high-risk domains where speed is of the essence (e.g., self-driving cars). Recent work in OOD detection has focused on estimating proxies for distributional uncertainty in a single pass, such as bi-Lipschitz regularized feature extractors for GPs (van Amersfoort et al., 2021; Liu et al., 2020) or parametrizing second-order uncertainty via neural networks within the framework of evidential learning (Charpentier et al., 2020; Amini et al., 2019). With this shift towards single-pass uncertainty quantification, DDGPs and the hybrid kernel introduced in subsection 2.6 are no longer appropriate, since they involve sampling the features at each hidden layer. In the next subsection we detail a deterministic variant which still preserves correlations between data points in the hidden layers.

Integrating GPs in convolutional architectures

GPs for image classification have garnered interest in recent years, with hybrid approaches, whereby a deep neural network embedding mechanism is trained end-to-end with a GP as the classification layer, being the first attempt to unify the two approaches (Wilson et al., 2016; Bradshaw et al., 2017). Garriga-Alonso et al. (2018) provided a conceptual framework by which classic CNN architectures are translated into the kernel of a shallow GP by exploiting the mathematical properties of the variance of weight matrices. Van der Wilk et al. (2017) proposed the first convolutional kernel, constructed by aggregating patch response functions. Dutordoir et al. (2019) addressed the complete spatial invariance of the convolutional kernel by adding an additional squared exponential kernel between the locations of two patches to account for spatial location, obtaining improvements in accuracy. To extend this shallow GP model to deeper architectures, Blomqvist et al. (2018) proposed to use the convolutional GP on top of a succession of feed-forward GP layers which process data in a convolutional manner akin to standard convolutional layers. However, scaling this framework to modern convolutional architectures with a large number of channels in each hidden layer is problematic for two reasons. Firstly, it would imply training high-dimensional multi-output GPs, for which efficient inference is still an open research avenue (Bruinsma et al., 2020), and ignoring correlations between channels would severely diminish the expressivity of the model. Secondly, the hidden-layer GPs that process data in a convolutional manner require inducing points whose dimensionality scales linearly with the number of channels, implying optimization over high-dimensional spaces for each hidden layer and potentially leading to local minima. We will later see an alternative framework for integrating GPs into a convolutional architecture, one that is more amenable to modern convolutional architectures.

3.1 Deep Wasserstein Kernel Learning

3.1.1 Generative Process

We now write the generative process of this new probabilistic framework, coined Deep Wasserstein Kernel Learning (DWKL), for two layers:

$$p(F_{1})=\mathcal{N}\left(\mathrm{PCA}(F_{0}),\mathrm{diag}\left[K_{ff}\right]\right) \tag{66}$$
$$p(F_{2})=\mathcal{N}\left[0,k^{W_{2}}\left(p(F_{1}),p(F_{1})\right)\right] \tag{67}$$

Due to the introduction of PCA mean functions, data points in the hidden layer are now correlated. To make this clear, we can compute the prior covariance between two data points explicitly:

$$p\begin{pmatrix}F_{2,i}\\ F_{2,j}\end{pmatrix}\sim\mathcal{N}\left[\begin{pmatrix}0\\ 0\end{pmatrix},\begin{pmatrix}\sigma_{2}^{2}\exp\frac{-W_{2}^{2}(\mu_{i},\mu_{i})}{l^{2}} & \sigma_{2}^{2}\exp\frac{-W_{2}^{2}(\mu_{i},\mu_{j})}{l^{2}}\\ \sigma_{2}^{2}\exp\frac{-W_{2}^{2}(\mu_{j},\mu_{i})}{l^{2}} & \sigma_{2}^{2}\exp\frac{-W_{2}^{2}(\mu_{j},\mu_{j})}{l^{2}}\end{pmatrix}\right] \tag{74}$$
$$\sim\mathcal{N}\left[\begin{pmatrix}0\\ 0\end{pmatrix},\begin{pmatrix}\sigma_{2}^{2} & K_{i,j}^{W_{2}}\\ K_{j,i}^{W_{2}} & \sigma_{2}^{2}\end{pmatrix}\right] \tag{79}$$

where $\mu_{i}=\mathcal{N}\left(\mathrm{PCA}(F_{0,i}),\sigma^{2}\right)$ and $\mu_{j}=\mathcal{N}\left(\mathrm{PCA}(F_{0,j}),\sigma^{2}\right)$. If the PCA embeddings of $x_{i}$ and $x_{j}$ are different, then the Wasserstein-2 distance will be non-zero, hence introducing correlations.
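To see how the PCA mean function induces the correlations of equation (74), the toy sketch below assembles the 2x2 prior covariance for a pair of inputs; the PCA projection matrix and variance values are assumptions chosen purely for illustration.

```python
import numpy as np

def w2_sq(m1, v1, m2, v2):
    # Squared W2 distance between univariate Gaussians.
    return (m1 - m2) ** 2 + (np.sqrt(v1) - np.sqrt(v2)) ** 2

def dwkl_prior_cov(x_i, x_j, pca_components, sigma1_sq=0.5, sigma2_sq=1.0, lengthscale=1.0):
    """2x2 prior covariance of (F_{2,i}, F_{2,j}) from equation (74) with PCA mean functions."""
    mu_i, mu_j = pca_components @ x_i, pca_components @ x_j   # PCA embeddings = first-layer means
    w2 = w2_sq(mu_i, sigma1_sq, mu_j, sigma1_sq).sum()        # both points share the prior variance
    cross = sigma2_sq * np.exp(-w2 / lengthscale ** 2)
    return np.array([[sigma2_sq, cross], [cross, sigma2_sq]])

x_i, x_j = np.array([1.0, 0.0, 2.0]), np.array([0.5, -1.0, 1.5])
pca = np.array([[0.6, 0.3, 0.1], [0.1, 0.7, 0.2]])            # assumed 3 -> 2 projection matrix
print(dwkl_prior_cov(x_i, x_j, pca))                          # off-diagonal shrinks as embeddings diverge
```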

3.1.2 Evidence lower bound

Deep Kernel Learning (DKL) (Wilson et al., 2016) is defined as a shallow GP with the input encoded by a neural network:

$$p(Y,F_{L},U_{L})=\underbrace{p(Y\mid F_{L})}_{\text{likelihood}}\;\underbrace{p(F_{L}\mid U_{L};Z_{L-1},\mathrm{Enc}(X))\,p(U_{L})}_{\text{prior}} \tag{80}$$

where $\mathrm{Enc}(X)$ represents the input passed through a neural network encoder, providing a deterministic transformation of the data which is then fed into an SVGP operating on Euclidean data (using equation (9)).

We diverge from this approach by utilising stacked DistGPs with Wasserstein-2 kernels as the encoder network, hence our transformed input is given by a Gaussian distribution $q(F_{L-1})$. Using the first two moments of the penultimate layer, we introduce a DistGP so as to obtain the final predictions. The conditional equation for a DistGP at an arbitrary layer $l\geq 2$ is written as:

$$p(F_{l}\mid U_{l};Z_{l-1},F_{l-1})=\mathcal{N}\left(F_{l}\mid K_{fu}^{W_{2}}{K_{uu}^{W_{2}}}^{-1}U_{l},\;K_{ff}^{W_{2}}-Q_{ff}^{W_{2}}\right) \tag{81}$$

where we have inducing points $Z_{l}\sim\mathcal{N}(Z_{l}\mid\mu_{Z_{l}},\Sigma_{Z_{l}})$ and uncertain input $F_{l-1}\sim\mathcal{N}\left(F_{l-1}\mid\tilde{U}_{l-1}(F_{l-2}),\tilde{\Sigma}_{l-1}(F_{l-2})\right)$. For computational reasons we take both $\tilde{\Sigma}_{l-1}(F_{l-2})$ and $\Sigma_{Z_{l}}$ to be diagonal matrices. For $l=1$ we have $q(F_{1})\sim\mathcal{N}\left(F_{1}\mid\tilde{U}_{1}(x),\tilde{\Sigma}_{1}(x)\right)$, which are the standard predictive equations for an SVGP as given in equations (31) and (32), since the first layer is governed by a standard SVGP operating on Euclidean data.

The joint density of Deep Wasserstein Kernel Learning (DWKL) is given as:

$$\underbrace{p(Y\mid F)}_{\text{likelihood}}\;\underbrace{p(F_{L}\mid U_{L};Z_{L-1},\mathrm{Enc}(X))\prod_{l=1}^{L}p(U_{l})}_{\text{prior}} \tag{82}$$

where in our case $\mathrm{Enc}(X)\sim\mathcal{N}\left[\tilde{U}_{L-1}(F_{L-2}),\tilde{\Sigma}_{L-1}(F_{L-2})\right]$ acts as the uncertain input for the final distributional GP. We introduce a posterior factorized between layers and dimensions, $q(F_{L},\{U_{l}\}_{l=1}^{L})=p(F_{L}\mid U_{L};Z_{L-1})\prod_{l=1}^{L}q(U_{l})$, where $q(U_{l})$ is taken to be a multivariate Gaussian with mean $m_{U_{l}}$ and variance $S_{U_{l}}$. This gives the DWKL variational lower bound:

$$\mathcal{L}_{DWKL}=\mathbb{E}_{q(F_{L},\{U_{l}\}_{l=1}^{L})}\left[\log p(Y\mid F_{L})\right]-\sum_{l=1}^{L}\mathrm{KL}\left[q(U_{l})\,\|\,p(U_{l})\right] \tag{83}$$

where $q(F_{L})=\mathcal{N}\left(\tilde{U}_{L}(\mathrm{Enc}(X)),\tilde{\Sigma}_{L}(\mathrm{Enc}(X))\right)$. For $1\leq l\leq L-1$, $F_{l}$ acts as features for the next kernel, as opposed to random variables that need to be integrated out. We provide pseudo-code for the previously mentioned operations in Algorithm 1.
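Assuming a Gaussian likelihood (as in regression), the expectation in equation (83) is available in closed form; the sketch below assembles the bound from the final-layer moments and the per-layer KL terms. The function names and the Gaussian-likelihood choice are assumptions, not the exact training objective used later for segmentation.

```python
import numpy as np

def gaussian_expected_log_lik(y, f_mean, f_var, noise_var):
    """E_{q(F_L)}[log N(y | F_L, noise_var)] in closed form for a Gaussian likelihood."""
    return -0.5 * np.sum(np.log(2 * np.pi * noise_var)
                         + ((y - f_mean) ** 2 + f_var) / noise_var)

def kl_gaussians(m_q, S_q, m_p, S_p):
    """KL[N(m_q, S_q) || N(m_p, S_p)] between variational and prior inducing distributions."""
    M = len(m_q)
    S_p_inv = np.linalg.inv(S_p)
    return 0.5 * (np.trace(S_p_inv @ S_q) + (m_p - m_q) @ S_p_inv @ (m_p - m_q)
                  - M + np.log(np.linalg.det(S_p) / np.linalg.det(S_q)))

def dwkl_elbo(y, fL_mean, fL_var, noise_var, q_params, p_params):
    """Equation (83): expected log-likelihood under q(F_L) minus the per-layer KL terms."""
    kl = sum(kl_gaussians(mq, Sq, mp, Sp) for (mq, Sq), (mp, Sp) in zip(q_params, p_params))
    return gaussian_expected_log_lik(y, fL_mean, fL_var, noise_var) - kl

# Toy usage with placeholder moments and inducing distributions
y, fm, fv = np.array([0.2, -0.1]), np.zeros(2), 0.3 * np.ones(2)
q = [(np.zeros(3), 0.5 * np.eye(3))]
p = [(np.zeros(3), np.eye(3))]
print(dwkl_elbo(y, fm, fv, noise_var=0.1, q_params=q, p_params=p))
```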

Input: Euclidean data $X = F_0$
First layer is a standard sparse variational GP
  Variational parameters: $U_1 \sim \mathcal{N}(m_{U_1}, S_{U_1})$
  Inducing points: Euclidean space $Z_0$
  $q(F_1) = \mathcal{N}\!\left(F_1 \mid K_{fu}^{SE}(K_{uu}^{SE})^{-1}m_{U_1},\; K_{ff}^{SE} - K_{fu}^{SE}(K_{uu}^{SE})^{-1}(K_{uu}^{SE} - S_{U_1})(K_{uu}^{SE})^{-1}K_{uf}^{SE}\right)$
for $l = 2$ to $L$ do
  Hidden layers are distributional sparse variational GPs
  Variational parameters: $U_l \sim \mathcal{N}(m_{U_l}, S_{U_l})$
  Inducing points: $Z_{l-1} \sim \mathcal{N}(\mu_{Z_{l-1}}, \Sigma_{Z_{l-1}})$
  Compute $K_{fu}^{W_2}$: $\sigma_l^2 \exp\sum_{d=1}^{D_l} \frac{-W_2^2\left(q(F_{l-1}[:,d]),\, Z_{l-1}[:,d]\right)}{l_{l,d}^2}$
  Compute $K_{uu}^{W_2}$: $\sigma_l^2 \exp\sum_{d=1}^{D_l} \frac{-W_2^2\left(Z_{l-1}[:,d],\, Z_{l-1}[:,d]\right)}{l_{l,d}^2}$
  $q(F_l) = \mathcal{N}\!\left(F_l \mid K_{fu}^{W_2}(K_{uu}^{W_2})^{-1}m_{U_l},\; K_{ff}^{W_2} - K_{fu}^{W_2}(K_{uu}^{W_2})^{-1}\left[K_{uu}^{W_2} - S_{U_l}\right](K_{uu}^{W_2})^{-1}K_{uf}^{W_2}\right)$
end for
Maximize ELBO: $\mathbb{E}_{q(F_L),\{q(U_l)\}_{l=1}^{L}}\, p(Y \mid F_L) - \sum_{l=1}^{L} KL\left[q(U_l)\,\|\,p(U_l)\right]$
Algorithm 1: Deep Wasserstein Kernel Learning
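To make the kernel computations in Algorithm 1 concrete, the sketch below is a minimal NumPy illustration (not the released implementation; array shapes and function names are our own assumptions) of the squared Wasserstein-2 distance between univariate Gaussian marginals, $W_2^2\big(\mathcal{N}(m_1,v_1),\mathcal{N}(m_2,v_2)\big) = (m_1-m_2)^2 + (\sqrt{v_1}-\sqrt{v_2})^2$, and of the resulting $K_{fu}^{W_2}$ matrix. $K_{uu}^{W_2}$ follows by applying the same function to the inducing moments themselves.

```python
import numpy as np

def w2_squared(m1, v1, m2, v2):
    # Squared Wasserstein-2 distance between univariate Gaussians:
    # (m1 - m2)^2 + (sqrt(v1) - sqrt(v2))^2, broadcast over the inputs.
    return (m1 - m2) ** 2 + (np.sqrt(v1) - np.sqrt(v2)) ** 2

def wasserstein_se_kernel(f_mean, f_var, z_mean, z_var, sigma2, lengthscales):
    # K_fu^{W2} of Algorithm 1: sigma2 * exp(-sum_d W2^2(q(F[:, d]), Z[:, d]) / l_d^2).
    # f_mean, f_var: (N, D) moments of q(F_{l-1}); z_mean, z_var: (M, D) inducing moments;
    # lengthscales: scalar or (D,) vector of per-dimension lengthscales.
    d2 = w2_squared(f_mean[:, None, :], f_var[:, None, :],
                    z_mean[None, :, :], z_var[None, :, :])            # (N, M, D)
    return sigma2 * np.exp(-np.sum(d2 / lengthscales ** 2, axis=-1))  # (N, M)
```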

3.2 Module Architecture

Figure 7: Schematic of the measure-preserving DistGP layer. A sparse variational GP is convolved over the input data to obtain the first hidden layer. An affine operator is convolved over this stochastic layer, propagating both mean and variance, to obtain the pre-activation of the second hidden layer. A distributional GP is then applied element-wise to introduce non-linearities and to propagate distributional uncertainty into the post-activation of the second hidden layer.

For ease of notation and graphical representation we describe the case of a 2D image input, with no loss of generality. We denote the image's representation at the $l$-th layer of a multi-layer model by $F_l \in \mathbb{R}^{H_l \times W_l \times C_l}$, with height $H_l$, width $W_l$ and $C_l$ channels; $F_0$ is the input image. Consider a square kernel of size $k_l \times k_l$. We denote by $F_l^{[p,k_l]} \in \mathbb{R}^{k_l \times k_l \times C_l}$ the $p$-th patch of $F_l$, which is the area of $F_l$ that the kernel covers when overlaid at position $p$ during convolution (e.g., the orange square for a $3 \times 3$ kernel in Figure 7). We introduce the convolved $GP_0 : F_0^{[p,k_0]} \rightarrow \mathcal{N}(m,k)$ with inducing points $Z_0 \in \mathbb{R}^{k_0 \times k_0 \times C_0}$, an SGP operating on the Euclidean space of patches of the input image, in a similar fashion to the layers introduced in Blomqvist et al. (2018).
For $1 \leq l \leq L$ we introduce affine operators $A_l \in \mathbb{R}^{k_l \times k_l \times C_{l-1} \times C_{l,pre}}$ which are convolved over the previous stochastic layer in the following manner:

$m(F_l^{pre}) = \text{Conv}_{\text{2D}}\big(m(F_{l-1}),\, A_l\big)$   (84)
$var(F_l^{pre}) = \text{Conv}_{\text{2D}}\big(var(F_{l-1}),\, A_l \odot A_l\big)$   (85)

where $\odot$ denotes the Hadamard product. The affine operator is applied to the mean and variance components of the previous layer $F_{l-1}$ so as to propagate a Gaussian distribution to the next pre-activation layer $F_l^{pre}$. To obtain the post-activation layer, we apply a $DistGP_l : F_l^{pre,[p,1]} \rightarrow \mathcal{N}(m,k)$ in a many-to-one manner on the pre-activation patches to arrive at $F_l^{post}$. Figure 7 depicts this new module, entitled the "measure-preserving DistGP" layer, with pseudo-code given in Algorithm 2. In Blomqvist et al. (2018) the convolved GP is used across the entire hierarchy, so the inducing points live in a high-dimensional space ($k_l^2 \cdot C_l$). In our case, the convolutional process is replaced by an inducing-point-free affine operator, with inducing points in a low-dimensional space ($C_{l,pre}$) for the DistGP activation functions. The affine operator outputs $C_{l,pre}$ channels, which is taken to be higher than the output dimensionality $C_l$ of the DistGP activation functions. Hence, the affine operator can cheaply expand the channels, in contrast to the layers in Blomqvist et al. (2018), which would require a high-dimensional multi-output GP. We motivate the preservation of distance in Wasserstein-2 space in the following section. Previous research has highlighted the importance of having an upper bound $\|h(x_1) - h(x_2)\|_h \leq L_{upper}\|x_1 - x_2\|_x$, as it ensures a certain degree of robustness towards adversarial examples, since it prevents the hidden forward mappings from being overly sensitive to conceptually meaningless perturbations in input space (Jacobsen et al., 2018; Sokolić et al., 2017; Weng et al., 2018). Conversely, a lower bound $\|h(x_1) - h(x_2)\|_h \geq L_{lower}\|x_1 - x_2\|_x$ ensures that the forward mappings do not become invariant to semantically meaningful changes in the input (van Amersfoort et al., 2020).
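As a concrete illustration of Eqs. (84)-(85), the following sketch propagates the first two moments of a Gaussian feature map through the affine operator with a plain NumPy loop (stride 1, 'valid' padding; the shapes and the function name are illustrative assumptions, not the paper's implementation).

```python
import numpy as np

def moment_propagating_conv2d(mean, var, A):
    # Eqs. (84)-(85): m(F_pre) = Conv2D(m(F_prev), A), var(F_pre) = Conv2D(var(F_prev), A ⊙ A).
    # mean, var: (H, W, C_in) moment maps of the previous stochastic layer;
    # A: (k, k, C_in, C_out) affine operator. Stride 1, 'valid' padding.
    k = A.shape[0]
    H, W, _ = mean.shape
    Ho, Wo = H - k + 1, W - k + 1
    m_out = np.empty((Ho, Wo, A.shape[-1]))
    v_out = np.empty_like(m_out)
    for i in range(Ho):
        for j in range(Wo):
            patch_m = mean[i:i + k, j:j + k, :]
            patch_v = var[i:i + k, j:j + k, :]
            # contract the (k, k, C_in) patch against every output channel
            m_out[i, j] = np.einsum('abc,abcd->d', patch_m, A)
            v_out[i, j] = np.einsum('abc,abcd->d', patch_v, A ** 2)
    return m_out, v_out
```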

Input: Euclidean data $X = F_0 \in \mathbb{R}^{H_0 \times W_0 \times C_0}$
First layer is a convolved sparse variational GP, $GP_0 : F_0^{[p,k_0]} \rightarrow F_1^{[p]}$
  Variational parameters: $U_1 \sim \mathcal{N}(m_{U_1}, S_{U_1})$
  Inducing points: Euclidean space $Z_0 \in \mathbb{R}^{k_0 \times k_0 \times C_0}$
  $q(F_1) = \mathcal{N}\!\left(F_1 \mid K_{fu}^{SE}(K_{uu}^{SE})^{-1}m_{U_1},\; K_{ff}^{SE} - K_{fu}^{SE}(K_{uu}^{SE})^{-1}(K_{uu}^{SE} - S_{U_1})(K_{uu}^{SE})^{-1}K_{uf}^{SE}\right)$
for $l = 2$ to $L$ do
  Affine operators: $A_l \in \mathbb{R}^{k_l \times k_l \times C_{l-1} \times C_{l,pre}}$
  $m(F_l^{pre}) = \text{Conv}_{\text{2D}}\big(m(F_{l-1}),\, A_l\big)$
  $var(F_l^{pre}) = \text{Conv}_{\text{2D}}\big(var(F_{l-1}),\, A_l \odot A_l\big)$
  Hidden-layer activation functions are sparse variational GPs, $DistGP_l : F_l^{pre,[p,1]} \rightarrow F_l^{post,[p,1]}$
  Variational parameters: $U_l \sim \mathcal{N}(m_{U_l}, S_{U_l})$
  Inducing points: $Z_{l-1} \sim \mathcal{N}(\mu_{Z_{l-1}}, \Sigma_{Z_{l-1}})$
  Compute $K_{fu}^{W_2}$: $\sigma_l^2 \exp\sum_{d=1}^{D_l} \frac{-W_2^2\left(q(F_{l-1}[:,d]),\, Z_{l-1}[:,d]\right)}{l_{l,d}^2}$
  Compute $K_{uu}^{W_2}$: $\sigma_l^2 \exp\sum_{d=1}^{D_l} \frac{-W_2^2\left(Z_{l-1}[:,d],\, Z_{l-1}[:,d]\right)}{l_{l,d}^2}$
  $q(F_l^{post}) = \mathcal{N}\!\left(F_l \mid K_{fu}^{W_2}(K_{uu}^{W_2})^{-1}m_{U_l},\; K_{ff}^{W_2} - K_{fu}^{W_2}(K_{uu}^{W_2})^{-1}\left[K_{uu}^{W_2} - S_{U_l}\right](K_{uu}^{W_2})^{-1}K_{uf}^{W_2}\right)$
end for
Maximize ELBO: $\mathbb{E}_{q(F_L),\{q(U_l)\}_{l=1}^{L}}\, p(Y \mid F_L) - \sum_{l=1}^{L} KL\left[q(U_l)\,\|\,p(U_l)\right]$
Algorithm 2: Distributional Gaussian Processes Layers
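Both algorithms reuse the same sparse variational conditional for $q(F_l)$. As a minimal sketch (NumPy/SciPy, diagonal predictive covariance only; all shapes and names are illustrative assumptions), these moments can be computed from precomputed kernel matrices and variational parameters as follows.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

def sparse_gp_moments(Kfu, Kuu, kff_diag, m_u, S_u, jitter=1e-6):
    # q(F) = N(Kfu Kuu^{-1} m_u, Kff - Kfu Kuu^{-1} (Kuu - S_u) Kuu^{-1} Kuf),
    # returning only the diagonal of the covariance.
    # Kfu: (N, M), Kuu: (M, M), kff_diag: (N,), m_u: (M, D), S_u: (M, M).
    M = Kuu.shape[0]
    chol = cho_factor(Kuu + jitter * np.eye(M), lower=True)
    A = cho_solve(chol, Kfu.T)                      # Kuu^{-1} Kuf, shape (M, N)
    mean = A.T @ m_u                                # (N, D)
    var = (kff_diag
           - np.einsum('nm,mn->n', Kfu, A)          # - diag(Kfu Kuu^{-1} Kuf)
           + np.einsum('mn,mk,kn->n', A, S_u, A))   # + diag(Kfu Kuu^{-1} S_u Kuu^{-1} Kuf)
    return mean, var
```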

3.3 Imposing Lipschitz Conditions in Convolutionally Warped DistGP

If a sample is identified as an outlier at a certain layer, i.e. it is flagged with high variance, then in an ideal scenario we would like to preserve that status throughout the remainder of the network. As the kernels operate in Wasserstein-2 space, the distance between a data point's first two moments and the inducing points is vital. Hence, we would like our network to vary smoothly between layers, so that similar objects in previous layers get mapped to similar regions of the Wasserstein-2 domain. In this section we accomplish this by quantifying the "Lipschitzness" of our "measure-preserving DistGP" layer and by imposing constraints on the affine operators so that they preserve distances in Wasserstein-2 space.

Proposition 3

For a given DistGP $F$ and a Gaussian distribution $\mu \sim \mathcal{N}(m_1, \Sigma_1)$ taken to be the centre of an annulus $B(\mu) = \left\{\nu \sim \mathcal{N}(m_2, \Sigma_2) \mid 0.125 \leq \frac{W_2(\mu,\nu)}{l^2} \leq 1.0\right\}$, and choosing any $\nu$ inside the annulus, we have the following Lipschitz bound: $W_2\big(F(\mu), F(\nu)\big) \leq L\, W_2(\mu, \nu)$, where $L = \left(\frac{4\sigma^2}{l}\right)^2 \left[\|K_{uu}^{-1}m\|_2^2 + \|K_{uu}^{-1}(K_{uu} - S)K_{uu}^{-1}\|_2\right]$ and $l, \sigma^2$ are the lengthscale and variance of the kernel.

Proof is given in Appendix A.
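For intuition, the constant $L$ of Proposition 3 can be evaluated directly from the variational parameters of a layer; a minimal sketch, assuming NumPy arrays for $K_{uu}$, $m$ and $S$ (names are illustrative):

```python
import numpy as np

def distgp_lipschitz_constant(Kuu, m_u, S_u, sigma2, lengthscale):
    # L = (4*sigma2 / l)^2 * ( ||Kuu^{-1} m||_2^2 + ||Kuu^{-1} (Kuu - S) Kuu^{-1}||_2 )
    Kuu_inv = np.linalg.inv(Kuu)
    term_mean = np.linalg.norm(Kuu_inv @ m_u) ** 2
    term_cov = np.linalg.norm(Kuu_inv @ (Kuu - S_u) @ Kuu_inv, ord=2)  # spectral norm
    return (4.0 * sigma2 / lengthscale) ** 2 * (term_mean + term_cov)
```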

Remark 4

This theoretical result shows that DistGP "activation functions" have Lipschitz constants with respect to the Wasserstein-2 metric in both the output and the input domain. This helps ensure that the separation between previously identified outliers and inliers is preserved in a controlled manner throughout the network. However, it is worth highlighting that we only obtain locally Lipschitz continuous functions, since the Lipschitz constants hold only for Gaussian distributions $\nu$ inside the annulus $B(\mu) = \left\{\nu \sim \mathcal{N}(m_2, \Sigma_2) \mid 0.125 \leq \frac{W_2(\mu,\nu)}{l^2} \leq 1\right\}$ centred at $\mu$.

We are now interested in finding Lipschitz constants for the affine operator A𝐴Aitalic_A that gets convolved to arrive at the pre-activation stochastic layer.

Proposition 5

We consider the affine operator $A \in \mathbb{R}^{C \times 1}$ operating on the space of multivariate Gaussian distributions of size $C$. Consider two distributions $\mu \sim \mathcal{N}(m_1, \sigma_1^2)$ and $\nu \sim \mathcal{N}(m_2, \sigma_2^2)$, which can be thought of as elements of a hidden-layer patch. Then for the affine operator function $f(\mu) = \mathcal{N}(m_1 A, \sigma_1^2 A^2)$ we have the following Lipschitz bound: $W_2\big(f(\mu), f(\nu)\big) \leq L\, W_2(\mu, \nu)$, where $L = \sqrt{C}\,\|A\|_2^2$.

Proof is given in Appendix A.

Remark 6

We denote by $A_{l,c}$ the column of the $l$-th layer weight matrix computing the $c$-th channel. We can impose the Lipschitz condition on Eqs. (84)-(85) by constraining the weight matrices to have elements of the form $A_{l,c} = \frac{A_{l,c}}{C^{\frac{1}{2}}\sqrt{\sum_{c=1}^{C} A_{l,c}^2}}$.
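A minimal sketch of this constraint, assuming a single output channel represented as a vector $A \in \mathbb{R}^{C}$ (variable names are illustrative): after rescaling, the bound of Proposition 5 evaluates to $\sqrt{C}\,\|A\|_2^2 = 1/\sqrt{C} \leq 1$.

```python
import numpy as np

def lipschitz_constrain(A):
    # Rescale a column of the affine operator so that sqrt(C) * ||A||_2^2 <= 1.
    C = A.shape[0]
    return A / (np.sqrt(C) * np.linalg.norm(A))

A_hat = lipschitz_constrain(np.random.randn(8))
L = np.sqrt(8) * np.linalg.norm(A_hat) ** 2   # equals 1/sqrt(8), i.e. below 1
```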

3.4 Feature-collapse in DistGP layers

In this subsection we delve deeper into the properties of DistGP layers from a function-space view. In light of recent interest in feature collapse (van Amersfoort et al., 2020), the pathological phenomenon whereby the representation layer collapses to a small finite set of values, with catastrophic consequences for OOD detection, we investigate the conditions under which our proposed network collapses in feature space. Subsequently, we investigate whether feature collapse is inherently encouraged by our loss function.

We commence by introducing notation conventions. We consider $\{u_l \in \mathbb{R}^{D_l}\}_{l=0:L}$, where $D_l$ is the number of dimensions in the $l$-th layer of the hierarchy. We consider the following two functions, $\Psi_l : u_{l-1} \to \mathbb{R}^{m_l}$ and $f_l : \mathbb{R}^{m_l} \to u_l$. To relate this notation to the construction of a DistGP layer introduced in Section 3.2, $m_l$ represents the number of dimensions of the warped GP (the warping being performed by the deterministic affine layer; see the dark green arrows in Figure 7). We take $f_l$ to be the DistGP (mean function included) taking values in the space of continuous functions $C(u_l; \mathbb{R}^{m_l})$, which corresponds to the "activation function" construction from Figure 7. Then we have the following composition for a given DistGP layer:

$u_l(x) = f_l\big(\Psi_l(u_{l-1})(x)\big)$   (86)

One can easily see that DWKL can be recovered by taking $\Psi_l = id$ instead of the affine embedding. The first-layer prior $p\big(u_1(x), u_1(x^*)\big)$ is defined as follows:

$\mathcal{N}\left[\begin{pmatrix} m_1(x) \\ m_1(x^*) \end{pmatrix}, \begin{pmatrix} \sigma_1^2 & k^{E}(x, x^*) \\ k^{E}(x^*, x) & \sigma_1^2 \end{pmatrix}\right]$   (87)

We now define the prior over the post-activation layers $p\big(u_l(x), u_l(x^*)\big)$ for $l \geq 2$ in the following recursive manner:

$\mathcal{N}\left[\begin{pmatrix} m_l(x) \\ m_l(x^*) \end{pmatrix}, \begin{pmatrix} \sigma_l^2 & k^{W_2}\big(\mu_{l-1}(x), \mu_{l-1}(x^*)\big) \\ k^{W_2}\big(\mu_{l-1}(x^*), \mu_{l-1}(x)\big) & \sigma_l^2 \end{pmatrix}\right]$   (88)

where $\mu_l(x) = \mathcal{N}\big(m_{l-1}(x)W_l,\; \sigma_{l-1}^2 W_l^2\big)$, with $m_{l+1}(\cdot) = \overline{\overline{m_1(\cdot)W_1}\cdots W_l}$ and $\overline{m_1(x)W_1}$ signifying the Principal Component Analysis (PCA) mean function of the first layer multiplied by $W_1$ and averaged across its dimensions.

Proposition 7

We assume $\mu_0$ to be bounded on bounded sets almost surely. If at each layer we have satisfied the following inequalities, $D_l^2\langle\tilde{W}_l,\tilde{W}_l\rangle \leq 1$, respectively $\left[D_L\,\langle\tilde{W}_L,\tilde{W}_L\rangle + \frac{\sigma_L^2}{2 l_L^2}\right] \leq 1$, where $D_l$ is the size of the $l$-th layer and $\tilde{W}_l$ represents a normalized version of the affine embedding $W_l$, then we have the following result:

$P\left(\|u_n(x) - u_n(x^*)\|_2 \to 0\right) = 1$   (89)

The proof of Proposition 7 can be found in Appendix B.

Remark 8

As outlined in the above derivation, if at each layer the inequalities $D_l m_{l-1} D_{l-1}\langle\tilde{W}_l,\tilde{W}_l\rangle \leq 1$, respectively $\left[m_l D_{L-1}\,\langle\tilde{W}_L,\tilde{W}_L\rangle + \frac{\sigma_L^2}{2 l_L^2}\right] \leq 1$, are satisfied, then the network collapses to constant values. Intuitively, if the norm of $W_l$ is not large enough, then it will not change the Gaussian random field much. Furthermore, if $\sigma_l^2$ is larger, which translates into an increased amplitude of the samples from the Gaussian random field, then the values will not collapse. As opposed to the hypothetical requirements for DGPs (Dunlop et al., 2018), we immediately notice that for DistGP layers there is no requirement on the kernel variance and lengthscales of the intermediate layers; the conditions rely solely on the last layer's hyperparameters. Lastly, we notice that as the width of the stochastic layers is increased, alongside the warped layers obtained through the affine embedding, the conditions become less likely to be satisfied.

3.5 Over-correlation in latent space

Ober et al. (2021) have highlighted a certain pathology in DKL applied to regression problems in the non-sparse scenario. The authors provide empirical examples of this pathology, whereby features in the representation-learning layer become almost perfectly correlated, corresponding to the feature-collapse phenomenon as coined by van Amersfoort et al. (2020). We commence by briefly introducing the main results from that paper and then adapt them to the sparse scenario, which bears more resemblance to what occurs in practice.

Full GPs are trained via type-II maximum likelihood:

$\log p(y) = \log\mathcal{N}\left(y \mid 0,\, K_{ff} + \sigma_{noise}^2\mathbb{I}_n\right)$   (90)
$\propto -\underbrace{\frac{1}{2}\log\left|K_{ff} + \sigma_{noise}^2\mathbb{I}_n\right|}_{\text{complexity penalty}} - \underbrace{\frac{1}{2}y^{\top}\left(K_{ff} + \sigma_{noise}^2\mathbb{I}_n\right)^{-1}y}_{\text{data fit}}$   (91)

where we define the squared exponential kernel $k^{SE}(x_i, x_j) = \sigma^2\exp\left[-\sum_{d=1}^{D}\frac{(x_{i,d} - x_{j,d})^2}{2 l_d^2}\right]$ for $x_i, x_j \in \mathbb{R}^{D}$.
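For reference, a small NumPy sketch of the squared exponential kernel and of the type-II objective of Eqs. (90)-(91), split into its complexity-penalty and data-fit terms (the function names and the omission of the constant $\frac{N}{2}\log 2\pi$ are our own choices):

```python
import numpy as np

def se_kernel(X1, X2, sigma2=1.0, lengthscales=1.0):
    # k(x_i, x_j) = sigma2 * exp(-sum_d (x_id - x_jd)^2 / (2 l_d^2))
    diff = (X1[:, None, :] - X2[None, :, :]) / lengthscales
    return sigma2 * np.exp(-0.5 * np.sum(diff ** 2, axis=-1))

def type_ii_objective(X, y, sigma2, lengthscales, noise2):
    # Eq. (91): -(complexity penalty) - (data fit), additive constant omitted.
    K = se_kernel(X, X, sigma2, lengthscales) + noise2 * np.eye(len(X))
    complexity = 0.5 * np.linalg.slogdet(K)[1]
    data_fit = 0.5 * y.T @ np.linalg.solve(K, y)
    return -(complexity + data_fit)
```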

The authors of Ober et al. (2021) go on to show that at optimal values the data-fit term converges towards $\frac{N}{2}$, where $N$ is the number of training points. Hence, once the model has reached convergence, it can only increase its log-likelihood score through modifications to the complexity penalty term, which can be broken up as follows:

$\frac{1}{2}\log\left|K_{ff} + \sigma_{noise}^2\mathbb{I}_n\right| = \frac{N}{2}\log\sigma_f^2 + \frac{1}{2}\log\left|\tilde{K}_{ff} + \tilde{\sigma}_{noise}^2\mathbb{I}_n\right|$   (92)

where we have introduced the reparametrizations $K_{ff} = \sigma_f^2\tilde{K}_{ff}$ and $\sigma_{noise}^2 = \sigma_f^2\tilde{\sigma}_{noise}^2$. We can easily see that if this term is to be minimized, one could decrease $\sigma_f^2$, with the caveat that this would harm the data fit. Hence, the only solution is to have high correlation values in $K_{ff}$ so as to obtain a determinant close to 0.
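The split in Eq. (92) can also be checked numerically; a small sketch with arbitrary illustrative values:

```python
import numpy as np

# Check of Eq. (92): 0.5*log|Kff + s_n^2 I| = (N/2)*log(s_f^2) + 0.5*log|K~ff + s~_n^2 I|,
# with Kff = s_f^2 * K~ff and s_n^2 = s_f^2 * s~_n^2 (all values below are arbitrary).
rng = np.random.default_rng(0)
N, sf2, sn2 = 5, 2.5, 0.1
X = rng.normal(size=(N, 1))
Ktilde = np.exp(-0.5 * (X - X.T) ** 2)   # unit-variance SE kernel Gram matrix
lhs = 0.5 * np.linalg.slogdet(sf2 * Ktilde + sn2 * np.eye(N))[1]
rhs = 0.5 * N * np.log(sf2) + 0.5 * np.linalg.slogdet(Ktilde + (sn2 / sf2) * np.eye(N))[1]
assert np.isclose(lhs, rhs)
```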

In the remainder of this subsection, we derive results similar to those of Ober et al. (2021) but in the sparse scenario. We make use of the collapsed bound introduced in Titsias (2009):

$\mathcal{L}_{Titsias} = \log\mathcal{N}\left(y \mid 0,\, Q_{ff} + \sigma_{noise}^2\mathbb{I}_n\right) - \frac{1}{2\sigma_{noise}^2}Tr\left[K_{ff} - Q_{ff}\right]$   (93)
$\propto -\underbrace{\frac{1}{2}\log\left|Q_{ff} + \sigma_{noise}^2\mathbb{I}_n\right|}_{\text{complexity penalty}} - \underbrace{\frac{1}{2}y^{\top}\left(Q_{ff} + \sigma_{noise}^2\mathbb{I}_n\right)^{-1}y}_{\text{data fit}} - \underbrace{\frac{1}{2\sigma_{noise}^2}Tr\left[K_{ff} - Q_{ff}\right]}_{\text{trace term}}$   (94)
$\propto -\frac{N}{2}\log\sigma_f^2 - \frac{1}{2}\log\left|\tilde{Q}_{ff} + \tilde{\sigma}_{noise}^2\mathbb{I}_n\right| - \frac{1}{2\sigma_f^2}y^{\top}\left[\tilde{Q}_{ff} + \tilde{\sigma}_{noise}^2\mathbb{I}_n\right]^{-1}y - \frac{1}{2\tilde{\sigma}_{noise}^2}Tr\left[\tilde{K}_{ff} - \tilde{Q}_{ff}\right]$   (95)

, where we have again used the notation $k(\cdot,\cdot)=\sigma^{2}\tilde{k}(\cdot,\cdot)$ for kernel terms and $\sigma^{2}_{noise}=\sigma^{2}\tilde{\sigma}^{2}_{noise}$. To obtain predictions at test time under this framework, we can make use of the optimal $q(U)$, whose first two moments are given by:

\begin{align}
m(U^{*}) &= \sigma^{-2}_{noise}K_{uu}\left[K_{uu}+\sigma^{-2}_{noise}K_{uf}K_{fu}\right]^{-1}K_{uf}y \tag{96}\\
v(U^{*}) &= K_{uu}\left[K_{uu}+\sigma^{-2}_{noise}K_{uf}K_{fu}\right]^{-1}K_{uu} \tag{97}
\end{align}

, which we can plug into the standard SVGP predictive equations (equations (31) and (32)).

We adapt the derivation in Ober et al. (2021) to the framework at hand:

\begin{align}
\frac{\partial\mathcal{L}_{Titsias}}{\partial\sigma^{2}} &= \frac{\partial}{\partial\sigma^{2}}\left(-\frac{N}{2}\log\sigma^{2} - \frac{1}{2}\log\left|\tilde{Q}_{ff}+\tilde{\sigma}^{2}_{noise}\mathbb{I}_{n}\right| - \frac{1}{2\sigma^{2}}y^{\top}\left[\tilde{Q}_{ff}+\tilde{\sigma}^{2}_{noise}\mathbb{I}_{n}\right]^{-1}y - \frac{1}{2\tilde{\sigma}^{2}_{noise}}\operatorname{Tr}\left[\tilde{K}_{ff}-\tilde{Q}_{ff}\right]\right) \tag{98}\\
&= -\frac{N}{2\sigma^{2}} + \frac{1}{2\sigma^{4}}y^{\top}\left[\tilde{Q}_{ff}+\tilde{\sigma}^{2}_{noise}\mathbb{I}_{n}\right]^{-1}y \tag{99}
\end{align}

Hence, setting the derivative to 0 yields $\sigma^{2}=\frac{1}{N}y^{\top}\left[\tilde{Q}_{ff}+\tilde{\sigma}^{2}_{noise}\mathbb{I}_{n}\right]^{-1}y$, which, when substituted back into the data fit term, results in $\frac{N}{2}$, similar to the non-sparse scenario analysed in Ober et al. (2021). The difference between the sparse and non-sparse frameworks is that, after the data fit term has converged, the model now has to achieve over-correlation in $Q_{ff}$, while still minimizing $K_{ff}-Q_{ff}$.
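As a numerical sanity check of this stationary point, the following minimal NumPy sketch (our own illustration under simplifying assumptions, not the training code used in the experiments) builds $\tilde{Q}_{ff}$ for a toy 1D dataset with a unit-variance squared-exponential kernel, computes the optimal $\sigma^{2}$ in closed form, and verifies that the data-fit term collapses to $N/2$.

```python
import numpy as np

def rbf(a, b, lengthscale=0.5):
    # Squared-exponential kernel with unit variance (the "tilde" kernel).
    d = a[:, None] - b[None, :]
    return np.exp(-0.5 * (d / lengthscale) ** 2)

rng = np.random.default_rng(0)
N, M = 80, 10                        # number of training points / inducing points
x = rng.uniform(0, 6, N)
y = np.sin(x) + 0.1 * rng.standard_normal(N)
z = np.linspace(0, 6, M)             # inducing locations

Kuf = rbf(z, x)
Kuu = rbf(z, z) + 1e-8 * np.eye(M)
Qff_tilde = Kuf.T @ np.linalg.solve(Kuu, Kuf)    # Nystrom approximation (tilde scale)
noise_tilde = 0.05                                # \tilde{\sigma}^2_{noise}

A = Qff_tilde + noise_tilde * np.eye(N)
alpha = np.linalg.solve(A, y)

# Stationary point of the collapsed bound with respect to the kernel variance sigma^2
sigma2_opt = (y @ alpha) / N

# Data-fit term of the bound evaluated at sigma2_opt: equals N/2 up to floating point error
data_fit = 0.5 * (y @ alpha) / sigma2_opt
print(data_fit, N / 2)
```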

3.6 Pooling operations on stochastic layers

Previous work combining GPs with convolutional architectures (Dutordoir et al., 2020; Kumar et al., 2018; Blomqvist et al., 2018) used simple architectures involving only a couple of stacked layers in their experiments. In this paper, we propose to experiment with more modern architectures such as DenseNet (Huang et al., 2017) or ResNet (He et al., 2016). However, both of these architectures include pooling layers such as average pooling, which for Euclidean data is a straightforward operation since we have a naturally induced metric. Since we are using stochastic layers that operate in the space of Gaussian distributions, this introduces complications, as it is not desirable to sample from the stochastic layers and subsequently apply the Euclidean-space average pooling operation. Nevertheless, in the remainder of this subsection we show a simple method for replicating average pooling in Wasserstein space by using Wasserstein barycentres (Agueh and Carlier, 2011).

We consider probability measures $\mu_{1},\dots,\mu_{K}$ and fixed weights $\theta_{1},\dots,\theta_{K}$, positive real numbers such that $\sum_{k=1}^{K}\theta_{k}=1$. For $\nu\in\mathbb{P}_{2}(\mathbb{R}^{d})$, where $\mathbb{P}_{2}$ is the set of Borel probability measures on $\mathbb{R}^{d}$ with finite second moment that are absolutely continuous with respect to the Lebesgue measure, we consider the following functionals:

\begin{align}
\mathbf{V}(\nu) &= \sum_{k=1}^{K}\theta_{k}\mathbf{W}_{2}^{2}(\nu,\mu_{k}) \tag{100}\\
\tilde{\mu} &= \operatorname*{arg\,min}_{\nu\in\mathbb{P}_{2}} \mathbf{V}(\nu) \tag{101}
\end{align}

, where $\tilde{\mu}$ is the barycentre with respect to the Wasserstein-2 distance of the set of probability measures $\{\mu_{1},\dots,\mu_{K}\}$. Intuitively, barycentres can be seen as the equivalent of averaging in Euclidean space, while still maintaining the geometric properties of the distributions at hand.

Theorem 9 (Theorem 4.2. in Álvarez-Esteban et al. (2016))

Assume $\Sigma_{1},\dots,\Sigma_{K}$ are symmetric positive semi-definite matrices, with at least one of them positive definite. We take $S_{0}\in\mathbb{M}_{d\times d}^{+}$ and define:

$$S_{n+1}=S_{n}^{-1/2}\left(\sum_{k=1}^{K}\theta_{k}\left(S_{n}^{1/2}\Sigma_{k}S_{n}^{1/2}\right)^{1/2}\right)^{2}S_{n}^{-1/2} \tag{102}$$

If $\mathcal{N}(0,\Sigma_{0})$ is the barycentre of $\mathcal{N}(0,\Sigma_{1}),\dots,\mathcal{N}(0,\Sigma_{K})$, then $W_{2}^{2}\left(\mathcal{N}(0,S_{n}),\mathcal{N}(0,\Sigma_{0})\right)\to 0$ as $n\to\infty$.
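For concreteness, the fixed-point iteration of Theorem 9 can be implemented in a few lines. The sketch below is our own illustrative NumPy implementation (not part of the released code), using an eigendecomposition-based matrix square root for symmetric positive semi-definite matrices.

```python
import numpy as np

def sym_sqrtm(A):
    # Matrix square root of a symmetric positive semi-definite matrix.
    w, V = np.linalg.eigh(A)
    return (V * np.sqrt(np.clip(w, 0, None))) @ V.T

def wasserstein_barycentre_cov(covs, weights, n_iter=50):
    # Fixed-point iteration of eq. (102) for centred Gaussians N(0, covs[k]).
    S = np.eye(covs[0].shape[0])               # S_0: any positive-definite start
    for _ in range(n_iter):
        S_half = sym_sqrtm(S)
        S_half_inv = np.linalg.inv(S_half)
        inner = sum(w * sym_sqrtm(S_half @ C @ S_half)
                    for w, C in zip(weights, covs))
        S = S_half_inv @ (inner @ inner) @ S_half_inv
    return S

# Example: barycentre covariance of two 2x2 Gaussians with equal weights.
covs = [np.array([[2.0, 0.3], [0.3, 1.0]]),
        np.array([[1.0, -0.2], [-0.2, 3.0]])]
print(np.round(wasserstein_barycentre_cov(covs, [0.5, 0.5]), 3))
```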

Remark 10

In the case of computing the barycentre of univariate Gaussian measures, the iterative algorithm converges in one iteration to $\Sigma_{0}=\left(\sum_{k=1}^{K}\theta_{k}\Sigma_{k}^{\frac{1}{2}}\right)^{2}$. This provides us with a deterministic, single-step equation to downsample stochastic layers, where we can additionally calculate the mean of the barycentre as $\sum_{k=1}^{K}\theta_{k}m_{k}$, with $\{m_{1},\dots,m_{K}\}$ the first moments of the respective distributions.
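As a simplified illustration of this single-step rule (a 2D, 2x2 average-pooling sketch on per-pixel Gaussian moments with uniform weights, not our full convolutional implementation), the barycentre mean is the average of the means and the barycentre variance is the squared average of the standard deviations:

```python
import numpy as np

def wasserstein_avg_pool2d(mean, var, k=2):
    # Average pooling in Wasserstein-2 space for per-pixel univariate Gaussians:
    # pooled mean = mean of means, pooled variance = (mean of std devs)^2,
    # with uniform weights theta_k = 1/k^2.
    H, W = mean.shape
    m = mean[:H - H % k, :W - W % k].reshape(H // k, k, W // k, k)
    s = np.sqrt(var[:H - H % k, :W - W % k]).reshape(H // k, k, W // k, k)
    pooled_mean = m.mean(axis=(1, 3))
    pooled_var = s.mean(axis=(1, 3)) ** 2
    return pooled_mean, pooled_var

mean = np.random.randn(8, 8)
var = np.random.rand(8, 8) + 0.1
pm, pv = wasserstein_avg_pool2d(mean, var)
print(pm.shape, pv.shape)   # (4, 4) (4, 4)
```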

4 DistGP Layer Networks & OOD detection

An outlier can be defined in various ways (Ruff et al., 2021). In this paper we follow the most basic one, namely: "An anomaly is an observation that deviates considerably from some concept of normality." More concretely, it can be formalised as follows: our data resides in $X\subset\mathbb{R}^{D}$; an anomaly/outlier is a data point $x\in X$ that lies in a low-probability region under the data distribution $\mathcal{P}$, such that the set of anomalies/outliers is defined as $A=\{x\in X \mid p(x)\leq\xi\},\ \xi\geq 0$, where $\xi$ is a threshold below which we consider data points to deviate sufficiently from what constitutes normality.
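As a toy illustration of this definition (our own example; the density model and the percentile used to pick $\xi$ are arbitrary choices, not part of the proposed method), one can fit a simple density model to training data and flag as outliers the points whose density falls below a threshold:

```python
import numpy as np

rng = np.random.default_rng(0)
train = rng.normal(0.0, 1.0, size=5000)        # in-distribution ("normal") data
mu, sigma = train.mean(), train.std()

def log_density(x):
    # Log-density of the fitted Gaussian model of normality; thresholding the
    # log-density at log(xi) is equivalent to thresholding p(x) at xi.
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

# Choose the threshold as the 5th percentile of training log-densities.
xi = np.percentile(log_density(train), 5)

test = np.array([0.1, 1.5, 4.0, -6.0])
is_outlier = log_density(test) <= xi
print(is_outlier)   # only the extreme points are flagged
```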

Influence of enforced Lipschitz condition.

We aim to visually assess whether the Lipschitz condition imposed via Proposition 5 negatively influences the predictive capabilities. We use a standard neural network architecture with two hidden layers of 5 dimensions each, with the affine embedding operations described in equations (84) and (85) replaced by a non-convolutional dense layer. From Figure 8 we can notice that imposing a unitary Lipschitz constant does not result in over-regularization of the predictive mean, beyond a slight smoothing effect in output space. Moreover, for the Lipschitz-constrained version we can discern a better fit of the data manifold in terms of distributional variance, with a noticeable difference in the second hidden layer.

Figure 8: Layer-wise predictive moments of DistGP-NN models trained on a toy binary classification dataset: (a) without and (b) with the unitary Lipschitz constraint.
Over-correlation in latent space.

We aim to understand whether the over-correlation phenomenon occurs for our model. We consider a standard neural network architecture with two hidden layers of 50 hidden units each. From Figure 9 we can notice that, for DKL, the sparse framework removes unwanted over-correlations in the latent space of the final hidden layer. In the unconstrained model, there is a notion of locality in the final hidden layer's latent space, albeit of a lower degree compared to the DKL model. With regard to OOD detection, of utmost importance is the fact that regions outside the training set manifold have a correlation value of 0. Perhaps unsurprisingly, introducing a unitary Lipschitz constraint results in increased correlation in the latent space, alongside a smoother predictive mean.

Figure 9: (a) Collapsed SGPR; (b) DKL; (c) DistGP-NN without Lipschitz constraint; (d) DistGP-NN with Lipschitz constraint. Top row: predictive mean and variance of the parametric part of the SGP; middle row: predictive mean and variance of the non-parametric part of the SGP; bottom row: kernel evaluated across the whole input span with respect to -2.0 (blue) and 2.0 (orange).

4.1 Reliability of in-between uncertainty estimates

We are interested in testing our newly introduced module in scenarios where in-between uncertainty can fail. For this we use the "snelson" dataset, with the training set taken to comprise the intervals between 0.0 and 2.0, and between 4.0 and 6.5. In an ideal scenario we would therefore expect our model to offer high distributional uncertainty estimates between 2.0 and 4.0, which constitutes our in-between region. To benchmark our approach, we compare it to a collapsed SGPR as defined in Titsias (2009).

Figure 10: Reliability of in-between uncertainty for (a) collapsed SGPR and (b) DistGP-NN. Top row: predictive mean and variance of the parametric part of the SGP; bottom row: predictive mean and variance of the non-parametric part of the SGP.

From Figure 10 we can observe that the behaviour of a collapsed SGPR and of a three-layer DistGP-Layers network is strikingly similar.

Reliability of within-data uncertainty estimates.

Within-data uncertainty, more commonly termed epistemic uncertainty, detects regions of the input space where the variance in the model parameters, in this case of $U$, could be further reduced by adding more data points in said regions. To test whether our newly introduced module can provide reliable within-data uncertainty estimates, one can subsample a dataset (as done in Figure 11, subsampling in the interval $[0,2.5]$), with the intended effect being an increase in within-data uncertainty across the input region where we subsampled.

From Figure 11 we can see that the low number of training points did not result in over-fitting, with our model exhibiting a relatively smooth predictive mean. Moreover, in comparison to Figure 9, we can notice that the within-data uncertainty has substantially increased in the $[0,2.5]$ interval.

Figure 11: Reliability of within-data uncertainty for (a) collapsed SGPR and (b) DistGP-NN. Top row: predictive mean and variance of the parametric part of the SGP; bottom row: predictive mean and variance of the non-parametric part of the SGP.
MNIST and CIFAR-10.

We compare our approach on the standard image classification benchmarks of MNIST (Lecun et al., 1998) and CIFAR-10 (Krizhevsky, 2009), which have standard training and test folds to facilitate direct performance comparisons. MNIST contains 60,000 training examples of $28\times 28$ grayscale images of 10 hand-drawn digits, with a separate set of 10,000 test images. CIFAR-10 contains 50,000 training examples of RGB colour images of size $32\times 32$ from 10 classes, with 5,000 images per class. We preprocess the images such that the input is normalized to be between 0 and 1. We compare our model primarily against the original shallow Convolutional Gaussian Process (Van der Wilk et al., 2017) and the Deep Convolutional Gaussian Process (DeepConvGP) (Blomqvist et al., 2018). In terms of model architectures, we have used a standard stacked convolutional approach, with the model entitled "DistGP-DeepConv" consisting of 64 hidden units for the "Convolutionally Warped DistGP" part of the module and 5 hidden units for the "DistGP activation-function" part. For the DeepConvGP, we used 64 hidden units at each hidden layer. All models use a stride of 2 at the first layer. In all experiments we use 250 inducing points at each layer. Lastly, we also devised 18-hidden-layer versions of the ResNet (He et al., 2016) and DenseNet (Huang et al., 2017) architectures.

Convolutional GP models | Hidden Layers | MNIST | CIFAR-10
ConvGP | 0 | 98.83 | 64.6
DeepConvGP | 1 | 98.38 | 58.65
DeepConvGP | 2 | 99.24 | 73.85
DeepConvGP | 3 | 99.44 | 75.89
DistGP-DeepConv | 1 | 99.01 | 70.12
DistGP-DeepConv | 2 | 99.43 | 76.54
DistGP-DeepConv | 3 | 99.67 | 78.49
DistGP-ResNet-18 | 18 | 99.52 | 74.56
DistGP-DenseNet | 18 | 99.75 | 75.29

Hybrid NN-GP models | Hidden Layers | MNIST | CIFAR-10
Deep Kernel Learning | 5 | 99.2 | 77.0
GPDNN | 40 | 99.95 | 93.0

Table 1: Classification accuracy on MNIST and CIFAR-10. Deep Kernel Learning refers to the models from Wilson et al. (2016), whereas GPDNN refers to the models published in Bradshaw et al. (2017). Results other than our method's are taken from the respective publications.

Table 1 shows the classification accuracy on MNIST and CIFAR-10 for different convolutional GP models. Our method achieves superior classification accuracy compared to other convolutional GP approaches such as DeepConvGP (Blomqvist et al., 2018). We find that for our method, adding more layers increases the performance significantly. However, this only holds for a couple of stacked layers, as the results from our ResNet and DenseNet variants do not support this assertion. The GPDNN models introduced in Bradshaw et al. (2017) are nonetheless close to the state of the art on CIFAR-10, but they use a variant of DenseNet (Huang et al., 2017) as the building block for their GP classifier.

Outlier detection on different fonts of digits.

We test whether DistGP-DeepConv models outperform OOD detection models from the literature such as DUQ (van Amersfoort et al., 2020), OVA-DM (Padhy et al., 2020) and OVNNI (Franchi et al., 2020). In these experiments we assess the capacity of our model to detect domain shift by training it on MNIST and looking at the uncertainty measures computed on the MNIST test set, the entire NotMNIST dataset (Bulatov, 2011) and SVHN (Netzer et al., 2011). The hypothesis is that we ought to see both higher predictive entropy and higher differential entropy of the distributional uncertainty (respectively, higher values of the OOD measure specific to each baseline model) on NotMNIST, whose digits stem from a wide array of fonts, none of which are handwritten, and on SVHN, whose digits exhibit different backgrounds and orientations besides not being handwritten.

Model | MNIST vs. NotMNIST AUC (Pred. Entropy) | MNIST vs. NotMNIST AUC (OOD measure) | MNIST vs. SVHN AUC (Pred. Entropy) | MNIST vs. SVHN AUC (OOD measure)
DistGP-DeepConv | 0.92 | 0.82 | 0.95 | 0.98
OVA-DM | 0.73 | 1.0 | 0.70 | 1.0
OVNNI | 0.68 | 0.55 | 0.56 | 0.81
DUQ | 0.82 | 0.81 | 0.65 | 0.74

Table 2: OOD detection results. Performance of OOD detection based on predictive entropy and distributional differential entropy (each baseline OOD model has its own OOD measure). Models are trained on MNIST (normative data).

From Table 2 we can observe that, in general, all models exhibit a shift in their uncertainty measure between MNIST and NotMNIST, with the notable exception of OVNNI, which barely separates the two datasets better than a random guess. Moreover, OVA-DM manages to completely separate the two datasets, with the caveat that it obtains lower predictive entropy for MNIST vs. NotMNIST compared to DistGP-DeepConv. The latter achieves similar results to DUQ, with the added benefit of a higher degree of separation using predictive entropy. In the case of SVHN we can observe patterns similar to NotMNIST, with OVA-DM and DistGP-DeepConv managing to almost separate the two datasets (MNIST vs. SVHN) by inspecting their uncertainty measure, again with the caveat for OVA-DM that it exhibits lower predictive entropy for SVHN in comparison to MNIST.

Sensitivity to input perturbations.

MorphoMNIST (Castro et al., 2018) enables the systematic deformation of MNIST digits using morphological operations. We use MorphoMNIST to better understand the outlier detection capabilities of each method by exposing them to increasingly deformed samples. We use the first 500 MNIST digits in the test set to generate new images with controlled morphological deformations, applying the swelling deformation with a strength of 3 and a radius increasing from 3 to 14. Our hypothesis is that the predictive entropy should increase as the deformation is increased, alongside the distributional differential entropy, which is a measure of the overall uncertainty in logit space. This is motivated by the fact that the newly obtained MorphoMNIST images lie outside the data manifold, which is different from having high uncertainty, as expressed by entropy, upon seeing a digit that is difficult to classify; in that case we would expect high entropy but low differential entropy.

All models are able to pick up on the shift in the data manifold as swelling is applied to the original digits, with the model-specific uncertainty measure steadily increasing (for OVNNI, a decrease in the measure translates to higher uncertainty) as increasing deformation is applied. However, for OVA-DM and DUQ the predictive entropy is stable or actually decreases as more deformation is applied, which is in contrast to what one would expect (Figure 12).

Figure 12: Predictive entropy and model-specific uncertainty measure for (a) DistGP-DeepConv, (b) OVA-DM, (c) OVNNI and (d) DUQ as swelling of increasing radius is applied to MNIST digits. Higher values of the uncertainty measure indicate outlier status, except for OVNNI where the inverse is true. Results are shown for 3 hidden layers, with the DistGP-DeepConv dimensionality set to 5 and the capacity of the convolutionally warped DistGPs set to 12; the OOD baseline models use 128 hidden units at each layer.

To further assess the sensitivity of our methods to input perturbations, we employ the experiments introduced in Gal and Ghahramani (2016a) by successively rotating digits from MNIST. We expect to see an increase in both predictive entropy and distributional differential entropy as digits are rotated. For our experiment we rotate the digit 6. When the digit is rotated by around 180 degrees, the entropy and differential entropy should revert back closer to their initial levels, as it will resemble the digit 9.

From Figure 13 we can notice that all models exhibit an increase in their specific uncertainty measures (for OVNNI, a decrease translates into higher uncertainty) for rotation angles between 40 and 160 degrees, respectively between 240 and 320 degrees. In terms of predictive entropy, we can discern relatively stable and highly overlapping values for OVA-DM and DUQ, whereas for DistGP-DeepConv and OVNNI we can observe a clear pattern of increases and decreases as what was originally a 6 becomes a 9.

Figure 13: Predictive entropy and model-specific uncertainty measure for (a) DistGP-DeepConv, (b) OVA-DM, (c) OVNNI and (d) DUQ as varying degrees of rotation are applied to the digit 6. Higher values of the uncertainty measure indicate outlier status, except for OVNNI where the inverse is true. Results are shown for 3 hidden layers, with the DistGP-DeepConv dimensionality set to 5 and the capacity of the convolutionally warped DistGPs set to 12; the OOD baseline models use 128 hidden units at each layer.

5 DistGP-based Segmentation Network & OOD Detection in Medical Imaging

Figure 14: Top: Schematic of the proposed DistGP-activated segmentation network. Above and below each layer we show the number of channels and their dimension, respectively. Bottom: Visual depiction of the two uncertainties in DistGP after fitting a toy regression dataset. Hyperparameters and variational approximate posteriors are optimized. Distributional uncertainty increases outside the manifold of training data and is therefore useful for OOD detection.

The modules introduced in Sec. 3.2 can be used to construct a convolutional network that benefits from the properties of DistGP. Specifically, we construct a 3D network for segmenting volumetric medical images, depicted in Figure 14 (top). It consists of a convolved GP layer followed by two measure-preserving DistGP layers. Each hidden layer uses filters of size $5\times 5\times 5$. To increase the model's receptive field, in the second layer we use convolution dilated by 2. We use 250 inducing points and 2 channels for the DistGP "activation functions". The affine operators project the stochastic patches into a 12-dimensional space. The size of the network is limited by the computational requirements of GP-based layers, which is an active research area. Like regular convolutional networks, this model can process input of arbitrary size, but GPU memory requirements increase with input size. We here provide input of size $32^{3}$ to the model, which then segments the central $16^{3}$ voxels. To segment a whole scan we divide it into tiles and stitch together the segmentations, as sketched below.
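To make the tiling procedure concrete, the following minimal NumPy sketch (our own illustration with a dummy `segment_tile` stand-in for the network, not our actual inference code) slides a $32^{3}$ window over a padded volume and stitches the $16^{3}$ central predictions back together.

```python
import numpy as np

IN, OUT = 32, 16                  # network input size and central output size
PAD = (IN - OUT) // 2             # context margin around each output tile

def segment_tile(tile):
    # Stand-in for the network: labels the central OUT^3 voxels of an IN^3 tile.
    c = slice(PAD, PAD + OUT)
    return (tile[c, c, c] > 0).astype(np.int8)

def segment_volume(vol):
    D, H, W = vol.shape
    # Pad with context on all sides and round each axis up to a multiple of OUT.
    grid = [int(np.ceil(s / OUT)) for s in vol.shape]
    padded = np.pad(vol, [(PAD, PAD + g * OUT - s) for g, s in zip(grid, vol.shape)],
                    mode='edge')
    out = np.zeros([g * OUT for g in grid], dtype=np.int8)
    for i in range(grid[0]):
        for j in range(grid[1]):
            for k in range(grid[2]):
                z, y, x = i * OUT, j * OUT, k * OUT
                tile = padded[z:z + IN, y:y + IN, x:x + IN]
                out[z:z + OUT, y:y + OUT, x:x + OUT] = segment_tile(tile)
    return out[:D, :H, :W]

print(segment_volume(np.random.randn(50, 60, 70)).shape)   # (50, 60, 70)
```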

5.1 Evaluation on Brain MRI

In this section we evaluate our method alongside recent OOD models (van Amersfoort et al., 2020; Franchi et al., 2020; Padhy et al., 2020), assessing whether they can reach segmentation performance comparable to well-established deterministic models and whether they can accurately detect outliers.

5.1.1 Data and pre-processing

For evaluation we use publicly available datasets:

1) Brain MRI scans from the UK Biobank (UKBB) study (Alfaro-Almagro et al., 2018), which contains scans from nearly 15,000 subjects. We selected for training and evaluation the bottom 10th percentile in terms of white matter hypointensities, with an equal split between training and testing. All subjects have been confirmed to be normal by radiological assessment. Segmentation of brain tissue (CSF, GM, WM) has been obtained with SPM12.

2) MRI scans of 285 patients with gliomas from BraTS 2017 (Bakas et al., 2017). All classes are fused into a tumor class, which we will use to quantify OOD detection performance.

In what follows, we use only the FLAIR sequence to perform the brain tissue segmentation task and OOD detection of tumors, as this MRI sequence is available for both UKBB and BraTS. All FLAIR images are pre-processed with skull-stripping, N4 bias correction, rigid registration to MNI152 space, and histogram matching between UKBB and BraTS. Finally, we normalize the intensities of each scan by linearly scaling its minimum and maximum intensities to the [-1, 1] range.
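The final normalization step is a per-scan min-max rescaling; a minimal version (our own utility sketch, assuming finite intensities and a non-constant scan) is:

```python
import numpy as np

def normalize_scan(scan):
    # Linearly rescale a scan's intensities to the [-1, 1] range.
    lo, hi = scan.min(), scan.max()
    return 2.0 * (scan - lo) / (hi - lo) - 1.0

scan = np.random.rand(32, 32, 32) * 400.0
norm = normalize_scan(scan)
print(norm.min(), norm.max())   # -1.0 1.0
```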

5.1.2 Brain tissue segmentation on normal MRI scans

Model | Hidden Layers | DICE CSF | DICE GM | DICE WM
OVA-DM (Padhy et al., 2020) | 3 | 0.72 | 0.79 | 0.77
OVNNI (Franchi et al., 2020) | 3 | 0.66 | 0.77 | 0.73
DUQ (van Amersfoort et al., 2020) | 3 | 0.745 | 0.825 | 0.781
DistGP-Seg (ours) | 3 | 0.829 | 0.823 | 0.867
U-Net | 3 scales | 0.85 | 0.89 | 0.86

Table 3: Performance on UK Biobank in terms of Dice scores per tissue.
Task:

We train and test our model on the task of segmenting brain tissue of healthy UKBB subjects. This corresponds to the within-data manifold in our setup.

Baselines:

We compare our model with recent Bayesian approaches for enabling task-specific models (such as image segmentation) to perform uncertainty-based OOD detection (van Amersfoort et al., 2020; Franchi et al., 2020; Padhy et al., 2020). For a fair comparison, we use these methods in an architecture similar to ours (Figure 14), except that each layer is replaced by a standard convolutional layer with 256 channels, LeakyReLU activations, and the same dilation rates as ours. We also compare these Bayesian methods with a well-established deterministic baseline, a U-Net with 3 scales (down/up-sampling) and 2 convolution layers per scale in the encoder and 2 in the decoder (12 layers in total).

Results:

Table 3 shows that DistGP-Seg surpasses other Bayesian methods with respect to Dice score for all tissue classes. Our method approaches the performance of the deterministic U-Net, which has a much larger architecture and receptive field. We emphasize this has not been previously achieved with GP-based architectures, as their size (e.g., number of layers) is limited due to computational requirements. This supports the potential of DistGP, which is bound to be further unlocked by advances in scaling GP-based models.

5.1.3 Outlier detection in MRI scans with tumors

Model | DICE | DICE | DICE | DICE FPR=5.0
OVA-DM (Padhy et al., 2020) | 0.382 | 0.428 | 0.457 | 0.410
OVNNI (Franchi et al., 2020) | ≤0.001 | ≤0.001 | ≤0.001 | ≤0.001
DUQ (van Amersfoort et al., 2020) | 0.068 | 0.121 | 0.169 | 0.182
DistGP-Seg (ours) | 0.512 | 0.571 | 0.532 | 0.489
VAE-LG (Chen et al., 2019) | 0.259 | 0.407 | 0.448 | 0.303
AAE-LG (Chen et al., 2019) | 0.220 | 0.395 | 0.418 | 0.302

Table 4: Performance comparison of Dice for detecting outliers on BraTS for different thresholds obtained from UKBB.
Task:

The previous task of brain tissue segmentation on UKBB serves as a proxy task for learning normative patterns with our network. Here, we apply this pre-trained network to BraTS scans with tumors. We expect the region surrounding the tumor and other related pathologies, such as squeezed brain parts or shifted ventricles, to be highlighted with higher distributional uncertainty, which is the OOD measure for the Bayesian deep learning models. To evaluate the quality of OOD detection at a voxel level, we follow the procedure in Chen et al. (2019): for example, to get the threshold value at a 5.0% False Positive Rate we compute the 95th percentile of the distributional variance on the UKBB test set, taking into consideration that there is no outlier tissue there. Subsequently, using this value we threshold the distributional variance heatmaps on BraTS, with tissue above the threshold being flagged as an outlier. We then quantify the overlap of the voxels detected as outliers (over the threshold) with the ground-truth tumor labels by computing the Dice score between them.
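This thresholding procedure can be summarised in a few lines; the sketch below (our own NumPy illustration with synthetic heatmaps and masks, not the evaluation code) derives the FPR=5% threshold from in-distribution variance maps and computes the Dice overlap between the thresholded OOD heatmap and a tumor mask.

```python
import numpy as np

rng = np.random.default_rng(0)

# Distributional variance heatmaps on the in-distribution (UKBB-like) test set.
ukbb_variance = rng.gamma(2.0, 0.5, size=(20, 16, 16, 16))

# Threshold such that only 5% of in-distribution voxels exceed it (FPR = 5%).
threshold = np.percentile(ukbb_variance, 95)

# A BraTS-like case: higher distributional variance inside the tumor mask.
tumor_mask = np.zeros((16, 16, 16), dtype=bool)
tumor_mask[4:10, 4:10, 4:10] = True
brats_variance = rng.gamma(2.0, 0.5, size=tumor_mask.shape)
brats_variance[tumor_mask] += 5.0

pred_outlier = brats_variance > threshold

def dice(a, b):
    # Dice overlap between two binary masks.
    return 2.0 * np.logical_and(a, b).sum() / (a.sum() + b.sum())

print(round(dice(pred_outlier, tumor_mask), 3))
```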


Results:

Table 4 shows the results from our experiments with DistGP and the compared Bayesian deep learning baselines. We also provide the performance of reconstruction-based OOD detection models as reported in Chen et al. (2019) for a similar experimental setup. DistGP-Seg surpasses its Bayesian deep learning counterparts, as well as the reconstruction-based models. In Figure 15 we provide representative results from the methods we implemented for qualitative assessment. Moreover, although BraTS does not provide labels for WM/GM/CSF tissues, and hence we cannot quantify how well these tissues are segmented, visual assessment shows that our method compares favourably to its counterparts.

Figure 15: Comparison between models in terms of voxel-level outlier detection of tumors on BraTS scans. Mean segmentation represents the hard segmentation of brain tissues. The OOD measure is the quantification of uncertainty for each model, using its own procedure. Higher values indicate outlier status, whereas for OVNNI the converse holds. OOD measures have been normalized to lie between 0 and 1 for each model.

In Figure 16 we plot the differential entropy measures on BraTS scans by overlaying the tumor labels on the uncertainties obtained from our model. We can notice that tumor tissue is highlighted with higher uncertainty, both inside and outside the data manifold, compared to healthy tissue. More detailed plots are available in Appendix C.

Figure 16: Comparison in terms of voxel-level epistemic and distributional differential entropy between non-tumor tissues and different tumor gradations from subjects in the BraTS dataset.

6 Discussion & Conclusion

We have introduced a novel Bayesian convolutional layer with Lipschitz continuity that is capable of reliably propagating uncertainty. We have shown on a wide array of general OOD detection tasks that it surpasses other OOD models from the literature, while also offering an increase in accuracy compared to counterpart architectures based solely on Euclidean-space SVGPs (Blomqvist et al., 2018). General criticism of deep and convolutional GPs involves their under-performance compared to other Bayesian deep learning techniques, and especially compared to deterministic networks. Our experiments demonstrate that our 3-layer model, limited in size due to computational cost, is capable of approaching the performance of a U-Net, an architecture with a much larger receptive field. Further advances in computationally efficient GP-based models, an active area of research, will enable our model to scale further and unlock its full potential. Importantly, we showed that our DistGP-Seg network offers better uncertainty estimates for OOD detection than state-of-the-art OOD detection models, and also surpasses some recent unsupervised reconstruction-based deep learning models for identifying outliers corresponding to pathology on brain scans.



This framework can also be used for regression and classification tasks within a medical imaging context, facilitating the adoption of deep learning in clinical settings thanks to enhanced accountability in predictions. For example, parts of scans flagged with high distributional uncertainty can be sent back for inspection and quality control. To support our claim, we have included additional results on flagging white matter hyperintensities as outliers (see Appendix D), as well as retinal pathologies (see Appendix E).


Our results indicate that OOD methods that do not take into account distances in latent space, such as OVNNI, tend to fail in detecting outliers, whereas OVA-DM and DUQ, which make predictions based on distances in the last layer, perform better. Our model utilises distances at every hidden layer, thus allowing the notion of outlier to evolve gradually through the depth of the network. This difference can be noticed in the smoothness of the OOD measure for our model in comparison to other methods in Figure 15. Furthermore, the issue of feature collapse (van Amersfoort et al., 2020) in deep networks can be precisely controlled due to the mathematical underpinnings of our proposed network, enabling us to assess by simple equations the scenarios in which this happens. Additionally, we have shown that, despite the possibility of achieving over-correlation in the latent space via the loss function, this does not happen in practice.


A drawback of our study resides in the small architecture used on medical imaging scans. Extending our "measure preserving DistGP" module to larger architectures, such as a U-Net for segmentation or modern CNNs for whole-image prediction tasks, remains a prospective research avenue fuelled by advances in the scalability of SGPs. Moreover, our experiments involving more complicated architectures, such as ResNet or DenseNet for standard multi-class classification, have not managed to surpass in accuracy a far less complex model with only 3 hidden layers. A plausible reason behind this under-fitting resides in the factorized approximate posterior formulation, which was shown to negatively affect predictive performance compared to MCMC inference schemes (Havasi et al., 2018). We posit that using alternative inference frameworks (Ustyuzhaninov et al., 2019), whereby we impose correlations between layers, might alleviate this issue. Moreover, the lack of added representational capacity upon adding new layers raises further questions regarding what the optimal architectures for hierarchical GPs are, what inductive biases they need, and how to properly initialize them to facilitate adequate training. Additionally, our comparison with reconstruction-based approaches to OOD detection was not complete, as it did not include a comprehensive list of recent models (Dey and Hong, 2021; Pinaya et al., 2021; Schlegl et al., 2019; Baur et al., 2018). However, a comparison with reconstruction-based approaches was not the intended goal of this paper; the main aim was to compare against models which provide accurate predictions alongside OOD detection capabilities. Another limitation of our work is the training speed of our proposed module, with matrix inversion operations and log-determinants being required at each layer. Future work should consider matrix-inversion-free inference techniques for GPs (van der Wilk et al., 2020).

In conclusion, our work shows that incorporating DistGP in convolutional architectures provides both competitive performance and reliable uncertainty quantification in medical image analysis alongside general OOD tasks, opening up a new direction of research.


Acknowledgments

SGP is funded by an EPSRC Centre for Doctoral Training studentship award to Imperial College London. KK is funded by the UKRI London Medical Imaging & Artificial Intelligence Centre for Value Based Healthcare. BG received funding from the European Research Council (ERC) under the European Union’s Horizon 2020 research and innovation programme (grant agreement No 757173, project MIRA, ERC-2017-STG). DJS is supported by the NIHR Biomedical Research Centre at Imperial College Healthcare NHS Trust and the UK Dementia Research Institute (DRI) Care Research and Technology Centre. JHC acknowledges funding from UKRI/MRC Innovation Fellowship (MR/R024790/2).

Ethical Standards

The work follows appropriate ethical standards in conducting research and writing the manuscript, following all applicable laws and regulations regarding treatment of human subjects.

Conflicts of Interest

BG has received grants from European Commission and UK Research and Innovation Engineering and Physical Sciences Research Council, during the conduct of this study; and is Scientific Advisor for Kheiron Medical Technologies and Advisor and Scientific Lead of the HeartFlow-Imperial Research Team. JHC is a shareholder in and Scientific Advisor to BrainKey and Claritas Healthcare, both medical image analysis software companies.


References

  • Agueh and Carlier (2011) Martial Agueh and Guillaume Carlier. Barycenters in the wasserstein space. SIAM Journal on Mathematical Analysis, 43(2):904–924, 2011.
  • Alfaro-Almagro et al. (2018) Fidel Alfaro-Almagro, Mark Jenkinson, Neal K Bangerter, Jesper LR Andersson, Ludovica Griffanti, Gwenaëlle Douaud, Stamatios N Sotiropoulos, Saad Jbabdi, Moises Hernandez-Fernandez, Emmanuel Vallee, et al. Image processing and quality control for the first 10,000 brain imaging datasets from uk biobank. Neuroimage, 166:400–424, 2018.
  • Álvarez-Esteban et al. (2016) Pedro C Álvarez-Esteban, E Del Barrio, JA Cuesta-Albertos, and C Matrán. A fixed-point approach to barycenters in wasserstein space. Journal of Mathematical Analysis and Applications, 441(2):744–762, 2016.
  • Amini et al. (2019) Alexander Amini, Wilko Schwarting, Ava Soleimany, and Daniela Rus. Deep evidential regression. arXiv preprint arXiv:1910.02600, 2019.
  • Bachoc et al. (2017) François Bachoc, Fabrice Gamboa, Jean-Michel Loubes, and Nil Venet. A gaussian process regression model for distribution inputs. IEEE Transactions on Information Theory, 64(10):6620–6637, 2017.
  • Bakas et al. (2017) Spyridon Bakas, Hamed Akbari, Aristeidis Sotiras, Michel Bilello, Martin Rozycki, Justin S Kirby, John B Freymann, Keyvan Farahani, and Christos Davatzikos. Advancing the cancer genome atlas glioma mri collections with expert segmentation labels and radiomic features. Scientific data, 4:170117, 2017.
  • Baumgartner et al. (2019) Christian F Baumgartner, Kerem C Tezcan, Krishna Chaitanya, Andreas M Hötker, Urs J Muehlematter, Khoschy Schawkat, Anton S Becker, Olivio Donati, and Ender Konukoglu. Phiseg: Capturing uncertainty in medical image segmentation. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 119–127. Springer, 2019.
  • Baur et al. (2018) Christoph Baur, Benedikt Wiestler, Shadi Albarqouni, and Nassir Navab. Deep autoencoding models for unsupervised anomaly segmentation in brain mr images. In International MICCAI Brainlesion Workshop, pages 161–169. Springer, 2018.
  • Blomqvist et al. (2018) Kenneth Blomqvist, Samuel Kaski, and Markus Heinonen. Deep convolutional gaussian processes. arXiv preprint arXiv:1810.03052, 2018.
  • Bradshaw et al. (2017) John Bradshaw, Alexander G de G Matthews, and Zoubin Ghahramani. Adversarial examples, uncertainty, and transfer testing robustness in gaussian process hybrid deep networks. arXiv preprint arXiv:1707.02476, 2017.
  • Bruinsma et al. (2020) Wessel Bruinsma, Eric Perim, William Tebbutt, Scott Hosking, Arno Solin, and Richard Turner. Scalable exact inference in multi-output gaussian processes. In International Conference on Machine Learning, pages 1190–1201. PMLR, 2020.
  • Bulatov (2011) Yaroslav Bulatov. NotMNIST dataset. Google (Books/OCR), Tech. Rep. [Online]. Available: http://yaroslavvb.blogspot.it/2011/09/notmnist-dataset.html, 2, 2011.
  • Castro et al. (2018) Daniel C Castro, Jeremy Tan, Bernhard Kainz, Ender Konukoglu, and Ben Glocker. Morpho-mnist: Quantitative assessment and diagnostics for representation learning. arXiv preprint arXiv:1809.10780, 2018.
  • Charpentier et al. (2020) Bertrand Charpentier, Daniel Zügner, and Stephan Günnemann. Posterior network: Uncertainty estimation without ood samples via density-based pseudo-counts. arXiv preprint arXiv:2006.09239, 2020.
  • Charpentier et al. (2021) Bertrand Charpentier, Oliver Borchert, Daniel Zügner, Simon Geisler, and Stephan Günnemann. Natural posterior network: Deep bayesian predictive uncertainty for exponential family distributions. arXiv preprint arXiv:2105.04471, 2021.
  • Chen et al. (2019) Xiaoran Chen, Nick Pawlowski, Ben Glocker, and Ender Konukoglu. Unsupervised lesion detection with locally gaussian approximation. In International Workshop on Machine Learning in Medical Imaging, pages 355–363. Springer, 2019.
  • Chen et al. (2021) Xiaoran Chen, Nick Pawlowski, Ben Glocker, and Ender Konukoglu. Normative ascent with local gaussians for unsupervised lesion detection. Medical Image Analysis, page 102208, 2021.
  • Curth et al. (2019) Alicia Curth, Patrick Thoral, Wilco van den Wildenberg, Peter Bijlstra, Daan de Bruin, Paul WG Elbers, and Mattia Fornasa. Transferring clinical prediction models across hospitals and electronic health record systems. In PKDD/ECML Workshops (1), pages 605–621, 2019.
  • Czolbe et al. (2021) Steffen Czolbe, Kasra Arnavaz, Oswin Krause, and Aasa Feragen. Is segmentation uncertainty useful? In International Conference on Information Processing in Medical Imaging, pages 715–726. Springer, 2021.
  • Damianou and Lawrence (2013) Andreas Damianou and Neil Lawrence. Deep gaussian processes. In Artificial Intelligence and Statistics, pages 207–215, 2013.
  • D’Angelo and Fortuin (2021) Francesco D’Angelo and Vincent Fortuin. Repulsive deep ensembles are bayesian. arXiv preprint arXiv:2106.11642, 2021.
  • Dey and Hong (2021) Raunak Dey and Yi Hong. Asc-net: Adversarial-based selective network for unsupervised anomaly segmentation. In Marleen de Bruijne, Philippe C. Cattin, Stéphane Cotin, Nicolas Padoy, Stefanie Speidel, Yefeng Zheng, and Caroline Essert, editors, Medical Image Computing and Computer Assisted Intervention – MICCAI 2021, pages 236–247, Cham, 2021. Springer International Publishing. ISBN 978-3-030-87240-3.
  • Dowson and Landau (1982) D. Dowson and B. Landau. The fréchet distance between multivariate normal distributions. Journal of Multivariate Analysis, 12:450–455, 1982.
  • Dunlop et al. (2018) Matthew M Dunlop, Mark A Girolami, Andrew M Stuart, and Aretha L Teckentrup. How deep are deep gaussian processes? Journal of Machine Learning Research, 19(54):1–46, 2018.
  • Dutordoir et al. (2019) Vincent Dutordoir, Mark van der Wilk, Artem Artemev, Marcin Tomczak, and James Hensman. Translation insensitivity for deep convolutional gaussian processes. arXiv preprint arXiv:1902.05888, 2019.
  • Dutordoir et al. (2020) Vincent Dutordoir, Mark van der Wilk, Artem Artemev, and James Hensman. Bayesian image classification with deep convolutional gaussian processes. In Silvia Chiappa and Roberto Calandra, editors, Proceedings of the Twenty Third International Conference on Artificial Intelligence and Statistics, volume 108 of Proceedings of Machine Learning Research, pages 1529–1539. PMLR, 26–28 Aug 2020. URL http://proceedings.mlr.press/v108/dutordoir20a.html.
  • Foong et al. (2019) Andrew YK Foong, Yingzhen Li, José Miguel Hernández-Lobato, and Richard E Turner. ’in-between’uncertainty in bayesian neural networks. arXiv preprint arXiv:1906.11537, 2019.
  • Franchi et al. (2020) Gianni Franchi, Andrei Bursuc, Emanuel Aldea, Severine Dubuisson, and Isabelle Bloch. One versus all for deep neural network incertitude (ovnni) quantification. arXiv preprint arXiv:2006.00954, 2020.
  • Gal and Ghahramani (2016a) Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In ICML, 2016a.
  • Gal and Ghahramani (2016b) Yarin Gal and Zoubin Ghahramani. Dropout as a bayesian approximation: Representing model uncertainty in deep learning. In international conference on machine learning, pages 1050–1059. PMLR, 2016b.
  • Garriga-Alonso et al. (2018) Adrià Garriga-Alonso, Laurence Aitchison, and Carl Edward Rasmussen. Deep convolutional networks as shallow gaussian processes. arXiv preprint arXiv:1808.05587, 2018.
  • Girard (2004) Agathe Girard. Approximate methods for propagation of uncertainty with Gaussian process models. University of Glasgow (United Kingdom), 2004.
  • Guo et al. (2017) Chuan Guo, Geoff Pleiss, Yu Sun, and Kilian Q Weinberger. On calibration of modern neural networks. arXiv preprint arXiv:1706.04599, 2017.
  • Hafner et al. (2020) Danijar Hafner, Dustin Tran, Timothy Lillicrap, Alex Irpan, and James Davidson. Noise contrastive priors for functional uncertainty. In Uncertainty in Artificial Intelligence, pages 905–914. PMLR, 2020.
  • Havasi et al. (2018) Marton Havasi, José Miguel Hernández Lobato, and Juan José Murillo Fuentes. Inference in deep gaussian processes using stochastic gradient hamiltonian monte carlo. arXiv preprint arXiv:1806.05490, 2018.
  • He et al. (2016) Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. Identity mappings in deep residual networks. In European conference on computer vision, pages 630–645. Springer, 2016.
  • Hendrycks and Gimpel (2016) Dan Hendrycks and Kevin Gimpel. A baseline for detecting misclassified and out-of-distribution examples in neural networks. arXiv preprint arXiv:1610.02136, 2016.
  • Henning et al. (2021) Christian Henning, Francesco D’Angelo, and Benjamin F Grewe. Are bayesian neural networks intrinsically good at out-of-distribution detection? arXiv preprint arXiv:2107.12248, 2021.
  • Hensman et al. (2015) James Hensman, Alexander Matthews, and Zoubin Ghahramani. Scalable variational gaussian process classification. In Artificial Intelligence and Statistics, pages 351–360. PMLR, 2015.
  • Hoffman et al. (2013) Matthew D Hoffman, David M Blei, Chong Wang, and John Paisley. Stochastic variational inference. The Journal of Machine Learning Research, 14(1):1303–1347, 2013.
  • Hoover et al. (2000) AD Hoover, Valentina Kouznetsova, and Michael Goldbaum. Locating blood vessels in retinal images by piecewise threshold probing of a matched filter response. IEEE Transactions on Medical imaging, 19(3):203–210, 2000.
  • Hu et al. (2019) Shi Hu, Daniel Worrall, Stefan Knegt, Bas Veeling, Henkjan Huisman, and Max Welling. Supervised uncertainty quantification for segmentation with multiple annotations. In International Conference on Medical Image Computing and Computer-Assisted Intervention, pages 137–145. Springer, 2019.
  • Huang et al. (2017) Gao Huang, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q Weinberger. Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition, pages 4700–4708, 2017.
  • Imai et al. (2020) Shungo Imai, Yoh Takekuma, Hitoshi Kashiwagi, Takayuki Miyai, Masaki Kobayashi, Ken Iseki, and Mitsuru Sugawara. Validation of the usefulness of artificial neural networks for risk prediction of adverse drug reactions used for individual patients in clinical practice. Plos one, 15(7):e0236789, 2020.
  • Jacobsen et al. (2018) Jörn-Henrik Jacobsen, Jens Behrmann, Richard Zemel, and Matthias Bethge. Excessive invariance causes adversarial vulnerability. arXiv preprint arXiv:1811.00401, 2018.
  • Kendall et al. (2015) Alex Kendall, Vijay Badrinarayanan, and Roberto Cipolla. Bayesian segnet: Model uncertainty in deep convolutional encoder-decoder architectures for scene understanding. arXiv preprint arXiv:1511.02680, 2015.
  • Kingma et al. (2015) Durk P Kingma, Tim Salimans, and Max Welling. Variational dropout and the local reparameterization trick. In Advances in neural information processing systems, pages 2575–2583, 2015.
  • Kohl et al. (2018) Simon Kohl, Bernardino Romera-Paredes, Clemens Meyer, Jeffrey De Fauw, Joseph R Ledsam, Klaus Maier-Hein, SM Ali Eslami, Danilo Jimenez Rezende, and Olaf Ronneberger. A probabilistic u-net for segmentation of ambiguous images. In Advances in Neural Information Processing Systems, pages 6965–6975, 2018.
  • Krizhevsky (2009) Alex Krizhevsky. Learning multiple layers of features from tiny images. Technical report, 2009.
  • Kumar et al. (2018) Vinayak Kumar, Vaibhav Singh, PK Srijith, and Andreas Damianou. Deep gaussian processes with convolutional kernels. arXiv preprint arXiv:1806.01655, 2018.
  • Lakshminarayanan et al. (2017) Balaji Lakshminarayanan, Alexander Pritzel, and Charles Blundell. Simple and scalable predictive uncertainty estimation using deep ensembles. Advances in neural information processing systems, 30, 2017.
  • LeCun et al. (1998) Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner. Gradient-based learning applied to document recognition. In Proceedings of the IEEE, pages 2278–2324, 1998.
  • Lee et al. (2017) Jaehoon Lee, Yasaman Bahri, Roman Novak, Samuel S Schoenholz, Jeffrey Pennington, and Jascha Sohl-Dickstein. Deep neural networks as gaussian processes. arXiv preprint arXiv:1711.00165, 2017.
  • Liang et al. (2017) Shiyu Liang, Yixuan Li, and Rayadurgam Srikant. Enhancing the reliability of out-of-distribution image detection in neural networks. arXiv preprint arXiv:1706.02690, 2017.
  • Liu et al. (2020) Jeremiah Zhe Liu, Zi Lin, Shreyas Padhy, Dustin Tran, Tania Bedrax-Weiss, and Balaji Lakshminarayanan. Simple and principled uncertainty estimation with deterministic deep learning via distance awareness. arXiv preprint arXiv:2006.10108, 2020.
  • Malinin and Gales (2018) Andrey Malinin and Mark Gales. Predictive uncertainty estimation via prior networks. In Advances in Neural Information Processing Systems, pages 7047–7058, 2018.
  • Mårtensson et al. (2020) Gustav Mårtensson, Daniel Ferreira, Tobias Granberg, Lena Cavallin, Ketil Oppedal, Alessandro Padovani, Irena Rektorova, Laura Bonanni, Matteo Pardini, Milica G Kramberger, et al. The reliability of a deep learning model in clinical out-of-distribution mri data: a multicohort study. Medical Image Analysis, 66:101714, 2020.
  • McClure et al. (2019) Patrick McClure, Nao Rho, John A Lee, Jakub R Kaczmarzyk, Charles Y Zheng, Satrajit S Ghosh, Dylan M Nielson, Adam G Thomas, Peter Bandettini, and Francisco Pereira. Knowing what you know in brain segmentation using bayesian deep neural networks. Frontiers in neuroinformatics, 13:67, 2019.
  • McHutchon (2013) Andrew McHutchon. Differentiating gaussian processes. Cambridge (ed.), 2013.
  • Milios et al. (2018) Dimitrios Milios, Raffaello Camoriano, Pietro Michiardi, Lorenzo Rosasco, and Maurizio Filippone. Dirichlet-based gaussian processes for large-scale calibrated classification. arXiv preprint arXiv:1805.10915, 2018.
  • Minka (2013) Thomas P Minka. Expectation propagation for approximate bayesian inference. arXiv preprint arXiv:1301.2294, 2013.
  • Monteiro et al. (2020) Miguel Monteiro, Loïc Le Folgoc, Daniel Coelho de Castro, Nick Pawlowski, Bernardo Marques, Konstantinos Kamnitsas, Mark van der Wilk, and Ben Glocker. Stochastic segmentation networks: Modelling spatially correlated aleatoric uncertainty. arXiv preprint arXiv:2006.06015, 2020.
  • Nair et al. (2020) Tanya Nair, Doina Precup, Douglas L Arnold, and Tal Arbel. Exploring uncertainty measures in deep networks for multiple sclerosis lesion detection and segmentation. Medical image analysis, 59:101557, 2020.
  • Neal (2012) Radford M Neal. Bayesian learning for neural networks, volume 118. Springer Science & Business Media, 2012.
  • Neal et al. (2011) Radford M Neal et al. Mcmc using hamiltonian dynamics. Handbook of markov chain monte carlo, 2(11):2, 2011.
  • Netzer et al. (2011) Yuval Netzer, Tao Wang, Adam Coates, Alessandro Bissacco, Bo Wu, and Andrew Y Ng. Reading digits in natural images with unsupervised feature learning. 2011.
  • Nguyen et al. (2014) Trung V Nguyen, Edwin V Bonilla, et al. Collaborative multi-output gaussian processes. In UAI, pages 643–652. Citeseer, 2014.
  • Ober et al. (2021) Sebastian W Ober, Carl E Rasmussen, and Mark van der Wilk. The promises and pitfalls of deep kernel learning. arXiv preprint arXiv:2102.12108, 2021.
  • Padhy et al. (2020) Shreyas Padhy, Zachary Nado, Jie Ren, Jeremiah Liu, Jasper Snoek, and Balaji Lakshminarayanan. Revisiting one-vs-all classifiers for predictive uncertainty and out-of-distribution detection in neural networks. arXiv preprint arXiv:2007.05134, 2020.
  • Pawlowski et al. (2017) Nick Pawlowski, Andrew Brock, Matthew CH Lee, Martin Rajchl, and Ben Glocker. Implicit weight uncertainty in neural networks. arXiv preprint arXiv:1711.01297, 2017.
  • Pinaya et al. (2021) Walter Hugo Lopez Pinaya, Petru-Daniel Tudosiu, Robert Gray, Geraint Rees, Parashkev Nachev, Sébastien Ourselin, and M. Jorge Cardoso. Unsupervised brain anomaly detection and segmentation with transformers. In Mattias Heinrich, Qi Dou, Marleen de Bruijne, Jan Lellmann, Alexander Schläfer, and Floris Ernst, editors, Proceedings of the Fourth Conference on Medical Imaging with Deep Learning, volume 143 of Proceedings of Machine Learning Research, pages 596–617. PMLR, 07–09 Jul 2021. URL https://proceedings.mlr.press/v143/pinaya21a.html.
  • Popescu et al. (2020) Sebastian Popescu, David Sharp, James Cole, and Ben Glocker. Hierarchical gaussian processes with wasserstein-2 kernels. arXiv preprint arXiv:2010.14877, 2020.
  • Porwal et al. (2018) Prasanna Porwal, Samiksha Pachade, Ravi Kamble, Manesh Kokare, Girish Deshmukh, Vivek Sahasrabuddhe, and Fabrice Meriaudeau. Indian diabetic retinopathy image dataset (idrid): a database for diabetic retinopathy screening research. Data, 3(3):25, 2018.
  • Quinonero-Candela and Rasmussen (2005) Joaquin Quinonero-Candela and Carl Edward Rasmussen. A unifying view of sparse approximate gaussian process regression. The Journal of Machine Learning Research, 6:1939–1959, 2005.
  • Rosca et al. (2020) Mihaela Rosca, Theophane Weber, Arthur Gretton, and Shakir Mohamed. A case for new neural network smoothness constraints. 2020.
  • Ruff et al. (2021) Lukas Ruff, Jacob R Kauffmann, Robert A Vandermeulen, Grégoire Montavon, Wojciech Samek, Marius Kloft, Thomas G Dietterich, and Klaus-Robert Müller. A unifying review of deep and shallow anomaly detection. Proceedings of the IEEE, 2021.
  • Salimbeni and Deisenroth (2017) Hugh Salimbeni and Marc Deisenroth. Doubly stochastic variational inference for deep gaussian processes. In Advances in Neural Information Processing Systems, pages 4588–4599, 2017.
  • Schlegl et al. (2019) Thomas Schlegl, Philipp Seeböck, Sebastian M Waldstein, Georg Langs, and Ursula Schmidt-Erfurth. f-anogan: Fast unsupervised anomaly detection with generative adversarial networks. Medical image analysis, 54:30–44, 2019.
  • Sokolić et al. (2017) Jure Sokolić, Raja Giryes, Guillermo Sapiro, and Miguel RD Rodrigues. Robust large margin deep neural networks. IEEE Transactions on Signal Processing, 65(16):4265–4280, 2017.
  • Staal et al. (2004) Joes Staal, Michael D Abràmoff, Meindert Niemeijer, Max A Viergever, and Bram Van Ginneken. Ridge-based vessel segmentation in color images of the retina. IEEE transactions on medical imaging, 23(4):501–509, 2004.
  • Sweeney et al. (2013) EM Sweeney, RT Shinohara, CD Shea, DS Reich, and Ciprian M Crainiceanu. Automatic lesion incidence estimation and detection in multiple sclerosis using multisequence longitudinal mri. American Journal of Neuroradiology, 34(1):68–73, 2013.
  • Tang (2019) Xiaoli Tang. The role of artificial intelligence in medical imaging research. BJR| Open, 2(1):20190031, 2019.
  • Titsias (2009) Michalis Titsias. Variational learning of inducing variables in sparse gaussian processes. In Artificial Intelligence and Statistics, pages 567–574, 2009.
  • Ustyuzhaninov et al. (2019) Ivan Ustyuzhaninov, Ieva Kazlauskaite, Markus Kaiser, Erik Bodin, Neill DF Campbell, and Carl Henrik Ek. Compositional uncertainty in deep gaussian processes. arXiv preprint arXiv:1909.07698, 2019.
  • van Amersfoort et al. (2020) Joost van Amersfoort, Lewis Smith, Yee Whye Teh, and Yarin Gal. Simple and scalable epistemic uncertainty estimation using a single deep deterministic neural network. arXiv preprint arXiv:2003.02037, 2020.
  • van Amersfoort et al. (2021) Joost van Amersfoort, Lewis Smith, Andrew Jesson, Oscar Key, and Yarin Gal. On feature collapse and deep kernel learning for single forward pass uncertainty. arXiv preprint arXiv:2102.11409, 2021.
  • Van der Wilk et al. (2017) Mark Van der Wilk, Carl Edward Rasmussen, and James Hensman. Convolutional gaussian processes. In Advances in Neural Information Processing Systems, pages 2849–2858, 2017.
  • van der Wilk et al. (2020) Mark van der Wilk, ST John, Artem Artemev, and James Hensman. Variational gaussian process models without matrix inverses. In Symposium on Advances in Approximate Bayesian Inference, pages 1–9. PMLR, 2020.
  • Weng et al. (2018) Tsui-Wei Weng, Huan Zhang, Pin-Yu Chen, Jinfeng Yi, Dong Su, Yupeng Gao, Cho-Jui Hsieh, and Luca Daniel. Evaluating the robustness of neural networks: An extreme value theory approach. arXiv preprint arXiv:1801.10578, 2018.
  • Williams and Barber (1998) Christopher KI Williams and David Barber. Bayesian classification with gaussian processes. IEEE Transactions on Pattern Analysis and Machine Intelligence, 20(12):1342–1351, 1998.
  • Wilson and Izmailov (2020) Andrew Gordon Wilson and Pavel Izmailov. Bayesian deep learning and a probabilistic perspective of generalization. arXiv preprint arXiv:2002.08791, 2020.
  • Wilson et al. (2016) Andrew Gordon Wilson, Zhiting Hu, Ruslan Salakhutdinov, and Eric P Xing. Deep kernel learning. In Artificial Intelligence and Statistics, pages 370–378, 2016.
  • Zhou et al. (2021) S Kevin Zhou, Hayit Greenspan, Christos Davatzikos, James S Duncan, Bram Van Ginneken, Anant Madabhushi, Jerry L Prince, Daniel Rueckert, and Ronald M Summers. A review of deep learning in medical imaging: Imaging traits, technology trends, case studies with progress highlights, and future promises. Proceedings of the IEEE, 2021.

A Proving Lipschitz bounds in a DistGP layer

Here we prove Propositions 3 and 5 of Sec. 3.3.


Lemmas on p-norms.

We have the following relations between norms: $\|x\|_{2}\leq\|x\|_{1}$ and $\|x\|_{1}\leq\sqrt{D}\|x\|_{2}$ for $x\in\mathbb{R}^{D}$. These will be used in the proof of Proposition 5.
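
As a quick numerical illustration of these norm relations (not part of the original derivation), the following minimal NumPy sketch checks both inequalities on randomly drawn vectors; the dimension and the number of trials are arbitrary choices.

```python
import numpy as np

# Sanity check of the p-norm relations ||x||_2 <= ||x||_1 <= sqrt(D) * ||x||_2
# on randomly drawn vectors; D and the number of trials are arbitrary.
rng = np.random.default_rng(0)
D = 16
for _ in range(1000):
    x = rng.normal(size=D)
    l1, l2 = np.linalg.norm(x, 1), np.linalg.norm(x, 2)
    assert l2 <= l1 + 1e-12
    assert l1 <= np.sqrt(D) * l2 + 1e-12
print("p-norm relations hold on all sampled vectors")
```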



Proof of Proposition 3.

Throughout this proof we refer to the first two moments of a Gaussian distribution by $m(\cdot)$ and $v(\cdot)$. Writing out the Wasserstein-2 distances of the inequality explicitly, we get:

$|m(F(\mu))-m(F(\nu))|^{2}+|v(F(\mu))-v(F(\nu))|^{2}\leq L\,|m_{1}-m_{2}|^{2}+|\Sigma_{1}-\Sigma_{2}|^{2}$  (103)

We focus on the mean part; applying Cauchy–Schwarz, we get the following inequality:

$\left|\left[K_{\mu u}-K_{\nu u}\right]K_{uu}^{-1}m\right|^{2}\leq\|K_{\mu u}-K_{\nu u}\|_{2}^{2}\,\|K_{uu}^{-1}m\|_{2}^{2}$  (104)

To simplify the problem, and without loss of generality, we consider $U_{z}$ to be a sufficient statistic for the set of inducing points $Z$. Expanding the first term on the r.h.s. we get:

$\left[\sigma^{2}\exp\left(\frac{-W_{2}(\mu,U_{z})}{l^{2}}\right)-\sigma^{2}\exp\left(\frac{-W_{2}(\nu,U_{z})}{l^{2}}\right)\right]^{2}$  (105)

We assume $\nu=\mu+h$, where $h\sim\mathcal{N}(|m_{1}-m_{2}|,|\Sigma_{1}-\Sigma_{2}|)$ and $\mu$ is a high-density point in the data manifold, hence $W_{2}(\mu,U_{z})=0$. We denote $m(h)^{2}+\mathrm{var}(h)^{2}=\lambda$. Using the general identity $\log(x-y)=\log(x)+\log(y)+\log\left(\frac{1}{y}-\frac{1}{x}\right)$ and applying it to our case, we get:

$\log|m(F(\mu))-m(F(\nu))|^{2}\leq\log\left[\sigma^{2}-\sigma^{2}\exp\left(\frac{-\lambda}{l^{2}}\right)\right]^{2}$  (106)
$\leq 2\log\sigma^{2}-2\frac{\lambda}{l^{2}}+2\log\left[\exp\left(\frac{\lambda}{l^{2}}\right)-1\right]$  (107)
$\leq 2\log\left[\sigma^{2}\exp\left(\frac{\lambda}{l^{2}}\right)\right]$  (108)

We have the general inequality $\exp(x)\leq 1+x+x^{2}$ for $x\leq 1.79$, which for $0\leq x\leq 1$ can be simplified to $\exp(x)\leq 1+2x$. Applying this inequality and taking the exponential, we now obtain:

$|m(F(\mu))-m(F(\nu))|^{2}\leq\left[\sigma^{2}+2\sigma^{2}\frac{\lambda}{l^{2}}\right]^{2}$  (109)
$\leq\sigma^{4}+4\sigma^{4}\frac{\lambda}{l^{2}}+4\sigma^{4}\frac{\lambda^{2}}{l^{4}}$  (110)
$\leq 16\sigma^{4}\frac{\lambda}{l^{2}}$  (111)

where the last inequality follows from the ball constraints made in the definition. Moving on to the variance components of the Lipschitz bound, we notice that

$|v(F(\mu))^{\frac{1}{2}}-v(F(\nu))^{\frac{1}{2}}|^{2}\leq|v(F(\mu))^{\frac{1}{2}}-v(F(\nu))^{\frac{1}{2}}|\,|v(F(\mu))^{\frac{1}{2}}+v(F(\nu))^{\frac{1}{2}}|$  (112)
$\leq|v(F(\mu))-v(F(\nu))|$  (113)

which after applying Cauchy–Schwarz results in an upper bound of the form:

$\|K_{\mu,U_{z}}-K_{\nu,U_{z}}\|_{2}^{2}\,\|K_{U_{z}}^{-1}(K_{U_{z}}-S)K_{U_{z}}^{-1}\|_{2}$  (114)

Using that $\|K_{\mu,U_{z}}-K_{\nu,U_{z}}\|_{2}^{2}\leq\frac{16\sigma^{4}\lambda}{l^{2}}$, we obtain:

$|v(F(\mu))-v(F(\nu))|\leq\frac{16\sigma^{4}\lambda}{l^{2}}\,\|K_{U_{z}}^{-1}(K_{U_{z}}-S)K_{U_{z}}^{-1}\|_{2}$  (115)

Taking into consideration both the upper bound on the mean component and the one on the variance component, we arrive at the desired Lipschitz constant.
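
As a rough numerical illustration of the mean-component bound (111), and not a substitute for the ball-constraint argument above, the sketch below checks $\left[\sigma^{2}-\sigma^{2}\exp(-\lambda/l^{2})\right]^{2}\leq 16\sigma^{4}\lambda/l^{2}$ on a grid of values $\lambda/l^{2}\in[0,1]$; the value of $\sigma^{2}$ and the grid are arbitrary choices.

```python
import numpy as np

# Check [sigma^2 * (1 - exp(-t))]^2 <= 16 * sigma^4 * t  for t = lambda / l^2 in [0, 1].
# sigma^2 and the grid of t values are arbitrary illustration choices.
sigma2 = 1.7
t = np.linspace(0.0, 1.0, 1001)
lhs = (sigma2 * (1.0 - np.exp(-t))) ** 2
rhs = 16.0 * sigma2 ** 2 * t
assert np.all(lhs <= rhs + 1e-12)
print("mean-component bound (111) holds on the sampled grid")
```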





Proof of Proposition 5.

Using the definition of the Wasserstein-2 distance for the l.h.s. of the inequality, we can re-express it as follows:

$W_{2}\left(f(\mu),f(\nu)\right)\leq\|m_{1}A-m_{2}A\|_{2}^{2}+\|(\sigma_{1}^{2}A^{2})^{1/2}-(\sigma_{2}^{2}A^{2})^{1/2}\|_{F}^{2}$  (116)

which after rearranging terms and noticing that inside the Frobenius norm we have scalars, becomes:

$W_{2}\left(f(\mu),f(\nu)\right)\leq\|(m_{1}-m_{2})A\|_{2}^{2}+\left[(\sigma_{1}^{2}A^{2})^{1/2}-(\sigma_{2}^{2}A^{2})^{1/2}\right]^{2}$  (117)

Applying the Cauchy–Schwarz inequality to the part involving the means, and multiplying the right-hand side by $\sqrt{C}$, where $C$ is the number of channels, we get:

$\|(m_{1}-m_{2})A\|_{2}^{2}+\left[(\sigma_{1}^{2}A^{2})^{1/2}-(\sigma_{2}^{2}A^{2})^{1/2}\right]^{2}\leq\|m_{1}-m_{2}\|_{2}^{2}\,\sqrt{C}\|A\|_{2}^{2}+\sqrt{C}\left[(\sigma_{1}^{2}A^{2})^{1/2}-(\sigma_{2}^{2}A^{2})^{1/2}\right]^{2}$  (118)

We notice that the Lipschitz constant for the mean component is $\sqrt{C}\|A\|_{2}^{2}$. We now show that the same constant $L$ also holds for the variance component, which amounts to the following equivalence:

$L=\sqrt{C}\|A\|_{2}^{2}\;\Longleftrightarrow\;\sqrt{C}\left[(\sigma_{1}^{2}A^{2})^{1/2}-(\sigma_{2}^{2}A^{2})^{1/2}\right]^{2}\leq\left[\sigma_{1}-\sigma_{2}\right]^{2}\sqrt{C}\|A\|_{2}^{2}$  (119)

By virtue of Cauchy–Schwarz we have the inequality $\sqrt{C}\left[\sigma_{1}A-\sigma_{2}A\right]^{2}\leq\left[\sigma_{1}-\sigma_{2}\right]^{2}\sqrt{C}\|A\|_{2}^{2}$. Hence the aforementioned if-and-only-if statement will hold if we prove that

$\sqrt{C}\left[(\sigma_{1}^{2}A^{2})^{\frac{1}{2}}-(\sigma_{2}^{2}A^{2})^{\frac{1}{2}}\right]^{2}\leq\sqrt{C}\left[\sigma_{1}A-\sigma_{2}A\right]^{2}$  (120)

which after expressing in terms of norms becomes:

$\sqrt{C}\left[\|\sigma_{1}A\|_{2}-\|\sigma_{2}A\|_{2}\right]^{2}\leq\sqrt{C}\left[\|\sigma_{1}A\|_{1}-\|\sigma_{2}A\|_{1}\right]^{2}$  (121)

Expanding the square brackets gives:

$\sqrt{C}\left[\|\sigma_{1}A\|_{2}^{2}+\|\sigma_{2}A\|_{2}^{2}-2\|\sigma_{1}A\|_{2}\|\sigma_{2}A\|_{2}\right]\leq\sqrt{C}\left[\|\sigma_{1}A\|_{1}^{2}+\|\sigma_{2}A\|_{1}^{2}-2\|\sigma_{1}A\|_{1}\|\sigma_{2}A\|_{1}\right]$  (122)

This inequality holds by the p-norm lemma, hence the if-and-only-if statement is satisfied. Consequently, the Lipschitz constant is $\sqrt{C}\|A\|_{2}^{2}$.
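
To make the distance-preserving behaviour of the affine operator concrete, the following illustrative check (not part of the proof) considers the scalar case $C=1$ with $A=a$: pushing two Gaussians through $x\mapsto ax$ scales the mean gap and the standard-deviation gap by $|a|$, so the squared Wasserstein-2 distance scales exactly by $a^{2}=\sqrt{C}\|A\|_{2}^{2}$. The parameter values below are arbitrary.

```python
import numpy as np

def w2_sq_gauss(m1, s1, m2, s2):
    """Squared Wasserstein-2 distance between 1-D Gaussians N(m1, s1^2) and N(m2, s2^2)."""
    return (m1 - m2) ** 2 + (s1 - s2) ** 2

# Arbitrary example values for the scalar (C = 1) case.
a = 0.8                     # affine operator x -> a * x (no bias)
m1, s1 = 0.3, 1.2
m2, s2 = -0.5, 0.7

lhs = w2_sq_gauss(a * m1, abs(a) * s1, a * m2, abs(a) * s2)
rhs = a ** 2 * w2_sq_gauss(m1, s1, m2, s2)   # sqrt(C) * ||A||_2^2 reduces to a^2 here
print(lhs, rhs)                              # the two values coincide in this case
assert lhs <= rhs + 1e-12
```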



B Deriving function contraction requirements in DistGP Layers

Here we prove Proposition 7 of Sec. 3.4.


Proof of Proposition 7.

We are interested in determining the specific scenarios in which the function space collapses to constant values. Hence we explicitly write $\mathbb{E}\left[\|u_{l}(x)-u_{l}(x^{*})\|_{2}^{2}\mid u_{l-1}\right]$ as:

$=\sum_{j=1}^{D_{l}}\mathbb{E}\left[\|u_{l}^{j}(x)-u_{l}^{j}(x^{*})\|_{2}^{2}\mid u_{l-1}\right]$  (123)
$=\sum_{j=1}^{D_{l}}\mathbb{E}\left[\left(u_{l}^{j}(x)\right)^{2}\mid u_{l-1}\right]-2\,\mathbb{E}\left[u_{l}^{j}(x)\,u_{l}^{j}(x^{*})\mid u_{l-1}\right]+\mathbb{E}\left[\left(u_{l}^{j}(x^{*})\right)^{2}\mid u_{l-1}\right]$  (124)
$=\sum_{j=1}^{D_{l}}\sigma_{l}^{2}+m_{l}^{2}(x)-2m_{l}(x)m_{l}(x^{*})-2k^{W_{2}}\left[\mu_{l}(x),\mu_{l}(x^{*})\right]+\sigma_{l}^{2}+m_{l}^{2}(x^{*})$  (125)
$=\sum_{j=1}^{D_{l}}\left[m_{l}^{j}(x)-m_{l}^{j}(x^{*})\right]^{2}+2\sigma_{l}^{2}-2\sigma_{l}^{2}\exp\left[-\frac{|m_{l-1}^{j}(x)-m_{l-1}^{j}(x^{*})|^{2}}{2l_{l}^{2}}\right]$  (126)

where in the last equation we have ignored the variance part of the Wasserstein-2 kernel, since the two variance terms are equal. We make use of the inequality $1-\exp(-x)\leq x$ for $x\geq 0$, with equality only at $x=0$, resulting in the following upper bound:

$\mathbb{E}\left[\|u_{l}(x)-u_{l}(x^{*})\|_{2}^{2}\mid u_{l-1}\right]\leq\sum_{j=1}^{D_{l}}\left[m_{l}^{j}(x)-m_{l}^{j}(x^{*})\right]^{2}+\sigma_{l}^{2}\,\frac{|m_{l-1}^{j}(x)-m_{l-1}^{j}(x^{*})|^{2}}{l_{l}^{2}}$  (127)

We can now view the previously defined operator $\overline{m_{l}(x)W_{l}}$ as an inner product in vector space between a tiled version of $m_{l}(x)$ and a normalised version of $W_{l}$, more specifically:

$\left\langle\left[m_{l}(x),\cdots,m_{l}(x)\right],\left[\frac{W_{l,1}}{m_{l}},\cdots,\frac{W_{l,m_{l}D_{l-1}}}{m_{l}}\right]\right\rangle=\overline{m_{l}(x)W_{l}}$  (128)

where $m_{l}$ is the number of dimensions introduced by the affine embedding function $\Psi_{l}$ in the $l$-th layer of the hierarchy.

Our current goal is to relate $m_{l}^{j}(\cdot)$ to $m_{l-1}^{j}(\cdot)$. We can now apply Cauchy–Schwarz to:

$|m_{l}^{j}(x)-m_{l}^{j}(x^{*})|^{2}=|\overline{m_{l-1}(x)W_{l}}-\overline{m_{l-1}(x^{*})W_{l}}|^{2}$  (129)
$=\left|\left\langle\left[m_{l-1}(x)-m_{l-1}(x^{*}),\cdots,m_{l-1}(x)-m_{l-1}(x^{*})\right],\left[\frac{W_{l,1}}{m_{l}},\cdots,\frac{W_{l,m_{l}D_{l-1}}}{m_{l}}\right]\right\rangle\right|^{2}$  (130)
$\leq D_{l-1}m_{l}\left[m_{l-1}(x)-m_{l-1}(x^{*})\right]^{2}\,\langle\tilde{W_{l}},\tilde{W_{l}}\rangle$  (131)

where in the last line we denoted $\tilde{W_{l}}=\left[\frac{W_{l,1}}{m_{l}},\cdots,\frac{W_{l,D_{l-1}m_{l}}}{m_{l}}\right]$ to avoid clutter.

We can now apply the previous result to equation (127):

$\mathbb{E}\left[\|u_{l}(x)-u_{l}(x^{*})\|_{2}^{2}\mid u_{l-1}\right]\leq\sum_{j=1}^{D_{l}}m_{l}D_{l-1}\left(m_{l-1}(x)-m_{l-1}(x^{*})\right)^{2}\langle\tilde{W_{l}},\tilde{W_{l}}\rangle+\frac{\sigma_{l}^{2}}{l_{l}^{2}}\,|m_{l-1}^{j}(x)-m_{l-1}^{j}(x^{*})|^{2}$  (132)
$\leq\sum_{j=1}^{D_{l}}\left[m_{l}D_{l-1}\,\langle\tilde{W_{l}},\tilde{W_{l}}\rangle+\frac{\sigma_{l}^{2}}{l_{l}^{2}}\right]|m_{l-1}^{j}(x)-m_{l-1}^{j}(x^{*})|^{2}$  (133)

We can now recursively apply the previously derived Cauchy-Schwarz based inequality to obtain:

$\mathbb{E}\left[\|u_{l}(x)-u_{l}(x^{*})\|_{2}^{2}\mid\{u_{l}\}_{l=1}^{l-1}\right]\leq\left[m_{l}D_{l-1}\,\langle\tilde{W_{l}},\tilde{W_{l}}\rangle+\frac{\sigma_{l}^{2}}{2l_{l}^{2}}\right]\prod_{l=1}^{l-1}D_{l}m_{l}D_{l-1}\langle\tilde{W_{l}},\tilde{W_{l}}\rangle\,\left[m_{1}(x)-m_{1}(x^{*})\right]^{2}$  (134)

By Markov's inequality, for any $\epsilon>0$ we have that:

$P\left(\|u_{l+1}(x)-u_{l+1}(x^{*})\|_{2}\geq\epsilon\right)\leq\frac{1}{\epsilon^{2}}\left[m_{l}D_{l-1}\,\langle\tilde{W_{l}},\tilde{W_{l}}\rangle+\frac{\sigma_{l}^{2}}{2l_{l}^{2}}\right]\prod_{l=1}^{l-1}D_{l}m_{l}D_{l-1}\langle\tilde{W_{l}},\tilde{W_{l}}\rangle\,\left[m_{1}(x)-m_{1}(x^{*})\right]^{2}$  (135)

Then, only in the case that $\left[m_{l}D_{l-1}\langle\tilde{W}_{l},\tilde{W}_{l}\rangle+\frac{\sigma_{l}^{2}}{2l_{l}^{2}}\right]\leq 1$ and $D_{l}m_{l}D_{l-1}\langle\tilde{W}_{l},\tilde{W}_{l}\rangle\leq 1$ are satisfied for the intermediate layers, so that the right-hand side of Eq. (135) is summable over depth, can we apply the first Borel-Cantelli lemma to obtain:

\[
P\left(\bigcap_{l=1}^{\infty}\bigcup_{m=l}^{\infty}\left\|u_{m}(x)-u_{m}(x^{*})\right\|_{2}\geq\epsilon\right)=0 \tag{136}
\]

Lastly, we can express the following:

\begin{align*}
P\left(\left\|u_{n}(x)-u_{n}(x^{*})\right\|_{2}\to 0\right)
&= P\left(\bigcap_{k=1}^{\infty}\bigcup_{l=1}^{\infty}\bigcap_{m=l}^{\infty}\left\|u_{m}(x)-u_{m}(x^{*})\right\|_{2}\leq\frac{1}{k}\right) \tag{137}\\
&= 1-P\left(\bigcup_{k=1}^{\infty}\bigcap_{l=1}^{\infty}\bigcup_{m=l}^{\infty}\left\|u_{m}(x)-u_{m}(x^{*})\right\|_{2}\geq\frac{1}{k}\right) \tag{138}\\
&\geq 1-\sum_{k=1}^{\infty}P\left(\bigcap_{l=1}^{\infty}\bigcup_{m=l}^{\infty}\left\|u_{m}(x)-u_{m}(x^{*})\right\|_{2}\geq\frac{1}{k}\right)=1 \tag{139}
\end{align*}

From this we obtain the claim of our proposition, namely $P\left(\left\|u_{n}(x)-u_{n}(x^{*})\right\|_{2}\to 0\right)=1$.
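To make the role of the two conditions above concrete, the short numerical sketch below evaluates the right-hand side of Eq. (135) for a stack of identical layers; the hyper-parameter values are hypothetical and chosen only for illustration, not taken from our experiments. When both bracketed quantities are at most one, the per-depth bounds form a geometric sequence whose sum over depth is finite, which is precisely the summability needed to invoke the first Borel-Cantelli lemma.

    import numpy as np

    # Hypothetical per-layer hyper-parameters (illustration only): m inducing
    # points, widths D and D_prev, squared weight norm <W~, W~>, kernel
    # variance sigma^2 and squared lengthscale l^2.
    m, D, D_prev = 32, 8, 8
    w_sq, sigma2, ell_sq = 4e-4, 1.0, 4.0

    bracket = m * D_prev * w_sq + sigma2 / (2.0 * ell_sq)  # first condition: <= 1
    ratio = D * m * D_prev * w_sq                          # second condition: <= 1
    assert bracket <= 1.0 and ratio < 1.0

    eps = 0.5
    delta_m1_sq = 0.05  # [m_1(x) - m_1(x*)]^2 for a pair of nearby inputs

    # Right-hand side of Eq. (135) as depth grows: a geometric sequence in `ratio`.
    depths = np.arange(1, 11)
    bounds = bracket * ratio ** (depths - 1) * delta_m1_sq / eps**2
    print(np.round(bounds, 5))                                # monotonically decreasing
    print("sum over all depths:", bounds[0] / (1.0 - ratio))  # finite geometric sum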




C Outlier detection in MRI scans with Tumors.

Remarks.

We provide additional plots for the task investigated in sec. 5.1.3 for DistGP-Seg and OVA-DM, as they were the only models to provide decent outlier detection capabilities. We refer the reader to Figures 17 and 18. From case study A, we can see that OVA-DM over-segments regions outside the tumor area almost at random across all FPR levels, whereas DistGP-Seg over-segments areas around the margins of the ventricles only at FPR = {1.0, 5.0}. For case study B, at FPR = {0.5, 1.0, 5.0} OVA-DM under-segments the tumor in comparison to DistGP-Seg. The same observation holds for case study C. Lastly, for case study D, DistGP-Seg under-segments the tumor at FPR = {0.1, 0.5}.
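For reference, outlier maps at a given FPR level can be obtained by thresholding the voxel-wise OOD measure at a percentile of its values on in-distribution data. The sketch below illustrates this general recipe; the array names are hypothetical and this is not necessarily the exact implementation used to generate the figures.

    import numpy as np

    def outlier_mask(scores_id, scores_test, fpr_percent):
        """Flag test voxels whose OOD score exceeds the threshold at which
        `fpr_percent` percent of in-distribution voxels would be flagged."""
        threshold = np.percentile(scores_id, 100.0 - fpr_percent)
        return scores_test >= threshold

    # Hypothetical usage for the FPR levels discussed above:
    # val_scores, case_scores = ...  # flattened voxel-wise OOD measures
    # masks = {fpr: outlier_mask(val_scores, case_scores, fpr)
    #          for fpr in (0.1, 0.5, 1.0, 5.0)}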


Figure 17: Detailed segmentation output for DistGP-Seg on BRATS. The mean segmentation is the hard segmentation of brain tissues. The OOD measure is the model's uncertainty quantification, computed using its own procedure. Higher values indicate that a voxel is more likely to be an outlier.
Figure 18: Detailed segmentation output for OVA-DM on BRATS. The mean segmentation is the hard segmentation of brain tissues. The OOD measure is the model's uncertainty quantification, computed using its own procedure. Higher values indicate that a voxel is more likely to be an outlier.

D Outlier detection in MRI scans with WMH.

Data and pre-processing.

We use brain MRI scans from the 2015 Longitudinal Multiple Sclerosis Lesion Segmentation Challenge Sweeney et al. (2013), which comprises FLAIR, PD, T2-weighted, and T1-weighted volumes from a total of 110 MR imaging studies (11 longitudinal studies for each of 10 subjects). All participants gave written consent and were scanned as part of an institutional review board approved natural history protocol. For the purposes of the task at hand, we only use the baseline FLAIR scans. All FLAIR images are pre-processed with skull-stripping, N4 bias correction, rigid registration to MNI152 space and histogram matching between UKBB and BraTS. Finally, we normalize the intensities of each scan by linearly scaling its minimum and maximum intensities to the [-1,1] range.
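The sketch below illustrates only the final normalization step, assuming a pre-processed FLAIR volume loaded as a NumPy array; nibabel and the file name are used purely for illustration.

    import numpy as np
    import nibabel as nib  # assumed I/O library; any NIfTI reader would do

    def rescale_to_unit_range(volume: np.ndarray) -> np.ndarray:
        """Linearly rescale a scan's intensities to [-1, 1] using its own min/max."""
        vmin, vmax = float(volume.min()), float(volume.max())
        unit = (volume - vmin) / (vmax - vmin)  # map [vmin, vmax] -> [0, 1]
        return 2.0 * unit - 1.0                 # map [0, 1] -> [-1, 1]

    # Hypothetical file name:
    flair = nib.load("baseline_flair_preprocessed.nii.gz").get_fdata()
    flair_norm = rescale_to_unit_range(flair)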


Remarks.

The task of detecting white matter hyperintensities (WMH) is considerably more difficult than detecting tumors: the latter usually presents itself as a large blob, whereas the former consists of multiple non-contiguous areas of varying shapes. From Figure 19 we can notice that the large connected WMH regions are reliably detected as outliers, while the smaller disconnected WMH regions are outlined only in some cases. Another issue is over-segmentation, as seen in case study C.

Figure 19: Detailed segmentation output for DistGP-Seg on the WMH dataset. The mean segmentation is the hard segmentation of brain tissues. The OOD measure is the model's uncertainty quantification, computed using its own procedure. Higher values indicate that a voxel is more likely to be an outlier.

E Evaluation on Retina scans.

Data and pre-processing.

DRIVE: The Digital Retinal Images for Vessel Extraction (DRIVE) dataset Staal et al. (2004) is a dataset for retinal vessel segmentation. It consists of a total of 40 JPEG color fundus images, including 7 abnormal pathology cases. The images were obtained from a diabetic retinopathy screening program in the Netherlands and were acquired using a Canon CR5 non-mydriatic 3CCD camera with a 45 degree field of view (FOV). Each image has a resolution of 584×565 pixels with eight bits per color channel (3 channels).

The set of 40 images was divided equally into 20 images for the training set and 20 images for the testing set. For each image in both sets, a circular field of view (FOV) mask with a diameter of approximately 540 pixels is provided. For each training image, one manual segmentation by an ophthalmological expert is available. For each testing image, two manual segmentations by two different observers are available, where the first observer's segmentation is accepted as the ground truth for performance evaluation.

STARE: The STructured Analysis of the Retina (STARE) database Hoover et al. (2000) was created by scanning and digitizing retinal photographs; hence, its image quality is lower than that of the other public databases. The STARE dataset comprises 97 images (59 AMD and 38 normal) taken using a fundus camera (TOPCON TRV-50; Topcon Corp., Tokyo, Japan) at a 35° field of view and with a resolution of 605×700 pixels. Its retina scans are from subjects suffering from the following retinal pathologies: Hollenhorst Emboli, Branch Retinal Artery Occlusion, Cilio-Retinal Artery Occlusion, Branch Retinal Vein Occlusion, Central Retinal Vein Occlusion, Hemi-Central Retinal Vein Occlusion, Background Diabetic Retinopathy, Proliferative Diabetic Retinopathy, Arteriosclerotic Retinopathy, Hypertensive Retinopathy, Coats' disease, Macroaneurysm, Choroidal Neovascularization.

IDRID: The Indian Diabetic Retinopathy Image Dataset (IDRID) Porwal et al. (2018) is a publicly available retinal fundus image database consisting of 516 images categorised into two parts: retina images with signs of Diabetic Retinopathy and/or Diabetic Macular Edema, and normal retinal images. Images were acquired using a Kowa VX-10a digital fundus camera with a 50° field of view (FOV). The images have a resolution of 4288×2848 pixels and are stored in JPG file format. We pre-processed these images to match the FOV and resolution of DRIVE.
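As an illustration of this last step, the sketch below resizes a fundus image to the DRIVE resolution stated above and applies a circular FOV mask of roughly 540 pixels in diameter; the specific resampling and masking choices are assumptions rather than the exact pipeline used for our experiments.

    import numpy as np
    from PIL import Image

    DRIVE_SIZE = (565, 584)   # (width, height) of DRIVE images
    FOV_DIAMETER = 540        # approximate DRIVE FOV diameter in pixels

    def match_drive_geometry(path: str) -> np.ndarray:
        """Resize a fundus image to DRIVE resolution and zero out pixels
        outside a centred circular field of view."""
        img = Image.open(path).convert("RGB").resize(DRIVE_SIZE, Image.BILINEAR)
        arr = np.asarray(img, dtype=np.float32)
        h, w, _ = arr.shape
        yy, xx = np.mgrid[:h, :w]
        fov = (yy - h / 2.0) ** 2 + (xx - w / 2.0) ** 2 <= (FOV_DIAMETER / 2.0) ** 2
        return arr * fov[..., None]

    # Hypothetical IDRID file name:
    # idrid_matched = match_drive_geometry("IDRiD_001.jpg")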


Task.

We train a similar DistGP-Seg architecture (see sec. 5), adapted for 2D data, on DRIVE (normative data) to segment blood vessels; at testing time we then apply it to STARE and IDRID (OOD data) to segment blood vessels in the presence of previously unseen retinal pathologies.


Blood vessel segmentation on normal Retina scans.

From Figure 20 we can observe that DistGP-Seg correctly segments the blood vessels, while both distributional and within-data uncertainty remain relatively low, as expected since these testing examples represent in-distribution data.

Figure 20: Detailed segmentation output for DistGP-Seg on the DRIVE dataset. The mean segmentation is the hard segmentation of blood vessels. The OOD measure is the model's uncertainty quantification, computed using its own procedure. Higher values indicate that a pixel is more likely to be an outlier.
Outlier detection of Retina pathologies.

From Figure 21 we can observe that DistGP-Seg manages to correctly identify the vast majority of soft and hard exudates as outliers.

Figure 21: Detailed segmentation output for DistGP-Seg on the STARE/IDRID datasets. The mean segmentation is the hard segmentation of blood vessels. The OOD measure is the model's uncertainty quantification, computed using its own procedure. Higher values indicate that a pixel is more likely to be an outlier. Case studies A-E originate from STARE, whereas the remainder originate from IDRID.