问小白 wenxiaobai
资讯
历史
科技
环境与自然
成长
游戏
财经
文学与艺术
美食
健康
家居
文化
情感
汽车
三农
军事
旅行
运动
教育
生活
星座命理

零基础学PyTorch:自动求导机制与梯度计算详解

创作时间:
作者:
@小白创作中心

零基础学PyTorch:自动求导机制与梯度计算详解

引用
CSDN
1.
https://m.blog.csdn.net/qq_46038405/article/details/145595731

PyTorch的自动求导机制是深度学习中非常重要的一部分,它能够自动计算梯度,从而帮助我们优化模型参数。本文将详细介绍PyTorch的自动求导机制,包括计算图构建、梯度计算、梯度累积特性等内容,并通过多项式函数求导等实战案例加深理解。

PyTorch Day 2:自动求导机制与梯度计算

一、Autograd核心原理

1. 计算图构建

import torch
# 创建需要跟踪梯度的张量
x = torch.tensor(2.0, requires_grad=True)
w = torch.tensor(1.0, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)
# 前向计算
y = w * x + b  # 构建计算图

2. 梯度计算

y.backward()  # 自动计算梯度
print(f'dy/dw = {w.grad}')  # 输出2.0 (x的值)
print(f'dy/dx = {x.grad}')  # 输出1.0 (w的值)
print(f'dy/db = {b.grad}')  # 输出1.0

3. 梯度累积特性

# 多次反向传播前需清零梯度
x.grad.data.zero_()
w.grad.data.zero_()
b.grad.data.zero_()

二、梯度计算实战

案例:多项式函数求导

f(x) = 3x^3 + 2x^2 + 5x + 1

x = torch.tensor(2.0, requires_grad=True)
f = 3*x**3 + 2*x**2 + 5*x + 1
f.backward()
print(f'在x=2处的导数:{x.grad:.2f}')  # 应输出49.00

验证计算正确性

手动计算:

f’(x) = 9x^2 + 4x + 5

当x=2时:

f’(2) = 94 + 42 + 5 = 36 + 8 + 5 = 49

三、高阶梯度应用

二阶导数计算

x = torch.tensor(3.0, requires_grad=True)
y = x**2 + 2*x
# 一阶导
first_grad = torch.autograd.grad(y, x, create_graph=True)[0]
# 二阶导
second_grad = torch.autograd.grad(first_grad, x)[0]
print(f'二阶导数值:{second_grad}')  # 输出2.0

四、梯度控制技巧

1. 禁用梯度跟踪

with torch.no_grad():
    y = x * 2  # 不记录计算历史

2. 分离计算图

loss = model(input)
# 仅保留数值,断开梯度连接
loss_value = loss.detach()

五、注意事项

  1. 梯度清零:每次反向传播前使用
    .zero_()
    方法清空梯度

  2. 类型匹配:确保所有参与计算的张量类型一致(float/double)

  3. 中间变量:避免对非叶子节点直接修改,可能导致梯度错误

  4. 内存管理:及时释放不再需要的计算图(
    del
    变量或使用
    with
    语句)

  5. GPU计算:梯度计算与设备无关,但需确保所有张量在同一设备

六、今日总结

  1. 理解计算图构建原理与梯度传播机制

  2. 掌握
    backward()

    grad
    的基本使用方法

  3. 实现手动梯度验证,理解自动微分正确性

  4. 学习梯度控制技巧,避免常见内存泄漏问题

完整代码示例:https://gitee.com/sr01/pytorch

⚠️常见错误

  • 忘记
    requires_grad=True
    导致无法计算梯度

  • 未及时清零梯度导致数值累加

  • 在验证阶段未禁用梯度跟踪造成内存浪费

© 2023 北京元石科技有限公司 ◎ 京公网安备 11010802042949号