1 Introduction

Deep Neural Networks (DNNs) have shown promising results in various medical applications, but their performance depends heavily on the amount and the diversity of the training data [10]. In the context of medical imaging, this is particularly challenging since the required training data may not be available at a single institution, due to the low incidence rate of some pathologies and limited numbers of patients. At the same time, it is often infeasible to collect and share patient data in a centralised data lake due to medical data privacy regulations.

One recent method that tackles this problem is Federated Learning (FL) [6, 8]: it allows the collaborative and decentralised training of DNNs without sharing patient data. Each node trains its own local model and periodically submits it to a parameter server. The server accumulates and aggregates the individual contributions to yield a global model, which is then shared with all nodes. The training data remain private to each node and are never shared during the learning process; only the model's trainable weights or updates are shared. Consequently, FL neatly sidesteps many of the data security challenges by leaving the data where they are, and enables multi-institutional collaboration.

Although FL can provide a high level of privacy protection, it is still vulnerable to misuse, such as the reconstruction of training examples by model inversion. One effective countermeasure is to inject noise into each node's training process, distorting the updates and limiting the granularity of the information shared among the nodes [1, 9]. However, existing privacy-preserving research focuses only on general machine learning benchmarks such as MNIST, and uses vanilla stochastic gradient descent algorithms.

In this work, we implement and evaluate practical federated learning systems for brain tumour segmentation. Through a series of experiments on the BraTS 2018 data, we demonstrate the feasibility of privacy-preserving techniques. Our primary contributions are: (1) implementing and evaluating, to the best of our knowledge, the first privacy-preserving federated learning system for medical image analysis; (2) comparing and contrasting various aspects of federated averaging algorithms for handling momentum-based optimisation and imbalanced training nodes; (3) empirically studying the sparse vector technique for providing a strong differential privacy guarantee.

Fig. 1. Left: illustration of the federated learning system; right: distribution of the training subjects (\(N=242\)) across the participating federated clients (\(K=13\)) studied in this paper.

2 Method

We study FL systems based on a client-server architecture (illustrated in Fig. 1 (left)) implementing the federated averaging algorithm [6]. In this configuration, a centralised server maintains a global DNN model and coordinates the clients' local stochastic gradient descent (SGD) updates. This section presents the client-side model training procedure, the server-side model aggregation procedure, and the privacy-preserving module deployed on the client side.

2.1 Client-Side Model Training

We assume each federated client has a fixed local dataset and reasonable computational resources to run mini-batch SGD updates. The clients also share the same DNN structure and loss functions. The proposed training procedure is listed in Algorithm 1. At federated round t, the local model is initialised by reading global model parameters \(W^{(t)}\) from the server, and is updated to \(W^{(l,t)}\) by running multiple iterations of SGD. After a fixed number of iterations \(N^{(local)}\), the model difference \(\varDelta W^{(t)}\) is shared with the aggregation server.
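
As a minimal sketch, the client-side procedure can be summarised as follows (PyTorch-style code for illustration; soft_dice_loss and local_loader are hypothetical stand-ins for a client's loss function and data pipeline):

```python
import torch

def client_update(model, optimizer, global_state, local_loader, n_local):
    """One federated round on a client (a sketch of Algorithm 1)."""
    model.load_state_dict(global_state)          # initialise from W^(t)
    for it, (x, y) in enumerate(local_loader):
        if it >= n_local:                        # N^(local) SGD iterations
            break
        optimizer.zero_grad()
        loss = soft_dice_loss(model(x), y)       # assumed loss function
        loss.backward()
        optimizer.step()
    # Only the model difference Delta W^(t) leaves the client.
    return {k: v.detach() - global_state[k]
            for k, v in model.state_dict().items()}
```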

DNNs for medical imaging are often trained with momentum-based SGD. Introducing momentum terms takes the previous SGD steps into account when computing the current one; this can accelerate training and reduce oscillation. We explore the design choices for handling these terms in FL. In the proposed Algorithm 1 (exemplified with the Adam optimiser [4]), we re-initialise each client's momentum terms at the beginning of each federated round (denoted as m. restart). Since the local model parameters are initialised from the global ones, which aggregate information from the other clients, the restarting operation effectively clears the clients' local optimiser states that could otherwise interfere with the training process. This is empirically compared with (a) clients maintaining a set of local momentum variables without sharing them (denoted as baseline m.); and (b) treating the momentum variables as part of the model, i.e. the variables are updated locally and aggregated by the server (denoted as m. aggregation). Although m. aggregation is theoretically plausible [11], it requires the momentum variables to be released to the server, which increases both the communication overheads and the data security risks.
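
For illustration, m. restart amounts to constructing a fresh optimiser at the start of every round; the following sketch contrasts the three treatments (the helper name is hypothetical, and only the restart variant is shown in full):

```python
def optimizer_for_round(model, strategy, prev_optimizer=None):
    if strategy == "m. restart":
        # Fresh Adam state: moment estimates start from zero, clearing
        # the client's local optimiser state at each federated round.
        return torch.optim.Adam(model.parameters(), lr=1e-4)
    if strategy == "baseline m.":
        # The client keeps its own moment estimates across rounds.
        return prev_optimizer
    if strategy == "m. aggregation":
        # The moment estimates would additionally be uploaded and averaged
        # by the server like the model weights (omitted here), at extra
        # communication and data security cost.
        raise NotImplementedError
    raise ValueError(strategy)
```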

Algorithm 1. Client-side model training.
Algorithm 2. Privacy-preserving sharing of model updates with the sparse vector technique.

2.2 Client-Side Privacy-Preserving Module

The client side is designed to have full control over which data to share, and the local training data never leave the client's site. Still, model inversion attacks such as [3] can potentially extract sensitive patient data from the update \(\varDelta W^{(t)}_k\) or the model \(W^{(t)}\) during federated training. We adopt selective parameter updates [9] and the sparse vector technique (SVT) [5] to provide strong protection against such indirect data leakage.

Selective Parameter Sharing. The full model at the end of a client-side training process might have over-fitted and memorised local training examples; sharing this model poses risks of revealing the training data. Selective parameter sharing limits the amount of information that a client shares. This is achieved by (1) only uploading a fraction of \(\varDelta W^{(t)}_k\): a component \(w_i\) of \(\varDelta W^{(t)}_k\) is shared iff \(abs(w_i)\) is greater than a threshold \(\tau ^{(t)}_k\); and (2) further clipping the shared values of \(\varDelta W^{(t)}_k\) to a fixed range \([-\gamma , \gamma ]\). Here abs(x) denotes the absolute value of x; \(\tau ^{(t)}_k\) is chosen by computing a percentile of \(abs(\varDelta W^{(t)}_k)\); \(\gamma \) is independent of the specific training data and can be chosen via a small publicly available validation set before training. Gradient clipping, a widely used method that acts as a model regulariser to prevent over-fitting, is also applied.
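
A minimal numpy sketch of this selection step, assuming \(\varDelta W^{(t)}_k\) has been flattened into a single vector:

```python
import numpy as np

def select_and_clip(delta_w, q=0.1, gamma=1e-4):
    """Share the fraction q of components with the largest magnitudes,
    clipped to [-gamma, gamma]."""
    tau = np.percentile(np.abs(delta_w), 100.0 * (1.0 - q))  # tau^(t)_k
    mask = np.abs(delta_w) > tau
    return mask, np.clip(delta_w[mask], -gamma, gamma)
```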

Differential Privacy Module. Selective parameter sharing can be further strengthened with a differential privacy guarantee using SVT. The procedure of selecting and sharing distorted components \(w_i\) is described in Algorithm 2. Intuitively, instead of simply thresholding \(abs(\varDelta W^{(t)}_k)\) and sharing its components \(w_i\), every shared \(w_i\) is controlled by the Laplace mechanism. This is implemented by first comparing a clipped and noisy version of \(abs(w_i)\) with a noisy threshold \(\tau ^{(t)} + Lap(s/\varepsilon _2)\) (Line 8, Algorithm 2), and then, only if the thresholding condition is satisfied, sharing a noisy answer \(clip(w_i + Lap(qs/\varepsilon _3), \gamma )\). Here Lap(x) denotes a random variable sampled from the Laplace distribution parameterised by x; \(clip(x, \gamma )\) denotes clipping of x to the range \([-\gamma , \gamma ]\); s denotes the sensitivity of the federated gradient, which is bounded by \(\gamma \) in this case [9]. The selection procedure is repeated until a fraction q of \(\varDelta W^{(t)}_k\) has been released. This procedure satisfies \((\varepsilon _1+\varepsilon _2+\varepsilon _3)\)-differential privacy [5].
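
The following numpy sketch mirrors this description. The query-side noise scale \(Lap(2qs/\varepsilon _1)\) is our assumption of how \(\varepsilon _1\) enters the mechanism, following [5, 9]; the threshold and answer noise scales are as stated above.

```python
def svt_share(delta_w, q, gamma, s, eps1, eps2, eps3, rng=None):
    """Release a noisy, clipped fraction q of delta_w (a sketch of
    Algorithm 2)."""
    rng = rng or np.random.default_rng()
    budget = int(q * delta_w.size)       # number of components to release
    tau = np.percentile(np.abs(delta_w), 100.0 * (1.0 - q))
    noisy_tau = tau + rng.laplace(scale=s / eps2)
    shared = {}
    for i in rng.permutation(delta_w.size):
        if len(shared) >= budget:
            break
        # Compare a clipped and noisy |w_i| against the noisy threshold.
        query = min(abs(delta_w[i]), gamma) \
            + rng.laplace(scale=2.0 * q * s / eps1)   # assumed eps1 scale
        if query >= noisy_tau:
            # Only a noisy answer is shared when the condition holds.
            shared[i] = np.clip(delta_w[i] + rng.laplace(scale=q * s / eps3),
                                -gamma, gamma)
    return shared
```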

2.3 Server-Side Model Aggregation

The server distributes a global model and receives synchronised updates from all clients at each federated round (Algorithm 3). Different clients may run different numbers of local iterations at round t, so the contributions from the clients could be SGD updates at different training speeds. It is therefore important to require each client to report its \(N^{(local)}\), and to weight the contributions accordingly when aggregating them (Line 8, Algorithm 3). In the case of partial model sharing, utilising the sparsity of \(\varDelta W^{(t)}_k\) to reduce the communication overheads is left for future work.
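
A sketch of the weighted aggregation (Line 8, Algorithm 3), where we assume each client's contribution is weighted by its reported number of local iterations:

```python
def aggregate(global_state, client_deltas, client_n_local):
    """W^(t+1) <- W^(t) + weighted average of the clients' Delta W^(t)_k."""
    total = float(sum(client_n_local))
    return {k: w + sum((n / total) * delta[k]
                       for n, delta in zip(client_n_local, client_deltas))
            for k, w in global_state.items()}
```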

Algorithm 3. Server-side model aggregation.

3 Experiments

This section describes the experimental setup, including the common hyper-parameters used for each FL system.

Data Preparation. The BraTS 2018 dataset [2] contains multi-parametric pre-operative MRI scans of 285 subjects with brain tumours. Each subject was scanned with four modalities: (1) T1-weighted, (2) T1-weighted with contrast enhancement, (3) T2-weighted, and (4) T2 fluid-attenuated inversion recovery (T2-FLAIR). Each subject was associated with voxel-level annotations of “whole tumour”, “tumour core”, and “enhancing tumour”. For details of the imaging and annotation protocols, we refer the reader to Bakas et al. [2]. The dataset has previously been used for benchmarking machine learning algorithms and is publicly available. We use it to evaluate the FL algorithms on the multi-modal and multi-class segmentation task. For the client-side local training, we adapted the state-of-the-art training pipeline originally designed for data-centralised training [7], implemented as part of the NVIDIA Clara Train SDK.

To test the generalisation ability across subjects, we randomly split the dataset into a model training set (\(N=242\) subjects) and a held-out test set (\(N=43\) subjects). The scans were collected from thirteen institutions with different equipment and imaging protocols, and thus heterogeneous image feature distributions. To make our federated setup realistic, we further stratified the training set into thirteen disjoint subsets according to where the image data originated, and assigned each subset to a federated client. This setup is challenging for FL algorithms, because (1) each client only processes data from a single institution, which potentially suffers from more severe domain-shift and over-fitting issues than data-centralised training; and (2) it reflects the highly imbalanced nature of the dataset (shown in Fig. 1).

Federated Model Setup. The evaluation of the FL procedures is orthogonal to the choice of convolutional network architecture. Without loss of generality, we chose the segmentation backbone of [7] as the underlying federated model and used the same set of local training hyperparameters for all experiments: the input image window size of the network was \(224\times 224\times 128\) voxels, and the spatial dropout ratio of the first convolutional layer was 0.2. Similarly to [7], we minimised a soft Dice loss using Adam [4] with a learning rate of \(10^{-4}\), a batch size of 1, \(\beta _1\) of 0.9, \(\beta _2\) of 0.999, and an \(\ell _2\) weight decay coefficient of \(10^{-5}\). For all federated training, we set the number of federated rounds to 300 with two local epochs per federated round, where a local epoch is defined as every client “seeing” its local training examples exactly once. At the beginning of each epoch, the data were shuffled locally for each client. For a comparison of model convergence, we also trained a data-centralised baseline for 600 epochs.
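
For reference, the shared hyperparameters can be collected as follows (a plain-Python sketch; the field names are illustrative, not the Clara Train SDK's actual configuration keys):

```python
train_config = dict(
    input_window=(224, 224, 128),     # voxels
    first_layer_spatial_dropout=0.2,
    loss="soft Dice",
    optimizer="Adam",
    learning_rate=1e-4,
    batch_size=1,
    betas=(0.9, 0.999),               # beta_1, beta_2
    l2_weight_decay=1e-5,
    federated_rounds=300,
    local_epochs_per_round=2,         # one epoch = one pass over local data
)
```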

In terms of computational costs, the segmentation model has about \(1.2\times 10^6\) parameters; a training iteration with an NVIDIA Tesla V100 GPU took \(0.85\,\text {s}\).

Evaluation Metrics. We measure the segmentation performance of the models on the held-out test set using mean-class Dice score averaged over the three types of tumour regions and all testing subjects. For the FL systems, we report the performance of the global model shared among the federated clients.
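
For reference, a minimal sketch of this metric, assuming one binary mask per tumour region and per subject:

```python
def dice_score(pred, ref, eps=1e-8):
    """Dice overlap between two binary masks."""
    intersection = np.logical_and(pred, ref).sum()
    return 2.0 * intersection / (pred.sum() + ref.sum() + eps)

def mean_class_dice(pred_masks, ref_masks):
    """Mean-class Dice of one subject over the three tumour regions; the
    reported score further averages this over all test subjects."""
    regions = ("whole tumour", "tumour core", "enhancing tumour")
    return float(np.mean([dice_score(pred_masks[r], ref_masks[r])
                          for r in regions]))
```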

Privacy-Preserving Setup. The selective parameter update module has two system parameters: the fraction of the model shared, q, and the gradient clipping value \(\gamma \). We report model performance when varying both. For differential privacy, we fixed \(\gamma \) to \(10^{-4}\), the sensitivity s to \(2\gamma \), and \(\varepsilon _2\) to \((2qs)^{\frac{2}{3}}\varepsilon _1\) according to [5]. The model performance when varying q, \(\varepsilon _1\), and \(\varepsilon _3\) is reported in the next section.
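
As a worked example of this coupling (with illustrative choices of q, \(\varepsilon _1\), and \(\varepsilon _3\), the quantities varied in the results):

```python
gamma = 1e-4
s = 2 * gamma                            # sensitivity, fixed to 2 * gamma
q = 0.1                                  # fraction of the model shared
eps1, eps3 = 1.0, 1.0                    # illustrative values
eps2 = (2 * q * s) ** (2 / 3) * eps1     # coupling from [5]
per_param_epsilon = eps1 + eps2 + eps3   # per-release DP cost
```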

Fig. 2. Comparison of segmentation performance on the test set with (left): FL vs. non-FL training, and (right): partial model sharing.

4 Results

Federated vs. Data-Centralised Training. The FL systems are compared with data-centralised training in Fig. 2 (left). The proposed FL procedure achieves a comparable segmentation performance without sharing the clients' data. In terms of training time, the data-centralised model converged at about 300 training epochs, FL training at about 600 epochs. In our experiments, one epoch of data-centralised training (\(N=242\)) with an NVIDIA Tesla V100 GPU takes 0.85 s \(\times 242 =\) 205.70 s. The duration of each federated round was determined by the slowest client (\(N=77\)), which takes 0.85 s \(\times \) 77 = 65.45 s, plus a small overhead for client-server communication.

Momentum Restarting and Weighted Averaging. Figure 2 (left) also compares variants of the FL procedure. For the treatment of the momentum variables, restarting them at each federated round outperforms all the other variants. This suggests that (1) each client maintaining an independent set of momentum variables slows down the convergence of the federated model; and (2) averaging the momentum variables across clients improves the convergence speed over baseline m., but still gives a worse global model than the data-centralised one. On the server side, weighted averaging of the model parameters outperforms simple model averaging (i.e. \(W^{(t+1)} \leftarrow {\sum _k W^{(t+1)}_k}/{K}\)). This suggests that the weighted version better handles imbalanced numbers of iterations across the clients.

Partial Model Sharing. Figure 2 (right) compares partial model sharing by varying the fraction of the model shared and the gradient clipping values. The figure suggests that sharing a larger proportion of the model achieves better performance. Partial model sharing does not affect the model convergence speed, and the performance decrease is almost negligible when only 40% of the full model is shared among the clients. Gradient clipping can sometimes improve the model performance, but the clipping value needs to be carefully tuned.

Differential Privacy Module. The model performances obtained by varying the differential privacy (DP) parameters are shown in Fig. 3. As expected, there is a trade-off between DP protection and model performance. Sharing 10% of the model showed better performance than sharing 40% under the same DP setup. This is because the overall privacy cost \(\varepsilon \) is jointly defined by the amount of noise added and the number of parameters shared during training: with the per-parameter DP cost fixed, sharing fewer variables incurs a lower overall DP cost and thus yields better model performance.

Fig. 3. Comparison of segmentation models (ave. mean-class Dice score) by varying the privacy parameters: percentage of partial models, \(\varepsilon _1\), and \(\varepsilon _3\).

5 Conclusion

We proposed a federated learning system for brain tumour segmentation and studied various practical aspects of federated model sharing with an emphasis on preserving patient data privacy. While a strong differential privacy guarantee is provided, the privacy cost allocation is conservative. In the future, we will explore differentially private SGD (e.g. [1]) for medical image analysis tasks.