SIMD
In this version we will be adding explicit SIMD vector types and vector instructions to utilize CPU registers to their full width.
As we saw in v2, compilers are sometimes able to auto-vectorize simple loops.
This time, however, we won't be hoping for auto-vectorization magic; we'll write all vector instructions directly into the code.
Since we only need a few simple instructions and are currently targeting only the x86_64 platform, we won't be pulling in any external crates.
Instead, we define our own tiny simd library with safe Rust wrappers around a few Intel AVX intrinsics.
We'll be using the same approach as in the reference solution, which is to pack all rows of d and t into 256-bit wide vectors (f32x8), each containing 8 single-precision (f32) floats.
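To give a concrete idea of these wrappers, here is a minimal sketch of what such a simd module could look like, built on the AVX intrinsics from std::arch. The function names match the ones used below, but the details (the type alias, the buffer-based horizontal_min, the target-feature handling) are illustrative assumptions, not the reference implementation:
use std::arch::x86_64::{
    __m256, _mm256_add_ps, _mm256_loadu_ps, _mm256_min_ps, _mm256_set1_ps, _mm256_storeu_ps,
};

// Number of f32 lanes in one 256-bit vector
#[allow(non_upper_case_globals)]
pub const f32x8_LENGTH: usize = 8;

// The 256-bit AVX vector type holding 8 f32 lanes
#[allow(non_camel_case_types)]
pub type f32x8 = __m256;

// NOTE: using these intrinsics assumes the binary is compiled with AVX enabled,
// e.g. with RUSTFLAGS='-C target-cpu=native'
#[inline(always)]
pub fn f32x8_infty() -> f32x8 {
    unsafe { _mm256_set1_ps(std::f32::INFINITY) }
}

#[inline(always)]
pub fn from_slice(s: &[f32]) -> f32x8 {
    debug_assert_eq!(s.len(), f32x8_LENGTH);
    unsafe { _mm256_loadu_ps(s.as_ptr()) }
}

#[inline(always)]
pub fn add(x: f32x8, y: f32x8) -> f32x8 {
    unsafe { _mm256_add_ps(x, y) }
}

#[inline(always)]
pub fn min(x: f32x8, y: f32x8) -> f32x8 {
    unsafe { _mm256_min_ps(x, y) }
}

// Reduce the 8 lanes of v into a single f32 by taking their minimum
#[inline(always)]
pub fn horizontal_min(v: f32x8) -> f32 {
    let mut buf = [std::f32::INFINITY; f32x8_LENGTH];
    unsafe { _mm256_storeu_ps(buf.as_mut_ptr(), v) };
    buf.iter().fold(std::f32::INFINITY, |acc, &x| acc.min(x))
}
Each wrapper is a safe function around a single unsafe intrinsic call, which keeps the unsafe surface small and lets the code below stay free of unsafe blocks.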
First, we initialize two std::vec::Vec containers for d and its transpose t. This time they will not contain f32 values, but SIMD vectors of 8 f32 elements:
// How many f32x8 vectors we need for all elements from a row or column of d
let vecs_per_row = (n + simd::f32x8_LENGTH - 1) / simd::f32x8_LENGTH;
// All rows and columns of d packed into f32x8 vectors,
// each initially filled with 8 f32::INFINITYs
let mut vd = std::vec![simd::f32x8_infty(); n * vecs_per_row];
let mut vt = std::vec![simd::f32x8_infty(); n * vecs_per_row];
// Assert that all addresses of vd and vt are properly aligned to the size of f32x8
debug_assert!(vd.iter().all(simd::is_aligned));
debug_assert!(vt.iter().all(simd::is_aligned));
We shouldn't have to worry about proper memory alignment, since std::vec::Vec allocates its memory aligned to the alignment of its element type. Just to make sure, though, we added some debug asserts that check the alignment of each address in vd and vt using this helper:
#[inline(always)]
pub fn is_aligned(v: &f32x8) -> bool {
    (v as *const f32x8).align_offset(std::mem::align_of::<f32x8>()) == 0
}
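As a side note on the padding in the initialization above: when n is not a multiple of 8, the last f32x8 of each row keeps its extra lanes at f32::INFINITY, which can never win a min later on, so the padding is harmless. A small sanity check with a hypothetical n:
// Hypothetical n that is not a multiple of 8
let n = 10;
// ceil(10 / 8) = 2 vectors per row; the last 6 lanes of the second vector
// in every row stay at f32::INFINITY and never affect any minimum
let vecs_per_row = (n + simd::f32x8_LENGTH - 1) / simd::f32x8_LENGTH;
assert_eq!(vecs_per_row, 2);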
Next, we will fill every row of vd and vt with f32x8 vectors in parallel. Each thread will read one row of d into vd and one column of d into vt in chunks of 8 elements. We use two f32 buffers of length 8, one for rows of d (vx_tmp) and one for columns of d (vy_tmp). Each time the buffers become full, they are converted into two f32x8 vectors and written into vd and vt:
// Function: for one row of f32x8 vectors in vd and one row of f32x8 vectors in vt,
// - copy all elements from row 'i' in d,
// - pack them into f32x8 vectors,
// - insert all into row 'i' of vd (vd_row)
// and
// - copy all elements from column 'i' in d,
// - pack them into f32x8 vectors,
// - insert all into row 'i' of vt (vt_row)
let pack_simd_row = |(i, (vd_row, vt_row)): (usize, (&mut [f32x8], &mut [f32x8]))| {
    // For every SIMD vector at row 'i', column 'jv' in vt and vd
    for (jv, (vx, vy)) in vd_row.iter_mut().zip(vt_row.iter_mut()).enumerate() {
        // Temporary buffers for f32 elements of two f32x8s
        let mut vx_tmp = [std::f32::INFINITY; simd::f32x8_LENGTH];
        let mut vy_tmp = [std::f32::INFINITY; simd::f32x8_LENGTH];
        // Iterate over 8 elements to fill the buffers
        for (b, (x, y)) in vx_tmp.iter_mut().zip(vy_tmp.iter_mut()).enumerate() {
            // Offset by 8 elements to get correct index mapping of j to d
            let j = jv * simd::f32x8_LENGTH + b;
            if i < n && j < n {
                *x = d[n * i + j];
                *y = d[n * j + i];
            }
        }
        // Initialize f32x8 vectors from buffer contents
        // and assign them into the std::vec::Vec containers
        *vx = simd::from_slice(&vx_tmp);
        *vy = simd::from_slice(&vy_tmp);
    }
};
// Fill rows of vd and vt in parallel one pair of rows at a time
vd.par_chunks_mut(vecs_per_row)
    .zip(vt.par_chunks_mut(vecs_per_row))
    .enumerate()
    .for_each(pack_simd_row);
The nice thing is that the preprocessing we just did is by far the hardest part. Now all data is packed into SIMD vectors, and we can reuse step_row from v1 with minimal changes:
// Function: for one row of f32x8 vectors from vd,
// compute n f32 results into one row of r
let step_row = |(r_row, vd_row): (&mut [f32], &[f32x8])| {
    let vt_rows = vt.chunks_exact(vecs_per_row);
    for (res, vt_row) in r_row.iter_mut().zip(vt_rows) {
        // Fold vd_row and vt_row into a single f32x8 result
        let tmp = vd_row.iter()
            .zip(vt_row)
            .fold(simd::f32x8_infty(),
                  |v, (&x, &y)| simd::min(v, simd::add(x, y)));
        // Reduce 8 different f32 results in tmp into the final result
        *res = simd::horizontal_min(tmp);
    }
};
r.par_chunks_mut(n)
    .zip(vd.par_chunks(vecs_per_row))
    .for_each(step_row);
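As a quick usage example of the simd helpers on their own (relying on the wrapper functions sketched earlier, so this is an illustrative test rather than part of the reference code), the kernel used by step_row can be exercised in isolation:
#[test]
fn f32x8_add_min_reduce() {
    // Two vectors whose element-wise sum is 9.0 in every lane
    let x = simd::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
    let y = simd::from_slice(&[8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]);
    // Same pattern as in step_row: min-fold of element-wise sums, then a horizontal min
    let v = simd::min(simd::f32x8_infty(), simd::add(x, y));
    assert_eq!(simd::horizontal_min(v), 9.0);
}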
Benchmark
Let's run benchmarks with the same settings as in v2, comparing our Rust program to the reference C++ version.
Implementation | Compiler | Time (s) | IPC |
---|---|---|---|
C++ v3 | gcc 7.4.0-1ubuntu1 | 11.5 | 1.31 |
C++ v3 | clang 6.0.0-1ubuntu2 | 11.8 | 1.37 |
Rust v3 | rustc 1.38.0-nightly | 11.4 | 1.04 |
The running times are roughly the same, but the Rust program clearly executes fewer instructions per cycle than the C++ program. Let's look at the disassembly to find out why.
gcc
This is the single-element loop from v0, but with 256-bit SIMD instructions and registers.
LOOP:
vmovaps ymm0,YMMWORD PTR [rcx+rax*1]
vaddps ymm0,ymm0,YMMWORD PTR [rdx+rax*1]
add rax,0x20
vminps ymm1,ymm1,ymm0
cmp rsi,rax
jne LOOP
More detailed analysis is available here.
clang
Like gcc, but for some reason there is a separate loop counter r10, instead of using r9 both for loading values and for checking whether the loop has ended. The extra addition could explain the higher instructions-per-cycle value.
LOOP:
vmovaps ymm2,YMMWORD PTR [r15+r9*1]
vaddps ymm2,ymm2,YMMWORD PTR [r8+r9*1]
vminps ymm1,ymm1,ymm2
add r10,0x1
add r9,0x20
cmp r10,rdi
jl LOOP
rustc
No bounds checking or extra instructions, except for a separate loop counter r12. The loop has also been unrolled by a factor of 4, which may be why we are seeing the lower IPC value.
LOOP:
vmovaps ymm3,YMMWORD PTR [rbx+rbp*1-0x60]
vmovaps ymm4,YMMWORD PTR [rbx+rbp*1-0x40]
vmovaps ymm5,YMMWORD PTR [rbx+rbp*1-0x20]
vmovaps ymm6,YMMWORD PTR [rbx+rbp*1]
vaddps ymm3,ymm3,YMMWORD PTR [r11+rbp*1-0x60]
vminps ymm2,ymm2,ymm3
vaddps ymm3,ymm4,YMMWORD PTR [r11+rbp*1-0x40]
vminps ymm2,ymm2,ymm3
vaddps ymm3,ymm5,YMMWORD PTR [r11+rbp*1-0x20]
vminps ymm2,ymm2,ymm3
add r12,0x4
vaddps ymm3,ymm6,YMMWORD PTR [r11+rbp*1]
vminps ymm2,ymm2,ymm3
sub rbp,0xffffffffffffff80
cmp r13,r12
jne LOOP