Efficient computation of bit convolution loss deltas

Meta

This post describes the current Rust implementation of a particular component of binary NN training. It is likely only of interest to the intersection of Rust devs and ML devs. The implementation is likely to change in future. To follow along, see the corresponding commit.

The Rust source code shown here is not typical, idiomatic Rust code. Its 9 type parameters encompass it about; its 46 lines of trait bounds blot out the sun. Most Rust code is far nicer. Please do not use it as an example of typical Rust.

Abstract

Often, we desire to observe the loss deltas of all weight mutations of a convolution layer. We wish to perform this computation quickly, on commodity hardware; however, it is computationally expensive. We describe and benchmark a number of different implementations, first achieving execution time quadratic in output and/or input size, then linear, and eventually barely sub-linear in input size. Which implementation will have the best performance is non-obvious. In attempting to improve performance, we employ some somewhat godly optimizations, but ultimately fail to achieve acceptably good performance.

In a future post, we will describe further optimizations.

Introduction

Convolution layers are a standard component of image classification neural nets.

As previously discussed, training models with binary activations and weights has some very nice benefits. To train a binary-weighted model, we need to be able to observe the loss deltas for each mutation. In other words: for each weight, if we set it to a new value, how would the loss change?

All benchmarks were carried out on an AMD Ryzen Threadripper 2950X 16-core processor with SMT disabled.

Benchmark table fields:

version: the implementation number
threads: the number of CPU threads used
input pixel size: the number of channels in each input pixel
output pixel size: the number of channels in each output pixel
n params: the total number of weights in the layer
ms per example: milliseconds to compute all loss deltas for one example
ns per pixel bit: time per example divided by input pixel size
ns per channel: time per example divided by output pixel size
ns per parameter: time per example divided by n params

Forward pass

The forward pass is simple.

<IS as Conv<<IPS as BitPack<bool>>::T, <OPS as BitPack<bool>>::T, PX, PY>>::conv(
    input,
    |patch| {
        <OPS as PackedMap<<[[IPS; PY]; PX] as BitPack<W>>::T, bool>>::map(
            &self.kernel,
            |weights| <[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch),
        )
    },
)

We use a convolution operation which takes an image of shape IS, and returns an image the same shape.

The input image holds pixels of type:

<IPS as BitPack<bool>>::T

while the output image holds pixels of type:

<OPS as BitPack<bool>>::T

In both cases, the bools contained in the pixels are efficiently packed, using type magic, such that they take only one bit each.
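As a rough illustration of what the type magic compiles down to, here is a minimal hand-rolled sketch, assuming a shape of exactly 32 bools and using plain u32 words (the pack and get helpers here are hypothetical stand-ins, not the post's BitPack machinery):

```rust
// Pack 32 bools into a u32, one bit each, and read individual bits back.
fn pack(bits: &[bool; 32]) -> u32 {
    bits.iter()
        .enumerate()
        .fold(0u32, |acc, (i, &b)| acc | ((b as u32) << i))
}

fn get(word: u32, i: usize) -> bool {
    (word >> i) & 1 == 1
}
```

The real code generalizes this to arbitrary nested shapes, but the storage cost is the same: one bit per bool.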

To the conv() method on the Conv trait on the IS type, we pass input and a function. As previously mentioned, input is of type

<IS as PixelPack<<IPS as BitPack<bool>>::T>>::I

The function takes a patch of type:

[[<IPS as BitPack<bool>>::T; PY]; PX]

where PY and PX are the dimensions of the patch, usually 3.

Inside this function, we perform a packed map operation:

<OPS as PackedMap<<[[IPS; PY]; PX] as BitPack<W>>::T, bool>>::map(
    &self.kernel,
    |weights| <[[IPS; PY]; PX] as WeightArray<W>>::act(weights, &patch),
)

We call the map() method on the output pixel shape. This function wants an array of shape OPS, and a function that maps a:

<[[IPS; PY]; PX] as BitPack<W>>::T

to a bool.

It will then pack the bools into a

<OPS as BitPack<bool>>::T
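For single-bit weights, the act() call at the heart of this forward pass reduces to comparing a hamming distance against a fixed threshold. A minimal sketch, assuming a 32-bit patch and a threshold of half the weight bits (the names are illustrative, not the actual WeightArray API):

```rust
// Sketch of one binary activation: the "multiply-accumulate" is a XOR
// followed by a popcount, and the sign function is a threshold comparison.
const THRESHOLD: u32 = 16; // half of the 32 weight bits

fn act(weights: u32, patch: u32) -> bool {
    (weights ^ patch).count_ones() > THRESHOLD
}
```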

Implementations of loss deltas computation

We present multiple implementations of computing all non-null weight mutation loss deltas. We assume implementation 0 to be correct, and compare the results of the other implementations against it. For 10,000 random inputs, all implementations shown here produce the same results. This is significant evidence that they implement the same function. A formal proof of equivalence is beyond the scope of this page, and of our current capabilities.

Implementation 0

source

The naive way to compute loss deltas is to mutate the weights, compute the loss, and subtract the null loss.

let null_loss = self.loss(input, class) as i64;

<Self::Weight as BitScaler>::states()
    .iter()
    .map(|&w| {
        Self::indices()
            .map(|i| {
                (
                    i,
                    w,
                    self.mutate(i, w).loss(input, class) as i64 - null_loss,
                )
            })
            .collect::<Vec<_>>()
    })
    .flatten()
    .filter(|(_, _, l)| l.abs() as u64 > threshold)
    .collect()

We compute the null loss and store it for later:

let null_loss = self.loss(input, class) as i64;

For each state of the weights,

<Self::Weight as BitScaler>::states().iter()

for each index of the weights,

Self::indices()
    .map(|i| {
        (
            i,
            w,
            self.mutate(i, w).loss(input, class) as i64 - null_loss,
        )
    })
    .collect::<Vec<_>>()

we set the ith weight to the value of w, and compute the loss for the new model.

self.mutate(i, w).loss(input, class)

Then we subtract null_loss from it to get the loss delta.

This is very simple and easy to understand.
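Stripped of the shape machinery, the same algorithm can be sketched on a toy 32-weight "model" with a hypothetical hamming-distance loss (both the model and the loss function are stand-ins for illustration, not the post's types):

```rust
// Hypothetical loss: hamming distance between activations and a target.
fn loss(weights: u32, input: u32, target: u32) -> i64 {
    ((weights ^ input) ^ target).count_ones() as i64
}

// Implementation 0, naively: for every weight index and every weight state,
// mutate, recompute the loss from scratch, and subtract the null loss.
fn loss_deltas(weights: u32, input: u32, target: u32) -> Vec<(usize, bool, i64)> {
    let null_loss = loss(weights, input, target);
    let mut deltas = Vec::new();
    for i in 0..32 {
        for w in [false, true] {
            // set the ith weight to w
            let mutated = (weights & !(1 << i)) | ((w as u32) << i);
            let delta = loss(mutated, input, target) - null_loss;
            if delta != 0 {
                deltas.push((i, w, delta));
            }
        }
    }
    deltas
}
```

Note that each of the 64 candidate mutations pays for a full loss computation; this is the source of the quadratic behavior measured below.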

Perf

But let us examine the performance characteristics of this implementation. Each time we double the output pixel size, example time increases by ~4x, and time per parameter increases by ~2x.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
0 16 32 8 2384 118.432 3701327.623 14805310.492 49682.250
0 16 32 16 4768 551.552 17236356.369 34472712.737 115680.244
0 16 32 32 9536 2102.064 65689691.446 65689691.446 220435.206
0 16 32 64 19072 8461.376 264418687.694 132209343.847 443655.516
0 16 32 128 38144 36606.400 1143954888.040 285988722.010 959693.698
0 16 32 256 76288 132802.500 4150090041.562 518761255.195 1740809.581

Or consider how numbers change as we increase input pixel size. Here, the curve is somewhat messier, likely due to some combination of intermittent compiler optimizations and cache levels, but still, parameter time increases by, very approximately, a factor of 2 each time we double input size.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
0 16 8 32 2624 464.848 58106625.976 14526656.494 177154.347
0 16 16 32 4928 981.216 61326093.455 30663046.727 199110.693
0 16 32 32 9536 2263.488 70734151.507 70734151.507 237362.925
0 16 64 32 18752 7282.031 113781979.748 227563959.495 388334.402
0 16 128 32 37184 25540.375 199534283.366 798137133.465 686865.003
0 16 256 32 74048 119042.500 465011068.965 3720088551.719 1607644.145

Time to observe a single mutation is linear both with output pixel size and input pixel size. Since number of weights is also linear with both output pixel size and input pixel size, total time to observe all loss deltas is quadratic with both input and output. This is undesirable. We can do significantly better.

Implementation 1

source

A single mutation to the weights matrix will only change one channel of the output. Our objective head (stored in self.tail) supports caching: we can feed it a single-channel update and it will give us the loss delta more cheaply than recomputing from scratch.

First we extract the null_acts.

let null_acts = self.apply(input);

Then we use it to initialize the cache.

let cache = self.tail.cache(&null_acts, class);

Then, for each channel, we extract that channel from each output pixel and store the result in null_chan_acts,

let null_chan_acts = <IS as PixelMap<<OPS as BitPack<bool>>::T, bool>>::map(
    &null_acts,
    |pixel| OPS::get(&pixel, o),
);

and then subtract it from the cache.

let chan_cache = self.tail.subtract_input(&cache, o, &null_chan_acts);

We also extract a channel of the weights for later use.

let weights_channel = OPS::index_get(&self.kernel, o);

For each input, we extract the current weight bit, and use it to filter out null mutations.

Then for each weight and input index, we mutate the weights channel,

let new_weights_channel = <[[IPS; PY]; PX]>::set(*weights_channel, i, w);

and use it to compute the new channel acts.

let new_acts = <IS as Conv<
    <IPS as BitPack<bool>>::T,
    bool,
    PX,
    PY,
>>::conv(
    input,
    |patch| {
        <[[IPS; PY]; PX] as WeightArray<W>>::act(
            &new_weights_channel,
            &patch,
        )
    },
);

Once we have the new acts, we can ask the cache for the loss delta,

let loss_delta = chan_cache.loss_delta(&new_acts);

and filter.

if loss_delta.abs() as u64 > threshold {
    Some((LayerIndex::Head((o, i)), w, loss_delta))
} else {
    None
}
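The caching idea can be sketched with a toy objective in which the loss decomposes into a sum of independent per-channel terms (an assumption made purely for illustration; the real objective head is more involved):

```rust
// Hypothetical per-channel loss term: hamming distance to a target word.
fn chan_loss(acts: u32, target: u32) -> i64 {
    (acts ^ target).count_ones() as i64
}

struct ChanCache {
    loss_minus_chan: i64, // null loss with channel c's term subtracted out
    target: u32,          // target for channel c
    null_loss: i64,
}

// Analog of cache() + subtract_input(): pay for the full loss once,
// then remove one channel's contribution.
fn subtract_input(channels: &[u32], targets: &[u32], c: usize) -> ChanCache {
    let null_loss: i64 = channels
        .iter()
        .zip(targets.iter())
        .map(|(&a, &t)| chan_loss(a, t))
        .sum();
    ChanCache {
        loss_minus_chan: null_loss - chan_loss(channels[c], targets[c]),
        target: targets[c],
        null_loss,
    }
}

impl ChanCache {
    // Loss delta for replacement acts of channel c, without a full recompute.
    fn loss_delta(&self, new_chan_acts: u32) -> i64 {
        (self.loss_minus_chan + chan_loss(new_chan_acts, self.target)) - self.null_loss
    }
}
```

Each mutation now costs one per-channel term instead of a whole-model loss, which is why time per parameter stops growing with output size.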

Perf

As we double the number of channels, time per example increases by ~2x, and time per parameter stays approximately constant. This is a nice improvement.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
1 16 32 8 2384 6.624 207371.187 829484.748 2783.506
1 16 32 16 4768 12.816 400695.533 801391.066 2689.232
1 16 32 32 9536 25.648 801605.402 801605.402 2689.951
1 16 32 64 19072 51.688 1615237.516 807618.758 2710.130
1 16 32 128 38144 106.500 3329722.520 832430.630 2793.391
1 16 32 256 76288 212.062 6628561.961 828570.245 2780.437

However, as we double input size, time per parameter tends to increase. It does not double smoothly, possibly due to some combination of intermittently applied compiler optimizations and/or cache levels.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
1 16 8 32 2624 8.816 1102565.246 275641.312 3361.479
1 16 16 32 4928 20.928 1308267.392 654133.696 4247.621
1 16 32 32 9536 25.760 805499.404 805499.404 2703.018
1 16 64 32 18752 96.531 1508643.859 3017287.718 5148.955
1 16 128 32 37184 742.438 5800456.872 23201827.486 19967.149
1 16 256 32 74048 2698.438 10541000.500 84328004.002 36442.525

It is desirable that time per parameter be constant with both output and input. We can make it so.

Implementation 2

source

Implementation 2 is approximately the same as implementation 1, with the exception that we are using the acts() method from the WeightArray trait on the patch [[IPS; PY]; PX]. Internally, acts() performs powerful bitwise magic to compute all the activations as if each of the weights were individually set to a given state.

Consider the expression

<IS as Conv<
    <IPS as BitPack<bool>>::T,
    <[[IPS; PY]; PX] as BitPack<bool>>::T,
    PX,
    PY,
>>::conv(input, |patch| {
    <[[IPS; PY]; PX] as WeightArray<W>>::acts(
        weights_channel,
        &patch,
        w,
    )
})

We perform a convolution over patches of input. The function which we are mapping is

<[[IPS; PY]; PX] as WeightArray<W>>::acts(weights_channel, &patch, w)

This expression returns one bit for each weight in weights_channel: the activation of the patch if that weight were set to w. All the “multiplication” is performed in a very efficient packed fashion, 32 bits at a time. If the patch is dead, it will short circuit appropriately. However, we are unable to skip null mutations.
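To see why acts() can be this cheap, here is a sketch of the trick for one channel of 32 single-bit weights, assuming sum = hamming distance and act = sum > t (the function name and signature are illustrative, not the actual trait method):

```rust
// All 32 "what if this weight were set to v" activations at once.
// A single weight mutation moves the sum by at most 1, so most sums
// short-circuit to all-dead or all-alive without per-bit work.
fn acts(w: u32, x: u32, v: bool, t: u32) -> u32 {
    let v_word = if v { u32::MAX } else { 0 };
    let cur = w ^ x; // per-weight contribution to the current sum
    let new = v_word ^ x; // per-weight contribution if that weight were v
    let s = cur.count_ones();
    if s > t + 1 {
        u32::MAX // so far above threshold that every mutation still activates
    } else if s + 1 <= t {
        0 // so far below threshold that no single mutation can activate
    } else if s == t {
        new & !cur // only mutations that add 1 cross the threshold
    } else {
        !(cur & !new) // s == t + 1: alive unless the mutation subtracts 1
    }
}
```

Only two values of the null sum require any per-bit computation at all, and even those are a handful of bitwise ops covering all 32 weights.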

Once we have the activations for each patch, we need only reorder the dimensions of the matrix. To achieve this, we map over the matrix, extracting the needed bit.

<IS as PixelMap<
    <[[IPS; PY]; PX] as BitPack<bool>>::T,
    bool,
>>::map(
    &chan_acts,
    |pixel| <[[IPS; PY]; PX]>::get(pixel, i),
)

We can then pass this image of bool pixels to the cache as before.

Let us examine the performance characteristics of this implementation.

As we double the number of output channels, time per parameter stays constant (or perhaps very slightly decreases, but this could just be an artifact of overhead).

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
2 16 32 8 2384 2.470 77294.393 309177.574 1037.509
2 16 32 16 4768 4.928 154059.696 308119.391 1033.958
2 16 32 32 9536 9.808 306596.198 306596.198 1028.846
2 16 32 64 19072 19.568 611546.380 305773.190 1026.085
2 16 32 128 38144 39.216 1225546.302 306386.575 1028.143
2 16 32 256 76288 78.032 2438572.927 304821.616 1022.891

As we double the number of input channels, time per parameter drifts gradually up. We hypothesize that this is an artifact of the cache.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
2 16 8 32 2624 2.928 366376.050 91594.012 1117.000
2 16 16 32 4928 5.165 322948.920 161474.460 1048.535
2 16 32 32 9536 9.674 302382.484 302382.484 1014.706
2 16 64 32 18752 27.366 427646.928 855293.856 1459.546
2 16 128 32 37184 53.968 421643.461 1686573.845 1451.440
2 16 256 32 74048 109.933 429425.194 3435401.553 1484.616

Consider the same table, but for only one thread. Here we see that time per parameter, while lower for small inputs, is still slightly super-linear with input size.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
2 1 8 32 2624 2.692 336600.884 84150.221 1026.222
2 1 16 32 4928 4.757 297349.003 148674.502 965.419
2 1 32 32 9536 7.756 242396.527 242396.527 813.411
2 1 64 32 18752 23.724 370694.938 741389.877 1265.170
2 1 128 32 37184 47.601 371887.977 1487551.907 1280.165
2 1 256 32 74048 99.474 388573.445 3108587.560 1343.383

At last, we have achieved time per parameter constant with output and almost constant with input.

But we can do better.

Implementation 3

source

Convolution

Consider the expression

<IS as Conv<
    <IPS as BitPack<bool>>::T,
    <[[IPS; PY]; PX] as BitPack<bool>>::T,
    PX,
    PY,
>>::conv(input, |patch| patch)

We are applying a 3x3 convolution. This operation takes an input image of shape IS holding pixels of type:

<IPS as BitPack<bool>>::T

that is to say, bits packed into the shape IPS. It returns an image of shape IS, holding pixels of type:

<[[IPS; PY]; PX] as BitPack<bool>>::T

that is, bits packed into the shape [[IPS; PY]; PX].

The convolution operation also wants a function that maps a

[[<IPS as BitPack<bool>>::T; PY]; PX]

to a

<[[IPS; PY]; PX] as BitPack<bool>>::T

As we have previously explained to the compiler, these two types are equivalent.

In this case, we are using the identity function, as we need only extract the patch unaltered. In this operation, we have preserved all information, but have increased memory usage by a factor of PX * PY. In the case of a 3x3 convolution, the memory usage of the new representation is 9 times larger than that of the original representation.

Dimension reordering

Now consider another, somewhat more complex, expression:

<[[IPS; PY]; PX]>::indices()
    .map(|i| {
        <IS as PixelMap<<[[IPS; PY]; PX] as BitPack<bool>>::T, bool>>::map(
            &chan_acts,
            |pixel| <[[IPS; PY]; PX]>::get(pixel, i),
        )
    })
    .collect::<Vec<<IS as PixelPack<bool>>::I>>()

Here, we are collecting a Vec of

<IS as PixelPack<bool>>::I

that is, an image of shape IS, holding pixels of type bool.

The map function is

<IS as PixelMap<<[[IPS; PY]; PX] as BitPack<bool>>::T, bool>>::map(
    &chan_acts,
    |pixel| <[[IPS; PY]; PX]>::get(pixel, i),
)

This performs a pixel-wise map over the pixels of chan_acts, for each pixel using the expression

<[[IPS; PY]; PX]>::get(pixel, i)

to extract the ith bit of the patch.

Using a bool to store what was previously a packed bit increases memory usage by a factor of 8.

Using these two expressions, we have increased memory usage by (assuming 3x3 patches) a factor of 9*8, or 72.

For each channel

We need to know the full sum of each patch; in other words, the hamming distance between each patch of the input and that channel's weight patch.

This can be easily achieved by using another convolution operation:

<IS as Conv<<IPS as BitPack<bool>>::T, u32, PX, PY>>::conv(
    input,
    |patch| <[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch),
)

In this expression, we map the following expression over each patch of input:

<[[IPS; PY]; PX] as BMA<W>>::bma(weights_channel, &patch)

to produce an image of shape IS, holding pixels of type u32.

If we wish to obtain the activations, we map over the pixels, and compare to the threshold.

<IS as PixelMap<u32, bool>>::map(
    &null_chan_full_sum,
    |&sum| sum > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD,
)

However, we do not wish to obtain the activations when no mutation is performed to the weights, but rather the activations if we were to set a single weight to a new value.

For each patch index

<IS as PixelZipMap<u32, bool, bool>>::zip_map(&null_chan_full_sum, layer, |sum, input| {
    ((sum - cur_weight.bma(input)) + w.bma(input)) > <[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD
})
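For a single weight, this incremental update can be sketched as follows (assuming a single-bit bma, in which a weight contributes 1 to the sum exactly when it differs from its input bit, and a null sum large enough that the subtraction cannot underflow):

```rust
// A weight's contribution to the hamming sum: 1 iff it differs from the input.
fn bma(weight: bool, input: bool) -> u32 {
    (weight != input) as u32
}

// Remove the current weight's contribution, add the candidate weight's
// contribution, and compare the updated sum to the threshold.
fn mutated_act(null_sum: u32, cur_weight: bool, new_weight: bool, input: bool, t: u32) -> bool {
    (null_sum - bma(cur_weight, input)) + bma(new_weight, input) > t
}
```

This is O(1) per mutation: the expensive full sum is computed once per patch and reused for every weight of that patch.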

Perf

Implementation 3 is somewhat flatter than implementation 2, while also being a constant factor faster.

As with implementation 2, time per parameter is ~constant with output size. However it is ~2/3 the time.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
3 16 32 8 2384 1.699 53182.833 212731.333 713.864
3 16 32 16 4768 3.274 102343.799 204687.597 686.871
3 16 32 32 9536 6.186 193356.556 193356.556 648.848
3 16 32 64 19072 12.131 379123.450 189561.725 636.113
3 16 32 128 38144 24.528 766585.951 191646.488 643.109
3 16 32 256 76288 49.046 1532744.650 191593.081 642.930

Likewise, time per parameter is only very slightly super-linear with input size.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
3 16 8 32 2624 1.763 220722.256 55180.564 672.934
3 16 16 32 4928 3.258 203678.717 101839.358 661.295
3 16 32 32 9536 6.262 195711.478 195711.478 656.750
3 16 64 32 18752 12.746 199150.159 398300.318 679.693
3 16 128 32 37184 25.485 199114.908 796459.632 685.421
3 16 256 32 74048 51.866 202603.022 1620824.176 700.443

Implementation 4

source

Consider the distribution of a popcnt. Assuming random input bits and/or weights (which is not necessarily a valid assumption, but we will pretend that it is), it is a bell curve. A single bit mutation can increase or decrease the popcnt by 2 in either direction. Therefore, if the popcnt of a patch with the current weights is outside this range around the threshold, the gradient is dead, and no mutation of any weight of that patch will result in the activation changing. As the pixel width grows, the portion of output bits which are alive shrinks. In the extreme case of very wide inputs, the vast majority of output activations will be dead. To exploit this, we need to put the patch loop inside the image spatial loop.
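This liveness test can be sketched as a small standalone predicate, mirroring the range filter used later in this implementation (assuming act = sum > t, with a mutation able to move the sum by less than range in either direction, and t >= range so the subtraction cannot underflow):

```rust
// A patch is alive iff some single mutation could move its sum across the
// threshold, i.e. iff the null sum lies in a narrow band around t.
fn patch_is_alive(sum: u32, t: u32, range: u32) -> bool {
    (sum > t - (range - 1)) && (sum < t + range)
}
```

For wide inputs the sum concentrates (bell curve) far from any fixed band, so this predicate rejects the overwhelming majority of patches.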

For each output channel, we initialize the chan cache as usual.

Then, for each weight and state, we do two things. Firstly, we fold over the patch indices of the image. source

The accumulator of this fold is a vec of length equal to the number of bits in the patch:

<[[IPS; PY]; PX]>::indices()
    .map(|_| null_chan_acts)
    .collect::<Vec<<IS as PixelPack<bool>>::I>>()

Each element of the vec is initialized to the null chan acts.

For each patch index, we extract the null sum,

IS::get_pixel(&null_chan_full_sum, spatial_index)

and then filter to acts within range.

(sum > (<[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD - (W::RANGE - 1)))
    & (sum < (<[[IPS; PY]; PX] as WeightArray<W>>::THRESHOLD + W::RANGE))

If we are within range, then, for each bit of the patch, we set the pixel at the current spatial index of the corresponding accumulator image to the new act.

<[[IPS; PY]; PX]>::indices().zip(acc.iter_mut()).for_each(|(i, target)| {
    let act: bool = <[[IPS; PY]; PX]>::get(&acts, i);
    <IS as PixelIndexSGet<bool>>::set_pixel_in_place(target, spatial_index, act);
});

In this way, we are avoiding mutating the majority of the pixels.

Secondly, we map over the resulting acts, and compute loss deltas as usual.

<[[IPS; PY]; PX]>::indices()
    .zip(acts.iter())
    .map(|(i, acts)| {
        let loss_delta = chan_cache.loss_delta(acts);
        (LayerIndex::<(OPS::Index, <[[IPS; PY]; PX] as Shape>::Index), H::Index>::Head((o, i)), w, loss_delta)
    })
    .collect::<Vec<_>>()

Perf

While the complexity with respect to output is no worse than that of implementation 3, performance (for input of size 32) is only slightly better than that of implementation 2.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
4 16 32 8 2384 2.797 87521.723 350086.891 1174.788
4 16 32 16 4768 4.469 139914.453 279828.905 939.023
4 16 32 32 9536 8.641 270129.776 270129.776 906.476
4 16 32 64 19072 17.156 536329.072 268164.536 899.881
4 16 32 128 38144 34.484 1077954.840 269488.710 904.325
4 16 32 256 76288 69.422 2169586.003 271198.250 910.061
4 16 32 512 152576 137.719 4303967.263 268997.954 902.678

Now, for the first time, performance is (somewhat) sub-linear with input size. However, at 256, time per parameter suddenly doubles. We believe this to be due to cache levels.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
4 16 8 32 2624 3.953 494690.453 123672.613 1508.203
4 16 16 32 4928 5.453 341502.390 170751.195 1108.774
4 16 32 32 9536 8.906 278530.672 278530.672 934.667
4 16 64 32 18752 15.688 245297.243 490594.487 837.192
4 16 128 32 37184 29.062 227061.526 908246.106 781.623
4 16 256 32 74048 140.734 549803.949 4398431.596 1900.792
4 16 512 32 147776 552.719 1079535.044 17272560.699 3740.269
4 16 1024 32 295232 1338.688 1307321.567 41834290.150 4534.391

With only a single thread, the performance improvements continue until 1024. At an input size of 256, we are almost competitive with implementation 3. If the trend were to continue, we would likely surpass implementation 3 for large inputs.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
4 1 8 32 2624 3.500 437547.905 109386.976 1333.988
4 1 16 32 4928 5.060 316507.318 158253.659 1027.621
4 1 32 32 9536 7.860 245626.743 245626.743 824.251
4 1 64 32 18752 14.040 219446.048 438892.095 748.963
4 1 128 32 37184 24.860 194250.423 777001.690 668.676
4 1 256 32 74048 47.500 185561.645 1484493.162 641.527
4 1 512 32 147776 106.100 207234.037 3315744.585 718.004
4 1 1024 32 295232 317.150 309724.084 9911170.677 1074.265

The accumulator for an input of size 1024 should be ~9.5 MB (3 * 3 * 1024 = 9216 patch bits, each holding a 32x32 image of one-byte bools: 9216 * 1024 bytes ≈ 9.4 MB). The L2 cache on this CPU is 8 MB. We suspect this to be the cause of the discontinuity.

Conclusion

Let us inspect our progress so far.

Consider the performance of a 64 channel wide conv layer. This is a smallish model, but is within an order of magnitude of realistic sizes.

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
0 16 64 64 37504 28503.000 445365259.312 445365259.312 760008.975
1 16 64 64 37504 237.750 3718063.523 3718063.523 6344.818
2 16 64 64 37504 58.578 915286.823 915286.823 1561.923
3 16 64 64 37504 26.094 407774.059 407774.059 695.860
4 16 64 64 37504 32.453 507122.407 507122.407 865.397

We have reduced the time complexity of single parameter loss deltas with respect to layer width from quadratic to constant time. At 64 wide, performance is ~1,000 times better, although comparing absolute performance of algorithms of different complexities at an arbitrary point in size space is unfair.

Within the three constant time implementations, we have reduced time per parameter by a factor of ~2.2.

Our input images are 32x32 pixels. They therefore have 30 * 30 = 900 patches. Implementation 3 can compute a loss delta for all 900 patches in an average of ~700 nanoseconds. If we divide by number of patches, this is ~0.77 ns per parameter per patch, or (assuming a 4 GHz CPU), ~3 cycles (on average).

26 ms per (small) example is nice, but consider a 1024-channel-wide layer (now single-threaded to be extra fair):

version threads input pixel size output pixel size n params ms per example ns per pixel bit ns per channel ns per parameter
2 1 1024 1024 9447424 14385.600 14048457.464 14048457.464 1522.703
3 1 1024 1024 9447424 6035.600 5894239.136 5894239.136 638.873
4 1 1024 1024 9447424 10745.600 10493801.507 10493801.507 1137.416

Using implementation 3 to compute loss deltas for all 9.4 million parameters costs us over six seconds. Larger images will be yet more expensive. While we could likely parallelize across output channels fairly easily, multi-core machines are more expensive, and our other cores have other examples to process. Six seconds per (large) example is unacceptably poor performance.

We can do better.

Zeros

The most beautiful thing about it is that its solution is zero. It’s my feeling that because the answer is zero, nobody needs to bother with it.

  • Oshino Ougi

After we apply two distinct layers of single bit activations, most mutations have a zero gradient.

While there exists a small set of mutations which have non-zero gradients, the great majority are dead. Why must we still compute them if they are only to be filtered out later? But we cannot avoid it. The objective head is the objective head; it is a black box. Unless all activations in a channel are dead, we must construct all the activations of that channel and call the objective head cache.

Fused pooling layer

What is our objective head? Currently, it is a global avg pooling followed by a single fully connected objective layer. If we are willing to fuse the conv layer with the pooling layer, we can exploit further performance optimizations.

More generally, our pooling operation is a segmented avg pooling. We divide the image into, for example, 2x2 segments. In each segment, for each channel, we count the number of set bits and compare to n_pixels/2. This results in an output of shape [2, 2, chans] which we can feed to the FC layer.
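A sketch of this pooling for a single channel, assuming a square image of bool activations and a 2x2 segmentation (the function and layout here are illustrative):

```rust
// Segmented avg pooling for one channel: split a square bool image into
// 2x2 segments, and pool each segment to (count of set bits) > n_pixels / 2.
fn pool_2x2_segments(image: &[Vec<bool>]) -> [[bool; 2]; 2] {
    let half = image.len() / 2;
    let n_pixels = (half * half) as u32;
    let mut out = [[false; 2]; 2];
    for sy in 0..2 {
        for sx in 0..2 {
            let mut count = 0u32;
            for y in 0..half {
                for x in 0..half {
                    count += image[sy * half + y][sx * half + x] as u32;
                }
            }
            out[sy][sx] = count > n_pixels / 2;
        }
    }
    out
}
```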

We have two levels of sign function activations. First, the conv patch activations, then the pooling activations. For a single channel, consider a single segment of the image; all of the patches which feed into one pooling region. Just as a conv patch activation is dead if its null sum is outside the range of the threshold, so too, a segment activation is dead if its null sum is outside the range of the segment threshold.

However, correctly detecting segment liveness is more complex than correctly detecting patch liveness. Firstly, the threshold of a segment is not necessarily known at compile time. This is somewhat annoying, but is ultimately not a significant problem. More importantly, the range of a segment threshold is non-obvious. Naively, the range is n_pixels on either side. However, recall that, for a given chan, most of the pixels in a segment will be dead. To correctly compute segment range, we must first compute liveness for all pixels. The min range is the number of dead pixels which are set, while the max range is the min range plus the number of live pixels. If the segment threshold is within this range, that channel of the segment is alive and its activation must be computed. But if the threshold is outside this range, it is dead and the activation will be the null activation. Note that segment death does not require that all of its patches be dead. While patch death tends to promote segment death, a segment can be dead while many of its patches are still locally alive. And yet, segment death propagates back to patches. Why should we bother to compute a gradient that we know will inevitably die later?
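The segment-liveness computation described above can be sketched as follows, assuming each patch in the segment reports an (is_alive, null_act) pair (this representation, and the exact strictness of the comparison, are hypothetical):

```rust
// Dead patches pin their bit to its null act; live patches can go either way.
// The achievable set-bit count therefore spans [min, max], and the segment
// act (count > threshold) can only flip if the threshold splits that range.
fn segment_is_alive(patches: &[(bool, bool)], seg_threshold: u32) -> bool {
    let min = patches
        .iter()
        .filter(|&&(alive, act)| !alive && act)
        .count() as u32; // dead pixels which are set
    let live = patches.iter().filter(|&&(alive, _)| alive).count() as u32;
    let max = min + live;
    seg_threshold >= min && seg_threshold < max
}
```

Note that with zero live patches, min == max and the segment is reported dead for every threshold, matching the intuition that nothing can change.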

If all four segments are dead, the FC layer is likewise dead, and that whole channel of the conv weights will have loss deltas of 0, and need not be computed. Just as the probability of a conv patch act being dead increases with input size, so too, the probability of a segment activation being dead increases with segment size.

Consider implementation 4. It was very nice, with the exception of its large memory usage and consequent cache contention. How can we adapt it to work better with a fused segmented avg pool?

For a given channel, we must compute the activations and liveness for each of the patches in that segment. From this we can compute the range, and check if the segment threshold is inside this range.

For implementation and perf numbers see the next post.

Last updated on: Wed, Dec 23, 2020