问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

高性能卷积计算:img2col 原理详解

创作时间:
作者:
@小白创作中心

高性能卷积计算:img2col 原理详解

引用
CSDN
1.
https://blog.csdn.net/taoqick/article/details/129051936

卷积神经网络(CNN)在图像处理和计算机视觉领域有着广泛的应用,而卷积计算的效率直接影响着模型的训练速度和性能。本文将详细介绍一种优化卷积计算的方法——img2col算法,通过将其转化为矩阵乘法(GEMM),可以显著提升计算效率。

简介

img2col是一种实现卷积操作的加速计算策略。它能将卷积操作转化为GEMM(通用矩阵乘法),从而最大化地缩短卷积计算的时间。

GEMM是通用矩阵乘(General Matrix Multiply)的英文缩写,其实就是一般意义上的矩阵乘法,数学表达就是C = A x B。根据上下文语境,GEMM有时也指实现矩阵乘法的函数接口。

为什么要将卷积操作转化为GEMM呢?

  1. 因为线性代数领域已经有非常成熟的计算接口(BLAS,Fortran语言实现)来高效地实现大型的矩阵乘法,几乎可以做到极限优化。
  2. 将卷积过程中用到的所有特征子矩阵整合成一个大型矩阵存放在连续的内存中,虽然增加了存储成本,但是减少了内存访问的次数,从而缩短了计算时间。

原理

img2col的原理可以用下面这一张图来概括:

此图出自论文:High Performance Convolutional Neural Networks for Document Processing,感兴趣的同学可以自行去拜读一下。

上面这张图还是有些抽象,下面我们一步一步来分解上面这张图:

1. Input Features -> Input Matrix

不难看出,输入特征图一共有三个通道,我们以不同的颜色来区分。

以蓝色的特征图为例,它是一个3 x 3的矩阵,而卷积核是一个2 x 2的矩阵,当卷积核的滑动步长为1时,那么传统的直接卷积计算一共需要进行4次卷积核与对应特征子矩阵之间的点积运算。

现在我们把每一个特征子矩阵都排列成一个行向量(如图中编号1️⃣、2️⃣所示),然后把这4个行向量堆叠成一个新的矩阵,就得到了蓝色特征图所对应的Input Matrix。

当输入特征图不止一个通道时,则对每一个通道的特征图都采用上述操作,然后再把每一个通道对应的Input Matrix堆叠成一个完整的Input Matrix。

2. Convolution Kernel -> Kernel Matrix

不难看出,卷积核一共有两个,每个均为三通道,我们以第一个卷积核为例进行讲解。

将卷积核转化成矩阵的方式和第一步有些类似,只是这里应该转化成列向量(如图中编号1️⃣、2️⃣、3️⃣所示)。如果第一步转化成列向量,则这里应该转化成行向量,这是由矩阵乘法的计算特性决定的,即一个矩阵的每一行和另一个矩阵的每一列做内积,所以特征图和卷积核只能一个展开为行,一个展开为列。

同样地,如果卷积核有多个通道,则对每一个通道的卷积核都采用上述操作,然后再把每一个通道对应的Kernel Matrix堆叠成一个完整的Kernel Matrix。

3. Input Matrix * Kernel Matrix = Output Matrix

在得到上述两个矩阵之后,接下来调用GEMM函数接口进行矩阵乘法运算即可得到输出矩阵,然后将输出矩阵通过col2img函数就可以得到和卷积运算一样的输出特征图。

结语

通过img2col函数,我们只需执行一次矩阵乘法计算就能得到与卷积运算相同的结果,而传统的直接卷积计算光是一个通道就需要进行4次(仅指本例中)卷积核与对应特征子矩阵之间的点积运算,那么如果通道数特别多?输入特征图非常庞大呢?那计算的次数将是成倍增长的!

有些同学可能会担心将所有特征子矩阵都堆叠到一个矩阵中,会不会导致内存不够用或者计算速度非常慢,尤其是在深度神经网络中。其实不用担心,因为矩阵的存储和计算其实都是非常规则的,很容易通过分布式和并行的方式来解决,感兴趣的同学可以自行阅读相关论文。

代码实现

以下是Python版本的img2col实现:

import numpy as np

def img2col(image: np.ndarray, kernel_size: int, stride: int = 1, padding: int = 0) -> np.ndarray:
    h, w = image.shape
    out_h = (h + 2 * padding - kernel_size) // stride + 1
    out_w = (w + 2 * padding - kernel_size) // stride + 1
    pad_image = np.zeros((h + 2 * padding, w + 2 * padding))
    pad_image[padding:h+padding, padding:w+padding] = image
    ph, pw = pad_image.shape
    res = np.zeros((out_h * out_w, kernel_size * kernel_size))
    idx = 0
    # x * (k^2) * k^2
    for i in range(0, ph - kernel_size + 1, stride):
        for j in range(0, pw - kernel_size + 1, stride):
            res[idx] = pad_image[i:i + kernel_size, j:j + kernel_size].reshape(-1)
            idx += 1
    return res

img = np.zeros((6,3))
print(img)
print(img2col(img, 2))

以下是Pytorch版本的img2col实现:

import torch

def img2col(image: torch.Tensor, kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
    # 将变成(h,w)的Tensor转换成(out_h*out_w,kernal_size*kernal_size)的Tensor,其中out_h = (h + 2 * padding - kernel_size) // stride + 1
    in_h, in_w = image.shape
    out_h = (in_h + 2 * padding - kernel_size) // stride + 1
    out_w = (in_w + 2 * padding - kernel_size) // stride + 1
    print('out_w={} out_h={}'.format(out_w, out_h))
    pad_image = torch.zeros((in_h + 2 * padding, in_w + 2 * padding))
    pad_image[padding:in_h+padding, padding:in_w+padding] = image
    ph, pw = pad_image.shape
    res = torch.zeros((out_h * out_w, kernel_size * kernel_size))
 
    idx = 0
    # x * (k^2) * k^2
    for i in range(0, ph - kernel_size + 1, stride):
        for j in range(0, pw - kernel_size + 1, stride):
            res[idx] = pad_image[i:i + kernel_size, j:j + kernel_size].reshape(-1)
            idx += 1
    return res

img = torch.rand((6,3))
print('raw_img={}'.format(img))
# print(img2col(img, 2).shape)
print('img2col={}'.format(img2col(img, kernel_size=2, padding=1, stride=1)))
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号