I'm interested in a fast 8x8 32-bit float matrix multiply in Rust, assuming availability of AVX2. After learning about the AVX2 intrinsics, here is what I came up with:
// The intrinsics below need AVX2 (permutevar8x32) and FMA; the caller must
// make sure both features are available before calling this.
#[target_feature(enable = "avx2,fma")]
pub unsafe fn mm8_simd(c: &mut [f32], a: &[f32], b: &[f32], lda: usize, ldb: usize) {
    use std::arch::x86_64::*;
    // Cache the eight rows of A in registers.
    let av0 = _mm256_loadu_ps(a.get_unchecked(0 * lda));
    let av1 = _mm256_loadu_ps(a.get_unchecked(1 * lda));
    let av2 = _mm256_loadu_ps(a.get_unchecked(2 * lda));
    let av3 = _mm256_loadu_ps(a.get_unchecked(3 * lda));
    let av4 = _mm256_loadu_ps(a.get_unchecked(4 * lda));
    let av5 = _mm256_loadu_ps(a.get_unchecked(5 * lda));
    let av6 = _mm256_loadu_ps(a.get_unchecked(6 * lda));
    let av7 = _mm256_loadu_ps(a.get_unchecked(7 * lda));
    // Accumulators: the eight rows of C, updated in place.
    let mut cv0 = _mm256_loadu_ps(c.get_unchecked_mut(0 * ldb));
    let mut cv1 = _mm256_loadu_ps(c.get_unchecked_mut(1 * ldb));
    let mut cv2 = _mm256_loadu_ps(c.get_unchecked_mut(2 * ldb));
    let mut cv3 = _mm256_loadu_ps(c.get_unchecked_mut(3 * ldb));
    let mut cv4 = _mm256_loadu_ps(c.get_unchecked_mut(4 * ldb));
    let mut cv5 = _mm256_loadu_ps(c.get_unchecked_mut(5 * ldb));
    let mut cv6 = _mm256_loadu_ps(c.get_unchecked_mut(6 * ldb));
    let mut cv7 = _mm256_loadu_ps(c.get_unchecked_mut(7 * ldb));
    // Index vector for permutevar8x32: at step k every lane holds k, so the
    // permute broadcasts element k of a cached row of A.
    let mut mask = _mm256_setzero_si256();
    for k in 0..8 {
        // Row k of B is loaded exactly once.
        let bv = _mm256_loadu_ps(b.get_unchecked(k * ldb));
        // c[i][:] += a[i][k] * b[k][:] for each row i.
        cv0 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av0, mask), bv, cv0);
        cv1 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av1, mask), bv, cv1);
        cv2 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av2, mask), bv, cv2);
        cv3 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av3, mask), bv, cv3);
        cv4 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av4, mask), bv, cv4);
        cv5 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av5, mask), bv, cv5);
        cv6 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av6, mask), bv, cv6);
        cv7 = _mm256_fmadd_ps(_mm256_permutevar8x32_ps(av7, mask), bv, cv7);
        mask = _mm256_add_epi32(mask, _mm256_set1_epi32(1));
    }
    // Write the updated rows of C back to memory.
    _mm256_storeu_ps(c.get_unchecked_mut(0 * ldb), cv0);
    _mm256_storeu_ps(c.get_unchecked_mut(1 * ldb), cv1);
    _mm256_storeu_ps(c.get_unchecked_mut(2 * ldb), cv2);
    _mm256_storeu_ps(c.get_unchecked_mut(3 * ldb), cv3);
    _mm256_storeu_ps(c.get_unchecked_mut(4 * ldb), cv4);
    _mm256_storeu_ps(c.get_unchecked_mut(5 * ldb), cv5);
    _mm256_storeu_ps(c.get_unchecked_mut(6 * ldb), cv6);
    _mm256_storeu_ps(c.get_unchecked_mut(7 * ldb), cv7);
}
I load the A and C matrices into 256-bit registers up front and multiply through, loading each row of B exactly once. So each element of A, B, and C should touch memory only once, which I believe minimizes the number of memory accesses.
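For reference, the scalar version of the same in-place update (a quick sketch I would compare against for correctness; mm8_scalar is just a name I am using here) is:

pub fn mm8_scalar(c: &mut [f32], a: &[f32], b: &[f32], lda: usize, ldb: usize) {
    // c[i][j] += sum over k of a[i][k] * b[k][j], with row stride lda for A
    // and ldb for both B and C, matching the SIMD kernel above.
    for i in 0..8 {
        for k in 0..8 {
            let aik = a[i * lda + k];
            for j in 0..8 {
                c[i * ldb + j] += aik * b[k * ldb + j];
            }
        }
    }
}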
Because I multiply one column of A with the corresponding row of B at a time, I use mask with _mm256_permutevar8x32_ps to pull a single float out of each cached row of A and broadcast it across a 256-bit register, which is then multiplied with the row of B. The partial products accumulate in the C registers as the loop runs.
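At step k the index vector holds k in every lane, so the permute is effectively a broadcast of a[i * lda + k]. A tiny standalone check of that equivalence (a sketch; broadcast_demo is only an illustrative name):

// With every index lane equal to k, permutevar8x32 replicates element k of the
// row, i.e. it produces the same result as _mm256_set1_ps on that element.
#[target_feature(enable = "avx2")]
unsafe fn broadcast_demo() {
    use std::arch::x86_64::*;
    let row = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
    let av = _mm256_loadu_ps(row.as_ptr());
    let k = 3usize;
    let mask = _mm256_set1_epi32(k as i32);
    let permuted = _mm256_permutevar8x32_ps(av, mask);
    let broadcast = _mm256_set1_ps(row[k]);
    let (mut p, mut s) = ([0.0f32; 8], [0.0f32; 8]);
    _mm256_storeu_ps(p.as_mut_ptr(), permuted);
    _mm256_storeu_ps(s.as_mut_ptr(), broadcast);
    assert_eq!(p, s); // both hold eight copies of row[3]
}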
Interestingly, on my machine, which supports AVX-512, the compiler optimizes away the explicit register caching and the permute intrinsics and emits a function that is almost nothing but unrolled vfmadd***ps instructions.
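For completeness, this is roughly how I would wrap the kernel for safe use (a sketch: runtime feature detection plus the bounds checks that justify the get_unchecked accesses; mm8 is just an illustrative name):

pub fn mm8(c: &mut [f32], a: &[f32], b: &[f32], lda: usize, ldb: usize) {
    // Each matrix is accessed as 8 rows of 8 floats, so the end of the last
    // row must be in bounds for the unchecked accesses to be sound.
    assert!(a.len() >= 7 * lda + 8);
    assert!(b.len() >= 7 * ldb + 8);
    assert!(c.len() >= 7 * ldb + 8);
    if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
        unsafe { mm8_simd(c, a, b, lda, ldb) };
    } else {
        mm8_scalar(c, a, b, lda, ldb); // scalar fallback from the sketch above
    }
}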
What are the opportunities to make this faster?