问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

PyTorch中的grid_sample函数详解

创作时间:
作者:
@小白创作中心

PyTorch中的grid_sample函数详解

引用
CSDN
1.
https://blog.csdn.net/weixin_42392454/article/details/141608192

grid_sample是PyTorch提供的重要函数,主要用于图像处理中的采样操作。它允许通过给定的采样坐标从输入张量中获取相应的值。当采样坐标包含小数时,grid_sample会使用插值方法计算出对应的值。

函数签名

torch.nn.functional.grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=True)

参数说明

  1. input:这是一个形状为(N, C, H_in, W_in)的4D张量,其中N是批次大小,C是通道数,H_in和W_in分别是输入特征图的高度和宽度。

  2. grid:这是一个形状为(N, H_out, W_out, 2)的4D张量,表示目标位置的网格。最后一维表示每个位置的(x, y)坐标,值的范围通常在[-1, 1]之间,其中-1对应左/上边界,1对应右/下边界。

  3. mode:指定插值方式,有两个选项:

  • 'bilinear'(默认):使用双线性插值。
  • 'nearest':使用最近邻插值。
  1. padding_mode:当采样点超出输入特征图边界时指定填充方式,有三个选项:
  • 'zeros'(默认):超出边界的点填充为0。
  • 'border':超出边界的点采用边界值填充。
  • 'reflection':超出边界的点使用对称填充。
  1. align_corners:这个参数影响坐标网格的解释方式:
  • True:采样网格的边缘点直接对齐到原始特征图的像素格上。
  • False:采样网格的边缘点直接对齐到原始特征图的像素格的角点上。

插值方法详解

grid_sample提供两种插值方式:

  1. mode='bilinear'(默认)
  • 进行双线性插值(bilinear interpolation)。当坐标包含小数时,grid_sample会根据周围的像素值来计算出精确的采样结果。这意味着,如果采样点的坐标(即displacement)落在像素之间,grid_sample会根据四个相邻像素的值进行加权平均,生成插值结果。
  • 具体来说,如果采样点(x, y)对应的坐标在(i, j)和(i+1, j+1)之间,双线性插值会计算如下:
    value = (1 - Δx)(1 - Δy) \cdot V_{i,j} + Δx(1 - Δy) \cdot V_{i+1,j} + (1 - Δx)Δy \cdot V_{i,j+1} + ΔxΔy \cdot V_{i+1,j+1}
    
    其中,Δx和Δy是坐标的小数部分,V_{i,j}是像素值。
  1. mode='nearest'
  • 采用最近邻插值(nearest-neighbor interpolation)。如果采样坐标包含小数,grid_sample会取最近的整数位置对应的像素值。

grid坐标范围处理

grid的取值范围是[-1, 1],在函数内部会进行尺度的复原:

real ix = THTensor_fastGet4d(grid, n, h, w, 0);
real iy = THTensor_fastGet4d(grid, n, h, w, 1);

// normalize ix, iy from [-1, 1] to [0, IH-1] & [0, IW-1]
ix = ((ix + 1) / 2) * (IW-1);
iy = ((iy + 1) / 2) * (IH-1);

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号