线性回归:损失

损失是一个数值指标,用于描述模型的预测有多大偏差。损失函数用于衡量模型预测与实际标签之间的距离。训练模型的目标是尽可能降低损失,将其降至最低值。

在下图中,您可以将损失可视化为从数据点指向模型的箭头。箭头表示模型的预测结果与实际值之间的差距。

图 9. 损失函数线将数据点与模型相关联。

图 9. 损失是从实际值到预测值衡量的。

丢失距离

在统计学和机器学习中,损失函数用于衡量预测值与实际值之间的差异。损失函数侧重于值之间的距离,而不是方向。例如,如果模型预测值为 2,但实际值为 5,我们并不关心损失为负值 $ -3 $($ 2-5=-3 $)。我们关心的是这两个值之间的距离为 $ 3 $。因此,所有用于计算损失的方法都会移除符号。

移除此标记的两种最常用方法如下:

  • 计算实际值与预测值之间的差值的绝对值。
  • 将实际值与预测值之间的差值平方。

损失类型

在线性回归中,有四种主要的损失函数,如下表所示。

损失类型 定义 公式
L1 损失 预测值与实际值之间差异的绝对值的总和。 $ ∑ | actual\ value - predicted\ value | $
平均绝对误差 (MAE) 一组示例的 L1 损失的平均值。 $ \frac{1}{N} ∑ | actual\ value - predicted\ value | $
L2 损失 预测值与实际值之间的平方差的总和。 $ ∑(actual\ value - predicted\ value)^2 $
均方误差 (MSE) 一组示例的 L2 损失的平均值。 $ \frac{1}{N} ∑ (actual\ value - predicted\ value)^2 $

L1 损失函数和 L2 损失函数(或 MAE 和 MSE)之间的功能差异在于平方。当预测值与标签之间的差异较大时,平方会使损失变得更大。当差异很小(小于 1)时,平方会使损失更小。

同时处理多个示例时,我们建议对所有示例的损失进行平均,无论是使用 MAE 还是 MSE。

计算损失示例

使用之前的最佳拟合直线,我们将计算单个示例的 L2 损失。从最优拟合线中,我们得到了权重和偏差的以下值:

  • $ \small{Weight: -3.6} $
  • $ \small{Bias: 30} $

如果模型预测重 2,370 磅的汽车每加仑可行驶 21.5 英里,但实际每加仑可行驶 24 英里,我们将按如下方式计算 L2 损失:

公式 结果
预测

$\small{bias + (weight * feature\ value)}$

$\small{30 + (-3.6*2.37)}$

$\small{21.5}$
实际值 $ \small{ label } $ $ \small{ 24 } $
L2 损失

$ \small{ (actual\ value - predicted\ value)^2 } $

$\small{ (24 - 21.5)^2 }$

$\small{6.25}$

在此示例中,该单个数据点的 L2 损失为 6.25。

选择损失

确定是使用 MAE 还是 MSE 可能取决于数据集以及您希望处理特定预测的方式。数据集中的大多数特征值通常属于一个特定范围。例如,汽车通常在 2,000 到 5,000 磅之间,每加仑汽油能行使的英里数介于 8 到 50 英里之间。重 8,000 磅的汽车或每加仑汽油能行驶 100 英里的汽车都超出了典型范围,会被视为离群值

离群值还可以指模型的预测与真实值之间的差距。例如,3,000 磅的车重属于典型的车重范围,而每加仑 40 英里的油耗属于典型的油耗范围。但是,对于每加仑 40 英里的 3,000 磅汽车,模型的预测结果会认为其属于离群值,因为模型会预测 3,000 磅的汽车每加仑能行驶 18 到 20 英里。

选择最佳损失函数时,请考虑您希望模型如何处理离群值。例如,MSE 会使模型更接近离群值,而 MAE 则不会。与 L1 损失函数相比,L2 损失函数对离群值的惩罚更高。例如,以下图片显示了使用 MAE 训练的模型和使用 MSE 训练的模型。红线表示将用于进行预测的完全训练好的模型。离群值更接近使用 MSE 训练的模型,而不是使用 MAE 训练的模型。

图 10。模型更倾向于离群值。

图 10. 使用 MSE 训练的模型会使模型更接近离群值。

图 11. 模型会进一步远离离群值。

图 11. 使用 MAE 训练的模型与离群值的距离更远。

请注意模型与数据之间的关系:

  • MSE。模型更接近离群值,但与大多数其他数据点的距离更远。

  • MAE。模型离离群值较远,但离大多数其他数据点较近。

检查您的理解情况

请考虑以下两个图表:

由 10 个点构成的曲线图。
      一条线穿过其中 6 个点。2 个点位于线条上方 1 个单位处;另外 2 个点位于线条下方 1 个单位处。 由 10 个点构成的曲线图。一条线穿过其中 8 个点。1 个点位于线条上方 2 个单位处;另一个点位于线条下方 2 个单位处。
以上两个图表中显示的数据集,哪个数据集的均方误差 (MSE) 较高
左侧的数据集。
该行中的 6 个示例的总损失为 0。四个不在线上的示例与线条相距不远,因此即使对其偏移量进行平方处理,得到的值仍然很低:$MSE = \frac{0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 0^2} {10} = 0.4$
右侧的数据集。
该行中的 8 个示例的总损失为 0。不过,尽管只有两个点在线外,但这两个点离线的距离依然是左图中离群点的 2 倍。平方损失进一步加大差异,因此两个单位的偏移量产生的损失是一个单位的 4 倍:$MSE = \frac{0^2 + 0^2 + 0^2 + 2^2 + 0^2 + 0^2 + 0^2 + 2^2 + 0^2 + 0^2} {10} = 0.8$