Fused Convolution Segmented Pooling Loss Deltas
Introduction
In the previous post, we described a number of implementations of bit conv loss deltas. In it, we used a global avg pooling layer. However, often we want avg pooling layers with multiple segments, usually 2x2 or perhaps 3x3.
The numbers presented here are not directly comparable to those in the previous post. We have made a number of miscellaneous improvements, in addition to the switch from global avg pooling to segmented avg pooling.
As before, this post is likely only of interest to the intersection of Rust and ML devs. To follow along, use the correct commit.
Segmented pooling
Traditionally, when we perform avg pooling, we average across all pixels of the image. However, this destroys all spatial information. Alternatively, we could use a fully connected layer, but that is incompatible with dynamically sized images.
One solution is to cut the image into x by y segments, where x and y are usually 2 or perhaps 3. Then we can apply a fully connected objective layer to the segments.
Consider the implementation of segmented pooling.
let mut target = <[[(u32, <P as BitPack<bool>>::T); SY]; SX]>::default();
for sx in 0..SX {
for sy in 0..SY {
let (n, counts) = <I as SegmentedPixelFold<(usize, <P as Pack<u32>>::T), <P as BitPack<bool>>::T, SX, SY, PX, PY>>::seg_fold(
input,
sx,
sy,
<(usize, <P as Pack<u32>>::T)>::default(),
|acc, pixel| P::counted_increment(pixel, acc),
);
let threshold = n as u32 / 2;
target[sx][sy] = (threshold, <P as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold));
}
}
For each segment, we fold over the pixels of that segment, incrementing counters. Then we threshold the counts at half the number of pixels in that segment.
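Stripped of the bit-packing trait machinery, the same computation can be sketched for a single boolean channel. The segment-boundary arithmetic below is an illustrative assumption, not taken from the source.

```rust
// A minimal sketch of segmented majority pooling over one boolean
// channel. The image is dynamically sized; each of the SX x SY
// segments covers a rectangular block of pixels, and its act is the
// majority vote of the pixels in that block.
fn segmented_pool<const SX: usize, const SY: usize>(
    image: &[Vec<bool>], // image[y][x]
) -> [[bool; SY]; SX] {
    let (height, width) = (image.len(), image[0].len());
    let mut out = [[false; SY]; SX];
    for sx in 0..SX {
        for sy in 0..SY {
            // segment [sx][sy] covers this rectangular block
            let (x0, x1) = (sx * width / SX, (sx + 1) * width / SX);
            let (y0, y1) = (sy * height / SY, (sy + 1) * height / SY);
            let mut n = 0u32;
            let mut count = 0u32;
            for y in y0..y1 {
                for x in x0..x1 {
                    n += 1;
                    count += image[y][x] as u32;
                }
            }
            // threshold at half the number of pixels in the segment
            out[sx][sy] = count > n / 2;
        }
    }
    out
}
```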
Inference
Fused inference can be implemented as a fold over segments. This is not actually significantly faster than generic layer by layer inference, but it is a useful example for later improvements. source
(0..SX)
.map(|sx| iter::repeat(sx).zip(0..SY))
.flatten()
.fold(<[u32; C]>::default(), |class_acts, (sx, sy)| {
let (n, counts) = <IS as SegmentedConvFold<
(usize, <OPS as Pack<u32>>::T),
<IPS as BitPack<bool>>::T,
SX,
SY,
PX,
PY,
>>::seg_conv_fold(
input,
sx,
sy,
<(usize, <OPS as Pack<u32>>::T)>::default(),
|(n, acc), patch| {
(
n + 1,
<OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
&acc,
&self.conv,
|sum, weights| {
sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(
weights, &patch,
)) as u32
},
),
)
},
);
let threshold = n as u32 / 2;
let acts = <OPS as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold);
<[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(
&class_acts,
&self.fc,
|sum, weights| sum + OPS::bma(&weights[sx][sy], &acts),
)
})
For each segment
We fold over the patches of the segment.
let (n, counts) = <IS as SegmentedConvFold<
(usize, <OPS as Pack<u32>>::T),
<IPS as BitPack<bool>>::T,
SX,
SY,
PX,
PY,
>>::seg_conv_fold(
input,
sx,
sy,
<(usize, <OPS as Pack<u32>>::T)>::default(),
|(n, acc), patch| {
(
n + 1,
<OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
&acc,
&self.conv,
|sum, weights| {
sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(
weights, &patch,
)) as u32
},
),
)
},
);
For each patch
For each output channel, we compute the act, and increment the sum by it.
<OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
&acc,
&self.conv,
|sum, weights| {
sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch)) as u32
},
),
Then bitpack the segment acts, and add the Bitwise Multiply Accumulate to the sum.
let threshold = n as u32 / 2;
let acts = <OPS as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold);
<[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(
&class_acts,
&self.fc,
|sum, weights| sum + OPS::bma(&weights[sx][sy], &acts),
)
Now we have the acts for each class.
To compute loss, we sum the squared distance.
class_acts
.iter()
.enumerate()
.map(|(c, &act)| {
let target_act = (c == class) as u32 * <[[OPS; SY]; SX] as WeightArray<W>>::MAX;
let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
(dist as u64).pow(2)
})
.sum()
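The distance computation uses a branchless absolute difference: for unsigned `a` and `b`, at most one of the two saturating subtractions is nonzero, so their bitwise OR is `|a - b|`. A minimal sketch with plain slices (the function names here are illustrative):

```rust
// Branchless |a - b| for unsigned integers.
fn abs_dist(a: u32, b: u32) -> u32 {
    a.saturating_sub(b) | b.saturating_sub(a)
}

// Squared-distance loss over the class acts: the target class is
// pushed toward the maximum attainable act, all others toward zero.
fn loss(class_acts: &[u32], class: usize, max_act: u32) -> u64 {
    class_acts
        .iter()
        .enumerate()
        .map(|(c, &act)| {
            let target_act = (c == class) as u32 * max_act;
            (abs_dist(act, target_act) as u64).pow(2)
        })
        .sum()
}
```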
Fused inference is relatively straightforward.
Now let us implement bulk loss deltas.
Implementations
Implementation 5
By way of a baseline, let us benchmark a simple layerwise, non-fused version.
Perf
Note that, unlike last post, these are all single threaded performance numbers.
We will analyze multi threaded performance later.
Also note that we are using lto = "thin" but not codegen-units = 1.
Performance is slightly superlinear with output size.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
5 | 2 | 1 | 32 | 32 | 9216 | 7.520 | 235806.315 | 235806.315 | 818.772 |
5 | 2 | 1 | 32 | 64 | 18432 | 14.720 | 460257.876 | 230128.938 | 799.059 |
5 | 2 | 1 | 32 | 128 | 36864 | 30.640 | 957925.780 | 239481.445 | 831.533 |
5 | 2 | 1 | 32 | 256 | 73728 | 66.000 | 2063432.201 | 257929.025 | 895.587 |
5 | 2 | 1 | 32 | 512 | 147456 | 151.833 | 4747994.635 | 296749.665 | 1030.381 |
5 | 2 | 1 | 32 | 1024 | 294912 | 379.667 | 11869930.125 | 370935.316 | 1287.970 |
However, it is very pleasingly linear with input size.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
5 | 2 | 1 | 32 | 32 | 9216 | 7.960 | 249888.235 | 249888.235 | 867.667 |
5 | 2 | 1 | 64 | 32 | 18432 | 14.360 | 224629.198 | 449258.395 | 779.962 |
5 | 2 | 1 | 128 | 32 | 36864 | 29.480 | 230611.853 | 922447.410 | 800.736 |
5 | 2 | 1 | 256 | 32 | 73728 | 58.917 | 230458.443 | 1843667.544 | 800.203 |
5 | 2 | 1 | 512 | 32 | 147456 | 120.000 | 234546.248 | 3752739.969 | 814.397 |
5 | 2 | 1 | 1024 | 32 | 294912 | 242.333 | 236870.863 | 7579867.615 | 822.468 |
Our target is 1024x1024 in less than 1 second per example. However, implementation 5 takes ~12 seconds per example.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
5 | 2 | 1 | 1024 | 1024 | 9437184 | 12224.000 | 11937557.654 | 11937557.654 | 1295.308 |
So far, we have looked at 2x2 segments; perhaps we will want 3x3 segments. With 3x3 segments, larger output sizes are even worse.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
5 | 3 | 1 | 32 | 32 | 9216 | 7.480 | 234254.375 | 234254.375 | 813.383 |
5 | 3 | 1 | 32 | 64 | 18432 | 16.000 | 500898.439 | 250449.219 | 869.615 |
5 | 3 | 1 | 32 | 128 | 36864 | 33.680 | 1052794.808 | 263198.702 | 913.884 |
5 | 3 | 1 | 32 | 256 | 73728 | 77.333 | 2419090.292 | 302386.286 | 1049.952 |
5 | 3 | 1 | 32 | 512 | 147456 | 201.500 | 6301272.781 | 393829.549 | 1367.464 |
5 | 3 | 1 | 32 | 1024 | 294912 | 598.000 | 18694642.365 | 584207.574 | 2028.499 |
5 | 3 | 1 | 32 | 32 | 9216 | 7.400 | 232492.975 | 232492.975 | 807.267 |
5 | 3 | 1 | 64 | 32 | 18432 | 15.240 | 238745.919 | 477491.838 | 828.979 |
5 | 3 | 1 | 128 | 32 | 36864 | 30.480 | 238321.659 | 953286.636 | 827.506 |
5 | 3 | 1 | 256 | 32 | 73728 | 60.167 | 235050.251 | 1880402.005 | 816.147 |
5 | 3 | 1 | 512 | 32 | 147456 | 121.500 | 237415.161 | 3798642.578 | 824.358 |
5 | 3 | 1 | 1024 | 32 | 294912 | 249.333 | 243755.361 | 7800171.552 | 846.373 |
5 | 3 | 1 | 1024 | 1024 | 9437184 | 19227.333 | 18776937.619 | 18776937.619 | 2037.428 |
This is not good enough. Our needed performance is in another implementation.
Implementation 6
First we make all things ready.
We precompute null class acts, null segment acts, and null loss.
Then we can begin iterating over the channels.
For each channel
We extract the patches, along with some additional information:
<IS as Conv<<IPS as BitPack<bool>>::T, ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), PX, PY>>::conv(input, |patch| {
let sum = <[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch);
let sign = sum > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD;
let state = <[[IPS; PY]; PX] as WeightArray<W>>::mutant_act(sum);
(patch, sum, state, sign)
})
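The `mutant_act` state is not shown in this post; for boolean weights, where flipping a single weight moves the sum by exactly one, it plausibly reduces to:

```rust
// A sketch of what `mutant_act` presumably computes for boolean
// weights. Flipping one weight changes the popcount sum by exactly
// one, so the activation is undecided ("alive", None) only when the
// sum sits right at the threshold boundary; otherwise the act is the
// same for every single-weight mutation ("dead", Some(sign)).
fn mutant_act(sum: u32, threshold: u32) -> Option<bool> {
    if sum == threshold || sum == threshold + 1 {
        None // a one-bit mutation can flip the activation
    } else {
        Some(sum > threshold) // act is fixed for all mutations
    }
}
```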
Then we compute the act were that channel to be removed.
(0..SX).map(|sx| iter::repeat(sx).zip(0..SY)).flatten().fold(null_class_acts, |class_acts, (sx, sy)| {
let (n, count) = <IS as SegmentedPixelFold<(usize, u32), ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), SX, SY, PX, PY>>::seg_fold(
&patches,
sx,
sy,
<(usize, u32)>::default(),
|(n, c), (_, _, _, sign)| (n + 1, c + *sign as u32),
);
let act = count > (n as u32 / 2);
<[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(&class_acts, &self.fc, |sum, weights| {
let w: W = OPS::get(&weights[sx][sy], o);
sum - w.bma(act)
})
})
And finally, the trit states of each segment:
let mut ranges = <[[Option<bool>; SY]; SX]>::default();
for sx in 0..SX {
for sy in 0..SY {
let (n, min, max) = <IS as SegmentedPixelFold<(usize, u32, u32), ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), SX, SY, PX, PY>>::seg_fold(
&patches,
sx,
sy,
<(usize, u32, u32)>::default(),
|(n, min, max), (_, _, state, _)| {
if let Some(sign) = state {
if *sign {
(n + 1, min + 1, max + 1)
} else {
(n + 1, min, max)
}
} else {
(n + 1, min, max + 1)
}
},
);
let threshold = n as u32 / 2;
ranges[sx][sy] = Some(max > threshold).filter(|_| (min > threshold) | (max <= threshold));
}
}
ranges
For each segment, we fold over the states of the patches of that segment.
If the patch is dead and set, we increment both min and max; if it is dead and unset, we increment neither; if it is alive, we increment only max.
if let Some(sign) = state {
if *sign {
(n + 1, min + 1, max + 1)
} else {
(n + 1, min, max)
}
} else {
(n + 1, min, max + 1)
}
Then we return the state of the segment: None if n/2 is within the range (min..max), or Some(sign) if it is outside the range.
Some(max > threshold).filter(|_| (min > threshold) | (max <= threshold))
If all the segments are dead, we know that all weights in this channel are dead, and we can short circuit. But if some are alive, we must compute the gradients.
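The range logic can be sketched standalone, using a plain slice of patch states in place of the segmented fold:

```rust
// A sketch of the segment-liveness computation. Each patch state is
// None (alive: the act could go either way under a mutation) or
// Some(sign) (dead). min counts acts that are certainly set; max
// counts acts that could possibly be set. The segment act is decided
// (Some) only if it is the same at both extremes of the range.
fn segment_state(patch_states: &[Option<bool>]) -> Option<bool> {
    let (n, min, max) = patch_states
        .iter()
        .fold((0u32, 0u32, 0u32), |(n, min, max), state| match state {
            Some(true) => (n + 1, min + 1, max + 1),
            Some(false) => (n + 1, min, max),
            None => (n + 1, min, max + 1),
        });
    let threshold = n / 2;
    // dead iff the thresholded act agrees at both min and max
    Some(max > threshold).filter(|_| (min > threshold) | (max <= threshold))
}
```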
For each weight
We fold over the segments, collecting class acts.
(0..SX)
.map(|sx| iter::repeat(sx).zip(0..SY))
.flatten()
.fold(<[[<IPS as Pack<[u32; C]>>::T; PY]; PX]>::default(), |class_acts, (sx, sy)| {
let fc_weights: [W; C] =
<[(); C] as Map<[[<OPS as BitPack<W>>::T; SY]; SX], W>>::map(&self.fc, |class_weights| <OPS as PackedIndexSGet<W>>::get(&class_weights[sx][sy], o));
if let Some(act) = seg_states[sx][sy] {
<[[IPS; PY]; PX] as Map<[u32; C], [u32; C]>>::map(&class_acts, |class_acts| {
<[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})
} else {
let counts = <IS as SegmentedPixelFold<
(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32),
([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool),
SX,
SY,
PX,
PY,
>>::seg_fold(
&patches,
sx,
sy,
<(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32)>::default(),
|mut acc, (patch, cur_sum, state, _)| {
if let Some(act) = state {
<[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
} else {
let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
<[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
}
acc
},
);
let (n, seg_counts) = <[[IPS; PY]; PX]>::finalize_option_counted_increment(counts);
let threshold = n as u32 / 2;
<[[IPS; PY]; PX] as ZipMap<[u32; C], u32, [u32; C]>>::zip_map(&class_acts, &seg_counts, |class_acts, &count| {
let act = count > threshold;
<[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})
}
})
If the segment is dead, increment all the class act counters accordingly.
<[[IPS; PY]; PX] as Map<[u32; C], [u32; C]>>::map(&class_acts, |class_acts| {
<[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})
But if it is alive, we fold over patches of this segment, counting patch acts.
<IS as SegmentedPixelFold<
(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32),
([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool),
SX,
SY,
PX,
PY,
>>::seg_fold(
&patches,
sx,
sy,
<(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32)>::default(),
|mut acc, (patch, cur_sum, state, _)| {
if let Some(act) = state {
<[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
} else {
let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
<[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
}
acc
},
)
If the patch is dead, we metaphorically increment all the counters by that sign. (Actually, we save up all our dead patches and apply them at the end.) But if it is alive, we use powerful bitwise magic to compute acts, and increment each counter accordingly.
if let Some(act) = state {
<[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
} else {
let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
<[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
}
Now that we know the counts for each member of the input patch, we increment each set of class acts accordingly.
<[[IPS; PY]; PX] as ZipMap<[u32; C], u32, [u32; C]>>::zip_map(&class_acts, &seg_counts, |class_acts, &count| {
let act = count > threshold;
<[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})
Now we have partial class acts for each patch mutation. We can map over them, adding the else class acts and computing loss.
<[[IPS; PY]; PX] as Shape>::indices()
.map(|i| {
let mut_loss = <[[IPS; PY]; PX] as IndexGet<[u32; C]>>::index_get(&patch_class_acts, i)
.iter()
.zip(target_acts.iter())
.zip(else_class_acts.iter())
.map(|((&act, &target_act), else_class_act)| {
let act = act + else_class_act;
let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
(dist as u64).pow(2)
})
.sum::<u64>();
let index = LayerIndex::Head((o, i));
let loss_delta = mut_loss as i64 - null_loss as i64;
(index, w, loss_delta)
})
.collect::<Vec<_>>()
This has been a fairly direct port of the fused forward pass implementation. It is relatively simple. But is it fast enough?
Perf
Performance is now pleasingly linear with output size.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
6 | 2 | 1 | 32 | 32 | 9216 | 6.240 | 196169.263 | 196169.263 | 681.143 |
6 | 2 | 1 | 32 | 64 | 18432 | 12.800 | 400123.318 | 200061.659 | 694.659 |
6 | 2 | 1 | 32 | 128 | 36864 | 25.120 | 785715.521 | 196428.880 | 682.045 |
6 | 2 | 1 | 32 | 256 | 73728 | 49.333 | 1543467.849 | 192933.481 | 669.908 |
6 | 2 | 1 | 32 | 512 | 147456 | 100.500 | 3141999.922 | 196374.995 | 681.858 |
6 | 2 | 1 | 32 | 1024 | 294912 | 196.000 | 6129373.125 | 191542.910 | 665.080 |
Even better, it is sublinear with input size.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
6 | 2 | 1 | 32 | 32 | 9216 | 6.840 | 214885.931 | 214885.931 | 746.132 |
6 | 2 | 1 | 64 | 32 | 18432 | 11.640 | 182382.667 | 364765.334 | 633.273 |
6 | 2 | 1 | 128 | 32 | 36864 | 18.640 | 145794.872 | 583179.490 | 506.232 |
6 | 2 | 1 | 256 | 32 | 73728 | 31.333 | 122501.882 | 980015.052 | 425.354 |
6 | 2 | 1 | 512 | 32 | 147456 | 57.500 | 112473.484 | 1799575.740 | 390.533 |
6 | 2 | 1 | 1024 | 32 | 294912 | 107.333 | 104913.308 | 3357225.854 | 364.282 |
1024x1024 is now ~3 seconds per example. This is very nice.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
6 | 2 | 1 | 1024 | 1024 | 9437184 | 3044.667 | 2973484.268 | 2973484.268 | 322.644 |
3x3 segments, while slightly more expensive, follows the same general curves.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|---|
6 | 3 | 1 | 32 | 32 | 9216 | 7.520 | 235910.079 | 235910.079 | 819.132 |
6 | 3 | 1 | 32 | 64 | 18432 | 13.720 | 429050.230 | 214525.115 | 744.879 |
6 | 3 | 1 | 32 | 128 | 36864 | 27.240 | 852088.600 | 213022.150 | 739.660 |
6 | 3 | 1 | 32 | 256 | 73728 | 53.917 | 1687038.039 | 210879.755 | 732.221 |
6 | 3 | 1 | 32 | 512 | 147456 | 108.500 | 3392835.042 | 212052.190 | 736.292 |
6 | 3 | 1 | 32 | 1024 | 294912 | 218.333 | 6824996.927 | 213281.154 | 740.560 |
6 | 3 | 1 | 32 | 32 | 9216 | 7.400 | 231396.531 | 231396.531 | 803.460 |
6 | 3 | 1 | 64 | 32 | 18432 | 12.200 | 190722.744 | 381445.487 | 662.232 |
6 | 3 | 1 | 128 | 32 | 36864 | 21.480 | 167821.317 | 671285.267 | 582.713 |
6 | 3 | 1 | 256 | 32 | 73728 | 38.000 | 148507.264 | 1188058.112 | 515.650 |
6 | 3 | 1 | 512 | 32 | 147456 | 72.000 | 140766.711 | 2252267.370 | 488.773 |
6 | 3 | 1 | 1024 | 32 | 294912 | 149.333 | 146105.803 | 4675385.688 | 507.312 |
6 | 3 | 1 | 1024 | 1024 | 9437184 | 4111.333 | 4014986.518 | 4014986.518 | 435.654 |
But consider performance changes as we run multiple examples in parallel.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
6 | 2 | 1 | 1024 | 1024 | 9437184 | 3030.500 | 2959610.248 | 321.138 |
6 | 2 | 2 | 1024 | 1024 | 9437184 | 3041.500 | 2970259.857 | 322.294 |
6 | 2 | 4 | 1024 | 1024 | 9437184 | 3110.500 | 3037650.779 | 329.606 |
6 | 2 | 8 | 1024 | 1024 | 9437184 | 4319.250 | 4218019.613 | 457.684 |
6 | 2 | 16 | 1024 | 1024 | 9437184 | 9745.750 | 9517400.445 | 1032.704 |
Our needed performance is in another implementation.
Implementation 7
Consider the size of an array of class acts. Given 10 classes, it is 40 bytes. Multiplied by the number of bits in a patch, this is quite large.
Consider the information transmitted by a single channel of a single segment: it is but one bit. A single channel across 2x2=4 segments is 4 bits; across 3x3=9 segments, 9 bits. Even if we use 4x4=16 segments (significantly larger than the 2x2=4 used in Belilovsky et al.), this is merely 16 bits. We may use more classes; 10 classes is small, CIFAR 100 has 100 classes, and ImageNet has 1000.
It is preferable to accumulate into an accumulator of 16 bits rather than 40 bytes.
Before beginning on a given output channel, we can cheaply analyze the liveness of each segment. For the segments which are dead, we need not compute their acts, for we know them already.
Given our set of segments, we can filter to those segments which are live. We then need only fold over the live segments accumulating into the bits of the segment index. Once we are done folding, we have a finite set of act combinations. We must then hydrate it.
Usually, 2^n_live_segs is significantly smaller than the size of an input patch.
Now let us consider the implementation. source
As before, we must first make all things ready.
We compute null class acts, null seg acts, null loss.
For each channel
As before, we prepare the patches and the segment states.
Now, however, instead of simply short-circuiting when all segments are dead, we perform more complex branching.
We have three cases.
No live segments
If no segments are alive, then as before, we know that all mutations are dead, and we can return an empty Vec.
This is easy.
But if at least one segment is alive, we compute the dead class acts.
We fold over all the segments; if the segment is dead, we keep its contribution, but if it is alive, we subtract its null act from the class acts.
seg_states.iter().fold(null_class_acts, |class_acts, &((sx, sy), _, state)| {
let act = <OPS as PackedIndexSGet<bool>>::get(&null_seg_acts[sx][sy], o);
if let Some(act) = state {
class_acts
} else {
<[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(&class_acts, &self.fc, |sum, weights| {
sum - <OPS as PackedIndexSGet<W>>::get(&weights[sx][sy], o).bma(act)
})
}
})
1 live segment
If but one segment is live, there can be only one non-zero loss delta, and we can use a simpler implementation.
To compute it, we take the dead class act, and the class weights, and combine them according to the inverse of the null segment act.
<[(); C] as ZipMap<u32, W, u32>>::zip_map(&dead_class_acts, &class_weights, |sum, w| sum + w.bma(!null_seg_act))
.iter()
.zip(target_acts.iter())
.map(|(&act, &target_act)| {
let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
(dist as u64).pow(2)
})
.sum::<u64>() as i64
- null_loss as i64
As before, we fold over the segment, and then, if the old weight is not the same as the new weight, we extract the segment act, and if that segment act is different, we return the loss delta.
<[[IPS; PY]; PX] as Shape>::indices()
.filter_map(|i| {
if <[[IPS; PY]; PX] as PackedIndexSGet<W>>::get(&weights_channel, i) == w {
None
} else {
let act = <[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg_acts, i);
if act == null_seg_act {
None
} else {
Some((LayerIndex::Head((o, i)), w, loss_delta))
}
}
})
.collect::<Vec<_>>()
If we have only one live segment, obtaining the set of patch mutations which satisfy it is fairly simple.
But if more than one segment is alive, we must do more complex things.
>1 live segments
The number of loss deltas grows exponentially with the number of live segments. Although 4x4=16 segments has, in the worst case, 65536 loss deltas, in practice it is quite rare for all segments to be live in a given channel. In almost all channels, the number of loss deltas which must be precomputed is many times smaller than the worst case.
First, we must compute the class acts for the two states of all live segments.
seg_states
.iter()
.filter(|(_, _, state)| state.is_none())
.map(|(_, class_weights, _)| {
[
<[(); C] as Map<W, u32>>::map(&class_weights, |w| w.bma(false)),
<[(); C] as Map<W, u32>>::map(&class_weights, |w| w.bma(true)),
]
})
.collect()
Then, for every index from 0 to 2^live_segs.len(), we combine the elements of the live class acts, selecting the segment act state according to the corresponding bit of the index, and compute the loss delta for that class act.
(0..2usize.pow(live_seg_class_acts.len() as u32))
.map(|i| {
live_seg_class_acts
.iter()
.enumerate()
.fold(dead_class_acts, |class_acts, (seg_index, live_class_acts)| {
<[(); C] as ZipMap<u32, u32, u32>>::zip_map(&class_acts, &live_class_acts[i.bit(seg_index) as usize], |sum, live_act| sum + live_act)
})
.iter()
.zip(target_acts.iter())
.map(|(&act, &target_act)| {
let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
(dist as u64).pow(2)
})
.sum::<u64>() as i64
- null_loss as i64
})
.collect()
Now we have a dense mod-2 cube with one dimension for each live segment. At each corner of this cube, there is a loss delta.
Now we need to find the segment acts for each patch weight mutation.
For each target weight, for each live segment, we fold over the patches of that segment as before, counting the patch acts. Then we threshold and bitpack the counts. Now we have a vec of, for each live segment, the bit packed segment acts.
This is a 2d bit matrix of shape [segs, patch_bit_index]. We need it in the shape [patch_bit_index, segs].
We iterate over the indices of the patch.
<[[IPS; PY]; PX] as Shape>::indices()
.filter_map(|i| {
if <[[IPS; PY]; PX] as PackedIndexSGet<W>>::get(&weights_channel, i) == w {
None
} else {
let loss_index: usize = seg_acts.iter().enumerate().fold(0usize, |loss_index, (bit_index, seg)| {
loss_index | ((<[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg, i) as usize) << bit_index)
});
Some((LayerIndex::Head((o, i)), w, loss_deltas[loss_index]))
}
})
.collect::<Vec<_>>()
We fold over the live segments, adding a bit to the index according to the act of that segment.
seg_acts.iter().enumerate().fold(0usize, |loss_index, (bit_index, seg)| {
loss_index | ((<[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg, i) as usize) << bit_index)
})
Then we can use the index to select the loss delta: loss_deltas[loss_index].
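The hydration and indexing steps can be sketched with scalar contributions in place of full class-act arrays; `live_seg_contrib` and `dead_base` are illustrative simplifications, not names from the source.

```rust
// Hydrate a loss delta for every corner of the k-dimensional mod-2
// cube: with k live segments there are only 2^k distinct combinations
// of segment acts. Each live segment contributes one of two values,
// selected by the corresponding bit of the corner index.
fn hydrate(live_seg_contrib: &[[i64; 2]], dead_base: i64) -> Vec<i64> {
    (0..1usize << live_seg_contrib.len())
        .map(|i| {
            live_seg_contrib
                .iter()
                .enumerate()
                .fold(dead_base, |acc, (seg, contrib)| {
                    // pick this segment's contribution by its act bit
                    acc + contrib[(i >> seg) & 1]
                })
        })
        .collect()
}

// Per weight mutation: assemble the cube index from the mutated acts
// of the live segments, one bit per segment.
fn loss_index(seg_acts: &[bool]) -> usize {
    seg_acts
        .iter()
        .enumerate()
        .fold(0usize, |index, (bit, &act)| index | ((act as usize) << bit))
}
```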
Perf
We have done many complex things.
With small output sizes, implementation 7 is slightly worse than implementation 6; however, scaling is better.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 2 | 1 | 32 | 32 | 9216 | 7.333 | 232860.406 | 808.543 |
7 | 2 | 1 | 32 | 64 | 18432 | 17.000 | 269651.302 | 936.289 |
7 | 2 | 1 | 32 | 128 | 36864 | 29.667 | 232224.052 | 806.334 |
7 | 2 | 1 | 32 | 256 | 73728 | 39.667 | 155113.255 | 538.588 |
7 | 2 | 1 | 32 | 512 | 147456 | 78.667 | 154234.921 | 535.538 |
7 | 2 | 1 | 32 | 1024 | 294912 | 157.333 | 153678.509 | 533.606 |
Scaling with input size is a small improvement over implementation 6.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 2 | 1 | 32 | 32 | 9216 | 6.000 | 190813.396 | 662.547 |
7 | 2 | 1 | 64 | 32 | 18432 | 7.667 | 249737.479 | 433.572 |
7 | 2 | 1 | 128 | 32 | 36864 | 24.000 | 757013.958 | 657.130 |
7 | 2 | 1 | 256 | 32 | 73728 | 27.667 | 866082.323 | 375.904 |
7 | 2 | 1 | 512 | 32 | 147456 | 39.667 | 1240358.542 | 269.175 |
7 | 2 | 1 | 1024 | 32 | 294912 | 59.333 | 1860848.979 | 201.915 |
1024 x 1024 is now 1.7 seconds, down from 3 seconds.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1789.333 | 1747511.273 | 189.617 |
Now, however, performance improves slightly as we increase the number of segments.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 1 | 1 | 1024 | 1024 | 9437184 | 1866.000 | 1822499.346 | 197.754 |
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1796.333 | 1754418.040 | 190.367 |
7 | 3 | 1 | 1024 | 1024 | 9437184 | 1726.667 | 1686353.853 | 182.981 |
7 | 4 | 1 | 1024 | 1024 | 9437184 | 1606.333 | 1568983.929 | 170.246 |
Consider performance changes as we run multiple examples in parallel. Although only slightly faster when single threaded, implementation 7 is significantly better than implementation 6 when running 16 examples in parallel.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1791.250 | 1749446.327 | 189.827 |
7 | 2 | 2 | 1024 | 1024 | 9437184 | 1774.500 | 1733053.783 | 188.048 |
7 | 2 | 4 | 1024 | 1024 | 9437184 | 1852.000 | 1808798.837 | 196.267 |
7 | 2 | 8 | 1024 | 1024 | 9437184 | 2010.500 | 1963392.940 | 213.042 |
7 | 2 | 16 | 1024 | 1024 | 9437184 | 2843.500 | 2777027.236 | 301.327 |
Implementation 8: Multi core
If we have only one core to work with, 1.7 seconds per example is about as good as we can get.
But, if we are willing to spend multiple cores per example, we can do better.
With rayon, parallelizing over output channels is quite simple.
Implementation 8 is essentially identical to implementation 7, with the exception that we insert .par_bridge() after OPS::indices().
rayon will automagically distribute the channels across our cores.
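The shape of channel parallelism can be sketched without rayon, using scoped threads from the standard library; the chunking scheme and names here are illustrative, not how rayon actually schedules work.

```rust
use std::thread;

// A dependency-free sketch of channel parallelism: split the output
// channels into one contiguous chunk per worker, process each chunk
// on its own thread, and join in order so channel order is preserved.
// The post uses rayon's par_bridge(), which load-balances instead of
// chunking statically.
fn per_channel_parallel<T: Send>(
    n_channels: usize,
    n_workers: usize,
    f: impl Fn(usize) -> T + Sync,
) -> Vec<T> {
    let chunk = (n_channels + n_workers - 1) / n_workers;
    thread::scope(|s| {
        // spawn one worker per chunk of output channels
        let handles: Vec<_> = (0..n_workers)
            .map(|w| {
                let f = &f;
                let start = (w * chunk).min(n_channels);
                let end = ((w + 1) * chunk).min(n_channels);
                s.spawn(move || (start..end).map(f).collect::<Vec<T>>())
            })
            .collect();
        // join in spawn order, preserving channel order
        handles
            .into_iter()
            .flat_map(|h| h.join().unwrap())
            .collect()
    })
}
```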
Consider the scaling as we increase the number of rayon worker threads. Note that now, the threads column refers to the number of threads per example. Each example was processed serially.
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
8 | 2 | 1 | 1024 | 1024 | 9437184 | 1612.467 | 1574732.979 | 170.869 |
8 | 2 | 2 | 1024 | 1024 | 9437184 | 912.333 | 890975.042 | 96.677 |
8 | 2 | 4 | 1024 | 1024 | 9437184 | 548.067 | 535264.414 | 58.080 |
8 | 2 | 8 | 1024 | 1024 | 9437184 | 378.400 | 369564.049 | 40.100 |
8 | 2 | 16 | 1024 | 1024 | 9437184 | 320.333 | 312889.043 | 33.951 |
16 threads can process an example in 320 ms. This is only 5 times faster than the 1612 ms of a single core. 8 threads is almost as fast.
Consider the throughput if we parallelize over examples, instead of channels.
(Note that, unlike in previous benchmarks, here ms per example is not normalized by the number of threads.)
version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter |
---|---|---|---|---|---|---|---|---|
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1690.000 | 1650410.841 | 179.081 |
7 | 2 | 2 | 1024 | 1024 | 9437184 | 890.000 | 1738340.941 | 188.622 |
7 | 2 | 4 | 1024 | 1024 | 9437184 | 453.133 | 1770094.310 | 192.068 |
7 | 2 | 8 | 1024 | 1024 | 9437184 | 255.733 | 1998001.801 | 216.797 |
7 | 2 | 16 | 1024 | 1024 | 9437184 | 170.733 | 2667993.305 | 289.496 |
By the time we get to 16 threads, example parallelism has almost twice the throughput of channel parallelism. What merit have we acquired through channel parallelism if we must burn twice as many cores?
However, if we care about per-example latency rather than total throughput, if our mini batch size is fixed, and if we have unlimited cores with independent caches, some small degree of channel parallelism may be valuable to reduce latency by a factor of 2 or 3.
Conclusions
In this story of mathematics, of counting, in which the conclusion is driven by the largest number; in this story of majority rule, what have we accomplished?
Consider a fairly representative point in hyper-parameter space:
- 32x32 pixel image
- 1024 input channels
- 3x3 patches
- 1024 output channels
- 2x2 segments
Running 16 examples in parallel, we have improved performance by a factor of 4.87. If we use only a single thread, the improvement is a factor of 7.
If we permit ourselves 16 cores per example, the improvement is a factor of ~25x. If we use 4x4 segments and compare single threaded implementation 5 to 16-thread-per-example implementation 8, we have improved by a factor of ~60. However, as we have discussed, using multiple cores per example is undesirable, and 4x4 segments is probably larger than we need.
Still, a factor of 5 improvement is nice.
Potential further improvements
Reducing heap allocations
Currently, we perform a number of needless collect()s. They are needed to avoid ownership issues. If we were a more skilled Rust user, perhaps we could avoid them and so avoid some heap allocations.
asm
There is likely some assembly magic that could improve performance. However for now, we are not willing to sacrifice portability.
Compiler optimizations
Currently we use only the compiler flag lto = "thin". We note that setting codegen-units = 1 or lto = "fat" increases execution time by a factor of ~15. This should not be the case; the compiler is evidently making bad optimization decisions. We have not yet investigated the cause of this performance drop.
Bigger CPU cache / Von Neumann bottleneck circumvention
We observe significant decline in performance when running multiple threads simultaneously. This is likely due to memory access contention.
Given that example parallelism is embarrassingly parallel, it would be better to use 16 single core machines than one 16 core machine.
One could imagine combining channel parallelism with example parallelism by assigning each example to a separate ~4 core machine. Channel parallelism still attains reasonably good scaling with 4 cores.
AWS spot prices per core
Observe that price per core increases as we increase the number of cores per machine. It is therefore desirable to use machines with few cores.
cores per machine | cost per 1000 cores |
---|---|
1 | $3.8 |
2 | $4.2 |
4 | $11.9 |
8 | $12.8 |
16 | $15.9 |
32 | $15.9 |
64 | $16.6 |
We have not yet benchmarked how memory bandwidth/cache per core changes as we increase core count.
Future work
For a single 32x32 pixel, 1024 channel image, at a single corner of our ~9 million dimensional cube, we can observe the gradient of all dimensions in ~1.7 seconds, or ~0.25 seconds if we are willing to use 8 cores on it. Given n times more resources, we can observe loss deltas for n examples in the same time. Given the average loss deltas for a minibatch of examples, we can select a subset of dimensions and move along them, improving loss on that minibatch. If we repeatedly reduce the loss of our model with respect to sufficiently large mini batches, we will reduce loss on the full training set and test set. Once we have descended the loss surface of our first layer, we can discard the objective head, apply the convolution layer, and repeat. Each time we train a layer, we improve linear separability. If we repeatedly improve linear separability, we will obtain a representation which achieves good classification accuracy.
This is theory; it has not yet been tested outside of Belilovsky et al. 2018.
In the short term, we will train models on CIFAR 10 using a single 16 core CPU. Assuming this to be successful, we then intend to:
- implement distributed training so as to distribute our work across 10,000 cheap spot cores
- implement dynamically sized images so as to support ImageNet and other larger, inconsistently sized, datasets