Go 语言字符串相等函数的 SIMD 实现

背景说明

Go 语言是 Google 开发的一种静态强类型、编译型、并发型,并具有垃圾回收功能的编程语言。它用批判吸收的眼光,融合 C 语言、Java 等众家之长,将简洁、高效演绎得淋漓尽致。

SIMD 全称 Single Instruction Multiple Data,即单指令多数据流,用一条指令对多个数据进行操作。一般用向量寄存器来实现。常用于加速数据密集型运算,如数列求和、矩阵乘法。

本实验对 Go 语言自带的字符串相等函数的源码进行分析,用自己实现的函数进行替换,并比较性能。

探索过程

不同 CPU 对 SIMD 的支持不同。用 CPU-Z 查看当前 CPU 支持的 SIMD 指令集。最新的指令集为 AVX2

Go 语言汇编器基于 Plan 9 汇编器的输入风格,与 GNU 汇编器不同。阅读和编写代码时需要注意。

Go 语言字符串相等函数的代码在 GOROOT/src/internal/bytealg/equal_[arch].s 文件中。当前 CPU 架构为 AMD64,对应文件为 equal_amd64.s

文件中定义了三个函数:runtime·memequalruntime·memequal_varlenmemeqbody。前两个函数为 ABI,供应用程序调用。它们将设置相应寄存器并跳转到第三个函数。部分定义如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
// memequal(a, b unsafe.Pointer, size uintptr) bool
TEXT runtime·memequal<ABIInternal>(SB),NOSPLIT,$0-25
// AX = a (want in SI)
// BX = b (want in DI)
// CX = size (want in BX)
...
JMP memeqbody<>(SB)

// memequal_varlen(a, b unsafe.Pointer) bool
TEXT runtime·memequal_varlen<ABIInternal>(SB),NOSPLIT,$0-17
// AX = a (want in SI)
// BX = b (want in DI)
// 8(DX) = size (want in BX)
...
JMP memeqbody<>(SB)

// Input:
// a in SI
// b in DI
// count in BX
// Output:
// result in AX
TEXT memeqbody<>(SB),NOSPLIT,$0-0
...
RET

我们只需要关注真正进行相等判断的 memeqbody 函数。它接收两个字符串的地址(SIDI)以及字符串的长度(BX),返回 0 或 1(AX)。Go 语言的编译器保证这两个字符串的长度相等(否则可以直接判断不相等)。

原代码巧妙地利用了 SIMD。思路如下:

  1. 如果字符串的长度不小于 64,则一轮循环比较 64 个字符(512 位),直到剩余长度小于 64;
  2. 如果字符串的长度不小于 8,则一轮循环比较 8 个字符(64 位),直到剩余长度小于 8;
  3. 比较剩余字符(不用循环)。

结合代码分析。在函数开头,根据字符串的长度进入不同的循环。

1
2
3
4
5
6
7
TEXT memeqbody<>(SB),NOSPLIT,$0-0
CMPQ BX, $8
JB small
CMPQ BX, $64
JB bigloop
CMPB internal∕cpu·X86+const_offsetX86HasAVX2(SB), $1
JE hugeloop_avx2

第 6 行的代码判断 CPU 是否支持 AVX2 指令集,如果支持则用 Y 系列寄存器比较 64 个字符,否则用 X 系列寄存器比较 64 个字符。64 个字符的循环如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
        // 64 bytes at a time using xmm registers
hugeloop:
CMPQ BX, $64
JB bigloop
MOVOU (SI), X0
MOVOU (DI), X1
MOVOU 16(SI), X2
MOVOU 16(DI), X3
MOVOU 32(SI), X4
MOVOU 32(DI), X5
MOVOU 48(SI), X6
MOVOU 48(DI), X7
PCMPEQB X1, X0
PCMPEQB X3, X2
PCMPEQB X5, X4
PCMPEQB X7, X6
PAND X2, X0
PAND X6, X4
PAND X4, X0
PMOVMSKB X0, DX
ADDQ $64, SI
ADDQ $64, DI
SUBQ $64, BX
CMPL DX, $0xffff
JEQ hugeloop
XORQ AX, AX // return 0
RET

// 64 bytes at a time using ymm registers
hugeloop_avx2:
CMPQ BX, $64
JB bigloop_avx2
VMOVDQU (SI), Y0
VMOVDQU (DI), Y1
VMOVDQU 32(SI), Y2
VMOVDQU 32(DI), Y3
VPCMPEQB Y1, Y0, Y4
VPCMPEQB Y2, Y3, Y5
VPAND Y4, Y5, Y6
VPMOVMSKB Y6, DX
ADDQ $64, SI
ADDQ $64, DI
SUBQ $64, BX
CMPL DX, $0xffffffff
JEQ hugeloop_avx2
VZEROUPPER
XORQ AX, AX // return 0
RET

如果发现不同,则直接返回 0,否则继续循环,直到剩余长度小于 64,进入 8 个字符的循环。一个通用寄存器刚好可以装下 8 个字符,所以不需要 SIMD。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
bigloop_avx2:
VZEROUPPER

// 8 bytes at a time using 64-bit register
bigloop:
CMPQ BX, $8
JBE leftover
MOVQ (SI), CX
MOVQ (DI), DX
ADDQ $8, SI
ADDQ $8, DI
SUBQ $8, BX
CMPQ CX, DX
JEQ bigloop
XORQ AX, AX // return 0
RET

最后比较剩余字符。不过这里并没有用循环,而是直接加载字符串末尾的 8 个字符(可能与之前判断过的字符重叠)。

1
2
3
4
5
6
7
        // remaining 0-8 bytes
leftover:
MOVQ -8(SI)(BX*1), CX
MOVQ -8(DI)(BX*1), DX
CMPQ CX, DX
SETEQ AX
RET

如果字符串的长度本来就小于 8,这么做会加载一些不属于字符串的字符。代码中对这种情况也做了处理。

可以看出,这个函数尽可能地使用了 SIMD,用一条指令判断多个字符是否相等,加快了处理速度,同时也保证了边界情况下的正确性,对较短的字符串进行特殊处理。

尝试用自己编写的函数进行替换。首先用最简单的方法,直接一个一个字符进行比较。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
TEXT memeqbody<>(SB),NOSPLIT,$0-0
loop_1:
CMPQ BX, $0
JEQ equal
MOVB (SI), CX
MOVB (DI), DX
ADDQ $1, SI
ADDQ $1, DI
SUBQ $1, BX
CMPB CX, DX
JEQ loop_1
XORQ AX, AX
RET

equal:
SETEQ AX
RET

代码很短,但显然效率不够高。添加 8 个字符的循环,剩余字符还是一个一个比较。

1
2
3
4
5
6
7
8
9
10
11
12
loop_8:
CMPQ BX, $8
JB loop_1
MOVQ (SI), CX
MOVQ (DI), DX
ADDQ $8, SI
ADDQ $8, DI
SUBQ $8, BX
CMPQ CX, DX
JEQ loop_8
XORQ AX, AX
RET

当前 CPU 支持 MMX、SSE4.2、AVX2 指令集。MMX 的寄存器为 64 位寄存器,一个寄存器可以装下 8 个字符。将 8 个字符的循环改为使用 MMX 的寄存器。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
loop_8_mmx:
CMPQ BX, $8
JB loop_1
MOVQ (SI), M0
MOVQ (DI), M1
PCMPEQB M0, M1
MOVQ M1, CX
ADDQ $8, SI
ADDQ $8, DI
SUBQ $8, BX
CMPQ CX, $-1
JEQ loop_8_mmx
XORQ AX, AX
RET

第 4、5 行将从地址 SIDI 开始的 8 个字符分别存入 M0M1 寄存器,即进行了打包。第 6 行用 PCMPEQB 指令(compare packed bytes for equal)比较 M0M1 打包字节整数值的相等性,并将比较结果存入 M1。如果 M0M1 中的某个字节相等,则 M1 中的这个字节会变成全 1(即有符号数的 -1),否则变成全 0。因为 MMX 指令不会修改状态寄存器,所以需要将 M1 的值存入 CX,再与 -1 比较。

SSE4.2 的寄存器为 128 位寄存器,一个寄存器可以装下 16 个字符。添加 16 个字符的循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
loop_16_sse:
CMPQ BX, $16
JB loop_8_mmx
MOVUPD (SI), X0
MOVUPD (DI), X1
PCMPEQB X0, X1
PMOVMSKB X1, CX
ADDQ $16, SI
ADDQ $16, DI
SUBQ $16, BX
CMPW CX, $0xffff
JEQ loop_16_sse
XORQ AX, AX
RET

第 4、5 行用 MOVUPD 指令(move two unaligned packed double-precision floating-point values between XMM registers and memory)将从地址 SIDI 开始的 16 个字符分别存入 X0X1 寄存器。第 6 行将比较结果存入 X1。第 7 行用 PMOVMSKB 指令(move byte mask)将 X1 中每个字节的最高位提取出来存入 CX 的低位。这是因为 X1 寄存器有 128 位,无法直接存入通用寄存器进行判断。而比较某个字节时,如果相等则这个字节会变成全 1,所以只要提取每个字节的最高位即可判断是否全部相等。

AVX2 的寄存器为 256 位寄存器,一个寄存器可以装下 32 个字符。添加 32 个字符的循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
loop_32_avx:
CMPQ BX, $32
JB loop_16_sse
VMOVUPD (SI), Y0
VMOVUPD (DI), Y1
VPCMPEQB Y0, Y1, Y1
VPMOVMSKB Y1, CX
ADDQ $32, SI
ADDQ $32, DI
SUBQ $32, BX
CMPL CX, $0xffffffff
JEQ loop_32_avx
XORQ AX, AX
RET

AVX2 指令集与 SSE4.2 指令集类似,只需在指令前加字母 V。唯一不同的是第 6 行的 VPCMPEQB 指令,它需要三个操作数,将前两个操作数的比较结果存入第三个操作数。

由于 CPU 不支持 AVX512 指令集,因此无法使用 512 位寄存器。不过,还可以用循环展开来优化代码。AVX2 的寄存器有 16 个,所以可以展开 2 次、4 次、8 次循环。以展开 4 次循环为例,添加 128 个字符的循环。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
loop_128_unroll:
CMPQ BX, $128
JB loop_32_avx
VMOVUPD (SI), Y0
VMOVUPD (DI), Y1
VMOVUPD 32(SI), Y2
VMOVUPD 32(DI), Y3
VMOVUPD 64(SI), Y4
VMOVUPD 64(DI), Y5
VMOVUPD 96(SI), Y6
VMOVUPD 96(DI), Y7
VPCMPEQB Y0, Y1, Y1
VPCMPEQB Y2, Y3, Y3
VPCMPEQB Y4, Y5, Y5
VPCMPEQB Y6, Y7, Y7
VPAND Y1, Y3, Y3
VPAND Y5, Y7, Y7
VPAND Y3, Y7, Y7
VPMOVMSKB Y7, CX
ADDQ $128, SI
ADDQ $128, DI
SUBQ $128, BX
CMPL CX, $0xffffffff
JEQ loop_128_unroll
XORQ AX, AX
RET

第 4 行到第 7 行将连续 128 个字符存入 Y 系列寄存器,第 12 行到第 15 行分别比较四对 Y 系列寄存器,第 16 行到第 18 行将比较结果作按位与(因为相等的比较结果为全 1),第 19 行将结果每个字节的最高位提取出来存入 CX 的低位。展开 2 次、8 次循环的代码类似。

至此,已基本实现了用 SIMD 优化的字符串相等函数。

效果分析

编写测试函数,用于测试字符串相等函数的正确性。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
func TestEqual(t *testing.T) {
d, s := []byte("abcde"), []byte("abcde")
if !bytes.Equal(d, s) { // len < 8
t.Errorf("Equal(d, s) = false; want true")
}
for i := 0; i < 50; i++ {
d, s = append(d, 'f'), append(s, 'f')
}
if !bytes.Equal(d, s) { // len < 64
t.Errorf("Equal(d, s) = false; want true")
}
for i := 0; i < 500; i++ {
d, s = append(d, 'g'), append(s, 'g')
}
if !bytes.Equal(d, s) { // len >= 64
t.Errorf("Equal(d, s) = false; want true")
}
d = append(d, 'h')
if bytes.Equal(d, s) { // len(d) > len(s)
t.Errorf("Equal(d, s) = true; want false")
}
s = append(s, 'i')
if bytes.Equal(d, s) { // len(d) == len(s) && d[len-1] != s[len-1]
t.Errorf("Equal(d, s) = true; want false")
}
s = append(s, 'j')
if bytes.Equal(d, s) { // len(d) < len(s)
t.Errorf("Equal(d, s) = true; want false")
}
d, s = []byte("k"), []byte("l")
for i := 0; i < 5000; i++ {
d, s = append(d, 'm'), append(s, 'm')
}
if bytes.Equal(d, s) { // len(d) == len(s) && d[0] != s[0]
t.Errorf("Equal(d, s) = true; want false")
}
}

go test -run ^TestEqual$ 命令运行测试函数。实验中编写的每一种循环都可以通过测试。

编写基准函数,用于测试字符串相等函数的性能。

1
2
3
4
5
6
7
8
9
10
11
func BenchmarkEqual(b *testing.B) {
d, s := []byte(""), []byte("")
for i := 0; i < 4096; i++ { // 4KB
d = append(d, 'n')
s = append(s, 'n')
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
bytes.Equal(d, s)
}
}

go test -run ^$ -bench ^BenchmarkEqual$ -count 10 命令运行基准函数,可以得到字符串相等函数的运行次数和平均运行时间。字符串大小均为 4KB。据此可以算出函数处理数据的速率。

在表格的说明中,R1 表示一轮比较 1 个字符,R8 表示用通用寄存器一轮比较 8 个字符,M8 表示用 MMX 寄存器一轮比较 8 个字符,X16 表示用 SSE4.2 寄存器一轮比较 16 个字符,Y32 表示用 AVX2 寄存器一轮比较 32 个字符,Y64Y128Y256 分别表示用 AVX2 寄存器和 2 次、4 次、8 次循环展开一轮比较 64 个、128 个、256 个字符。

测试程序 说明 运行次数 运行时间(ns/op) 处理速率(GB/s)
original.asm Go 语言自带 13499010 87.92 43.39
loop_1.asm R1 424518 2721 1.402
loop_8.asm R8 + R1 3993097 298.0 12.80
loop_8_mmx.asm M8 + R1 2980017 395.4 9.648
loop_16_sse.asm X16 + R8 + R1 5688907 206.8 18.45
loop_32_avx.asm Y32 + R8 + R1 10291144 115.2 33.11
loop_64_unroll.asm Y64 + R8 + R1 13146477 88.00 43.35
loop_128_unroll.asm Y128 + Y32 + R8 + R1 17871183 64.31 59.32
loop_256_unroll.asm Y256 + Y32 + R8 + R1 21847575 54.56 69.92

可以看出:

  1. 随着 SIMD 寄存器位数的增加,函数的运行时间会减少,处理数据的速率也会变快;
  2. 用通用寄存器比较 8 个字符比用 MMX 寄存器比较 8 个字符要快,是因为 MMX 寄存器比较相等之后还需要移回通用寄存器进行跳转判断;
  3. Go 语言自带的函数用 AVX2 的寄存器时只展开了 2 次循环,实验中展开了 2 次循环的函数与 Go 语言自带的函数速度相近,而展开了 4 次、8 次循环的函数速度更快;
  4. 对于大小为 4KB 的字符串,实验中最快的函数相比最慢的函数速度提升 4887%,相比 Go 语言自带的函数速度提升 61.14%。

以上结果符合预期。

参考文献

  1. A Quick Guide to Go’s Assembler, https://go.dev/doc/asm.
  2. A Manual for the Plan 9 assembler, https://9p.io/sys/doc/asm.html.
  3. x64 Cheat Sheet, https://cs.brown.edu/courses/cs033/docs/guides/x64_cheatsheet.pdf.
  4. x86 and amd64 instruction reference, https://www.felixcloutier.com/x86/index.html.
  5. x86 Assembly Language Reference Manual, https://docs.oracle.com/cd/E37838_01/html/E61064/index.html.
  6. Intel® Instruction Set Extensions Technology, https://www.intel.com/content/www/us/en/support/articles/000005779/processors.html.
  7. equal_amd64.s, https://github.com/golang/go/blob/master/src/internal/bytealg/equal_amd64.s.
  8. testing package, https://pkg.go.dev/testing.

本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!