Distributed training

Epistemic status: intended as persuasive argument, may somewhat overstate case and fail to steelman alternatives

Please note that most of what is described here is currently slideware. It is entirely possible that it is impracticable, or theoretically impossible.


Some time ago, I described some techniques for black-box function optimization. However, I was confused at the time, and these techniques also relied on the functions being continuous, which is troublesome.

Consider the case of multilayer convolution networks of the sort often used for image classification. Traditionally, training such models requires end-to-end gradient descent via backpropagation. This requires floating-point weights and activations.

While Moore’s law is not dead (though perhaps halted in a strict sense of the term), and though there is much low-hanging fruit to be picked in novel hardware and OS design, gradient descent via backpropagation of large, deep models on large training sets will remain computationally expensive for the foreseeable future.

The obvious solution to this is to split the work across many training nodes. In theory, distributed training is faster and cheaper: more workers can perform the same work in linearly less time. Per core, small spot/preemptible instances are significantly cheaper than large on-demand instances. As of September 2020, AWS will rent us 1000 spot CPU cores for $3-$4 (2020) per hour. This price is somewhat higher if we want dedicated cores, but still, spot instances are cheap. By comparison, a single on-demand AWS machine with 96 cores costs ~$5 (2020) per hour. If we are willing to accept that our workers are individually small and may be killed with little warning, we can rent the same number of cores for well under 1/10 of the price. In addition to total memory capacity, the total memory bandwidth of a large number of machines, each with small local memory, is far greater than could be achieved over a single memory bus.
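The "well under 1/10" figure can be checked with a few lines of arithmetic. This is a minimal sketch using only the September 2020 prices quoted above (the helper name `spot_to_on_demand_ratio` is hypothetical, and it allows fractional 96-core machines); current prices will differ.

```python
# Prices are the article's September 2020 figures, in USD per hour.
def spot_to_on_demand_ratio(spot_price_1000_cores: float,
                            on_demand_price_96_cores: float = 5.00) -> float:
    """Per-core spot price as a fraction of the per-core on-demand price."""
    spot_per_core = spot_price_1000_cores / 1000
    on_demand_per_core = on_demand_price_96_cores / 96
    return spot_per_core / on_demand_per_core


# 1000 cores rented as (fractional) 96-core on-demand machines
print(f"1000 on-demand cores: ${1000 * 5.00 / 96:.2f}/hour")

# Both ends of the quoted $3-$4 spot range
for spot in (3.00, 4.00):
    ratio = spot_to_on_demand_ratio(spot)
    print(f"${spot:.2f}/hour spot -> {ratio:.3f} of the on-demand per-core price")
```

Renting 1000 cores on demand comes to about $52 per hour, and the spot ratios come out at roughly 0.058 and 0.077, both comfortably under 1/10.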

It is difficult to make fair comparisons between CPUs and various kinds of specialized hardware such as GPUs, TPUs, or FPGAs, but even if such hardware is more cost-effective than commodity CPUs, more of it will, in theory, train faster. Often, a large cluster of small GPUs is cheaper than a small cluster of large GPUs of equivalent computational capacity. Fast networking is more expensive than slow networking. NVSwitch is non-trivially expensive.

For these reasons, it is highly desirable to make training efficiently distributable across large numbers of workers with slow networking.

However, distributing computation is often costly.

Each update, the parameter server must send out weights to each worker, and each worker must send back gradients. Weights and gradients are big. If we are to distribute training across many training nodes, the workers must have high-bandwidth connections to the parameter server. The number of serial updates required is, in the worst case, exponential in the number of layers. Since updates must be serial, even if our interconnect is high-bandwidth, if it is also high-latency, we are doomed to slow training. No amount of parallelism can avoid communication latency in serial operations.

NVIDIA/Mellanox has been building hardware primarily targeted at ML training for some years now. But the A100 is less than an order of magnitude more powerful than the V100 of three years earlier.

We use f16, we add skip connections, and David Page develops a wonderful bag of tricks, but we remain dissatisfied. Google TPU pods, Graphcore IPUs and the Cerebras Wafer-Scale Engine are very powerful hardware. They will scale, they will improve their interconnect, and devs will train vast models, many thousands of layers deep, on them in record times, and yet still be bound upon the wheel of a worst-case-exponential-with-depth number of serial updates.

Some day, we will have a sphere of computronium, with coolant flowing through it as fast as can be pumped, turning electricity into waste heat to produce computation, performing one update per light speed round trip across the sphere, and yet still being painfully slow to train our deepest models. When this happens, what will we do? If we want yet deeper models, we must simply wait longer.

Already, we are pushing increasingly hard against the practical limits of computation as implemented in current hardware. Long before we encounter the true limits of light speed and heat dissipation, we will encounter the limits of current packet routing and silicon based compute. As we push against reality, reality pushes back.

One may ask, “Do we really care if it costs $1000 and a few days to train a large model from scratch?”

But consider that developing a model is not just a matter of training it: we must first train all manner of poorly performing architectures to discover the correct architecture. Many recent models could have been trained five, or even ten, years earlier using merely a few weeks or months on a large compute cluster. But we would have had to know the exact hyperparameters to use.

If a developer wants to test 100 different model architectures, they must pay the training cost 100 times. Perhaps a developer can afford to pay $1,000 to train a model once, but can they afford to pay $100,000 to train 100 different architectures?

ML devs like to learn serially. Purely for the sake of developer convenience, fast training is nice, even if it costs somewhat more. Even if it is acceptable to spend a week training the model once, how are we to have discovered the architecture if we needed 100 weeks to test 100 ideas serially?

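Training is computationally expensive for a number of reasons:

  1. Weights and gradients are large, and so are slow to transfer.
  2. Training is hard to parallelize.
  3. Each update uses a lot of bandwidth.
  4. Training requires many serial updates.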

These problems exacerbate each other. Weights would be cheaper to transfer if they were smaller. If we could parallelize better, we would not care so much that training is computationally expensive. If single updates were faster and used less bandwidth, we would not mind so much that training requires many serial updates. High bandwidth usage per update would not matter so much if we had fewer updates.

Doing better

We will present here some techniques to train faster and cheaper:

  1. Binary/ternary weights
  2. Binary activations
  3. Greedy layerwise training
  4. Sparse gradients

None of these techniques are new. However, we show that they facilitate and complement each other.

Binary/Ternary weights

Traditionally, weights are stored as 32-bit floating-point. This is needlessly high precision. NNs are robust; we can use lower-precision weights.

It is fairly standard practice to quantize weights to 8-bit ints for inference.

While 8-bit ints are nicer than 32-bit floating-point, they are still big and expensive. Would it not be nice to be able to store 64 weights in one 64-bit word?

That would be truly wonderful.

It is possible to quantize weights to one bit.

This means that if we want to represent larger numbers, we must use unary encoding.
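Storing 64 one-bit weights in a 64-bit word is straightforward. The sketch below is a minimal illustration, assuming NumPy, quantization by sign (with 0 treated as +1), and the encoding used later in this post (+1 as an unset bit, -1 as a set bit); `pack_binary` and `unpack_binary` are hypothetical helper names.

```python
import numpy as np


def pack_binary(weights):
    """Quantize float weights to +/-1 by sign and pack 64 per uint64 word.

    Encoding: +1 -> unset bit, -1 -> set bit. Zero is treated as +1.
    This sketch requires the number of weights to be a multiple of 64.
    """
    w = np.asarray(weights, dtype=np.float32)
    assert w.size % 64 == 0
    bits = (w < 0).astype(np.uint8)                      # 1 where the weight is negative
    packed_bytes = np.packbits(bits, bitorder="little")  # element i -> bit i % 8 of byte i // 8
    return packed_bytes.view("<u8")                      # 8 bytes -> one little-endian uint64


def unpack_binary(words, n):
    """Inverse of pack_binary: recover the first n weights as int8 +/-1."""
    as_bytes = np.asarray(words, dtype="<u8").view(np.uint8)
    bits = np.unpackbits(as_bytes, bitorder="little")[:n]
    return np.where(bits == 1, -1, 1).astype(np.int8)
```

With this layout, weight i lives in bit i % 64 of word i // 64, so 128 float32 weights (512 bytes) shrink to two words (16 bytes).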

Or we can quantize weights to ternary, in which case we need two 64-bit words to store 64 trits.
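One plausible two-word layout, assumed here rather than taken from the text, uses a sign word (set bit means -1) and a non-zero mask word (set bit means the trit is not 0). Python ints stand in for 64-bit words, and `pack_ternary`, `unpack_ternary` and `ternary_dot` are hypothetical names. The dot product with +/-1 inputs shows where the extra AND comes in: the mask gates which XOR results count.

```python
def popcount(x: int) -> int:
    return bin(x).count("1")


def pack_ternary(trits):
    """Pack up to 64 trits in {-1, 0, +1} into two words: (sign, nonzero)."""
    assert len(trits) <= 64
    sign = nonzero = 0
    for i, t in enumerate(trits):
        if t not in (-1, 0, 1):
            raise ValueError(f"not a trit: {t}")
        if t != 0:
            nonzero |= 1 << i
        if t == -1:
            sign |= 1 << i
    return sign, nonzero


def unpack_ternary(sign: int, nonzero: int, n: int):
    out = []
    for i in range(n):
        if not (nonzero >> i) & 1:
            out.append(0)
        elif (sign >> i) & 1:
            out.append(-1)
        else:
            out.append(1)
    return out


def ternary_dot(x_bits: int, sign: int, nonzero: int) -> int:
    """Dot product of packed +/-1 inputs (set bit = -1) with packed ternary weights.

    Where the weight is non-zero, agreeing signs add +1 and disagreeing signs add -1;
    zero weights are masked out with AND.
    """
    disagree = x_bits ^ sign
    return popcount(nonzero & ~disagree) - popcount(nonzero & disagree)
```

For example, trits `[1, -1, 0, 1]` pack to `sign = 0b0010` and `nonzero = 0b1011`, and their dot product with inputs `[+1, +1, -1, -1]` (packed as `0b1100`) is 1 - 1 + 0 - 1 = -1.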

Ternary weighted NNs are particularly suited for hard-coding into FPGAs, reaping matchless performance.

The optimal precision for weights is fairly clearly less than 8 bits, and probably more than 1 bit; we suspect close to ternary. We make a hand-wavy appeal to radix economy, although we suspect that it might be inapplicable in this situation, and that the optimal precision is instead dependent on properties of the dataset and/or model architecture.
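The radix-economy appeal can be made concrete. Representing a number N in base b takes about log_b(N) digits, each with b possible states, so the cost is roughly b · log_b(N) = (b / ln b) · ln N. The sketch below compares the asymptotic factor b / ln b across bases; as noted above, whether this argument transfers to NN weights is an open assumption.

```python
import math


def asymptotic_radix_economy(b: int) -> float:
    """Cost factor b / ln(b): digits ~ log_b(N), each costing b states (lower is better)."""
    return b / math.log(b)


for b in (2, 3, 4, 8, 10):
    print(b, round(asymptotic_radix_economy(b), 3))
```

The continuous optimum is at b = e ≈ 2.718; among integer bases, 3 is best (about 2.731), with 2 and 4 tied just behind (about 2.885).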

Binary/ternary weights are all very nice, but so far this only works for inference on already-trained models.

Training binary weights

Consider a simple single hidden layer NN with floating-point activations but low-precision weights. We compute the forward pass and obtain a loss, backpropagate the loss across the first layer weights, and get some nice small updates. We apply the updates, but find that if we add 0.001 to -1 and re-quantize, nothing changes. Unless our updates are big enough to flip a -1 to a +1, they will vanish at weight update time.

There are two classes of solutions:

  1. Stochastic weight updates
  2. Evolutionary optimization

In approach 1, we compute gradients for each bit, and flip weights stochastically according to their gradient; if we can’t flip a bit by 0.01, we can flip it 1% of the time. This is the approach primarily used by current hybrid binarized training schemes.
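A minimal sketch of approach 1, assuming NumPy and weights stored as int8 +/-1. The flip probability is chosen so that the expected change in each weight equals the continuous update `-lr * grad`. Since a flip moves a weight by 2, an update of 0.02 (one percent of the distance from -1 to +1) flips the weight 1% of the time. The function name and this exact probability rule are illustrative assumptions, not a specific published scheme.

```python
import numpy as np


def stochastic_binary_update(w, grad, lr, rng):
    """One stochastic update of binary weights w in {-1, +1}.

    A float update would move w by step = -lr * grad. A flip moves w by 2, so we
    flip with probability |step| / 2 (capped at 1), but only where the step points
    from w toward -w. Steps pointing away from -w would push w past +/-1, so they
    are dropped, as a clipped float update would be.
    """
    step = -lr * grad
    toward_other = np.sign(step) == -w
    p = np.clip(np.abs(step) / 2.0, 0.0, 1.0)
    flip = toward_other & (rng.random(w.shape) < p)
    return np.where(flip, -w, w)
```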

However, approach 2 is somewhat more interesting to us at this time.

Consider a black-box function that takes a string of n bits and returns a loss. We can think of this function as an n-dimensional cube in which each corner has a loss. When we flip a bit, we move along one edge of the cube. If we flip all bits, we move to the opposite corner. The loss delta between any two adjacent corners represents the loss delta of a single bitflip. Since n will be many thousands, and it is somewhat expensive to observe the loss at any given corner, it is impractical to brute-force the function. However, in practice, if we choose a good class of functions, such as a NN, the function is mostly convex [citation needed].

Imagine that we are on one corner of this cube. We move to the 0th adjacent corner and observe the loss. If the loss is lower than our own, we stay; otherwise, we return. We repeat for our current 1st adjacent corner, then the 2nd, and so on. In this way, we spiral our way around, descending the loss surface.
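The one-edge-at-a-time walk described above is greedy bitwise hill climbing. A minimal sketch, treating the loss as an arbitrary black-box callable over a list of 0/1 bits (`bitwise_hill_climb` is a hypothetical name):

```python
from typing import Callable, List


def bitwise_hill_climb(bits: List[int],
                       loss_fn: Callable[[List[int]], float],
                       sweeps: int = 1) -> List[int]:
    """Visit adjacent corners of the n-cube one edge at a time.

    For each bit i in order: flip it, observe the loss, keep the flip if the
    loss went down, otherwise flip it back. Every accepted flip strictly
    lowers the loss, at the cost of one loss evaluation per edge tried.
    """
    bits = list(bits)
    best = loss_fn(bits)
    for _ in range(sweeps):
        for i in range(len(bits)):
            bits[i] ^= 1          # move along edge i
            loss = loss_fn(bits)
            if loss < best:
                best = loss       # stay on the new corner
            else:
                bits[i] ^= 1      # return to the previous corner
    return bits
```

On a toy loss such as the Hamming distance to a hidden target string, a single sweep recovers the target, because every bit is independent. Real models are not that cooperative.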

But actually, all 65535 of our friends are standing on the corresponding corners of their own cubes, each holding a different training example, so each of our friends’ cubes has a slightly different loss. We need to descend the average of our cubes’ losses together, and it is slow to send messages to each other reporting observed losses.

Consider also that, on a random corner of the cube, approximately half of the adjacent corners will have a lower loss; about half the loss deltas will be positive and the remaining half negative. However, as we descend the loss surface, our loss becomes anomalously low, and the proportion of adjacent corners which are an improvement will tend to decrease. In other words, a random bit string will have an approximately average loss, and the bit strings at Hamming distance one from it will have greater or lesser losses in approximately equal proportions. But as we optimize the bit string, increasingly many of the n single-bit mutations will cause regression to the mean.

Given that we need to descend in as few steps as possible, perhaps we can do better than testing one bit at a time.

Let us start with the current corner and observe the losses at all n adjacent corners. If we subtract the loss at the current corner from the loss at each adjacent corner, we get the loss deltas. We now have n edges, each with a loss delta. If we travel along any one of them, that is to say, if we flip the corresponding bit of the n-bit string, the loss will change accordingly. If we flip any single bit whose loss delta is negative, we are guaranteed to achieve a lower loss.

Do we need to traverse the cube of inputs one edge at a time? If we observe that two of the edges are good updates, can we move along both simultaneously, that is, flip both bits simultaneously?

If all bits are completely independent, yes; we can travel across all loss-reducing edges simultaneously and get a good result.

In practice, not all bits are completely independent for the class of functions which we wish to optimize, but still, there is significant independence. In practice, we can usually mutate more than one bit at a time.

But if we mutate too many bits, we will overshoot.
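The multi-bit version can be sketched as: evaluate all n single-bit deltas, then flip at most k of the most loss-reducing bits at once. Here k is a hypothetical knob bounding overshoot when bits interact. If the bits were fully independent, flipping every negative-delta bit would be exact. The sketch assumes NumPy integer bit arrays and a black-box loss.

```python
import numpy as np


def neighbor_deltas(bits, loss_fn):
    """Loss delta of every single-bit flip: one loss evaluation per adjacent corner."""
    base = loss_fn(bits)
    deltas = np.empty(len(bits))
    for i in range(len(bits)):
        flipped = bits.copy()
        flipped[i] ^= 1
        deltas[i] = loss_fn(flipped) - base
    return deltas


def multi_flip_step(bits, loss_fn, k):
    """Flip up to k bits simultaneously: those with the most negative deltas.

    Bits whose own delta is not negative are never flipped. With independent
    bits this is exact; with interacting bits, a smaller k limits overshoot.
    """
    deltas = neighbor_deltas(bits, loss_fn)
    order = np.argsort(deltas)[:k]          # most loss-reducing first
    chosen = order[deltas[order] < 0]
    new_bits = bits.copy()
    new_bits[chosen] ^= 1
    return new_bits, deltas
```

For instance, with the fully separable loss `v @ bits` and `v = [3, -1, 2, -5, 0.5, -2]`, one step with `k = 6` from all zeros flips bits 1, 3 and 5 together and lands directly on the minimum, while `k = 1` flips only bit 3.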

In conclusion, if the function is convenient to us, we can optimize binary weighted models. Ternary weights likely make it even easier.

Binary activations

As most every modern ML paper never tires of reiterating, in recent years, multilayer neural nets have achieved impressive results. And, as most every ML efficiency paper mentions, modern multilayer neural networks are computationally expensive to train.

Binary/ternary weights reduce storage costs, but if the activations are still floating-point, computation is only somewhat cheaper. Would it not be nice not only to store 64 weights in a 64-bit word, but also to perform 64 multiply-accumulate operations using only a single bitwise XOR (plus an AND, in the case of trits), a POPCNT, and an integer add?

That too would be truly wonderful.

Tanh activation is commonly used, and it is essentially a smoothed sign function. It is therefore relatively simple to binarize activations for inference.

If our inputs are +/-1 and our weights are +/-1, we need only store a +1 as an unset bit and a -1 as a set bit. Now we XOR the inputs with the weights, 64 bits at a time, and count the ones: each set bit marks a disagreement, so with d disagreements among n bits, the sum is n - 2d. Instead of tanh, we use an integer compare, and bitpack the result into the output.
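A minimal sketch of the XOR/popcount neuron just described. Python ints stand in for 64-bit words; in C or Rust one would loop over `uint64` words and use the hardware POPCNT instruction. The helper names, the threshold parameter, and sending ties to -1 are illustrative assumptions.

```python
def pack_pm1(values) -> int:
    """Pack +/-1 values into an int: +1 -> unset bit, -1 -> set bit (bit i <- values[i])."""
    word = 0
    for i, v in enumerate(values):
        if v == -1:
            word |= 1 << i
        elif v != 1:
            raise ValueError(f"expected +/-1, got {v}")
    return word


def popcount(x: int) -> int:
    return bin(x).count("1")


def binary_dot(x_word: int, w_word: int, n: int) -> int:
    """Sum of x_i * w_i over n packed +/-1 values: d disagreements give n - 2d."""
    return n - 2 * popcount(x_word ^ w_word)


def binary_layer(x_word: int, weight_words, n: int, threshold: int = 0) -> int:
    """One binary layer: compare each output channel's dot product with an integer
    threshold and bitpack the results (+1 -> unset bit, -1 -> set bit)."""
    out = 0
    for j, w_word in enumerate(weight_words):
        if binary_dot(x_word, w_word, n) <= threshold:   # ties go to -1 in this sketch
            out |= 1 << j
    return out
```

For example, `x = [1, -1, 1, 1]` and `w = [1, 1, -1, 1]` disagree in two places, so `binary_dot` returns 4 - 2·2 = 0, matching 1 - 1 - 1 + 1.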

Integer compare, that is, the sign function, has no useful gradient: its derivative is zero everywhere except at the threshold, where it is undefined. Therefore, while binary/ternary weighted models are relatively simple to train, models with binary activations are not.

Training across binary activations

Consider a simple single hidden layer NN with high-precision floating-point weights, but binary activations. We perform a forward pass, produce a loss, and then attempt to compute the loss gradient with respect to the first layer weights. We then find that all the weights have a null gradient because the sign activation killed the gradients.

But consider if the weights are binary. We observe loss deltas with respect to single-bit mutations. A whole bitflip moves the popcount by a full unit rather than by an infinitesimal amount, so there is a non-zero chance of flipping the activation bit. Now, although many gradients will die, some will survive. We can train a single hidden layer model.

But we want deeper models. What if our model has two hidden layers, that is, two layers of binary activations? A single layer of binary activations killed the gradients on most of our weights. How much worse would two be? Or three or four? We want to train models with tens or hundreds of layers. Binary activations are not compatible with end-to-end training of models of any significant depth.

Previously deep models may have required exponentially many updates to train but they could be trained. Now, with binary activations, they are impossible to train. This is not an improvement. It would be truly wonderful if there was a way of reducing a many layer end-to-end training problem to a series of single hidden layer problems.

Layerwise training

We find ourselves in the middle of a vast many-dimensional loss surface. We need to find our way down. We cannot look into the distance, only observe the slope directly under our feet. If we move a short distance in the downhill direction, we will not get down very fast. If we move a long distance in the downhill direction, we may overshoot and end up higher than we started. Each update, we must move in a straight line.

To look at it differently, there exists an ideal smooth path down the loss surface. We must approximate it using a finite series of line segments. Each line segment must start in line with the gradient.

Now consider what happens when we add another layer of nonlinearity to our model. The loss surface crinkles and becomes more twisty. Our path down it requires more line segments to approximate. We must perform more updates.

Very approximately, each additional layer of nonlinearity multiplies the number of updates required.

Due to ResNet, skip connections, good initialization and other techniques, it is often in practice less than truly exponential; still, it is superlinear.

Would it not be nice if we could avoid this?

That would be truly wonderful.

Greedy layerwise training

In particular, we would like to focus on Belilovsky et al. (2018). In it, the authors demonstrate that greedy optimization of a series of single hidden layer convolution models on ImageNet produces respectably good results.

Consider the problem of training an n-layer NN. If we know the weights of the first n-1 layers, we can pass the training set through them, cache the results, and train the final layer. This would be very nice. But consider if instead we want to train the first layer. We need the layers after it. How else would we know if it is a good layer?

What makes a good layer? It needs to preserve and increase linear separability of information which is useful to later layers while discarding information not useful to later layers. In the case of a convolution layer, it needs to combine low level spatial features into larger and higher level spatial features.

Can we measure the goodness of a layer directly, without a large stack of layers on top of it?

According to Belilovsky et al. (2018), yes.

This is very nice. We have reduced an exponential problem to a linear one. Now, there will be some cost to accuracy. Layerwise training is strictly less powerful than end-to-end training. We suspect that it will fail utterly on pathological datasets. But empirically, on photographs of the real world, it works respectably well.
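The shape of greedy layerwise training can be sketched in a few dozen lines. This is an illustrative float/tanh toy in NumPy, not Belilovsky et al.'s convolutional setup and not the binary scheme of this post. Each layer is trained jointly with an auxiliary linear softmax head on the cached outputs of the frozen layers below it, so backpropagation never crosses more than one hidden layer. All names and hyperparameters are assumptions.

```python
import numpy as np


def softmax_xent(logits, y):
    """Mean softmax cross-entropy and its gradient with respect to the logits."""
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=1, keepdims=True)
    rows = np.arange(len(y))
    loss = -np.log(p[rows, y] + 1e-12).mean()
    dlogits = p.copy()
    dlogits[rows, y] -= 1.0
    return loss, dlogits / len(y)


def train_layer(x, y, width, n_classes, steps=500, lr=0.1, seed=0):
    """Train one tanh layer W jointly with an auxiliary linear head A on fixed inputs x."""
    rng = np.random.default_rng(seed)
    W = rng.normal(0.0, 1.0 / np.sqrt(x.shape[1]), (x.shape[1], width))
    A = rng.normal(0.0, 1.0 / np.sqrt(width), (width, n_classes))
    for _ in range(steps):
        h = np.tanh(x @ W)
        loss, dlogits = softmax_xent(h @ A, y)
        dA = h.T @ dlogits
        dpre = (dlogits @ A.T) * (1.0 - h ** 2)   # backprop through this one layer only
        W -= lr * (x.T @ dpre)
        A -= lr * dA
    return W, A, loss


def greedy_layerwise(x, y, widths, n_classes):
    """Train layers one at a time; each is frozen before the next is trained."""
    assert widths, "need at least one layer"
    layers, losses = [], []
    for i, width in enumerate(widths):
        W, A, loss = train_layer(x, y, width, n_classes, seed=i)
        layers.append(W)
        losses.append(loss)
        x = np.tanh(x @ W)       # cache this layer's outputs; W is never updated again
    return layers, A, losses     # A: the last layer's auxiliary head, kept as the classifier
```

Each call to `train_layer` costs the same regardless of depth, which is the linear-instead-of-exponential property described above; the price is that no layer is ever adjusted to suit the layers above it.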

Consider another benefit of layerwise training. Imagine two models whose first n layers are identical, and which differ only in the final layer. The shared base layers need only be trained once.

Imagine we train a model, but then realize that we did the last layer wrong. It is cheap to retrain the last layer.

With layerwise training, the weights of a layer are strictly a function of the input data (which is a function of the layers before it and the data set), and its hyperparameters. It does not depend on the layers after it. We have a tree of layers. If we have skip connections, that is, pixelwise concatenation layers which take two input images, it is a DAG of layers. It is cheap to search through the space of higher layers; all lower layers can be shared.
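Because a layer depends only on the layers below it and on its own hyperparameters, a tree of candidate models can be memoized by the path of hyperparameters from the dataset down to each layer. The sketch below assumes a caller-supplied, hypothetical `train_layer(inputs, hparams)` that returns the trained layer's outputs. It covers only the tree case, not DAGs with skip connections.

```python
from typing import Any, Callable, Dict, Tuple

Hparams = Tuple[Tuple[str, Any], ...]        # frozen, hashable hyperparameters of one layer


def freeze(hparams: dict) -> Hparams:
    return tuple(sorted(hparams.items()))


class LayerTreeCache:
    """Memoizes greedy layerwise training over a tree of layers.

    A layer's trained outputs depend only on the path of hyperparameters from the
    dataset down to it, never on the layers above it, so that path is the cache
    key. Models that share a prefix of layers train that prefix once.
    """

    def __init__(self, dataset: Any, train_layer: Callable[[Any, dict], Any]):
        self.dataset = dataset
        self.train_layer = train_layer       # hypothetical: (inputs, hparams) -> layer outputs
        self.cache: Dict[Tuple[Hparams, ...], Any] = {}
        self.trainings = 0

    def outputs(self, path: Tuple[Hparams, ...]) -> Any:
        if not path:
            return self.dataset
        if path not in self.cache:
            inputs = self.outputs(path[:-1])  # shared lower layers are trained at most once
            self.cache[path] = self.train_layer(inputs, dict(path[-1]))
            self.trainings += 1
        return self.cache[path]
```

Exploring two three-layer models that differ only in their top layer then costs four layer trainings rather than six.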

Still, while layerwise training does reduce the number of updates, and even makes bandwidth usage per update constant with number of layers instead of linear, bandwidth usage is still large.

Sparse gradients

When Mirai compromised a million devices, each storing no doubt interesting data, it did not use its vast compute to train a big neural net. Mirai can bruteforce hashes and DDoS Dyn, but it cannot train a multilayer NN.

Consider data parallel gradient descent. We have, for example, 65536 examples, one on each of the 65536 worker nodes. The parameter server sends the weights out to all workers; each worker computes a forward pass and a backward pass and sends gradients to the parameter server, which then averages the gradients and updates the weights. This process must be repeated many thousands of times. Weights are big. Gradients are big.

If our model has 256 * 3 * 3 * 256 * 10 = 5,898,240 weights, this is about 23.6 MB of 32-bit gradients. Each of the 65536 workers must send its own version of the gradients back to the parameter server. In total, this is 1.5 terabytes which the parameter server must receive per update. If we need 2500 updates, this is about 3.9 petabytes of total parameter server ingress traffic needed to train the model. If the model is larger, or we need more updates, this gets even larger. Quite apart from the bandwidth cost, transferring large chunks of data takes time. Now in practice, we will have a hierarchy of parameter servers summing gradients so that no one of them needs to receive gradients from all 65536 workers. But interposing a tree of summing servers will not improve per-iteration latency, and total ingress traffic is still impractically large for any seriously large set of workers outside of a single DC. The point of Mirai was to DDoS Dyn, not DDoS the C&C server.
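The traffic figures follow from straightforward multiplication. The sketch reads `256 * 3 * 3 * 256 * 10` as ten 3x3 convolution layers of 256 input and 256 output channels, and assumes 4-byte float32 gradients and decimal units.

```python
weights = 256 * 3 * 3 * 256 * 10     # ten 3x3 conv layers, 256 -> 256 channels
bytes_per_gradient = 4               # float32
workers = 65536
updates = 2500

per_worker = weights * bytes_per_gradient   # bytes each worker uploads per update
per_update = per_worker * workers           # bytes the parameter server receives per update
total = per_update * updates                # bytes received over the whole training run

print(f"weights:                {weights:,}")
print(f"gradients per worker:   {per_worker / 1e6:.1f} MB")
print(f"server ingress/update:  {per_update / 1e12:.2f} TB")
print(f"total server ingress:   {total / 1e15:.2f} PB")
```

This gives 23.6 MB per worker, 1.55 TB per update, and 3.87 PB in total.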

Consider another vast botnet, this one run, not by a Minecraft booter service, but by Google. Google does train neural nets on their vast botnet.

Why does Google’s C&C server not fall over? Because Google uses a number of optimizations to reduce bandwidth usage, primarily gradient sparsity.

Just as we can make our weights smaller by reducing precision or by sparsifying them, so too can we make our gradients smaller, by lowering their precision or by sparsifying them.

Even in a normal full floating-point NN, many of the gradients are very close to 0. They can be represented sparsely with little cost to accuracy. How much more so when binary activations are already killing so many of our gradients?

We transmit only gradients of magnitude greater than a threshold. If a weight has no gradient, this is no great matter; perhaps another example has a gradient for it, or perhaps, as we wander about the parameter space, it will come alive again.

Consider the matter differently. Imagine that we are very lucky and that each of our 65536 examples contains perfect gradients for 1/65536th part of the model, each covering a different portion of the weights. Now, when we sparsify, the sum of the workers’ bandwidth is as if there were but one worker sending full gradients.

Although not nicely orchestrated, this is approximately what we get by using extremely sparse gradients from many workers. When averaged together, they will form a fairly complete set of gradients and if a gradient was not to be found in any of the examples, then it was probably not a very important gradient.
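Threshold sparsification and server-side aggregation can be sketched as follows, assuming NumPy and a wire format of (int32 index, float32 value) pairs. That format is an illustrative choice: at 8 bytes per kept entry, it only beats sending dense float32 gradients when fewer than half the entries survive the threshold.

```python
import numpy as np


def sparsify(grad, threshold):
    """Worker side: keep only entries with |g| > threshold, as (indices, values)."""
    idx = np.flatnonzero(np.abs(grad) > threshold)
    return idx.astype(np.int32), grad[idx].astype(np.float32)


def payload_bytes(sparse_grad):
    """Bytes on the wire for one worker's sparse gradient."""
    idx, vals = sparse_grad
    return idx.nbytes + vals.nbytes


def aggregate(sparse_grads, n_weights, n_workers):
    """Server side: scatter-add every worker's sparse gradient into one dense average.

    Weights that no worker reported simply receive a zero gradient this update.
    """
    dense = np.zeros(n_weights, dtype=np.float64)
    for idx, vals in sparse_grads:
        np.add.at(dense, idx, vals)   # unbuffered add; also correct for repeated indices
    return dense / n_workers
```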

Distributed training

Traditional end-to-end gradient descent of deep NNs requires much compute and, to distribute well, high bandwidth low latency interconnect.

We have now described how to reduce weights to one bit each and how to train them. We have also described how to train single hidden layer NNs with binary activations as well as binary weights. With greedy layerwise training, we can reduce our many-layer deep training task to a series of shallow training tasks that are amenable to binary activations, and make the number of updates linear instead of exponential, reducing the cost of latency. Finally, gradient sparsity reduces bandwidth usage per update.

With these components, we should be able to train large, deep models across many tens of thousands of workers connected by low-bandwidth, high-latency connections, perhaps even in geographically distant data centers.

With all this in mind, we present:

Binary Learning


(Uninterested parties should feel free to skip over this section. It is not important for understanding other parts.)

For the sake of simplicity, let us assume the CIFAR-10 dataset: 50,000 32x32 pixel training images. Let us also assume 64 input channels, being turned into 64 hidden channels. Each image has 32*32=1024 pixels, and each pixel has 64 one-bit channels.

Consider how a single training node will compute the gradients for a 3x3 convolution layer on a single image. For the sake of simplicity, let us assume single-bit weights, although the same principle will mostly apply to ternary weights.

For a given pixel channel of the layer’s output, we can decide whether there is a gradient directly below, directly above, or not within range. That is to say: is the popcount of that pixel’s patch XORed with that channel’s patch weights within one bit of the threshold? If we were to increase the popcount by one, or decrease it by one, would the activation bit flip?

If the current popcount is more than one bit away from the threshold, all the gradients of that pixel channel are dead; mutating any one bit of that channel’s weights will cause no change to the pixel’s activation.

Assuming 64 * 3 * 3 = 576 bits per patch, although the popcount will be approximately normally distributed, still only a minority of pixel channels will be within one bit of the threshold and have a gradient. Once we know the one trit of information that is the gradient of the pixel channel, we can compute the per-patch bit gradients.
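The per-pixel-channel trit and the resulting per-bit gradient mask can be sketched with Python ints as bitsets. This sketch assumes the activation bit is `popcount > threshold`. Flipping weight bit j toggles bit j of `patch ^ weights`, raising the popcount by one where that bit is 0 and lowering it where it is 1. The last function estimates the "minority" under the further assumption that the popcount is Binomial(576, 1/2) with the threshold at its mean. All names are hypothetical.

```python
import math


def gradient_trit(popcount: int, threshold: int) -> int:
    """+1: raising the popcount by one flips the activation (popcount == threshold).
    -1: lowering it by one flips the activation (popcount == threshold + 1).
     0: more than one bit away; every weight bit of this pixel channel is dead."""
    if popcount == threshold:
        return 1
    if popcount == threshold + 1:
        return -1
    return 0


def live_weight_bits(patch: int, weights: int, n_bits: int, threshold: int) -> int:
    """Bitmask of weight bits whose single flip would flip this pixel channel's activation."""
    mask = (1 << n_bits) - 1
    xor = (patch ^ weights) & mask
    trit = gradient_trit(bin(xor).count("1"), threshold)
    if trit == 1:
        return ~xor & mask       # flipping these raises the popcount
    if trit == -1:
        return xor               # flipping these lowers the popcount
    return 0


def fraction_with_gradient(n_bits: int = 576) -> float:
    """P(popcount is threshold or threshold + 1) for Binomial(n, 1/2), threshold at the mean."""
    t = n_bits // 2
    return (math.comb(n_bits, t) + math.comb(n_bits, t + 1)) / 2 ** n_bits
```

Under these assumptions, `fraction_with_gradient()` comes out at roughly 0.066, so only about one pixel channel in fifteen has any live weight bits, even at the most favorable threshold.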

But, although we know which bits of the pixel channel patch will affect it, we do not know the loss delta which they affect. Also, we have only been computing gradients within single pixel patches. This is a convolution layer, weights are shared, and so, gradients must be correspondingly summed.

Let us digress for a moment. Consider the space of weight mutations. There are 64 * 3 * 3 * 64 = 36864 weights, and so 36864 single-bit mutations. But any weight mutation must be bottlenecked at the hidden layer activations. For a 3x3 convolution with 64 channels (30x30 being the output size of an unpadded 3x3 convolution over a 32x32 image), there are 64 * 2^(30 * 30), roughly 5 × 10^272, states. This does not help.

Transforming weight mutation space into image mutation space does not help unless the image is extremely small.

However, the pixel channel mutations are quite sparse. Many weight mutations have no gradient; they all map to the same, unchanged, pixel channel state. In practice, the space of pixel channel states which can be reached by weight mutations is likely significantly smaller than the space of weight mutations.

Even as many bits of the weights of one pixel’s patch are dead or identical, and yet gain personality when summed with the gradients of other patches, so too many weight bits in an image are dead or identical, and yet gain personality when summed with the gradients of other images.

Now, for each of the 64 channels, we have a sparse set of 32x32 bit strings.

For the current hidden activations, we can compute the summed activations of the auxiliary convolution layer. For each of the 64 channels, we can compute the per-class summed activations of just that channel. As a sanity check, the per-class sums of the 64 per-channel auxiliary convolution layer activations should equal the per-class sums of the full-word auxiliary convolution layer activations.

Now, for a given channel, we can subtract that channel’s activations from the total, and then, for each member of that channel’s sparse set of pixel bitmaps, add the per-channel activations back.

Now that we have the per class activations for each location in pixel channel space, we can compute loss using a standard floating-point softmax loss function.
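The subtract-and-add-back trick works because the auxiliary layer's class scores are linear in the hidden bits, so they decompose into a sum of per-channel contributions. For brevity, this NumPy sketch assumes the auxiliary layer is a per-pixel linear classifier with weights of shape (classes, channels, pixels). A convolution followed by summing over pixels is also linear and can be written in this form. Hidden bits are taken as 0/1, and all names are hypothetical.

```python
import numpy as np


def channel_contributions(hidden, aux):
    """hidden: (C, P) array of 0/1 bits; aux: (K, C, P) per-class weights.
    Returns a (C, K) array: each channel's contribution to each class score."""
    return np.einsum("cp,kcp->ck", hidden, aux)


def softmax_xent(scores, label):
    """Standard softmax cross-entropy for one example."""
    z = scores - scores.max()
    return float(np.log(np.exp(z).sum()) - z[label])


def candidate_losses(hidden, aux, label, channel, candidates):
    """Loss for each candidate bitmap of one channel, reusing the other channels' totals."""
    contrib = channel_contributions(hidden, aux)
    base = contrib.sum(axis=0) - contrib[channel]    # class scores without this channel
    losses = []
    for bitmap in candidates:                        # the channel's sparse set of reachable bitmaps
        new_contrib = aux[:, channel, :] @ bitmap    # (K,) scores from the mutated channel
        losses.append(softmax_xent(base + new_contrib, label))
    return losses
```

Each candidate then costs one (K, P) matrix-vector product instead of a full recomputation over all channels.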

Once we have a list of (weight_index, gradient) pairs, we filter to abs(gradient) > threshold and send this small list up to the parameter server. The parameter server combines the sparse gradients of all the workers back into dense gradients, removes the gradients whose weights are already maxed out, and broadcasts the indices of the top k mutations to all worker nodes, which apply the mutations and repeat.
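The parameter-server step can be sketched for ternary weights, under two stated assumptions: each reported gradient is a signed dL/dw, and a mutation moves a weight one step against its gradient. A weight is then "maxed out" when that step would leave {-1, 0, +1}. The sketch uses NumPy, and the function names are hypothetical.

```python
import numpy as np


def select_mutations(sparse_grads, weights, k):
    """Combine sparse worker gradients and pick the top-k ternary mutations.

    sparse_grads: iterable of (indices, values) pairs from the workers.
    weights: int8 array with entries in {-1, 0, +1}.
    Returns the chosen indices and the +/-1 step to apply at each.
    """
    dense = np.zeros(weights.shape, dtype=np.float64)
    for idx, vals in sparse_grads:
        np.add.at(dense, idx, vals)
    direction = -np.sign(dense).astype(np.int8)        # +1: raise w, -1: lower w
    maxed_out = (direction == 0) | (weights + direction > 1) | (weights + direction < -1)
    candidates = np.flatnonzero(~maxed_out)
    top = candidates[np.argsort(-np.abs(dense[candidates]))[:k]]
    return top, direction[top]


def apply_mutations(weights, indices, directions):
    """Worker side: apply the broadcast mutations to the local copy of the weights."""
    new = weights.copy()
    new[indices] += directions
    return new
```

Broadcasting only k indices, plus one direction bit each, keeps the downlink as small as the uplink.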

As an added optimization, if k is significantly smaller than the number of channels, one could imagine caching the non-mutated channels, that is, recomputing only the channels which were mutated in the last update. However, this will be broken by minibatching.

Already existing alternative solutions

There are already a variety of solutions to rapidly train large many layer NNs.

Or just use a TPU pod. TPUs have brute strength. Brute strength will solve all your problems.


All interested parties should consider themselves welcome to steal our ideas and implement them themselves before we have the time to do so.

For over two years, we have been attempting to make training cheaper/faster on commodity hardware. We invite interested parties to inspect the commit history to see our numerous and diverse failures.

We intend to implement and test the techniques which we have described here. If successful, we will train vast models across tens of thousands of cheap spot CPU instances, and we will achieve first place in training time and cost.

However, based on past experience, and assuming EMH to be true, it will not work, or will be only mildly preferable to existing techniques.


Even should these techniques work as expected, they will have many inherent limitations:


I thank Noah Walton for much insightful discussion.

I thank James Crook for many helpful comments.

I thank Egan Ford for permitting me to spend time on this work.

All mistakes are my own.


Last updated on: Mon, Sep 7, 2020