Efficient computation of bit convolution loss deltas
Meta
This post describes the current Rust implementation of a particular component of binary NN training. It is likely only of interest to the intersection of Rust devs and ML devs. The implementation is likely to change in future. To follow along, see the corresponding commit.
The Rust source code shown here is not typical, idiomatic Rust code. Its 9 type parameters encompass it about, its 46 lines of trait bounds blot out the sun. Most Rust code is far nicer. Please do not use it as an example of typical Rust.
Abstract
Often, we desire to observe the loss deltas of all weight mutations of a convolution layer. We wish to perform this computation quickly, on commodity hardware; however, it is computationally expensive. We describe and benchmark a number of different implementations, first achieving execution time quadratic in output and/or input size, then linear, and eventually slightly sub-linear in input size. Which implementation will have the best performance is non-obvious. In attempting to improve performance, we utilize some rather extreme optimizations, but ultimately fail to achieve acceptably good performance.
In a future post, we will describe further optimizations.
Introduction
Convolution layers are a standard component of image classification neural nets.
As previously discussed, training models with binary activations and weights has some very nice benefits. To train a binary-weight model, we need to be able to observe the loss deltas for each mutation. In other words: for each weight, if we set it to a new value, how would the loss change?
All benchmarks were carried out on an AMD Ryzen Threadripper 2950X 16 core processor with SMT disabled.
Benchmark table fields:
- version: The version of the implementation being tested
- threads: The number of worker threads simultaneously computing loss deltas. Each thread was given different examples.
- input pixel size: The number of bits per pixel of input. Patches are 3x3 times larger.
- output pixel size: The number of bits in the output pixel.
- n params: The number of params. This will be slightly larger than <input pixel size> * 3 * 3 * <output pixel size> because the objective head is included.
- ms per example: millisecond time, per example, per core
- ns per pixel bit: nanosecond time, per bit of the input pixel, per core
- ns per channel: nanosecond time, per output channel, per core
- ns per parameter: nanosecond time per parameter, per core
Forward pass
The forward pass is simple.
<IS as Conv<<IPS as BitPack<bool>>::T, <OPS as BitPack<bool>>::T, PX, PY>>::conv(
input,
|patch| {
<OPS as PackedMap<<[[IPS; PY]; PX] as BitPack<W>>::T, bool>>::map(
&self.kernel,
|weights| <[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch),
)
},
)
We use a convolution operation which takes an image of shape IS, and returns an image of the same shape.
The input image holds pixels of type:
<IPS as BitPack<bool>>::T
while the output image holds pixels of type:
<OPS as BitPack<bool>>::T
In both cases, the bools contained in the pixels are efficiently packed using type magic such that they take only one bit each.
To the conv() method on the Conv trait of the IS type, we pass input and a function.
As previously mentioned, input is of type:
<IS as PixelPack<<IPS as BitPack<bool>>::T>>::I
The function takes a patch of type:
[[<IPS as BitPack<bool>>::T; PY]; PX]
where PY and PX are the dimensions of the patch, usually 3.
Inside this function, we perform a packed map operation:
<OPS as PackedMap<<[[IPS; PY]; PX] as BitPack<W>>::T, bool>>::map(
&self.kernel,
|weights| <[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch),
)
We call the map() method on the output pixel shape. This function wants an array of shape OPS, and a function that maps a:
<[[IPS; PY]; PX] as BitPack<W>>::T
to a bool.
It will then pack the bools into a:
<OPS as BitPack<bool>>::T
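To make the shape-generic code above more concrete, here is a minimal sketch of the same forward pass with fixed types: u16 pixels (16 channels packed one bit each), 3x3 patches, 16 output channels, and XOR + popcount as the binary multiply-accumulate. The names act and conv here are illustrative stand-ins, not the post's actual API.

```rust
// Activation of one output channel: hamming distance between the weight
// patch and the input patch, compared against a fixed threshold.
fn act(weights: &[[u16; 3]; 3], patch: &[[u16; 3]; 3]) -> bool {
    let sum: u32 = weights
        .iter()
        .flatten()
        .zip(patch.iter().flatten())
        .map(|(w, p)| (w ^ p).count_ones())
        .sum();
    sum > (16 * 9) / 2 // threshold: half the number of weight bits
}

// 3x3 conv over a bit-packed image; one kernel per output channel, and one
// output bool packed per bit of the output pixel.
fn conv(input: &[Vec<u16>], kernels: &[[[u16; 3]; 3]; 16]) -> Vec<Vec<u16>> {
    let h = input.len();
    let w = input[0].len();
    let mut out = vec![vec![0u16; w]; h];
    for y in 0..h.saturating_sub(2) {
        for x in 0..w.saturating_sub(2) {
            let mut patch = [[0u16; 3]; 3];
            for (py, row) in patch.iter_mut().enumerate() {
                for (px, p) in row.iter_mut().enumerate() {
                    *p = input[y + py][x + px];
                }
            }
            let mut pixel = 0u16;
            for (o, k) in kernels.iter().enumerate() {
                pixel |= (act(k, &patch) as u16) << o;
            }
            out[y + 1][x + 1] = pixel;
        }
    }
    out
}
```

As in the generic version, the output image has the same shape as the input; in this sketch the border pixels are simply left at zero.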
Implementations of loss deltas computation
We present multiple implementations of computing all non-null weight mutation loss deltas. We assume implementation 0 to be correct, and compare the results of the other implementations with its results. For 10,000 random inputs, all implementations shown here produce the same results. This is significant evidence that they implement the same function. A formal proof of equivalence is beyond the scope of this page and of our current capabilities.
Implementation 0
The naive way to compute loss deltas is to mutate the weights, compute the loss, and subtract the null loss.
let null_loss = self.loss(input, class) as i64;
<Self::Weight as BitScaler>::states()
.iter()
.map(|&w| {
Self::indices()
.map(|i| {
(
i,
w,
self.mutate(i, w).loss(input, class) as i64 - null_loss,
)
})
.collect::<Vec<_>>()
})
.flatten()
.filter(|(_, _, l)| l.abs() as u64 > threshold)
.collect()
We compute the null loss and store it for later:
let null_loss = self.loss(input, class) as i64;
For each state of the weights,
<Self::Weight as BitScaler>::states().iter()
for each index of the weights,
Self::indices()
.map(|i| {
(
i,
w,
self.mutate(i, w).loss(input, class) as i64 - null_loss,
)
})
.collect::<Vec<_>>()
we set the i-th weight to the value of w, and compute the loss for the new model.
self.mutate(i, w).loss(input, class)
Then we subtract null_loss from it to get the loss delta.
This is very simple and easy to understand.
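The shape of implementation 0 can be sketched with a toy stand-in model (the Model struct, its toy loss, and mutate here are hypothetical; the real loss is the full network's, not this stand-in):

```rust
#[derive(Clone)]
struct Model {
    weights: u16,
}

impl Model {
    // Toy loss: how far the weight/input disagreement count is from a target.
    fn loss(&self, input: u16, target: i64) -> u32 {
        ((self.weights ^ input).count_ones() as i64 - target).unsigned_abs() as u32
    }
    // Return a copy of the model with weight `i` set to `w`.
    fn mutate(&self, i: usize, w: bool) -> Model {
        let mut m = self.clone();
        if w {
            m.weights |= 1 << i;
        } else {
            m.weights &= !(1 << i);
        }
        m
    }
    // Implementation-0 style: for every (state, index) pair, rebuild the
    // model, recompute the loss from scratch, and subtract the null loss.
    fn loss_deltas(&self, input: u16, target: i64, threshold: u64) -> Vec<(usize, bool, i64)> {
        let null_loss = self.loss(input, target) as i64;
        [false, true]
            .iter()
            .flat_map(|&w| {
                (0..16usize)
                    .map(move |i| (i, w, self.mutate(i, w).loss(input, target) as i64 - null_loss))
            })
            .filter(|(_, _, l)| l.unsigned_abs() > threshold)
            .collect()
    }
}
```

The cost structure is visible in the sketch: each of the 32 candidate mutations pays for a full loss evaluation.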
Perf
But let us examine the performance characteristics of this implementation. Each time we double the output pixel size, time per example increases by ~4x, and time per parameter increases by ~2x.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
0 | 16 | 32 | 8 | 2384 | 118.432 | 3701327.623 | 14805310.492 | 49682.250 |
0 | 16 | 32 | 16 | 4768 | 551.552 | 17236356.369 | 34472712.737 | 115680.244 |
0 | 16 | 32 | 32 | 9536 | 2102.064 | 65689691.446 | 65689691.446 | 220435.206 |
0 | 16 | 32 | 64 | 19072 | 8461.376 | 264418687.694 | 132209343.847 | 443655.516 |
0 | 16 | 32 | 128 | 38144 | 36606.400 | 1143954888.040 | 285988722.010 | 959693.698 |
0 | 16 | 32 | 256 | 76288 | 132802.500 | 4150090041.562 | 518761255.195 | 1740809.581 |
Or consider how the numbers change as we increase input pixel size. Here, the curve is somewhat messier, likely due to some combination of intermittently applied compiler optimizations and cache levels, but still, parameter time increases by, very approximately, a factor of 2 each time we double input size.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
0 | 16 | 8 | 32 | 2624 | 464.848 | 58106625.976 | 14526656.494 | 177154.347 |
0 | 16 | 16 | 32 | 4928 | 981.216 | 61326093.455 | 30663046.727 | 199110.693 |
0 | 16 | 32 | 32 | 9536 | 2263.488 | 70734151.507 | 70734151.507 | 237362.925 |
0 | 16 | 64 | 32 | 18752 | 7282.031 | 113781979.748 | 227563959.495 | 388334.402 |
0 | 16 | 128 | 32 | 37184 | 25540.375 | 199534283.366 | 798137133.465 | 686865.003 |
0 | 16 | 256 | 32 | 74048 | 119042.500 | 465011068.965 | 3720088551.719 | 1607644.145 |
Time to observe a single mutation is linear in both output pixel size and input pixel size. Since the number of weights is also linear in both output pixel size and input pixel size, total time to observe all loss deltas is quadratic in both input and output. This is undesirable. We can do significantly better.
Implementation 1
A single mutation to the weights matrix will only change one channel of the output.
Our objective head (stored in self.tail) supports caching: we can feed it a single-channel update and it will give us the loss delta more cheaply than recomputing from scratch.
First we extract the null_acts.
let null_acts = self.apply(input);
Then we use it to initialize the cache.
let cache = self.tail.cache(&null_acts, class);
Then, for each channel, we extract the one channel from each output pixel and store it in null_chan_acts,
let null_chan_acts = <IS as PixelMap<<OPS as BitPack<bool>>::T, bool>>::map(
&null_acts,
|pixel| OPS::get(&pixel, o),
);
and then subtract it from the cache.
let chan_cache = self.tail.subtract_input(&cache, o, &null_chan_acts);
We also extract a channel of the weights for later use.
let weights_channel = OPS::index_get(&self.kernel, o);
For each input, we extract the current weight bit, and use it to filter out null mutations.
Then for each weight and input index, we mutate the weights channel,
let new_weights_channel = <[[IPS; PY]; PX]>::set(*weights_channel, i, w);
and use it to compute the new channel acts.
let new_acts = <IS as Conv<
<IPS as BitPack<bool>>::T,
bool,
PX,
PY,
>>::conv(
input,
|patch| {
<[[IPS; PY]; PX] as WeightArray<W>>::act(
&new_weights_channel,
&patch,
)
},
);
Once we have the new acts, we can ask the cache for the loss delta,
let loss_delta = chan_cache.loss_delta(&new_acts);
and filter.
if loss_delta.abs() as u64 > threshold {
Some((LayerIndex::Head((o, i)), w, loss_delta))
} else {
None
}
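The property the chan cache exploits can be sketched with a hypothetical linear head: if the loss depends on the acts only through a per-channel sum, then subtracting channel o's null contribution once makes each candidate-act query O(1). This is an illustrative stand-in, not the post's tail implementation.

```rust
struct ChanCache {
    partial: i64, // head dot-product over all channels except `o`
    weight: i64,  // head weight of channel `o`
    target: i64,
    null_loss: i64,
}

// Build the cache once per channel: pay for the full sum a single time.
fn cache(acts: &[i64], head: &[i64], target: i64, o: usize) -> ChanCache {
    let full: i64 = acts.iter().zip(head).map(|(a, w)| a * w).sum();
    ChanCache {
        partial: full - acts[o] * head[o],
        weight: head[o],
        target,
        null_loss: (full - target).abs(),
    }
}

impl ChanCache {
    // Loss delta if channel `o`'s act were replaced by `new_act`:
    // O(1) per query instead of re-summing every channel.
    fn loss_delta(&self, new_act: i64) -> i64 {
        (self.partial + new_act * self.weight - self.target).abs() - self.null_loss
    }
}
```

The real head is more involved, but the division of labor is the same: one expensive cache construction per channel, then a cheap loss_delta per mutation.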
Perf
As we double the number of channels, time per example increases by ~2x, and time per parameter stays approximately constant. This is a nice improvement.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
1 | 16 | 32 | 8 | 2384 | 6.624 | 207371.187 | 829484.748 | 2783.506 |
1 | 16 | 32 | 16 | 4768 | 12.816 | 400695.533 | 801391.066 | 2689.232 |
1 | 16 | 32 | 32 | 9536 | 25.648 | 801605.402 | 801605.402 | 2689.951 |
1 | 16 | 32 | 64 | 19072 | 51.688 | 1615237.516 | 807618.758 | 2710.130 |
1 | 16 | 32 | 128 | 38144 | 106.500 | 3329722.520 | 832430.630 | 2793.391 |
1 | 16 | 32 | 256 | 76288 | 212.062 | 6628561.961 | 828570.245 | 2780.437 |
However, as we double input size, time per parameter tends to increase. It does not double smoothly, possibly due to some combination of intermittently applied compiler optimizations and/or cache levels.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
1 | 16 | 8 | 32 | 2624 | 8.816 | 1102565.246 | 275641.312 | 3361.479 |
1 | 16 | 16 | 32 | 4928 | 20.928 | 1308267.392 | 654133.696 | 4247.621 |
1 | 16 | 32 | 32 | 9536 | 25.760 | 805499.404 | 805499.404 | 2703.018 |
1 | 16 | 64 | 32 | 18752 | 96.531 | 1508643.859 | 3017287.718 | 5148.955 |
1 | 16 | 128 | 32 | 37184 | 742.438 | 5800456.872 | 23201827.486 | 19967.149 |
1 | 16 | 256 | 32 | 74048 | 2698.438 | 10541000.500 | 84328004.002 | 36442.525 |
It is desirable that time per parameter be constant with both output and input. We can make it so.
Implementation 2
Implementation 2 is approximately the same as implementation 1, with the exception that we are using the acts() method from the WeightArray trait on the patch [[IPS; PY]; PX].
Internally, acts() performs powerful bitwise magic to compute what all the activations would be were each of the weights individually set to a given state.
Consider the expression
<IS as Conv<
<IPS as BitPack<bool>>::T,
<[[IPS; PY]; PX] as BitPack<bool>>::T,
PX,
PY,
>>::conv(input, |patch| {
<[[IPS; PY]; PX] as WeightArray<W>>::acts(
weights_channel,
&patch,
w,
)
})
We perform a convolution over patches of input. The function which we are mapping is
<[[IPS; PY]; PX] as WeightArray<W>>::acts(weights_channel, &patch, w)
This expression returns one bit for each weight in weights_channel: the activation of the patch if that weight were to be set to w.
All the “multiplication” is being performed in a very efficient packed fashion, 32 bits at a time.
If the patch is dead, it will short-circuit appropriately. We are unable to skip null mutations, however.
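A sketch of how such an acts() computation can work on a single u16 lane, assuming bool weights and XOR/popcount matching (an assumption; the real code is generic over shapes and weight types): a single-weight mutation either leaves the patch sum unchanged, raises it by 1, or lowers it by 1, and each of the three cases can be resolved for all 16 weights at once with bitwise ops.

```rust
// One bit per weight: the activation of the patch if that weight alone were
// set to `v`. The full popcount is computed once, not once per weight.
fn acts(weights: u16, patch: u16, v: bool, threshold: u32) -> u16 {
    let sum = (weights ^ patch).count_ones();
    let vmask: u16 = if v { 0xFFFF } else { 0 };
    let old = weights ^ patch; // current per-weight contribution to the sum
    let new = vmask ^ patch;   // per-weight contribution after setting it to v
    let up = !old & new;       // mutations that raise the sum by 1
    let down = old & !new;     // mutations that lower the sum by 1
    let same = !(up | down);   // null mutations: the act is unchanged
    let null: u16 = if sum > threshold { 0xFFFF } else { 0 };
    let up_act: u16 = if sum + 1 > threshold { 0xFFFF } else { 0 };
    let down_act: u16 = if sum.saturating_sub(1) > threshold { 0xFFFF } else { 0 };
    (same & null) | (up & up_act) | (down & down_act)
}
```

The brute-force equivalent recomputes a popcount per weight; here the sum is computed once per patch and the per-weight results fall out of three masked selects.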
Once we have the activations for each patch, we need only reorder the dimensions of the matrix. To achieve this, we map over the matrix, extracting the needed bit.
<IS as PixelMap<
<[[IPS; PY]; PX] as BitPack<bool>>::T,
bool,
>>::map(
&chan_acts,
|pixel| <[[IPS; PY]; PX]>::get(pixel, i),
)
We can then pass this image of bool pixels to the cache as before.
Let us examine the performance characteristics of this implementation.
As we double number of output channels, time per parameter stays constant, (or perhaps very slightly decreases, but this could just be an artifact of overhead).
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
2 | 16 | 32 | 8 | 2384 | 2.470 | 77294.393 | 309177.574 | 1037.509 |
2 | 16 | 32 | 16 | 4768 | 4.928 | 154059.696 | 308119.391 | 1033.958 |
2 | 16 | 32 | 32 | 9536 | 9.808 | 306596.198 | 306596.198 | 1028.846 |
2 | 16 | 32 | 64 | 19072 | 19.568 | 611546.380 | 305773.190 | 1026.085 |
2 | 16 | 32 | 128 | 38144 | 39.216 | 1225546.302 | 306386.575 | 1028.143 |
2 | 16 | 32 | 256 | 76288 | 78.032 | 2438572.927 | 304821.616 | 1022.891 |
As we double number of input channels, time per parameter drifts gradually up. We hypothesize that this is an artifact of cache.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
2 | 16 | 8 | 32 | 2624 | 2.928 | 366376.050 | 91594.012 | 1117.000 |
2 | 16 | 16 | 32 | 4928 | 5.165 | 322948.920 | 161474.460 | 1048.535 |
2 | 16 | 32 | 32 | 9536 | 9.674 | 302382.484 | 302382.484 | 1014.706 |
2 | 16 | 64 | 32 | 18752 | 27.366 | 427646.928 | 855293.856 | 1459.546 |
2 | 16 | 128 | 32 | 37184 | 53.968 | 421643.461 | 1686573.845 | 1451.440 |
2 | 16 | 256 | 32 | 74048 | 109.933 | 429425.194 | 3435401.553 | 1484.616 |
Consider the same table but for only one thread. Here we see that time per parameter, while lower for small inputs, is still slightly super linear with input size.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
2 | 1 | 8 | 32 | 2624 | 2.692 | 336600.884 | 84150.221 | 1026.222 |
2 | 1 | 16 | 32 | 4928 | 4.757 | 297349.003 | 148674.502 | 965.419 |
2 | 1 | 32 | 32 | 9536 | 7.756 | 242396.527 | 242396.527 | 813.411 |
2 | 1 | 64 | 32 | 18752 | 23.724 | 370694.938 | 741389.877 | 1265.170 |
2 | 1 | 128 | 32 | 37184 | 47.601 | 371887.977 | 1487551.907 | 1280.165 |
2 | 1 | 256 | 32 | 74048 | 99.474 | 388573.445 | 3108587.560 | 1343.383 |
At last, we have achieved time per parameter constant with output and almost constant with input.
But we can do better.
Implementation 3
Convolution
Consider the expression
<IS as Conv<
<IPS as BitPack<bool>>::T,
<[[IPS; PY]; PX] as BitPack<bool>>::T,
PX,
PY,
>>::conv(input, |patch| patch)
We are applying a 3x3 convolution. This operation takes an input image of shape IS, holding pixels of type:
<IPS as BitPack<bool>>::T
that is to say, bits packed into the shape IPS.
It returns an image of shape IS, holding pixels of type:
<[[IPS; PY]; PX] as BitPack<bool>>::T
that is, bits packed into the shape [[IPS; PY]; PX].
The convolution operation also wants a function that maps a
[[<IPS as BitPack<bool>>::T; PY]; PX]
to a
<[[IPS; PY]; PX] as BitPack<bool>>::T
As we have previously explained to the compiler, these two types are equivalent.
In this case, we are using the identity function, as we need only extract the patch unaltered.
In this operation, we have preserved all information, but have increased memory usage by a factor of PX * PY.
In the case of a 3x3 convolution, the memory usage of the new representation is 9 times larger than the original representation.
Dimension reordering
Now consider another, somewhat more complex, expression:
<[[IPS; PY]; PX]>::indices()
.map(|i| {
<IS as PixelMap<<[[IPS; PY]; PX] as BitPack<bool>>::T, bool>>::map(
&chan_acts,
|pixel| <[[IPS; PY]; PX]>::get(pixel, i),
)
})
.collect::<Vec<<IS as PixelPack<bool>>::I>>()
Here, we are collecting a Vec of:
<IS as PixelPack<bool>>::I
that is, an image of shape IS, holding pixels of type bool.
The map function is
<IS as PixelMap<<[[IPS; PY]; PX] as BitPack<bool>>::T, bool>>::map(
&chan_acts,
|pixel| <[[IPS; PY]; PX]>::get(pixel, i),
)
This performs a pixel-wise map over the pixels of chan_acts, for each pixel using the expression
<[[IPS; PY]; PX]>::get(pixel, i)
to extract the i-th bit of the patch.
Using a bool to store what was previously a packed bit increases memory usage by a factor of 8.
Using these two expressions, we have increased memory usage by (assuming 3x3 patches) a factor of 9*8, or 72.
For each channel
We need to know the full sum of each patch, in other words, the hamming distance between each patch of the input and that channel's weights patch.
This can be easily achieved by using another convolution operation:
<IS as Conv<<IPS as BitPack<bool>>::T, u32, PX, PY>>::conv(input, |patch| <[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch))
In this expression, we map the following expression over each patch of input:
<[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch)
to produce an image of shape IS, holding pixels of type u32.
If we wish to obtain the activations, we map over the pixels, and compare to the threshold.
<IS as PixelMap<u32, bool>>::map(&null_chan_full_sum, |&sum| sum > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD)
However, we do not wish to obtain the activations when no mutation is performed to the weights, but rather the activations were we to set a single weight to a new value.
For each patch index
<IS as PixelZipMap<u32, bool, bool>>::zip_map(&null_chan_full_sum, layer, |sum, input| {
((sum - cur_weight.bma(input)) + w.bma(input)) > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD
})
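A scalar sketch of the adjustment inside that zip_map: given the cached null full sum for a patch, the mutant activation needs only a local bma correction, not a full re-multiply. The bma function here is a hypothetical stand-in for the generic trait method, assuming bool weights and hamming-distance matching.

```rust
// Hamming-distance contribution of a single weight/input bit pair.
fn bma(weight: bool, input: bool) -> u32 {
    (weight ^ input) as u32
}

// Activation of a patch if one weight were set to `new_w`. `null_sum`
// already includes `cur`'s contribution, so the subtraction cannot underflow.
fn mutant_act(null_sum: u32, cur: bool, new_w: bool, input: bool, threshold: u32) -> bool {
    (null_sum - bma(cur, input) + bma(new_w, input)) > threshold
}
```

This is why the per-mutation cost no longer depends on patch size: the full sum is shared, and each candidate pays only for the one weight it changes.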
Perf
Implementation 3 is somewhat flatter than implementation 2, while also being a constant factor faster.
As with implementation 2, time per parameter is ~constant with output size. However it is ~2/3 the time.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
3 | 16 | 32 | 8 | 2384 | 1.699 | 53182.833 | 212731.333 | 713.864 |
3 | 16 | 32 | 16 | 4768 | 3.274 | 102343.799 | 204687.597 | 686.871 |
3 | 16 | 32 | 32 | 9536 | 6.186 | 193356.556 | 193356.556 | 648.848 |
3 | 16 | 32 | 64 | 19072 | 12.131 | 379123.450 | 189561.725 | 636.113 |
3 | 16 | 32 | 128 | 38144 | 24.528 | 766585.951 | 191646.488 | 643.109 |
3 | 16 | 32 | 256 | 76288 | 49.046 | 1532744.650 | 191593.081 | 642.930 |
Likewise, time per parameter is only very slightly super-linear with input size.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
3 | 16 | 8 | 32 | 2624 | 1.763 | 220722.256 | 55180.564 | 672.934 |
3 | 16 | 16 | 32 | 4928 | 3.258 | 203678.717 | 101839.358 | 661.295 |
3 | 16 | 32 | 32 | 9536 | 6.262 | 195711.478 | 195711.478 | 656.750 |
3 | 16 | 64 | 32 | 18752 | 12.746 | 199150.159 | 398300.318 | 679.693 |
3 | 16 | 128 | 32 | 37184 | 25.485 | 199114.908 | 796459.632 | 685.421 |
3 | 16 | 256 | 32 | 74048 | 51.866 | 202603.022 | 1620824.176 | 700.443 |
Implementation 4
Consider the distribution of a popcnt. Assuming random input bits and/or weights (which is not necessarily a valid assumption, but we will pretend that it is), it is a bell curve. A single bit mutation can change the popcnt by at most 2 in either direction. Therefore, if the popcnt of a patch with the current weights is outside this range around the threshold, the gradient is dead, and no mutation of any weight of that patch will result in the activation changing. As the pixel width grows, the portion of output bits which are alive shrinks. In an extreme case of very wide inputs, the vast majority of output activations will be dead. To exploit this, we need to put the patch loop inside the image spatial loop.
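This liveness test can be sketched as a small predicate, using signed arithmetic to sidestep unsigned underflow (for bool weights, W::RANGE is 2, so only sums in a narrow band around the threshold are alive; the weight type is an assumption of this sketch).

```rust
// A patch is alive iff some single-weight mutation could flip its activation,
// i.e. the null sum sits within (threshold - (range - 1), threshold + range).
fn patch_is_alive(null_sum: i64, threshold: i64, range: i64) -> bool {
    (null_sum > threshold - (range - 1)) && (null_sum < threshold + range)
}
```

With range = 2 this admits only sums equal to threshold or threshold + 1; every other patch can be skipped entirely.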
For each output channel, we initialize the chan cache as usual.
Then, for each weight state, we do two things. Firstly, we fold over the patch indices of the image (source).
The accumulator of this fold is a vec of length equal to the number of bits in the patch:
<[[IPS; PY]; PX]>::indices().map(|_| null_chan_acts).collect::<Vec<<IS as PixelPack<bool>>::I>>()
Each element of the vec is initialized to the null chan acts.
For each patch index, we extract the null sum,
IS::get_pixel(&null_chan_full_sum, spatial_index)
and then filter to acts within range.
(sum > (<[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD - (W::RANGE - 1))) & (sum < (<[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD + W::RANGE))
If we are within range, then for each bit of the patch, we set the current pixel of the corresponding accumulator image to the new act.
<[[IPS; PY]; PX]>::indices().zip(acc.iter_mut()).for_each(|(i, target)| {
let act: bool = <[[IPS; PY]; PX]>::get(&acts, i);
<IS as PixelIndexSGet<bool>>::set_pixel_in_place(target, spatial_index, act);
});
In this way, we are avoiding mutating the majority of the pixels.
Secondly, we map over the resulting acts, and compute loss deltas as usual.
<[[IPS; PY]; PX]>::indices()
.zip(acts.iter())
.map(|(i, acts)| {
let loss_delta = chan_cache.loss_delta(acts);
(LayerIndex::<(OPS::Index, <[[IPS; PY]; PX] as Shape>::Index), H::Index>::Head((o, i)), w, loss_delta)
})
.collect::<Vec<_>>()
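A scalar sketch of this fold, with flattened image indices and a hypothetical acts_for_patch closure standing in for the packed acts() call: every per-weight act image starts at the null acts, and only the pixels whose patch is alive are ever overwritten.

```rust
// One act image per weight of the patch; only live spatial indices are touched.
fn mutant_act_images(
    null_acts: &[bool],                          // null act per spatial index
    null_sums: &[u32],                           // null full sum per spatial index
    acts_for_patch: impl Fn(usize) -> Vec<bool>, // mutant acts for a live patch
    threshold: u32,
    n_weights: usize,
) -> Vec<Vec<bool>> {
    let mut acc: Vec<Vec<bool>> = (0..n_weights).map(|_| null_acts.to_vec()).collect();
    for (s, &sum) in null_sums.iter().enumerate() {
        // Liveness filter: a single mutation moves this sum by at most 1
        // (bool-weight assumption, threshold >= 1 assumed to avoid an edge case).
        if sum > threshold.saturating_sub(1) && sum < threshold + 2 {
            let acts = acts_for_patch(s);
            for (i, target) in acc.iter_mut().enumerate() {
                target[s] = acts[i];
            }
        }
    }
    acc
}
```

When most patches are dead, the inner loop over weights runs rarely, which is the source of the sub-linear behavior observed below.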
Perf
While the complexity with respect to output size is no worse than implementation 3's, performance (for input of size 32) is only slightly better than that of implementation 2.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
4 | 16 | 32 | 8 | 2384 | 2.797 | 87521.723 | 350086.891 | 1174.788 |
4 | 16 | 32 | 16 | 4768 | 4.469 | 139914.453 | 279828.905 | 939.023 |
4 | 16 | 32 | 32 | 9536 | 8.641 | 270129.776 | 270129.776 | 906.476 |
4 | 16 | 32 | 64 | 19072 | 17.156 | 536329.072 | 268164.536 | 899.881 |
4 | 16 | 32 | 128 | 38144 | 34.484 | 1077954.840 | 269488.710 | 904.325 |
4 | 16 | 32 | 256 | 76288 | 69.422 | 2169586.003 | 271198.250 | 910.061 |
4 | 16 | 32 | 512 | 152576 | 137.719 | 4303967.263 | 268997.954 | 902.678 |
Now, for the first time, performance is (somewhat) sub linear with input size. However at 256, time per parameter suddenly doubles. We believe this to be due to cache levels.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
4 | 16 | 8 | 32 | 2624 | 3.953 | 494690.453 | 123672.613 | 1508.203 |
4 | 16 | 16 | 32 | 4928 | 5.453 | 341502.390 | 170751.195 | 1108.774 |
4 | 16 | 32 | 32 | 9536 | 8.906 | 278530.672 | 278530.672 | 934.667 |
4 | 16 | 64 | 32 | 18752 | 15.688 | 245297.243 | 490594.487 | 837.192 |
4 | 16 | 128 | 32 | 37184 | 29.062 | 227061.526 | 908246.106 | 781.623 |
4 | 16 | 256 | 32 | 74048 | 140.734 | 549803.949 | 4398431.596 | 1900.792 |
4 | 16 | 512 | 32 | 147776 | 552.719 | 1079535.044 | 17272560.699 | 3740.269 |
4 | 16 | 1024 | 32 | 295232 | 1338.688 | 1307321.567 | 41834290.150 | 4534.391 |
With only a single thread, the performance improvements continue until 1024. At input size of 256, we are almost competitive with implementation 3. If the trend were to continue, we could likely surpass 3 for large inputs.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
4 | 1 | 8 | 32 | 2624 | 3.500 | 437547.905 | 109386.976 | 1333.988 |
4 | 1 | 16 | 32 | 4928 | 5.060 | 316507.318 | 158253.659 | 1027.621 |
4 | 1 | 32 | 32 | 9536 | 7.860 | 245626.743 | 245626.743 | 824.251 |
4 | 1 | 64 | 32 | 18752 | 14.040 | 219446.048 | 438892.095 | 748.963 |
4 | 1 | 128 | 32 | 37184 | 24.860 | 194250.423 | 777001.690 | 668.676 |
4 | 1 | 256 | 32 | 74048 | 47.500 | 185561.645 | 1484493.162 | 641.527 |
4 | 1 | 512 | 32 | 147776 | 106.100 | 207234.037 | 3315744.585 | 718.004 |
4 | 1 | 1024 | 32 | 295232 | 317.150 | 309724.084 | 9911170.677 | 1074.265 |
The accumulator for an input of size 1024 should be ~9.5 MB. The L2 cache on this CPU is 8 MB. We suspect this to be the cause of the discontinuity.
Conclusion
Let us inspect our progress so far.
Consider the performance of a 64 channel wide conv layer. This is a smallish model, but is within an order of magnitude of realistic sizes.
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
0 | 16 | 64 | 64 | 37504 | 28503.000 | 445365259.312 | 445365259.312 | 760008.975 |
1 | 16 | 64 | 64 | 37504 | 237.750 | 3718063.523 | 3718063.523 | 6344.818 |
2 | 16 | 64 | 64 | 37504 | 58.578 | 915286.823 | 915286.823 | 1561.923 |
3 | 16 | 64 | 64 | 37504 | 26.094 | 407774.059 | 407774.059 | 695.860 |
4 | 16 | 64 | 64 | 37504 | 32.453 | 507122.407 | 507122.407 | 865.397 |
We have reduced the time complexity of single parameter loss deltas with respect to layer width from quadratic to constant time. At 64 wide, performance is ~1,000 times better, although comparing absolute performance of algorithms of different complexities at an arbitrary point in size space is unfair.
Within the three constant time implementations, we have reduced time per parameter by a factor of ~2.2.
Our input images are 32x32 pixels.
They therefore have 30 * 30 = 900 patches.
Implementation 3 can compute a loss delta for all 900 patches in an average of ~700 nanoseconds.
If we divide by number of patches, this is ~0.77 ns per parameter per patch, or (assuming a 4 GHz CPU), ~3 cycles (on average).
26 ms per (small) example is nice, but consider a 1024 channel wide layer (now single threaded to be extra fair):
version | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
2 | 1 | 1024 | 1024 | 9447424 | 14385.600 | 14048457.464 | 14048457.464 | 1522.703 |
3 | 1 | 1024 | 1024 | 9447424 | 6035.600 | 5894239.136 | 5894239.136 | 638.873 |
4 | 1 | 1024 | 1024 | 9447424 | 10745.600 | 10493801.507 | 10493801.507 | 1137.416 |
Using implementation 3 to compute loss deltas for all 9.4 million parameters costs us over six seconds. Larger images will be yet more expensive. While we could likely parallelize across output channels fairly easily, multi-core machines are more expensive, and our other cores have other examples to process. Six seconds per (large) example is unacceptably poor performance.
We can do better.
Zeros
The most beautiful thing about it is that its solution is zero. It’s my feeling that because the answer is zero, nobody needs to bother with it.
- Oshino Ougi
After we apply two distinct layers of single bit activations, most mutations have a zero gradient.
While there exists a small set of mutations which have non-zero gradients, the great majority are dead. Why must we still compute them if they are only to be filtered out later? But we cannot avoid it. The objective head is a black box. Unless all activations in a channel are dead, we must construct all the activations of that channel, and call the objective head cache.
Fused pooling layer
What is our objective head? Currently, it is a global avg pooling followed by a single fully connected objective layer. If we are willing to fuse the conv layer with the pooling layer, we can exploit further performance optimizations.
More generally, our pooling operation is a segmented avg pooling. We divide the image into, for example, 2x2 segments. In each segment, for each channel, we count the number of set bits and compare to n_pixels/2. This results in an output of shape [2, 2, chans] which we can feed to the FC layer.
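A minimal sketch of the segmented avg pooling for a single channel, assuming a 4x4 bit image divided into 2x2 segments (types simplified to plain bool arrays; the real code is shape-generic and multi-channel):

```rust
// Each output bit: does its 2x2 segment have more than n_pixels/2 set bits?
fn seg_avg_pool(image: &[[bool; 4]; 4]) -> [[bool; 2]; 2] {
    let mut out = [[false; 2]; 2];
    for (sy, out_row) in out.iter_mut().enumerate() {
        for (sx, o) in out_row.iter_mut().enumerate() {
            let mut count = 0u32;
            for y in 0..2 {
                for x in 0..2 {
                    count += image[sy * 2 + y][sx * 2 + x] as u32;
                }
            }
            *o = count > 2; // n_pixels / 2
        }
    }
    out
}
```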
We have two levels of sign function activations. First, the conv patch activations, then the pooling activations. For a single channel, consider a single segment of the image; all of the patches which feed into one pooling region. Just as a conv patch activation is dead if its null sum is outside the range of the threshold, so too, a segment activation is dead if its null sum is outside the range of the segment threshold.
However, correctly detecting segment liveness is more complex than correctly detecting patch liveness. Firstly, the threshold of a segment is not necessarily known at compile time. This is somewhat annoying but ultimately not a significant problem. More importantly, the range of a segment threshold is non-obvious. Naively, the range is n_pixels on either side. However, recall that, for a given chan, most of the pixels in a segment will be dead. To correctly compute segment range, we must first compute liveness for all pixels. The min of the range is the number of dead pixels which are set, while the max is the min plus the number of live pixels. If the segment threshold is within this range, that channel of the segment is alive and its activation must be computed. But if the threshold is outside this range, it is dead and the activation will be the null activation. Note that segment death does not require that all of its patches be dead. While patch death tends to promote segment death, a segment can be dead while many of its patches are still locally alive. And yet, segment death propagates back to patches. Why should we bother to compute a gradient that we know will inevitably die later?
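The segment range computation described above, as a hypothetical scalar helper: patches holds one (is_live, null_act) pair per patch in the segment, and a segment channel is alive iff the set-bit count could land on either side of the threshold.

```rust
// Dead-but-set patches contribute to the count unconditionally; live patches
// might go either way under some mutation.
fn segment_is_alive(patches: &[(bool, bool)], seg_threshold: usize) -> bool {
    // min: set bits the segment is guaranteed to have (dead patches stuck at 1)
    let min = patches.iter().filter(|&&(live, set)| !live && set).count();
    // max: min plus every live patch, each of which could become 1
    let max = min + patches.iter().filter(|&&(live, _)| live).count();
    // threshold < min: always active; threshold >= max: never active.
    (seg_threshold >= min) && (seg_threshold < max)
}
```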
If all four segments are dead, the FC layer is likewise dead, and that whole channel of the conv weights will have loss deltas of 0, and need not be computed. Just as the probability of a conv patch act being dead increases with input size, so too, the probability of a segment activation being dead increases with segment size.
Consider implementation 4. It was very nice with the exception of large memory usage and consequent cache contention. How can we adapt it to work better with fused segmented avg pool?
For a given channel, we must compute the activations and liveness for each of the patches in that segment. From this we can compute the range, and check if the segment threshold is inside this range.
For implementation and perf numbers see the next post.