Fused Convolution Segmented Pooling Loss Deltas

Introduction

In the previous post, we described a number of implementations of bit conv loss deltas. In it, we used a global avg pooling layer. However, often we want avg pooling layers with multiple segments, usually 2x2 or perhaps 3x3.

The numbers presented here are not directly comparable to those presented in the previous post. We have made a number of miscellaneous improvements, in addition to switching from global avg pooling to segmented avg pooling.

As before, this post is likely only of interest to the intersection of Rust and ML devs. To follow along, use the correct commit.

Segmented pooling

Traditionally, when we perform avg pooling, we average across all pixels of the image. However, this destroys all spatial information. Alternatively, we could use a fully connected layer, but this is incompatible with dynamically sized images.

One solution is to cut the image into x by y segments, where x and y are usually 2 or perhaps 3. Then we can apply a fully connected objective layer to the segments.
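
As a sketch of the idea (with plain scalar arguments, not the shape machinery used below), the segment to which a pixel belongs can be computed directly from its coordinates and the image dimensions, so the scheme works for any image size:

// Hypothetical helper: map a pixel coordinate to its (sx, sy) segment index,
// for an image of `width` x `height` pixels cut into `x_segs` x `y_segs` segments.
fn segment_of(px: usize, py: usize, width: usize, height: usize, x_segs: usize, y_segs: usize) -> (usize, usize) {
    ((px * x_segs) / width, (py * y_segs) / height)
}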

Consider the implementation of segmented pooling.

let mut target = <[[(u32, <P as BitPack<bool>>::T); SY]; SX]>::default();
for sx in 0..SX {
    for sy in 0..SY {
        let (n, counts) = <I as SegmentedPixelFold<(usize, <P as Pack<u32>>::T), <P as BitPack<bool>>::T, SX, SY, PX, PY>>::seg_fold(
            input,
            sx,
            sy,
            <(usize, <P as Pack<u32>>::T)>::default(),
            |acc, pixel| P::counted_increment(pixel, acc),
        );
        let threshold = n as u32 / 2;
        target[sx][sy] = (threshold, <P as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold));
    }
}

For each segment, we fold over the pixels of that segment, incrementing counters. Then we threshold the counts at half the number of pixels in that segment.
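
As a scalar sketch of what this fold computes for a single channel of a single segment (assuming the pixel bits are available as plain bools):

// Count the set pixels of the segment and threshold at half the pixel count.
fn pooled_bit(segment_pixels: &[bool]) -> bool {
    let count: u32 = segment_pixels.iter().map(|&b| b as u32).sum();
    let threshold = segment_pixels.len() as u32 / 2;
    count > threshold
}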

Inference

Fused inference can be implemented as a fold over segments. This is not actually significantly faster than generic layer-by-layer inference, but it is a useful example for later improvements. source

(0..SX)
    .map(|sx| iter::repeat(sx).zip(0..SY))
    .flatten()
    .fold(<[u32; C]>::default(), |class_acts, (sx, sy)| {
        let (n, counts) = <IS as SegmentedConvFold<
            (usize, <OPS as Pack<u32>>::T),
            <IPS as BitPack<bool>>::T,
            SX,
            SY,
            PX,
            PY,
        >>::seg_conv_fold(
            input,
            sx,
            sy,
            <(usize, <OPS as Pack<u32>>::T)>::default(),
            |(n, acc), patch| {
                (
                    n + 1,
                    <OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
                        &acc,
                        &self.conv,
                        |sum, weights| {
                            sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(
                                weights, &patch,
                            )) as u32
                        },
                    ),
                )
            },
        );
        let threshold = n as u32 / 2;
        let acts = <OPS as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold);
        <[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(
            &class_acts,
            &self.fc,
            |sum, weights| sum + OPS::bma(&weights[sx][sy], &acts),
        )
    })

For each segment

We fold over the patches of the segment.

let (n, counts) = <IS as SegmentedConvFold<
    (usize, <OPS as Pack<u32>>::T),
    <IPS as BitPack<bool>>::T,
    SX,
    SY,
    PX,
    PY,
>>::seg_conv_fold(
    input,
    sx,
    sy,
    <(usize, <OPS as Pack<u32>>::T)>::default(),
    |(n, acc), patch| {
        (
            n + 1,
            <OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
                &acc,
                &self.conv,
                |sum, weights| {
                    sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(
                        weights, &patch,
                    )) as u32
                },
            ),
        )
    },
);

For each patch

For each output channel, we compute the act, and increment the sum by it.

<OPS as ZipMap<u32, [[<IPS as BitPack<W>>::T; PY]; PX], u32>>::zip_map(
    &acc,
    &self.conv,
    |sum, weights| {
        sum + (<[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch)) as u32
    },
),

Then we bitpack the segment acts and add the bitwise multiply accumulate (BMA) of the acts with the class weights to the class sums.

let threshold = n as u32 / 2;
let acts = <OPS as PackedMap<u32, bool>>::map(&counts, |&sum| sum > threshold);
<[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(
    &class_acts,
    &self.fc,
    |sum, weights| sum + OPS::bma(&weights[sx][sy], &acts),
)

Now we have the acts for each class.

To compute the loss, we sum the squared distances between the class acts and the target acts.

class_acts
    .iter()
    .enumerate()
    .map(|(c, &act)| {
        let target_act = (c == class) as u32 * <[[OPS; SY]; SX] as WeightArray<W>>::MAX;
        let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
        (dist as u64).pow(2)
    })
    .sum()
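
Since the acts are unsigned, the distance expression is a branch-free absolute difference: exactly one of the two saturating subtractions is non-zero, so OR-ing them together yields |act - target_act|. A small standalone sketch:

// One of the two saturating subtractions is always 0, so the OR equals |a - b|.
fn abs_dist(a: u32, b: u32) -> u32 {
    a.saturating_sub(b) | b.saturating_sub(a)
}
// abs_dist(7, 3) == 4, abs_dist(3, 7) == 4, abs_dist(5, 5) == 0.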

Fused inference is relatively straightforward.

Now let us implement bulk loss deltas.

Implementations

Implementation 5

source

By way of a baseline, let us benchmark a simple layerwise, non-fused version.

Perf

Note that, unlike in the last post, these are all single-threaded performance numbers. We will analyze multi-threaded performance later. Also note that we are using lto = "thin" but not codegen-units = 1.

Performance is slightly super-linear with output size.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
5 | 2 | 1 | 32 | 32 | 9216 | 7.520 | 235806.315 | 235806.315 | 818.772
5 | 2 | 1 | 32 | 64 | 18432 | 14.720 | 460257.876 | 230128.938 | 799.059
5 | 2 | 1 | 32 | 128 | 36864 | 30.640 | 957925.780 | 239481.445 | 831.533
5 | 2 | 1 | 32 | 256 | 73728 | 66.000 | 2063432.201 | 257929.025 | 895.587
5 | 2 | 1 | 32 | 512 | 147456 | 151.833 | 4747994.635 | 296749.665 | 1030.381
5 | 2 | 1 | 32 | 1024 | 294912 | 379.667 | 11869930.125 | 370935.316 | 1287.970

However it is very pleasingly linear with input size.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
5 | 2 | 1 | 32 | 32 | 9216 | 7.960 | 249888.235 | 249888.235 | 867.667
5 | 2 | 1 | 64 | 32 | 18432 | 14.360 | 224629.198 | 449258.395 | 779.962
5 | 2 | 1 | 128 | 32 | 36864 | 29.480 | 230611.853 | 922447.410 | 800.736
5 | 2 | 1 | 256 | 32 | 73728 | 58.917 | 230458.443 | 1843667.544 | 800.203
5 | 2 | 1 | 512 | 32 | 147456 | 120.000 | 234546.248 | 3752739.969 | 814.397
5 | 2 | 1 | 1024 | 32 | 294912 | 242.333 | 236870.863 | 7579867.615 | 822.468

Our target is 1024x1024 in less than 1 second per example. However, implementation 5 takes ~12 seconds per example.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
5 | 2 | 1 | 1024 | 1024 | 9437184 | 12224.000 | 11937557.654 | 11937557.654 | 1295.308

So far, we have looked at 2x2 segments. Perhaps we will want 3x3 segments. With 3x3 segments, larger output sizes are even worse.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
5 | 3 | 1 | 32 | 32 | 9216 | 7.480 | 234254.375 | 234254.375 | 813.383
5 | 3 | 1 | 32 | 64 | 18432 | 16.000 | 500898.439 | 250449.219 | 869.615
5 | 3 | 1 | 32 | 128 | 36864 | 33.680 | 1052794.808 | 263198.702 | 913.884
5 | 3 | 1 | 32 | 256 | 73728 | 77.333 | 2419090.292 | 302386.286 | 1049.952
5 | 3 | 1 | 32 | 512 | 147456 | 201.500 | 6301272.781 | 393829.549 | 1367.464
5 | 3 | 1 | 32 | 1024 | 294912 | 598.000 | 18694642.365 | 584207.574 | 2028.499
5 | 3 | 1 | 32 | 32 | 9216 | 7.400 | 232492.975 | 232492.975 | 807.267
5 | 3 | 1 | 64 | 32 | 18432 | 15.240 | 238745.919 | 477491.838 | 828.979
5 | 3 | 1 | 128 | 32 | 36864 | 30.480 | 238321.659 | 953286.636 | 827.506
5 | 3 | 1 | 256 | 32 | 73728 | 60.167 | 235050.251 | 1880402.005 | 816.147
5 | 3 | 1 | 512 | 32 | 147456 | 121.500 | 237415.161 | 3798642.578 | 824.358
5 | 3 | 1 | 1024 | 32 | 294912 | 249.333 | 243755.361 | 7800171.552 | 846.373
5 | 3 | 1 | 1024 | 1024 | 9437184 | 19227.333 | 18776937.619 | 18776937.619 | 2037.428

This is not good enough. Our needed performance is in another implementation.

Implementation 6

source

First we make all things ready.

We precompute null class acts, null segment acts, and null loss.

Then we can begin iterating over the channels.

For each channel

We extract the patches, along with some additional information:

<IS as Conv<<IPS as BitPack<bool>>::T, ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), PX, PY>>::conv(input, |patch| {
    let sum = <[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch);
    let sign = sum > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD;
    let state = <[[IPS; PY]; PX] as WeightArray<W>>::mutant_act(sum);
    (patch, sum, state, sign)
})

Then we compute the class acts as they would be were that channel to be removed.

(0..SX).map(|sx| iter::repeat(sx).zip(0..SY)).flatten().fold(null_class_acts, |class_acts, (sx, sy)| {
    let (n, count) = <IS as SegmentedPixelFold<(usize, u32), ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), SX, SY, PX, PY>>::seg_fold(
        &patches,
        sx,
        sy,
        <(usize, u32)>::default(),
        |(n, c), (_, _, _, sign)| (n + 1, c + *sign as u32),
    );

    let act = count > (n as u32 / 2);
    <[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(&class_acts, &self.fc, |sum, weights| {
        let w: W = OPS::get(&weights[sx][sy], o);
        sum - w.bma(act)
    })
})

And finally, the trit states of each segment:

let mut ranges = <[[Option<bool>; SY]; SX]>::default();
for sx in 0..SX {
    for sy in 0..SY {
        let (n, min, max) = <IS as SegmentedPixelFold<(usize, u32, u32), ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool), SX, SY, PX, PY>>::seg_fold(
            &patches,
            sx,
            sy,
            <(usize, u32, u32)>::default(),
            |(n, min, max), (_, _, state, _)| {
                if let Some(sign) = state {
                    if *sign {
                        (n + 1, min + 1, max + 1)
                    } else {
                        (n + 1, min, max)
                    }
                } else {
                    (n + 1, min, max + 1)
                }
            },
        );
        let threshold = n as u32 / 2;
        ranges[sx][sy] = Some(max > threshold).filter(|_| (min > threshold) | (max <= threshold));
    }
}
ranges

For each segment, we fold over the states of the patches of that segment.

If the patch is dead and set, we increment both min and max; if it is dead and unset, we increment neither; if it is alive, we increment only max. In every case we increment n.

if let Some(sign) = state {
    if *sign {
        (n + 1, min + 1, max + 1)
    } else {
        (n + 1, min, max)
    }
} else {
    (n + 1, min, max + 1)
}

Then we return the state of the segment: None if n/2 is within the range (min..max), Some(sign) if it is outside that range.

Some(max > threshold).filter(|_| (min > threshold) | (max <= threshold))
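
Unpacked into explicit branches, that one-liner is equivalent to something like the following sketch (with scalar arguments):

fn segment_state(n: usize, min: u32, max: u32) -> Option<bool> {
    let threshold = n as u32 / 2;
    if min > threshold {
        Some(true) // even the minimum possible count clears the threshold
    } else if max <= threshold {
        Some(false) // even the maximum possible count cannot clear it
    } else {
        None // the segment act depends on which weight mutation we apply
    }
}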

If all the segments are dead, we know that all weights in this channel are dead, and we can short-circuit. But if some are alive, we must compute the gradients.
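
A sketch of the short-circuit check, assuming the `[[Option<bool>; SY]; SX]` shaped `ranges` computed above; a dead segment is one whose act is fixed (Some) regardless of mutation:

fn all_segments_dead<const SX: usize, const SY: usize>(ranges: &[[Option<bool>; SY]; SX]) -> bool {
    // If every segment act is fixed, no mutation of this channel can change the loss.
    ranges.iter().flatten().all(|state| state.is_some())
}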

For each weight

We fold over the segments, collecting class acts.

(0..SX)
    .map(|sx| iter::repeat(sx).zip(0..SY))
    .flatten()
    .fold(<[[<IPS as Pack<[u32; C]>>::T; PY]; PX]>::default(), |class_acts, (sx, sy)| {
        let fc_weights: [W; C] =
            <[(); C] as Map<[[<OPS as BitPack<W>>::T; SY]; SX], W>>::map(&self.fc, |class_weights| <OPS as PackedIndexSGet<W>>::get(&class_weights[sx][sy], o));

        if let Some(act) = seg_states[sx][sy] {
            <[[IPS; PY]; PX] as Map<[u32; C], [u32; C]>>::map(&class_acts, |class_acts| {
                <[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
            })
        } else {
            let counts = <IS as SegmentedPixelFold<
                (usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32),
                ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool),
                SX,
                SY,
                PX,
                PY,
            >>::seg_fold(
                &patches,
                sx,
                sy,
                <(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32)>::default(),
                |mut acc, (patch, cur_sum, state, _)| {
                    if let Some(act) = state {
                        <[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
                    } else {
                        let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
                        <[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
                    }
                    acc
                },
            );
            let (n, seg_counts) = <[[IPS; PY]; PX]>::finalize_option_counted_increment(counts);
            let threshold = n as u32 / 2;

            <[[IPS; PY]; PX] as ZipMap<[u32; C], u32, [u32; C]>>::zip_map(&class_acts, &seg_counts, |class_acts, &count| {
                let act = count > threshold;
                <[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
            })
        }
    })

If the segment is dead, increment all the class act counters accordingly.

<[[IPS; PY]; PX] as Map<[u32; C], [u32; C]>>::map(&class_acts, |class_acts| {
    <[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})

But if it is alive, we fold over patches of this segment, counting patch acts.

<IS as SegmentedPixelFold<
    (usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32),
    ([[<IPS as BitPack<bool>>::T; PY]; PX], u32, Option<bool>, bool),
    SX,
    SY,
    PX,
    PY,
>>::seg_fold(
    &patches,
    sx,
    sy,
    <(usize, [[<IPS as Pack<u32>>::T; PY]; PX], u32)>::default(),
    |mut acc, (patch, cur_sum, state, _)| {
        if let Some(act) = state {
            <[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
        } else {
            let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
            <[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
        }
        acc
    },
)

If the patch is dead, we metaphorically increment all the counters by that sign. (Actually, we save up all our dead patches and apply them at the end.) But if it is alive, we use powerful bitwise magic to compute the acts, and increment each counter accordingly.

if let Some(act) = state {
    <[[IPS; PY]; PX]>::none_option_counted_increment_in_place(*act, &mut acc);
} else {
    let acts = <[[IPS; PY]; PX]>::acts_simple(weights_channel, patch, *cur_sum, w);
    <[[IPS; PY]; PX]>::some_option_counted_increment_in_place(&acts, &mut acc);
}

Now that we know the counts for each member of the input patch, we increment each set of class acts accordingly.

<[[IPS; PY]; PX] as ZipMap<[u32; C], u32, [u32; C]>>::zip_map(&class_acts, &seg_counts, |class_acts, &count| {
    let act = count > threshold;
    <[(); C] as ZipMap<u32, W, u32>>::zip_map(&class_acts, &fc_weights, |sum, class_weight| sum + class_weight.bma(act))
})

Now we have partial class acts for each patch mutation. We can map over them, adding the else class acts and computing the loss.

<[[IPS; PY]; PX] as Shape>::indices()
    .map(|i| {
        let mut_loss = <[[IPS; PY]; PX] as IndexGet<[u32; C]>>::index_get(&patch_class_acts, i)
            .iter()
            .zip(target_acts.iter())
            .zip(else_class_acts.iter())
            .map(|((&act, &target_act), else_class_act)| {
                let act = act + else_class_act;
                let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
                (dist as u64).pow(2)
            })
            .sum::<u64>();
        let index = LayerIndex::Head((o, i));
        let loss_delta = mut_loss as i64 - null_loss as i64;
        (index, w, loss_delta)
    })
    .collect::<Vec<_>>()

This has been a fairly direct port of the fused forward pass implementation. It is relatively simple. But is it fast enough?

Perf

Performance is now pleasingly linear with output size.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
6 | 2 | 1 | 32 | 32 | 9216 | 6.240 | 196169.263 | 196169.263 | 681.143
6 | 2 | 1 | 32 | 64 | 18432 | 12.800 | 400123.318 | 200061.659 | 694.659
6 | 2 | 1 | 32 | 128 | 36864 | 25.120 | 785715.521 | 196428.880 | 682.045
6 | 2 | 1 | 32 | 256 | 73728 | 49.333 | 1543467.849 | 192933.481 | 669.908
6 | 2 | 1 | 32 | 512 | 147456 | 100.500 | 3141999.922 | 196374.995 | 681.858
6 | 2 | 1 | 32 | 1024 | 294912 | 196.000 | 6129373.125 | 191542.910 | 665.080

Even better, it is sub-linear with input size.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
6 | 2 | 1 | 32 | 32 | 9216 | 6.840 | 214885.931 | 214885.931 | 746.132
6 | 2 | 1 | 64 | 32 | 18432 | 11.640 | 182382.667 | 364765.334 | 633.273
6 | 2 | 1 | 128 | 32 | 36864 | 18.640 | 145794.872 | 583179.490 | 506.232
6 | 2 | 1 | 256 | 32 | 73728 | 31.333 | 122501.882 | 980015.052 | 425.354
6 | 2 | 1 | 512 | 32 | 147456 | 57.500 | 112473.484 | 1799575.740 | 390.533
6 | 2 | 1 | 1024 | 32 | 294912 | 107.333 | 104913.308 | 3357225.854 | 364.282

1024x1024 is now ~3 seconds per example. This is very nice.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
6 | 2 | 1 | 1024 | 1024 | 9437184 | 3044.667 | 2973484.268 | 2973484.268 | 322.644

3x3 segments, while slightly more expensive, follow the same general curves.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per pixel bit | ns per channel | ns per parameter
6 | 3 | 1 | 32 | 32 | 9216 | 7.520 | 235910.079 | 235910.079 | 819.132
6 | 3 | 1 | 32 | 64 | 18432 | 13.720 | 429050.230 | 214525.115 | 744.879
6 | 3 | 1 | 32 | 128 | 36864 | 27.240 | 852088.600 | 213022.150 | 739.660
6 | 3 | 1 | 32 | 256 | 73728 | 53.917 | 1687038.039 | 210879.755 | 732.221
6 | 3 | 1 | 32 | 512 | 147456 | 108.500 | 3392835.042 | 212052.190 | 736.292
6 | 3 | 1 | 32 | 1024 | 294912 | 218.333 | 6824996.927 | 213281.154 | 740.560
6 | 3 | 1 | 32 | 32 | 9216 | 7.400 | 231396.531 | 231396.531 | 803.460
6 | 3 | 1 | 64 | 32 | 18432 | 12.200 | 190722.744 | 381445.487 | 662.232
6 | 3 | 1 | 128 | 32 | 36864 | 21.480 | 167821.317 | 671285.267 | 582.713
6 | 3 | 1 | 256 | 32 | 73728 | 38.000 | 148507.264 | 1188058.112 | 515.650
6 | 3 | 1 | 512 | 32 | 147456 | 72.000 | 140766.711 | 2252267.370 | 488.773
6 | 3 | 1 | 1024 | 32 | 294912 | 149.333 | 146105.803 | 4675385.688 | 507.312
6 | 3 | 1 | 1024 | 1024 | 9437184 | 4111.333 | 4014986.518 | 4014986.518 | 435.654

But consider performance changes as we run multiple examples in parallel.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
6 | 2 | 1 | 1024 | 1024 | 9437184 | 3030.500 | 2959610.248 | 321.138
6 | 2 | 2 | 1024 | 1024 | 9437184 | 3041.500 | 2970259.857 | 322.294
6 | 2 | 4 | 1024 | 1024 | 9437184 | 3110.500 | 3037650.779 | 329.606
6 | 2 | 8 | 1024 | 1024 | 9437184 | 4319.250 | 4218019.613 | 457.684
6 | 2 | 16 | 1024 | 1024 | 9437184 | 9745.750 | 9517400.445 | 1032.704

Our needed performance is in another implementation.

Implementation 7

Consider the size of an array of class acts. Given 10 classes, it is 40 bytes. Multiplied by the number of bits in a patch, this is quite large.

Consider the information transmitted by a single channel of a single segment. It is but one bit. A single channel of 2x2=4 segments is 4 bits; of 3x3=9 segments, 9 bits. Even if we use 4x4=16 segments (significantly larger than the 2x2=4 used in Belilovsky et al.), this is merely 16 bits. We may use more classes. 10 classes is small; CIFAR 100 has 100 classes, ImageNet has 1000.

It is preferable to accumulate things into an accumulator of 16 bits rather than 40 bytes.

Before beginning on a given output channel, we can cheaply analyze the liveness of each segment. For the segments which are dead, we need not compute their acts, for we know them already.

Given our set of segments, we can filter to those segments which are live. We then need only fold over the live segments accumulating into the bits of the segment index. Once we are done folding, we have a finite set of act combinations. We must then hydrate it.

Usually, 2^n_live_segs is significantly smaller than the size of an input patch.
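
As a toy sketch of the indexing scheme (with plain bool slices rather than the packed types): each live segment contributes one bit to an index into the table of precomputed loss deltas, so with, say, 3 live segments there are only 2^3 = 8 loss deltas to precompute.

// Pack the act of each live segment into one bit of the loss delta index.
fn loss_index(live_seg_acts: &[bool]) -> usize {
    live_seg_acts
        .iter()
        .enumerate()
        .fold(0usize, |index, (bit, &act)| index | ((act as usize) << bit))
}
// loss_deltas[loss_index(&[true, false, true])] selects the delta for that act combination.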

Now let us consider the implementation. source

As before, we must first make all things ready.

We compute the null class acts, the null seg acts, and the null loss.

For each channel

As before, we prepare the patches and the segment states.

Now, however, instead of simply short-circuiting when all segments are dead, we perform more complex branching.

We have three cases.

No live segments

If no segments are alive, then as before, we know that all mutations are dead, and we can return an empty Vec. This is easy.

But if at least one segment is alive, we compute the dead class acts.

We fold over all the segments, starting from the null class acts. If a segment is dead we keep its contribution, but if it is alive we subtract its null contribution, leaving the acts contributed by the dead segments.

seg_states.iter().fold(null_class_acts, |class_acts, &((sx, sy), _, state)| {
    let act = <OPS as PackedIndexSGet<bool>>::get(&null_seg_acts[sx][sy], o);
    if let Some(act) = state {
        class_acts
    } else {
        <[(); C] as ZipMap<u32, [[<OPS as BitPack<W>>::T; SY]; SX], u32>>::zip_map(&class_acts, &self.fc, |sum, weights| {
            sum - <OPS as PackedIndexSGet<W>>::get(&weights[sx][sy], o).bma(act)
        })
    }
})

1 live segment

If but one segment is live, there can only be one non-zero loss delta, and we can use a simpler implementation.

To compute it, we take the dead class acts and the class weights, and combine them according to the inverse of the null segment act.

<[(); C] as ZipMap<u32, W, u32>>::zip_map(&dead_class_acts, &class_weights, |sum, w| sum + w.bma(!null_seg_act))
    .iter()
    .zip(target_acts.iter())
    .map(|(&act, &target_act)| {
        let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
        (dist as u64).pow(2)
    })
    .sum::<u64>() as i64
    - null_loss as i64

As before, we fold over the live segment to obtain its segment acts. Then, for each weight index, if the old weight is not the same as the new weight, we extract the segment act, and if it differs from the null segment act, we return the loss delta.

<[[IPS; PY]; PX] as Shape>::indices()
    .filter_map(|i| {
        if <[[IPS; PY]; PX] as PackedIndexSGet<W>>::get(&weights_channel, i) == w {
            None
        } else {
            let act = <[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg_acts, i);
            if act == null_seg_act {
                None
            } else {
                Some((LayerIndex::Head((o, i)), w, loss_delta))
            }
        }
    })
    .collect::<Vec<_>>()

If we have only one live segment, obtaining the set of patch mutations which satisfy it is fairly simple.

But if more than one segment is alive, we must do more complex things.

>1 live segments

The number of loss deltas grows exponentially with the number of live segments. Although 4x4=16 segments has, in the worst case, 65536 loss deltas, in practice it is quite rare for all segments to be live in a given channel. In almost all channels, the number of loss deltas which must be precomputed is many times smaller than the worst case.

First, we must compute the class acts for the two states of all live segments.

seg_states
    .iter()
    .filter(|(_, _, state)| state.is_none())
    .map(|(_, class_weights, _)| {
        [
            <[(); C] as Map<W, u32>>::map(&class_weights, |w| w.bma(false)),
            <[(); C] as Map<W, u32>>::map(&class_weights, |w| w.bma(true)),
        ]
    })
    .collect()

Then, for every index from 0 to 2^(number of live segments), we combine the elements of the live class acts, selecting each segment's act state according to the corresponding bit of the index, and compute the loss delta for that set of class acts.

(0..2usize.pow(live_seg_class_acts.len() as u32))
    .map(|i| {
        live_seg_class_acts
            .iter()
            .enumerate()
            .fold(dead_class_acts, |class_acts, (seg_index, live_class_acts)| {
                <[(); C] as ZipMap<u32, u32, u32>>::zip_map(&class_acts, &live_class_acts[i.bit(seg_index) as usize], |sum, live_act| sum + live_act)
            })
            .iter()
            .zip(target_acts.iter())
            .map(|(&act, &target_act)| {
                let dist = act.saturating_sub(target_act) | target_act.saturating_sub(act);
                (dist as u64).pow(2)
            })
            .sum::<u64>() as i64
            - null_loss as i64
    })
    .collect()

Now we have a dense mod2 cube with one dimension for each live segment. At each corner of this cube, there is a loss delta.

Now we need to find the segment acts for each patch weight mutation.

For each target weight, for each live segment, we fold over the patches of that segment as before, counting the patch acts. Then we threshold and bitpack the counts. Now we have a Vec containing, for each live segment, the bit-packed segment acts.

This is a 2d bit matrix of shape [segs, patch_bit_index]. We need it in the shape [patch_bit_index, segs].
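
Conceptually, this reshape is a transpose. A toy sketch with plain Vec<bool> rows, ignoring the bit packing (in the real code we never materialize the transposed matrix; we simply read the per-segment bit for each patch index, as shown below):

// From shape [segs][patch_bit_index] to shape [patch_bit_index][segs].
fn transpose(seg_acts: &[Vec<bool>]) -> Vec<Vec<bool>> {
    let n_bits = seg_acts[0].len();
    (0..n_bits)
        .map(|i| seg_acts.iter().map(|seg| seg[i]).collect())
        .collect()
}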

We iterate over the indices of the patch.

<[[IPS; PY]; PX] as Shape>::indices()
    .filter_map(|i| {
        if <[[IPS; PY]; PX] as PackedIndexSGet<W>>::get(&weights_channel, i) == w {
            None
        } else {
            let loss_index: usize = seg_acts.iter().enumerate().fold(0usize, |loss_index, (bit_index, seg)| {
                loss_index | ((<[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg, i) as usize) << bit_index)
            });
            Some((LayerIndex::Head((o, i)), w, loss_deltas[loss_index]))
        }
    })
    .collect::<Vec<_>>()

We fold over the live segments, adding a bit to the index according to the act of that segment.

seg_acts.iter().enumerate().fold(0usize, |loss_index, (bit_index, seg)| {
    loss_index | ((<[[IPS; PY]; PX] as PackedIndexSGet<bool>>::get(&seg, i) as usize) << bit_index)
})

Then we can use the index to select the loss delta: loss_deltas[loss_index].

Perf

We have done many complex things.

With small output sizes, implementation 7 is slightly worse than implementation 6; however, scaling is better.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 2 | 1 | 32 | 32 | 9216 | 7.333 | 232860.406 | 808.543
7 | 2 | 1 | 32 | 64 | 18432 | 17.000 | 269651.302 | 936.289
7 | 2 | 1 | 32 | 128 | 36864 | 29.667 | 232224.052 | 806.334
7 | 2 | 1 | 32 | 256 | 73728 | 39.667 | 155113.255 | 538.588
7 | 2 | 1 | 32 | 512 | 147456 | 78.667 | 154234.921 | 535.538
7 | 2 | 1 | 32 | 1024 | 294912 | 157.333 | 153678.509 | 533.606

Scaling with input is a small improvement over implementation 6.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 2 | 1 | 32 | 32 | 9216 | 6.000 | 190813.396 | 662.547
7 | 2 | 1 | 64 | 32 | 18432 | 7.667 | 249737.479 | 433.572
7 | 2 | 1 | 128 | 32 | 36864 | 24.000 | 757013.958 | 657.130
7 | 2 | 1 | 256 | 32 | 73728 | 27.667 | 866082.323 | 375.904
7 | 2 | 1 | 512 | 32 | 147456 | 39.667 | 1240358.542 | 269.175
7 | 2 | 1 | 1024 | 32 | 294912 | 59.333 | 1860848.979 | 201.915

1024 x 1024 is now 1.7 seconds, down from 3 seconds.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1789.333 | 1747511.273 | 189.617

Now, however, performance improves slightly as we increase the number of segments.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 1 | 1 | 1024 | 1024 | 9437184 | 1866.000 | 1822499.346 | 197.754
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1796.333 | 1754418.040 | 190.367
7 | 3 | 1 | 1024 | 1024 | 9437184 | 1726.667 | 1686353.853 | 182.981
7 | 4 | 1 | 1024 | 1024 | 9437184 | 1606.333 | 1568983.929 | 170.246

Consider performance changes as we run multiple examples in parallel. Although only slightly faster when single-threaded, when running 16 examples in parallel, implementation 7 is significantly better than implementation 6.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1791.250 | 1749446.327 | 189.827
7 | 2 | 2 | 1024 | 1024 | 9437184 | 1774.500 | 1733053.783 | 188.048
7 | 2 | 4 | 1024 | 1024 | 9437184 | 1852.000 | 1808798.837 | 196.267
7 | 2 | 8 | 1024 | 1024 | 9437184 | 2010.500 | 1963392.940 | 213.042
7 | 2 | 16 | 1024 | 1024 | 9437184 | 2843.500 | 2777027.236 | 301.327

Implementation 8: Multi-core

source

If we have only one core to work with, 1.7 seconds per example is about as good as we can get. But, if we are willing to spend multiple cores per example, we can do better. With rayon, parallelizing over output channels is quite simple. Implementation 8 is essentially identical to implementation 7 with the exception that we insert .par_bridge() after OPS::indices(). rayon will automagically distribute the channels across our cores.
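
A minimal sketch of the rayon pattern, with hypothetical stand-ins (`n_channels` and `loss_deltas_for_channel`) for the real OPS::indices() iterator and the per-channel body of implementation 7:

use rayon::prelude::*;

// Hypothetical stand-in for the work implementation 7 does for one output channel.
fn loss_deltas_for_channel(o: usize) -> Vec<(usize, i64)> {
    vec![(o, 0)]
}

fn parallel_loss_deltas(n_channels: usize) -> Vec<(usize, i64)> {
    (0..n_channels)
        .par_bridge() // distribute channel indices across rayon worker threads
        .flat_map(loss_deltas_for_channel)
        .collect()
}

We use par_bridge() rather than into_par_iter() because, as in the real code, the channel index iterator is an ordinary serial iterator rather than an indexed parallel one.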

Consider the scaling as we increase the number of rayon worker threads. Note that now, the threads column refers to the number of threads per example. Each example was processed serially.

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
8 | 2 | 1 | 1024 | 1024 | 9437184 | 1612.467 | 1574732.979 | 170.869
8 | 2 | 2 | 1024 | 1024 | 9437184 | 912.333 | 890975.042 | 96.677
8 | 2 | 4 | 1024 | 1024 | 9437184 | 548.067 | 535264.414 | 58.080
8 | 2 | 8 | 1024 | 1024 | 9437184 | 378.400 | 369564.049 | 40.100
8 | 2 | 16 | 1024 | 1024 | 9437184 | 320.333 | 312889.043 | 33.951

16 threads can process an example in 320 ms. This is only 5 times faster than the 1612 ms of a single core. 8 threads is almost as fast.

Consider the throughput if we parallelize over examples, instead of channels.

(note that, unlike in previous benchmarks, here, ms per example is not normalized by number of threads)

version | segs | threads | input pixel size | output pixel size | n params | ms per example | ns per channel | ns per parameter
7 | 2 | 1 | 1024 | 1024 | 9437184 | 1690.000 | 1650410.841 | 179.081
7 | 2 | 2 | 1024 | 1024 | 9437184 | 890.000 | 1738340.941 | 188.622
7 | 2 | 4 | 1024 | 1024 | 9437184 | 453.133 | 1770094.310 | 192.068
7 | 2 | 8 | 1024 | 1024 | 9437184 | 255.733 | 1998001.801 | 216.797
7 | 2 | 16 | 1024 | 1024 | 9437184 | 170.733 | 2667993.305 | 289.496

By the time we get to 16 threads, example parallelism has almost twice the throughput of channel parallelism. What merit have we acquired through channel parallelism if we must burn twice as many cores?

However, if we care about per-example latency rather than total throughput, if our mini-batch size is fixed, and we have unlimited cores with independent caches, some small degree of channel parallelism may be valuable to reduce latency by a factor of 2 or 3.

Conclusions

In this story of mathematics, of counting, in which the conclusion is driven by the largest number; in this story of majority rule, what have we accomplished?

Consider a fairly representative point in hyper-parameter space:

Running 16 examples in parallel, we have improved performance by a factor of 4.87. If we use only a single thread, the improvement is a factor of 7.

If we permit ourselves 16 cores per example, the improvement is a factor of ~25. If we use 4x4 segments, and compare single-threaded implementation 5 to 16-threads-per-example implementation 8, we have improved by a factor of ~60. However, as we have discussed, using multiple cores per example is undesirable, and 4x4 segments is probably larger than we need.

Still, a factor of 5 improvement is nice.

Potential further improvements

Reducing heap allocations

Currently, we perform a number of collect()s whose only purpose is to avoid ownership issues. If we were more skilled Rust users, perhaps we could avoid them and so avoid some heap allocations.

asm

There is likely some assembly magic that could improve performance. However for now, we are not willing to sacrifice portability.

Compiler optimizations

Currently we use only the compiler flag lto = "thin".

We note that setting codegen-units = 1 or lto = "fat" increases execution time by a factor of ~15. This should not be the case; the compiler is making bad optimization decisions.

We have not yet investigated the cause of this performance drop. It seems plausible that if the compiler made better optimization decisions the code would run faster.

Bigger CPU cache / Von Neumann bottleneck circumvention

We observe significant decline in performance when running multiple threads simultaneously. This is likely due to memory access contention.

Given that example parallelism is embarrassingly parallel, it would be better to use 16 single-core machines than one 16-core machine.

One could imagine combining channel parallelism with example parallelism by assigning each example to a separate ~4 core machine. Channel parallelism still attains reasonably good scaling with 4 cores.

AWS spot prices per core

Observe that the price per core increases as we increase the number of cores per machine. It is therefore desirable to use machines with few cores.

cores per machine | cost per 1000 cores
1 | $3.8
2 | $4.2
4 | $11.9
8 | $12.8
16 | $15.9
32 | $15.9
64 | $16.6

We have not yet benchmarked how memory bandwidth/cache per core changes as we increase core count.

Future work

For a single 32x32 pixel, 1024 channel image, at a single corner of our ~9 million dimensional cube, we can observe the gradient of all dimensions in ~1.7 seconds, or ~0.25 seconds if we are willing to use 8 cores on it. Given n times more resources, we can observe loss deltas for n examples in the same time. Given the average loss deltas for a minibatch of examples, we can select a subset of dimensions and move along them, improving loss on that minibatch. If we repeatedly reduce the loss of our model with respect to sufficiently large mini-batches, we will reduce loss on the full training set and test set. Once we have descended the loss surface of our first layer, we can discard the objective head, apply the convolution layer, and repeat. Each time we train a layer, we improve linear separability. If we repeatedly improve linear separability, we will obtain a representation which achieves good classification accuracy.

This is theory; it has not yet been tested outside of Belilovsky et al. 2018.

In the short term, we will train models on CIFAR 10 using a single 16 core CPU. Assuming this to be successful, we then intend to:

Last updated on: Sat, Jan 9, 2021