问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

GEMM算法及优化流程详解

创作时间:
作者:
@小白创作中心

GEMM算法及优化流程详解

引用
CSDN
1.
https://blog.csdn.net/qq_20880415/article/details/104332743

GEMM(General Matrix Multiply)算法是深度学习中最重要的基础算法之一,广泛应用于全连接层和卷积层的计算。本文将详细介绍GEMM算法及其优化过程,帮助读者理解如何通过减少内存访问和优化循环结构来提升计算效率。

前言

神经网络前向耗时主要由卷积的耗时决定。主流的卷积加速方法有以下几种:

  • im2col+GEMM:目前几乎所有的主流计算框架包括 Caffe, MXNet 等都实现了该方法。该方法把整个卷积过程转化成了GEMM过程,而GEMM在各种 BLAS 库中都是被极致优化的,一般来说,速度较快。

  • Winograd:Winograd 是存在已久最近被重新发现的方法,在大部分场景中, Winograd方法都显示和较大的优势,目前cudnn中计算卷积就使用了该方法。

  • Strassen:1969年,Volker Strassen提出了第一个时间复杂度低于O(N^3)的算法,其复杂度为O(N^(2^(log2(7)))),但这种方法只在大卷积核情况下优势才比较明显,目前还没有在开源框架中见到这种方法。

  • FFT:傅里叶变换和快速傅里叶变化是在经典图像处理里面经常使用的计算方法,但是,在 ConvNet中通常不采用,主要是因为在 ConvNet 中的卷积模板通常都比较小,例如 3×3 等,这种情况下,FFT 的时间开销反而更大,所以很少在CNN中利用FFT实现卷积。

im2col+GEMM算法简介

GEMM在深度学习中是十分重要的,全连接层以及卷积层基本上都是通过GEMM来实现的,而网络中大约90%的运算都是在这两层中。而一个良好的GEMM的实现可以充分利用系统的多级存储结构和程序执行的局部性来充分加速运算。

常规的卷积操作为:

3维卷积运算执行完毕,得一个2维的平面:

将卷积操作的3维立体变为二维矩阵乘法,可以调用BLAS中的GEMM库,按 [kernel_height, kernel_width, kernel_depth] ⇒ 将输入分成 3 维的 patch,并将其展成一维向量:
此时的卷积操作就可转化为矩阵乘法:

下面我们将以M=K=N=600为例说明GEMM算法的优化过程:

直接暴力卷积

for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
        for (int k = 0; k < K; k++) {
            C[m][n]+= A[m][k] * B[k][n];
        }
    }
}

上述公式总计算量为2MNK FLOPs(其中 𝑀、𝑁、𝐾 分别指代三层循环执行的次数,2 指代循环最内层的一次乘法和加法),内存访问操作总数为 4MNK(其中 2MNK 指代对 𝐶 的内存访问,𝐶 需要先读取内存、累和再存储)。GEMM 的优化均以此为基点。

耗时分析:上述暴力gemm代码耗时约为872ms

GEMM算法优化

optimize1

首先能想到的就是减少C矩阵的访存次数,将C[m][n]放到外面,全部累和之后再赋值即可:

for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n++) {
        float temp = C[m][n];
        for (int k = 0; k < K; k++) {
            temp += A[m][k] * B[k][n];
        }
        C[m][n] = temp;
    }
}

上述公式总计算量依然为2MNK FLOPs,内存访问操作总数为 2MNK+2MN(其中 2MN 指代对 𝐶 的内存访问,𝐶 需要先读取内存、累加完毕在存储)。

耗时分析:上述代码耗时约为791ms,耗时变少的原因是减少了部分C的访存

optimize2

将输出的计算拆分为 1×4 的小块,即将 𝑁 维度拆分为两部分。计算该块输出时,需要使用 𝐴 矩阵的 1 行,和 𝐵 矩阵的 4 列。

图一:矩阵乘计算 1×4输出

下面是该计算的伪代码表示,这里已经将 1×4 中 N 维度的内部拆分进行了展开。这里的计算量仍然是 2𝑀𝑁𝐾 ,这一点在本文中不会有变化。

for (int m = 0; m < M; m++) {
    for (int n = 0; n < N; n += 4) {
        float temp_m0n0 = C[m][n + 0];
        float temp_m0n1 = C[m][n + 1];
        float temp_m0n2 = C[m][n + 2];
        float temp_m0n3 = C[m][n + 3];
        for (int k = 0; k < K; k++) {
            float temp = A[m][k];
            temp_m0n0 += temp * B[k][n + 0];
            temp_m0n1 += temp * B[k][n + 1];
            temp_m0n2 += temp * B[k][n + 2];
            temp_m0n3 += temp * B[k][n + 3];
        }
        C[m][n + 0] = temp_m0n0;
        C[m][n + 1] = temp_m0n1;
        C[m][n + 2] = temp_m0n2;
        C[m][n + 3] = temp_m0n3;
    }
}

简单的观察即可发现,上述伪代码的最内侧计算使用的矩阵 𝐴 的元素是一致的。因此可以将 𝐴[𝑚][𝑘] 读取到寄存器中,从而实现 4 次数据复用(这里不再给出示例)。一般将最内侧循环称作计算核(micro kernel)。进行这样的优化后,内存访问操作数量变为 2MN+5/4MNK,访存约为上面的5/8。

耗时分析:本优化耗时约为473ms,相比暴力耗时减少300ms左右,可能的两个原因:1、由于B是行优先排列,1x4方法能够减少数据从内存到cache的加载次数;2、合理利用寄存器,减少对𝐴矩阵访存次数

optimize3

类似地,我们可以继续拆分输出的 𝑀 维度,从而在内侧循环中计算 4×4 输出,如图二。

图二:矩阵乘计算 4×4输出

同样地,将计算核心展开,可以得到下面的伪代码。由于乘数效应,4×4 的拆分可以将对输入数据的访存缩减到 MN/16*(162+8K)=2MN+1/2MNK。这相对于最开始的 4MNK 已经得到了 8X 的改进,这些改进都是通过展开循环后利用寄存器存储数据减少访存得到的。

for (int m = 0; m < M; m += 4) {
    for (int n = 0; n < N; n += 4) {
        float temp_m0n0 = C[m + 0][n + 0];
        float temp_m0n1 = C[m + 0][n + 1];
        float temp_m0n2 = C[m + 0][n + 2];
        float temp_m0n3 = C[m + 0][n + 3];
        
        float temp_m1n0 = C[m + 1][n + 0];
        float temp_m1n1 = C[m + 1][n + 1];
        float temp_m1n2 = C[m + 1][n + 2];
        float temp_m1n3 = C[m + 1][n + 3];
        float temp_m2n0 = C[m + 2][n + 0];
        float temp_m2n1 = C[m + 2][n + 1];
        float temp_m2n2 = C[m + 2][n + 2];
        float temp_m2n3 = C[m + 2][n + 3];
        float temp_m3n0 = C[m + 3][n + 0];
        float temp_m3n1 = C[m + 3][n + 1];
        float temp_m3n2 = C[m + 3][n + 2];
        float temp_m3n3 = C[m + 3][n + 3];
        
        for (int k = 0; k < K; k++) {
            float temp_m0 = A[m + 0][k];
            float temp_m1 = A[m + 1][k];
            float temp_m2 = A[m + 2][k];
            float temp_m3 = A[m + 3][k];
            float temp_n0 = B[k][n + 0];
            float temp_n1 = B[k][n + 1];
            float temp_n2 = B[k][n + 2];
            float temp_n3 = B[k][n + 3];
            temp_m0n0 += temp_m0 * temp_n0;
            temp_m0n1 += temp_m0 * temp_n1;
            temp_m0n2 += temp_m0 * temp_n2;
            temp_m0n3 += temp_m0 * temp_n3;
            temp_m1n0 += temp_m1 * temp_n0;
            temp_m1n1 += temp_m1 * temp_n1;
            temp_m1n2 += temp_m1 * temp_n2;
            temp_m1n3 += temp_m1 * temp_n3;
            temp_m2n0 += temp_m2 * temp_n0;
            temp_m2n1 += temp_m2 * temp_n1;
            temp_m2n2 += temp_m2 * temp_n2;
            temp_m2n3 += temp_m2 * temp_n3;
            temp_m3n0 += temp_m3 * temp_n0;
            temp_m3n1 += temp_m3 * temp_n1;
            temp_m3n2 += temp_m3 * temp_n2;
            temp_m3n3 += temp_m3 * temp_n3;
        }
        C[m + 0][n + 0] = temp_m0n0;
        C[m + 0][n + 1] = temp_m0n1;
        C[m + 0][n + 2] = temp_m0n2;
        C[m + 0][n + 3] = temp_m0n3;
        C[m + 1][n + 0] = temp_m1n0;
        C[m + 1][n + 1] = temp_m1n1;
        C[m + 1][n + 2] = temp_m1n2;
        C[m + 1][n + 3] = temp_m1n3;
        C[m + 2][n + 0] = temp_m2n0;
        C[m + 2][n + 1] = temp_m2n1;
        C[m + 2][n + 2] = temp_m2n2;
        C[m + 2][n + 3] = temp_m2n3;
        C[m + 3][n + 0] = temp_m3n0;
        C[m + 3][n + 1] = temp_m3n1;
        C[m + 3][n + 2] = temp_m3n2;
        C[m + 3][n + 3] = temp_m3n3;
    }
}

耗时分析:本优化耗时约为354ms,相比1x4耗时减少120ms左右

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号