Gradients can propagate across layers of single bit activations
Epistemic status: We have empirical evidence; however, it is possible that we have misinterpreted the results.
This is a work in progress; check back later for more complete numbers.
Abstract
The sign function has no gradient across almost all of its domain.
In other words, if our activations are of one bit precision, we cannot compute gradients for any weights before a single bit activation.
As a consequence, it is conventional wisdom that gradients cannot propagate across even a single layer of single bit precision activations.
We empirically demonstrate this to be untrue for bit and trit weighted fully connected multi layer neural networks trained on binarized MNIST.
We train FC models with zero, one, two and three single bit precision hidden layers, and observe that gradients do propagate to early layers, and that the models are able to achieve significantly greater than random accuracy.
We note that models with more hidden layers are more difficult to train, and that training more than three hidden layers is effectively impossible.
We achieve this by computing loss deltas with respect to weight mutations rather than the more traditional method of computing weight gradients with respect to loss.
Implementation
We wrote a pure Rust implementation which is generic across weight implementations and hidden layer shape.
We have implemented bit and trit weights, each bit packed for efficiency, and intend to implement quat and pent in future.
Inference requires only bitwise operations and unsigned integer addition.
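As a rough illustration of what "only bitwise operations and unsigned integer addition" can look like, here is a minimal sketch of one bit-weighted layer over inputs packed into 32-bit words. It is not taken from the actual implementation: the function names, the per-neuron `threshold`, and the choice of XOR (rather than XNOR) are assumptions made for the example.

```rust
/// One output bit of a bit-weighted neuron: count the bit positions where
/// the packed input and the packed weights disagree (XOR + POPCNT + ADD),
/// then compare the count against a threshold.
/// `threshold` is a hypothetical per-neuron parameter, not from the source.
fn bit_neuron(input: &[u32], weights: &[u32], threshold: u32) -> bool {
    let hamming: u32 = input
        .iter()
        .zip(weights)
        .map(|(x, w)| (x ^ w).count_ones())
        .sum();
    hamming > threshold
}

/// Pack the output bits of up to 32 neurons into a single word
/// using a compare, a bit shift and an OR per neuron.
fn bit_layer(input: &[u32], neurons: &[(Vec<u32>, u32)]) -> u32 {
    neurons
        .iter()
        .enumerate()
        .fold(0u32, |acc, (i, (w, t))| {
            acc | ((bit_neuron(input, w, *t) as u32) << i)
        })
}
```

Calling `bit_layer(&input, &neurons)` yields one packed word of activations that can be fed directly to the next layer, so no floating point is needed at any stage.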
Loss delta computation is inherently sparse; most gradients are dead and need not be computed.
For each example in a minibatch, a sparse set of loss deltas is computed by means of skillful magic.
This set is further sparsified by taking the n largest positive and n smallest negative loss deltas.
The sets of 2n loss deltas from each example are merged together by summation.
The resulting summed set is then sorted and the best k mutations are applied.
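The truncate, merge, sort and apply steps can be sketched as follows. This is an illustration under stated assumptions, not the project's code: `Mutation` is a hypothetical identifier for one candidate weight change, loss deltas are assumed to be signed integers, "smallest negative" is read as "most negative", and "best" is read as "largest summed reduction in loss". How the per-example loss deltas are computed in the first place is not covered.

```rust
use std::collections::HashMap;

/// Hypothetical identifier for a single weight mutation (e.g. a flat index).
type Mutation = usize;

/// Keep the `n` most negative and the `n` largest positive loss deltas
/// of one example (at most 2n entries in total).
fn truncate(mut deltas: Vec<(Mutation, i64)>, n: usize) -> Vec<(Mutation, i64)> {
    deltas.sort_by_key(|&(_, d)| d); // ascending by loss delta
    let neg: Vec<_> = deltas.iter().copied().filter(|&(_, d)| d < 0).take(n).collect();
    let pos: Vec<_> = deltas.iter().rev().copied().filter(|&(_, d)| d > 0).take(n).collect();
    neg.into_iter().chain(pos).collect()
}

/// Sum the truncated sets over a minibatch and return the `k` mutations
/// whose summed loss delta is most negative (assumed to mean "best").
fn best_mutations(per_example: Vec<Vec<(Mutation, i64)>>, n: usize, k: usize) -> Vec<Mutation> {
    let mut summed: HashMap<Mutation, i64> = HashMap::new();
    for deltas in per_example {
        for (m, d) in truncate(deltas, n) {
            *summed.entry(m).or_insert(0) += d;
        }
    }
    let mut ranked: Vec<(Mutation, i64)> = summed.into_iter().collect();
    ranked.sort_by_key(|&(m, d)| (d, m)); // most loss-reducing first, ties by id
    ranked
        .into_iter()
        .filter(|&(_, d)| d < 0)
        .take(k)
        .map(|(m, _)| m)
        .collect()
}
```

In this sketch `n` plays the role of the per-example truncation and `k` the number of updates applied; a mutation that helps one example but hurts several others can cancel out in the sum and never be applied.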
Hyperparameters
The model has four training hyperparameters.
- truncation: The number of samples from each example in a minibatch to be collected. The top `truncation` negative and positive samples are collected, so twice this number are collected in total.
- n updates: The number of updates applied each sample. A larger value means faster training.
- min: The smallest minibatch size. Smaller means more epochs and more samples.
- scale: A fraction. The ratio of minibatch sizes between two consecutive epochs. Must be less than 1. Larger means more epochs and more samples.
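To make the interaction between min and scale concrete, here is one plausible reading of the schedule, written as a sketch. It assumes minibatch sizes form a geometric sequence that starts at the full training set and is multiplied by scale until it would fall below min; the source does not spell out the exact rule or the order in which sizes are visited, so the function name, its parameters and the rounding are all assumptions, and the sketch is not claimed to reproduce the epoch counts in the tables.

```rust
/// Minibatch sizes as a geometric sequence: start at `full` and multiply by
/// `num/den` (the `scale` fraction) until the size would drop below `min`.
/// Hypothetical helper; integer division rounds each size down.
fn minibatch_sizes(full: usize, min: usize, num: usize, den: usize) -> Vec<usize> {
    assert!(num > 0 && num < den, "scale must be a fraction between 0 and 1");
    assert!(min > 0, "min must be positive");
    let mut sizes = Vec::new();
    let mut size = full;
    while size >= min {
        sizes.push(size);
        size = size * num / den;
    }
    sizes
}
```

Under this reading, lowering min or moving scale closer to 1 both lengthen the sequence, which matches the stated effect of the two hyperparameters: more epochs and more samples.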
The model has three architecture hyperparameters.
- Weight type: Precision of weight. Currently only bit or trit.
- n layers: Number of hidden layers.
- hidden layer width: Number of bits in a hidden layer. For simplicity, width is the same across all hidden layers.
Results
Train time is in seconds, measured on an AMD Ryzen Threadripper 2950X with 32 rayon threads.
Best test accuracies for each depth:
- zero hidden layers: 79%
- one hidden layer: 82%
- two hidden layers: 66%
For a float weighted FC model on full precision MNIST, 82% is embarrassingly bad.
For binarised MNIST, it is somewhat less embarrassingly bad.
This 82% accuracy model takes 152 seconds to train, and (32*25*128 + 128*10)/8 = 12,960 bytes (less than 13 KB) to store.
Inference should require 25*128 + 4*10 = 3240 32-bit XOR, POPCNT, and ADD operations, and 128 compare, bit shift, and OR operations.
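The storage and operation counts above can be checked with a few lines of arithmetic. The sketch below assumes the input is packed into 25 words of 32 bits (784 MNIST pixels padded to 800 bits), a hidden layer of 128 bits, and 10 output classes; the function names are made up for the example.

```rust
/// Bits per machine word used for packing.
const WORD_BITS: usize = 32;

/// Storage in bytes for the bit weights of a one-hidden-layer model:
/// one weight bit per (input bit, hidden unit) and per (hidden unit, class).
fn storage_bytes(input_words: usize, hidden: usize, classes: usize) -> usize {
    (WORD_BITS * input_words * hidden + hidden * classes) / 8
}

/// Word-level XOR/POPCNT/ADD steps for one forward pass:
/// `input_words` per hidden unit, plus `hidden / 32` words per class.
fn word_ops(input_words: usize, hidden: usize, classes: usize) -> usize {
    input_words * hidden + (hidden / WORD_BITS) * classes
}
```

With `input_words = 25`, `hidden = 128` and `classes = 10`, `storage_bytes` gives 12960 and `word_ops` gives 3240, in agreement with the figures quoted in the text.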
zero hidden layers
bit weights
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 77.812% | 78.320% | 120 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 77.572% | 78.020% | 82 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 77.975% | 78.610% | 70 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 78.712% | 79.080% | 61 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.482% | 78.210% | 52 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 77.842% | 78.190% | 35 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 75.627% | 76.170% | 32 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 73.700% | 74.330% | 53 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 77.787% | 78.520% | 53 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.482% | 78.210% | 53 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 78.317% | 78.810% | 52 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 74.683% | 75.610% | 52 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 77.633% | 78.200% | 25 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 78.148% | 78.590% | 27 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 77.465% | 78.160% | 35 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.482% | 78.210% | 53 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 78.627% | 79.490% | 64 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 22.423% | 22.010% | 23 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 65.107% | 66.300% | 28 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 75.802% | 76.720% | 41 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.482% | 78.210% | 53 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 70.567% | 70.740% | 69 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 65.443% | 65.870% | 91 |
trit weights
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 69.288% | 69.650% | 197 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 69.122% | 69.680% | 146 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 69.510% | 70.580% | 127 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 68.010% | 68.680% | 112 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.315% | 70.060% | 98 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 69.982% | 70.710% | 70 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 69.085% | 69.730% | 62 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 30.823% | 31.300% | 83 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 68.883% | 69.600% | 97 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.315% | 70.060% | 99 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 68.312% | 68.920% | 100 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 69.830% | 70.760% | 100 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 68.207% | 68.970% | 47 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 69.102% | 70.090% | 49 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 70.228% | 70.790% | 65 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.315% | 70.060% | 99 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 68.853% | 70.040% | 122 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 56.873% | 57.850% | 60 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 75.138% | 76.130% | 64 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 72.768% | 73.870% | 73 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.315% | 70.060% | 98 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 45.770% | 46.760% | 135 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 66.395% | 66.670% | 191 |
one hidden layer
bit 1 x 32
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 36.227% | 36.000% | 44 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 74.592% | 75.430% | 45 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 74.180% | 74.910% | 38 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 72.270% | 72.930% | 34 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 73.218% | 74.190% | 30 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 75.035% | 76.040% | 17 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 72.825% | 73.690% | 14 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 58.915% | 59.610% | 39 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 73.707% | 74.480% | 33 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 73.218% | 74.190% | 30 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 68.593% | 69.700% | 27 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 58.205% | 58.910% | 23 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 67.712% | 67.910% | 15 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 64.735% | 65.130% | 16 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 67.700% | 67.950% | 20 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 73.218% | 74.190% | 30 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 75.292% | 76.250% | 34 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 14.675% | 14.190% | 13 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 62.512% | 62.520% | 18 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 75.602% | 76.790% | 28 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 73.218% | 74.190% | 30 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 67.645% | 68.470% | 30 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 72.817% | 73.930% | 29 |
trit 1 x 32
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 9.870% | 10.150% | 88 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 72.080% | 72.690% | 93 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 60.960% | 60.980% | 82 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 69.440% | 70.040% | 77 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 75.075% | 76.390% | 66 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 38.025% | 38.300% | 43 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 64.530% | 64.630% | 37 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 56.757% | 58.420% | 81 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 64.642% | 65.840% | 71 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 75.075% | 76.390% | 66 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 74.267% | 75.610% | 62 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 60.312% | 61.530% | 59 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 70.310% | 71.020% | 32 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 68.157% | 69.030% | 35 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 70.365% | 71.060% | 47 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 75.075% | 76.390% | 65 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 72.287% | 73.080% | 79 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 10.912% | 11.160% | 33 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 17.098% | 17.200% | 38 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 66.902% | 67.170% | 54 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 75.075% | 76.390% | 65 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 74.472% | 74.980% | 66 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 67.212% | 68.180% | 69 |
bit 1 x 64
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 76.377% | 76.790% | 102 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 76.703% | 77.720% | 102 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 79.592% | 80.230% | 92 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 77.830% | 78.620% | 79 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.368% | 78.130% | 69 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 73.768% | 74.820% | 40 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 68.457% | 69.020% | 35 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 46.227% | 46.440% | 89 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 73.818% | 74.370% | 76 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.368% | 78.130% | 69 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 79.748% | 80.610% | 65 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 75.460% | 76.180% | 57 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 73.523% | 73.620% | 35 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 74.850% | 75.420% | 38 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 79.233% | 79.420% | 47 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.368% | 78.130% | 69 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 78.523% | 79.200% | 79 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 23.685% | 24.330% | 21 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 46.715% | 47.280% | 34 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 62.575% | 63.550% | 62 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 77.368% | 78.130% | 69 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 78.087% | 78.970% | 70 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 77.128% | 78.280% | 68 |
trit 1 x 64
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 39.847% | 40.090% | 154 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 61.537% | 62.040% | 172 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 68.332% | 69.090% | 164 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 67.207% | 67.310% | 155 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.842% | 70.310% | 134 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 48.548% | 49.460% | 97 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 58.883% | 60.160% | 84 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 29.125% | 29.790% | 167 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 43.450% | 44.360% | 148 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.842% | 70.310% | 134 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 70.887% | 71.780% | 126 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 71.358% | 71.340% | 114 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 63.938% | 64.540% | 58 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 42.183% | 42.720% | 67 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 50.235% | 51.760% | 90 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.842% | 70.310% | 134 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 72.327% | 73.220% | 165 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 10.945% | 11.100% | 59 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 20.532% | 21.160% | 70 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 27.960% | 28.270% | 95 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 69.842% | 70.310% | 134 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 80.133% | 81.040% | 151 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 79.515% | 80.010% | 150 |
bit 1 x 128
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 69.590% | 70.870% | 192 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 64.617% | 65.190% | 187 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 66.873% | 67.040% | 171 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 62.333% | 63.090% | 160 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 56.935% | 57.470% | 143 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 54.892% | 55.290% | 92 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 53.640% | 54.930% | 78 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 55.453% | 56.060% | 154 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 57.440% | 58.380% | 149 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 56.935% | 57.470% | 144 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 59.353% | 60.140% | 137 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 51.248% | 51.680% | 127 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 56.373% | 56.760% | 64 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 60.127% | 61.020% | 75 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 57.332% | 57.820% | 95 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 56.935% | 57.470% | 142 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 58.545% | 59.140% | 173 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 20.128% | 20.530% | 32 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 48.848% | 49.690% | 49 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 50.050% | 50.350% | 97 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 56.935% | 57.470% | 143 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 80.762% | 80.940% | 155 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 81.900% | 82.140% | 152 |
trit 1 x 128
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 10.435% | 10.690% | 233 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 33.213% | 33.730% | 264 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 44.198% | 44.860% | 264 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 18.840% | 19.660% | 240 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 31.325% | 31.190% | 227 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 22.113% | 23.030% | 168 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 28.097% | 28.790% | 158 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 22.447% | 21.730% | 285 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 25.963% | 25.640% | 267 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 31.325% | 31.190% | 226 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 31.118% | 31.590% | 213 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 29.708% | 30.010% | 187 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 30.445% | 30.270% | 95 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 26.688% | 27.800% | 110 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 39.520% | 40.180% | 148 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 31.325% | 31.190% | 231 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 26.275% | 26.890% | 287 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 27.758% | 27.980% | 113 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 13.607% | 13.450% | 129 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 17.518% | 17.330% | 184 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 31.325% | 31.190% | 226 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 62.773% | 63.760% | 307 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 80.888% | 81.500% | 343 |
bit 1 x 256
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 74.045% | 74.270% | 342 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 69.655% | 70.140% | 296 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 63.477% | 63.550% | 268 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 62.312% | 62.830% | 247 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 60.877% | 61.000% | 223 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 48.470% | 49.100% | 154 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 42.968% | 43.760% | 138 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 27.827% | 28.350% | 250 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 58.080% | 58.860% | 237 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 60.877% | 61.000% | 227 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 60.197% | 60.620% | 208 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 20.720% | 21.790% | 174 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 64.490% | 64.810% | 96 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 57.942% | 58.260% | 110 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 57.322% | 57.860% | 149 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 60.877% | 61.000% | 222 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 56.803% | 57.500% | 274 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 23.710% | 23.540% | 53 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 12.190% | 12.560% | 76 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 45.965% | 46.090% | 136 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 60.877% | 61.000% | 221 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 62.818% | 62.890% | 342 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 80.167% | 80.430% | 378 |
trit 1 x 256
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 7 | 1000 | 11.747% | 11.870% | 483 |
| 19 | 4475 | 50 | 2/3 | 7 | 1000 | 21.753% | 20.900% | 500 |
| 18 | 2975 | 70 | 2/3 | 7 | 1000 | 24.923% | 23.970% | 478 |
| 17 | 1975 | 100 | 2/3 | 7 | 1000 | 36.458% | 36.390% | 456 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 20.702% | 20.190% | 422 |
| 13 | 382 | 500 | 2/3 | 7 | 1000 | 32.597% | 32.750% | 315 |
| 12 | 253 | 1000 | 2/3 | 7 | 1000 | 30.987% | 31.630% | 278 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 33.227% | 33.890% | 444 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 28.892% | 29.070% | 432 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 20.702% | 20.190% | 410 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 23.120% | 22.840% | 392 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 21.022% | 20.350% | 347 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 7 | 1000 | 25.632% | 25.420% | 178 |
| 7 | 1095 | 200 | 1/3 | 7 | 1000 | 23.687% | 23.400% | 197 |
| 10 | 1023 | 200 | 1/2 | 7 | 1000 | 20.210% | 19.860% | 273 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 20.702% | 20.190% | 410 |
| 21 | 1255 | 200 | 3/4 | 7 | 1000 | 29.015% | 29.660% | 511 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 7 | 100 | 11.410% | 11.420% | 208 |
| 16 | 1309 | 200 | 2/3 | 7 | 200 | 22.233% | 21.910% | 229 |
| 16 | 1309 | 200 | 2/3 | 7 | 500 | 28.828% | 28.690% | 305 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 20.702% | 20.190% | 411 |
| 16 | 1309 | 200 | 2/3 | 7 | 2000 | 14.680% | 14.110% | 543 |
| 16 | 1309 | 200 | 2/3 | 7 | 5000 | 60.752% | 61.870% | 753 |
two hidden layers
bit 2 x 32
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 36.720% | 37.160% | 67 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 40.013% | 40.260% | 58 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 40.240% | 40.130% | 55 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 48.278% | 48.740% | 49 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 48.842% | 48.800% | 42 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 43.052% | 42.880% | 26 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 39.032% | 38.450% | 22 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 52.883% | 52.810% | 44 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 46.643% | 46.710% | 43 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 48.842% | 48.800% | 42 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 37.942% | 37.690% | 40 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 9.833% | 9.540% | 29 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 9.872% | 10.140% | 30 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 10.482% | 10.750% | 18 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 3 | 1000 | 37.275% | 37.960% | 21 |
| 7 | 1095 | 200 | 1/3 | 3 | 1000 | 37.087% | 37.040% | 22 |
| 10 | 1023 | 200 | 1/2 | 3 | 1000 | 42.428% | 42.920% | 29 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 48.842% | 48.800% | 42 |
| 21 | 1255 | 200 | 3/4 | 3 | 1000 | 48.532% | 49.080% | 50 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 3 | 100 | 26.228% | 26.030% | 16 |
| 16 | 1309 | 200 | 2/3 | 3 | 200 | 41.350% | 41.460% | 23 |
| 16 | 1309 | 200 | 2/3 | 3 | 500 | 44.663% | 44.800% | 36 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 48.842% | 48.800% | 42 |
| 16 | 1309 | 200 | 2/3 | 3 | 2000 | 36.442% | 36.420% | 42 |
| 16 | 1309 | 200 | 2/3 | 3 | 5000 | 48.812% | 48.740% | 43 |
trit 2 x 32
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 38.730% | 38.760% | 100 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 60.645% | 61.590% | 105 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 40.160% | 40.020% | 90 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 45.692% | 45.600% | 85 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 45.045% | 45.110% | 78 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 45.765% | 46.390% | 50 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 42.840% | 42.130% | 45 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 48.198% | 48.330% | 82 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 57.767% | 58.110% | 78 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 45.045% | 45.110% | 77 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 47.975% | 48.830% | 70 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 9.865% | 10.140% | 54 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 9.035% | 8.630% | 49 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 9.405% | 9.470% | 48 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 3 | 1000 | 40.658% | 41.800% | 37 |
| 7 | 1095 | 200 | 1/3 | 3 | 1000 | 38.690% | 39.350% | 38 |
| 10 | 1023 | 200 | 1/2 | 3 | 1000 | 25.193% | 25.620% | 48 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 45.045% | 45.110% | 77 |
| 21 | 1255 | 200 | 3/4 | 3 | 1000 | 49.023% | 50.000% | 88 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 3 | 100 | 10.723% | 11.150% | 38 |
| 16 | 1309 | 200 | 2/3 | 3 | 200 | 13.163% | 13.550% | 44 |
| 16 | 1309 | 200 | 2/3 | 3 | 500 | 43.393% | 43.300% | 62 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 45.045% | 45.110% | 77 |
| 16 | 1309 | 200 | 2/3 | 3 | 2000 | 48.015% | 48.920% | 83 |
| 16 | 1309 | 200 | 2/3 | 3 | 5000 | 53.875% | 54.730% | 86 |
bit 2 x 64
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 45.487% | 45.250% | 126 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 53.705% | 53.840% | 114 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 59.217% | 59.290% | 108 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 63.665% | 64.330% | 97 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 55.478% | 56.250% | 87 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 52.237% | 52.640% | 59 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 48.428% | 48.770% | 52 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 40.372% | 41.180% | 100 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 61.812% | 61.900% | 89 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 55.478% | 56.250% | 86 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 52.292% | 53.080% | 85 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 45.303% | 46.770% | 80 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 10.118% | 10.370% | 69 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 9.880% | 9.440% | 52 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 3 | 1000 | 46.943% | 46.480% | 45 |
| 7 | 1095 | 200 | 1/3 | 3 | 1000 | 51.633% | 52.380% | 49 |
| 10 | 1023 | 200 | 1/2 | 3 | 1000 | 59.210% | 59.350% | 60 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 55.478% | 56.250% | 87 |
| 21 | 1255 | 200 | 3/4 | 3 | 1000 | 56.825% | 56.890% | 100 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 3 | 100 | 12.567% | 12.690% | 25 |
| 16 | 1309 | 200 | 2/3 | 3 | 200 | 36.493% | 36.690% | 40 |
| 16 | 1309 | 200 | 2/3 | 3 | 500 | 45.743% | 46.190% | 69 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 55.478% | 56.250% | 86 |
| 16 | 1309 | 200 | 2/3 | 3 | 2000 | 66.028% | 66.630% | 95 |
| 16 | 1309 | 200 | 2/3 | 3 | 5000 | 65.070% | 65.850% | 96 |
trit 2 x 64
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 20.878% | 21.200% | 168 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 21.045% | 21.360% | 174 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 44.842% | 45.300% | 173 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 33.473% | 34.040% | 150 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 28.810% | 29.200% | 147 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 24.037% | 23.850% | 108 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 18.920% | 18.690% | 99 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 15.713% | 16.030% | 160 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 20.447% | 20.180% | 148 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 28.810% | 29.200% | 148 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 29.525% | 29.480% | 139 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 11.470% | 11.460% | 116 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 9.915% | 9.780% | 126 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 9.872% | 10.010% | 116 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 3 | 1000 | 15.533% | 15.270% | 63 |
| 7 | 1095 | 200 | 1/3 | 3 | 1000 | 19.385% | 19.130% | 70 |
| 10 | 1023 | 200 | 1/2 | 3 | 1000 | 43.350% | 43.670% | 98 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 28.810% | 29.200% | 145 |
| 21 | 1255 | 200 | 3/4 | 3 | 1000 | 15.758% | 15.490% | 178 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 3 | 100 | 14.513% | 14.240% | 73 |
| 16 | 1309 | 200 | 2/3 | 3 | 200 | 11.380% | 11.490% | 88 |
| 16 | 1309 | 200 | 2/3 | 3 | 500 | 13.373% | 13.130% | 111 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 28.810% | 29.200% | 148 |
| 16 | 1309 | 200 | 2/3 | 3 | 2000 | 50.128% | 51.130% | 180 |
| 16 | 1309 | 200 | 2/3 | 3 | 5000 | 57.550% | 58.180% | 196 |
bit 2 x 128
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 43.238% | 43.690% | 195 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 49.682% | 49.610% | 205 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 52.742% | 53.000% | 215 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 45.303% | 46.860% | 207 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 39.400% | 40.550% | 186 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 24.747% | 25.010% | 123 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 23.115% | 23.190% | 107 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 24.035% | 24.120% | 197 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 29.423% | 30.400% | 195 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 39.400% | 40.550% | 186 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 48.733% | 49.290% | 173 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 54.185% | 54.940% | 158 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 43.345% | 44.320% | 146 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 9.843% | 9.560% | 125 |
scale
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 6 | 1375 | 200 | 1/4 | 3 | 1000 | 31.133% | 31.540% | 88 |
| 7 | 1095 | 200 | 1/3 | 3 | 1000 | 29.793% | 29.670% | 96 |
| 10 | 1023 | 200 | 1/2 | 3 | 1000 | 32.292% | 33.120% | 122 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 39.400% | 40.550% | 180 |
| 21 | 1255 | 200 | 3/4 | 3 | 1000 | 44.907% | 45.760% | 214 |
truncation
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 3 | 100 | 14.183% | 14.430% | 50 |
| 16 | 1309 | 200 | 2/3 | 3 | 200 | 26.337% | 26.170% | 76 |
| 16 | 1309 | 200 | 2/3 | 3 | 500 | 35.265% | 35.450% | 126 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 39.400% | 40.550% | 181 |
| 16 | 1309 | 200 | 2/3 | 3 | 2000 | 60.387% | 60.970% | 221 |
| 16 | 1309 | 200 | 2/3 | 3 | 5000 | 60.320% | 61.190% | 243 |
trit 2 x 128
min
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 21 | 10311 | 20 | 2/3 | 3 | 1000 | 10.202% | 10.320% | 263 |
| 19 | 4475 | 50 | 2/3 | 3 | 1000 | 12.350% | 12.460% | 299 |
| 18 | 2975 | 70 | 2/3 | 3 | 1000 | 13.672% | 13.790% | 333 |
| 17 | 1975 | 100 | 2/3 | 3 | 1000 | 13.477% | 13.620% | 308 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 11.938% | 11.960% | 294 |
| 13 | 382 | 500 | 2/3 | 3 | 1000 | 18.890% | 19.120% | 228 |
| 12 | 253 | 1000 | 2/3 | 3 | 1000 | 20.463% | 21.040% | 206 |
n updates
| n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time (s) |
|---|---|---|---|---|---|---|---|---|
| 16 | 1309 | 200 | 2/3 | 1 | 1000 | 26.787% | 27.530% | 354 |
| 16 | 1309 | 200 | 2/3 | 2 | 1000 | 24.108% | 24.240% | 335 |
| 16 | 1309 | 200 | 2/3 | 3 | 1000 | 11.938% | 11.960% | 294 |
| 16 | 1309 | 200 | 2/3 | 4 | 1000 | 11.843% | 11.870% | 278 |
| 16 | 1309 | 200 | 2/3 | 7 | 1000 | 10.308% | 10.600% | 248 |
| 16 | 1309 | 200 | 2/3 | 10 | 1000 | 10.753% | 11.260% | 242 |
| 16 | 1309 | 200 | 2/3 | 20 | 1000 | 11.448% | 11.510% | 222 |
scale
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
6 | 1375 | 200 | 1/4 | 3 | 1000 | 11.337% | 11.320% | 137
7 | 1095 | 200 | 1/3 | 3 | 1000 | 11.498% | 11.590% | 157
10 | 1023 | 200 | 1/2 | 3 | 1000 | 12.507% | 12.500% | 215
16 | 1309 | 200 | 2/3 | 3 | 1000 | 11.938% | 11.960% | 295
21 | 1255 | 200 | 3/4 | 3 | 1000 | 11.378% | 11.370% | 363
truncation
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
16 | 1309 | 200 | 2/3 | 3 | 100 | 11.588% | 11.440% | 141
16 | 1309 | 200 | 2/3 | 3 | 200 | 14.092% | 14.370% | 167
16 | 1309 | 200 | 2/3 | 3 | 500 | 16.903% | 17.320% | 239
16 | 1309 | 200 | 2/3 | 3 | 1000 | 11.938% | 11.960% | 299
16 | 1309 | 200 | 2/3 | 3 | 2000 | 14.422% | 14.650% | 369
16 | 1309 | 200 | 2/3 | 3 | 5000 | 52.958% | 53.490% | 466
bit 2 x 256
min
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
21 | 10311 | 20 | 2/3 | 3 | 1000 | 19.887% | 20.440% | 315
19 | 4475 | 50 | 2/3 | 3 | 1000 | 34.190% | 34.640% | 449
18 | 2975 | 70 | 2/3 | 3 | 1000 | 35.387% | 35.770% | 420
17 | 1975 | 100 | 2/3 | 3 | 1000 | 25.567% | 25.510% | 381
16 | 1309 | 200 | 2/3 | 3 | 1000 | 24.228% | 24.860% | 337
13 | 382 | 500 | 2/3 | 3 | 1000 | 16.043% | 16.000% | 236
12 | 253 | 1000 | 2/3 | 3 | 1000 | 11.285% | 11.250% | 211
n updates
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
16 | 1309 | 200 | 2/3 | 1 | 1000 | 14.277% | 14.330% | 347
16 | 1309 | 200 | 2/3 | 2 | 1000 | 16.337% | 16.530% | 346
16 | 1309 | 200 | 2/3 | 3 | 1000 | 24.228% | 24.860% | 339
16 | 1309 | 200 | 2/3 | 4 | 1000 | 28.688% | 29.810% | 336
16 | 1309 | 200 | 2/3 | 7 | 1000 | 29.345% | 30.380% | 322
16 | 1309 | 200 | 2/3 | 10 | 1000 | 23.863% | 24.610% | 293
16 | 1309 | 200 | 2/3 | 20 | 1000 | 14.570% | 14.940% | 233
scale
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
6 | 1375 | 200 | 1/4 | 3 | 1000 | 19.767% | 20.070% | 150
7 | 1095 | 200 | 1/3 | 3 | 1000 | 14.328% | 15.020% | 169
10 | 1023 | 200 | 1/2 | 3 | 1000 | 18.477% | 18.320% | 229
16 | 1309 | 200 | 2/3 | 3 | 1000 | 24.228% | 24.860% | 336
21 | 1255 | 200 | 3/4 | 3 | 1000 | 24.882% | 25.220% | 410
truncation
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
16 | 1309 | 200 | 2/3 | 3 | 100 | 13.555% | 13.660% | 113
16 | 1309 | 200 | 2/3 | 3 | 200 | 17.433% | 17.530% | 145
16 | 1309 | 200 | 2/3 | 3 | 500 | 13.622% | 13.760% | 230
16 | 1309 | 200 | 2/3 | 3 | 1000 | 24.228% | 24.860% | 336
16 | 1309 | 200 | 2/3 | 3 | 2000 | 25.888% | 27.110% | 500
16 | 1309 | 200 | 2/3 | 3 | 5000 | 31.478% | 32.300% | 720
trit 2 x 256
min
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
21 | 10311 | 20 | 2/3 | 3 | 1000 | 9.927% | 9.770% | 590
19 | 4475 | 50 | 2/3 | 3 | 1000 | 11.482% | 11.610% | 769
18 | 2975 | 70 | 2/3 | 3 | 1000 | 22.790% | 22.750% | 723
17 | 1975 | 100 | 2/3 | 3 | 1000 | 26.713% | 26.730% | 666
16 | 1309 | 200 | 2/3 | 3 | 1000 | 25.152% | 25.440% | 603
13 | 382 | 500 | 2/3 | 3 | 1000 | 14.002% | 14.510% | 443
12 | 253 | 1000 | 2/3 | 3 | 1000 | 12.562% | 12.980% | 396
n updates
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
16 | 1309 | 200 | 2/3 | 1 | 1000 | 15.457% | 15.300% | 610
16 | 1309 | 200 | 2/3 | 2 | 1000 | 17.095% | 17.440% | 607
16 | 1309 | 200 | 2/3 | 3 | 1000 | 25.152% | 25.440% | 600
16 | 1309 | 200 | 2/3 | 4 | 1000 | 16.878% | 17.080% | 577
16 | 1309 | 200 | 2/3 | 7 | 1000 | 12.653% | 12.570% | 553
16 | 1309 | 200 | 2/3 | 10 | 1000 | 13.683% | 13.950% | 516
16 | 1309 | 200 | 2/3 | 20 | 1000 | 11.547% | 11.560% | 467
scale
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
6 | 1375 | 200 | 1/4 | 3 | 1000 | 20.588% | 19.890% | 259
7 | 1095 | 200 | 1/3 | 3 | 1000 | 20.985% | 20.850% | 296
10 | 1023 | 200 | 1/2 | 3 | 1000 | 19.803% | 20.170% | 391
16 | 1309 | 200 | 2/3 | 3 | 1000 | 25.152% | 25.440% | 591
21 | 1255 | 200 | 3/4 | 3 | 1000 | 15.405% | 15.700% | 737
truncation
n epochs | n samples | min | scale | updates | truncation | train acc | test acc | train time
16 | 1309 | 200 | 2/3 | 3 | 100 | 11.237% | 11.270% | 323
16 | 1309 | 200 | 2/3 | 3 | 200 | 13.475% | 13.390% | 363
16 | 1309 | 200 | 2/3 | 3 | 500 | 23.185% | 22.570% | 456
16 | 1309 | 200 | 2/3 | 3 | 1000 | 25.152% | 25.440% | 597
16 | 1309 | 200 | 2/3 | 3 | 2000 | 24.458% | 24.720% | 792
16 | 1309 | 200 | 2/3 | 3 | 5000 | 14.488% | 14.450% | 1087
Source code
To reproduce these numbers, check out commit foo from the repo and run
cargo run --release --bin mnist_fc_multi_layer_bench -- mnist
where mnist is the path to the MNIST data set.
You will need a recent nightly Rust toolchain.
RAM usage is low (< 150 MB), but training is quite CPU intensive. We recommend a CPU with many cores and good memory bandwidth and cache.
Questions
- Why do trit weighted models sometimes fail to train?
- Why does a large truncation sometimes produce a worse accuracy than a smaller truncation?
- Why is test accuracy often slightly higher than train accuracy? It is understandable that a model with such limited power would not overfit much, but why is test higher?
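To put a number on "slightly higher", the gap can be computed directly from the tables. The sketch below uses the six (train acc, test acc) pairs from the "bit 2 x 256" truncation table; the function name `gap_summary` is purely illustrative and not part of the repo.

```rust
/// (train accuracy %, test accuracy %) pairs copied from the
/// "bit 2 x 256" truncation table.
const ROWS: [(f64, f64); 6] = [
    (13.555, 13.660),
    (17.433, 17.530),
    (13.622, 13.760),
    (24.228, 24.860),
    (25.888, 27.110),
    (31.478, 32.300),
];

/// Returns (number of rows where test > train, mean of test - train)
/// in percentage points. Hypothetical helper, for illustration only.
fn gap_summary(rows: &[(f64, f64)]) -> (usize, f64) {
    let higher = rows.iter().filter(|(train, test)| test > train).count();
    let total_gap: f64 = rows.iter().map(|(train, test)| test - train).sum();
    (higher, total_gap / rows.len() as f64)
}

fn main() {
    let (higher, mean_gap) = gap_summary(&ROWS);
    println!(
        "test > train in {higher} of {} rows, mean gap {mean_gap:.3} points",
        ROWS.len()
    );
}
```

In this one table, test accuracy exceeds train accuracy in every row. Other tables contain rows where the sign flips, so running the same summary per table (and, eventually, per seed) would show how consistent the effect is.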
Future work
- 2D conv layers
- Test with multiple seeds
- Stacking layerwise trained layers so as to achieve deeper models
- Distributed training
- Better loss delta filtering algorithm
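As a baseline for the last item, the current filtering described under Implementation (truncate each example's loss deltas to the n largest positive and n most negative, sum across the minibatch, sort, apply the best k) can be sketched as follows. This is a minimal sketch rather than the repo's code. `MutationId`, `truncate` and `best_mutations` are hypothetical names, and it assumes that "best" means the most negative summed loss delta.

```rust
use std::collections::HashMap;

/// Hypothetical identifier for one candidate weight mutation
/// (for example, an index into a flattened weight array).
type MutationId = usize;

/// Keep only the `n` most negative and `n` largest positive loss deltas
/// from one example's sparse set.
fn truncate(mut deltas: Vec<(MutationId, f64)>, n: usize) -> Vec<(MutationId, f64)> {
    // Ascending order: most negative first, most positive last.
    deltas.sort_by(|a, b| a.1.total_cmp(&b.1));
    let negatives: Vec<_> = deltas.iter().copied().filter(|d| d.1 < 0.0).take(n).collect();
    let positives: Vec<_> = deltas.iter().rev().copied().filter(|d| d.1 > 0.0).take(n).collect();
    negatives.into_iter().chain(positives).collect()
}

/// Sum the truncated sets of every example in a minibatch, then return the
/// `k` mutations with the most negative summed loss delta.
/// Assumption: "best" means the largest reduction in loss.
fn best_mutations(
    per_example: Vec<Vec<(MutationId, f64)>>,
    n: usize,
    k: usize,
) -> Vec<(MutationId, f64)> {
    let mut summed: HashMap<MutationId, f64> = HashMap::new();
    for deltas in per_example {
        for (id, delta) in truncate(deltas, n) {
            *summed.entry(id).or_insert(0.0) += delta;
        }
    }
    let mut merged: Vec<(MutationId, f64)> = summed.into_iter().collect();
    // Most negative first; ties are broken by id so the result is deterministic.
    merged.sort_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
    merged.into_iter().filter(|d| d.1 < 0.0).take(k).collect()
}

fn main() {
    // Toy minibatch of two examples with made-up loss deltas.
    let minibatch = vec![
        vec![(0, -0.5), (1, 0.25), (2, -0.125), (3, 0.75)],
        vec![(0, -0.25), (1, -0.5), (3, 0.5)],
    ];
    for (id, delta) in best_mutations(minibatch, 2, 2) {
        println!("apply mutation {id}: summed loss delta {delta}");
    }
}
```

Keeping the positive deltas through truncation matters because a mutation that helps one example can hurt another; the positives cancel negatives when the sets are summed. A better filter would have to preserve that cancellation while discarding more candidates early.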