I’m experimenting with the MathWorks example that inserts a multi-head self-attention layer into a simple CNN for the DigitDataset:
NUM_HEADS = 4;          % set to 4 or 8 (see the two cases below)
NUM_KEY_CHANNELS = 784; % set to 784 or 392 (see the two cases below)

layers = [
    imageInputLayer([28 28 1])
    convolution2dLayer(3,32,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    convolution2dLayer(3,64,'Padding','same')
    batchNormalizationLayer
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    flattenLayer
    selfAttentionLayer(NUM_HEADS, NUM_KEY_CHANNELS) % <-- point of interest
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
When I change only the self-attention hyperparameters, MATLAB reports a different number of learnable parameters for each of the two configurations below:
Case 1: NumHeads = 4, NumKeyChannels = 784
Case 2: NumHeads = 8, NumKeyChannels = 392
In both cases the product NumHeads × NumKeyChannels = 3136, so I expected the two networks to have the same number of learnable parameters. However, MATLAB reports different totals.
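For concreteness, the counts can be reproduced programmatically like this (a minimal sketch, assuming a release where selfAttentionLayer and dlnetwork are available; I strip the trailing classificationLayer because dlnetwork does not accept output layers):

% Build the layer array above with NUM_HEADS / NUM_KEY_CHANNELS set for the case
% of interest, then count every learnable element in the initialized network.
net = dlnetwork(layers(1:end-1));   % layers(1:end-1) drops classificationLayer
numLearnables = sum(cellfun(@numel, net.Learnables.Value))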
My understanding from research papers is that the total parameterization of the Q/K/V projections should scale with the total key dimension, not with how that dimension is split across heads.
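Here is the back-of-the-envelope count behind that expectation (my own arithmetic using the standard multi-head formulation from the papers, not necessarily how MATLAB interprets NumKeyChannels; the input to the attention layer is the flattened 7×7×64 = 3136-channel feature map):

dModel = 7*7*64;                   % channels entering selfAttentionLayer after flattenLayer
dKey   = [4*784, 8*392];           % assumed total key dimension, case 1 and case 2 (both 3136)
qkv    = 3*(dModel.*dKey + dKey);  % Q, K, V projection weights plus biases
proj   = dKey.*dModel + dModel;    % output projection back to dModel channels, plus bias
qkv + proj                         % identical for both cases under this reading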
Why does MATLAB’s selfAttentionLayer produce different parameter counts for these two configurations? Am I misinterpreting how the layer is implemented in this toolbox?
I’d appreciate any clarification on how MATLAB calculates the number of learnable parameters here, especially since this doesn’t seem to be spelled out in the Deep Learning Toolbox documentation.