问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

【可视化3D卷积计算过程】

创作时间:
作者:
@小白创作中心

【可视化3D卷积计算过程】

引用
CSDN
1.
https://blog.csdn.net/qq_44166630/article/details/138320939

3D卷积可以用来处理视频输入,对于图片来讲,shape为:[C_in, H, W]。而视频多了时间这一维度,因此视频的shape为:[C_in, D, H, W],其中D为帧数(frame),比如一条视频有10帧,则D=10。(以上都忽略了batch size N)

假如我们现在的输入的视频shape为:[3, 7, 4, 4]。即:

input_channel frame H W
3 7 4 4

kernel shape为:[5, 3, 2, 2, 2]

output_channel input_channel kernel_D kernel_H kernel_W
5 3 2 2 2


计算过程可视化如下:


output shape为:[5, 6, 3, 3],其中:

output_channel output_D output_H output_W
5 6 3 3

代码验证

import torch
import torch.nn as nn

N, C_in, D, H, W = 1, 3, 7, 4, 4
C_out = 5

m = nn.Conv3d(in_channels=C_in, out_channels=C_out, kernel_size=2, stride=1, bias=False)
inputs = torch.zeros(N, C, D, H, W)
m.weight = nn.Parameter(torch.ones(C_out, C_in, 2, 2, 2))

inputs[:, 0, :, :, :] = torch.ones(D, H, W)
inputs[:, 1, :, :, :] = torch.ones(D, H, W) * 2
inputs[:, 2, :, :, :] = torch.ones(D, H, W) * 3

output = m(inputs)
print(inputs, inputs.shape)
"""
tensor([[[[[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]],
          [[1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.],
           [1., 1., 1., 1.]]],
         [[[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]],
          [[2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.],
           [2., 2., 2., 2.]]],
         [[[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]],
          [[3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.],
           [3., 3., 3., 3.]]]]])
shape:
torch.Size([1, 3, 7, 4, 4])
"""

print(output, output.shape)
"""
tensor([[[[[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]]],
         [[[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]]],
         [[[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]]],
         [[[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]]],
         [[[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]],
          [[48., 48., 48.],
           [48., 48., 48.],
           [48., 48., 48.]]]]], grad_fn=<SlowConv3DBackward0>)
shape:
torch.Size([1, 5, 6, 3, 3])
"""

48怎么来的?
2x2x2x1 + 2x2x2x2 + 2x2x2x3 = 48
© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号