损失函数(Loss Function)的全面介绍(简单易懂版)
损失函数(Loss Function)的全面介绍(简单易懂版)
一、什么是损失函数
损失函数是机器学习中一个非常重要的概念。简单来说,损失函数用于衡量模型预测值与真实值之间的差异。在机器学习中,我们希望模型的预测值能够尽可能地接近真实值,因此需要定义一个损失函数来量化这种差异。损失函数本质上就是计算预测值和真实值差距的一类函数,通过库(如PyTorch、TensorFlow等)的封装形成了有具体名字的函数。
二、为什么需要损失函数
损失函数在机器学习中的作用非常重要。通过损失函数,我们可以量化模型预测值与真实值之间的差异,从而指导模型参数的优化。在训练过程中,我们希望最小化损失函数的值,使得模型的预测结果尽可能接近真实值。选择合适的损失函数对于模型的训练效果至关重要,不同的损失函数在梯度下降的速度上可能会有很大差异。
三、损失函数通常使用的位置
在机器学习中,损失函数通常位于向前传播(Forward Pass)和向后传播(Backward Pass)之间。向前传播是指输入特征通过模型预测出输出值的过程,而向后传播则是根据损失函数计算的误差来更新模型参数的过程。损失函数在这里起到了承上启下的作用:它接收模型的预测值,计算预测值和真实值的差值,为反向传播提供输入数据。
四、常用的损失函数(基于PyTorch)
1. L1Loss函数
(1)数学本质
说明:
- 带小帽子的y(y_hat)表示的是经过模型的预测值,y可以表示真实值。
- 图中的m指的是一行数据中的m列。
(2)证明
我们可以通过一个简单的PyTorch代码示例来验证L1Loss函数的计算结果:
import torch as th
import torch.nn as nn
loss = nn.L1Loss()
input = th.Tensor([2, 3, 4, 5])
target = th.Tensor([4, 5, 6, 7])
output = loss(input, target)
print(output)
# tensor(2.)
手动计算验证:
output = (|2-4| + |3-5| + |4-6| + |5-7|) / 4 = 2
说明:
- 因为我们函数的“reduction”(L1Loss函数的参数)选择的是默认的"mean"(平均值),所以还会在除以一个"4"。
- 如果设置“loss=L1Loss(reduction='sum')”,则不用再除以4。
2. MSELoss函数
(1)数学本质
说明:
- 在此数学公式中的参数含义与L1Loss函数参数意义相同。
(2)证明
我们可以通过一个简单的PyTorch代码示例来验证MSELoss函数的计算结果:
import torch as th
import torch.nn as nn
loss = nn.MSELoss()
input = th.Tensor([2, 3, 4, 5])
target = th.Tensor([4, 5, 6, 7])
output = loss(input, target)
print(output)
# tensor(4.)
手动计算验证:
output = [(2-4)^2 + (3-5)^2 + (4-6)^2 + (5-7)^2] / 4 = 4
3. CrossEntropyLoss函数(交叉熵函数)
CrossEntropyLoss函数主要用于分类任务中。它结合了Softmax函数和交叉熵损失,适用于多分类问题。关于CrossEntropyLoss函数的详细介绍,可以参考以下链接:
五、补充知识点
1. CrossEntropyLoss函数主要用于分类项目中运用
2. One-hot编码
One-hot编码是一种常用的分类数据表示方法。例如:
one-hot | 猫 | 狗 | 兔 |
---|---|---|---|
1 | 1 | 0 | 0 |
2 | 0 | 1 | 0 |
3 | 0 | 0 | 1 |
说明:
- 最左边的一列1,2,3代表样本属于猫、狗、兔中的某一种。
- 最上面一行是分类(图中猫、狗、兔三类属于三分类问题,当然在编码过程中种类是用数字来代替的)。
- 这种编码方式让样本与样本的欧式距离一致,便于模型处理。
3. Softmax函数的理解
关于图中公式的理解,此变换是Softmax函数变换具体表达方式如下:
假设输入的y_hat=[1,2,3,4]
则经过图中函数变换输出的值为:
[ \left[ \frac{e^1}{e^1+e^2+e^3+e^4}, \frac{e^2}{e^1+e^2+e^3+e^4}, \frac{e^3}{e^1+e^2+e^3+e^4}, \frac{e^4}{e^1+e^2+e^3+e^4} \right] ]