线性回归:损失

损失是一个数值指标,用于描述模型的预测有多不准确。损失函数用于衡量模型预测结果与实际标签之间的距离。训练模型的目标是尽可能降低损失,将其降至最低值。

在下图中,您可以将损失可视化为从数据点指向模型的箭头。箭头表示模型的预测结果与实际值之间的差距。

图 9. 损失率线将数据点与模型相关联。

图 9. 损失是从实际值到预测值衡量的。

丢失距离

在统计学和机器学习中,损失函数用于衡量预测值与实际值之间的差异。损失函数侧重于值之间的距离,而不是方向。例如,如果模型预测值为 2,但实际值为 5,我们并不关心损失为负值 $ -3 $($ 2-5=-3 $)。我们关心的是这两个值之间的距离为 $ 3 $。因此,所有用于计算损失的方法都会移除符号。

移除符号的最常用方法为:

  • 计算实际值与预测值之差的绝对值。
  • 将实际值与预测值之间的差值平方。

损失类型

在线性回归中,有四种主要的损失函数,如下表所示。

损失类型 定义 公式
L1 损失 预测值与实际值之间差异的绝对值的总和。 $ ∑ | actual\ value - predicted\ value | $
平均绝对误差 (MAE) 一组示例的 L1 损失的平均值。 $ \frac{1}{N} ∑ | actual\ value - predicted\ value | $
L2 损失 预测值与实际值之间的平方差的总和。 $ ∑(actual\ value - predicted\ value)^2 $
均方误差 (MSE) 一组样本的 L2 损失平均值。 $ \frac{1}{N} ∑ (actual\ value - predicted\ value)^2 $

L1 损失与 L2 损失(或 MAE 和 MSE)之间的功能性差异是平方。当预测值与标签之间的差异较大时,平方会使损失变得更大。当差值很小(小于 1)时,平方会使损失更小。

同时处理多个示例时,我们建议对所有示例的损失进行平均,无论是使用 MAE 还是 MSE。

计算损失示例

使用前面的最佳拟合线,我们将计算单个样本的 L2 损失。从最佳拟合线中,我们获取了以下权重和偏差值:

  • $ \small{重量:-3.6} $
  • $ \small{Bias: 30} $

如果模型预测重 2,370 磅的汽车每加仑可行驶 21.5 英里,但实际每加仑可行驶 24 英里,我们将按如下方式计算 L2 损失:

公式 结果
预测

$\small{bias + (weight * feature\ value)}$

$\small{30 + (-3.6*2.37)}$

$\small{21.5}$
实际值 $ \small{ 标签 } $ $ \small{ 24 } $
L2 损失

$ \small{ (预测 - 实际\值)^2} $

$\small{ (21.5 - 24)^2 }$

$\small{6.25}$

在此示例中,该单个数据点的 L2 损失为 6.25。

选择损失

确定是使用 MAE 还是 MSE 可能取决于数据集以及您希望处理特定预测的方式。数据集中的大多数特征值通常属于一个特定范围。例如,汽车通常在 2,000 到 5,000 磅之间,每加仑汽油能行使的英里数介于 8 到 50 英里之间。重 8,000 磅的汽车或每加仑行驶 100 英里的汽车不在典型范围之内,会被视为离群值

离群值还可以指模型预测结果与真实值之间的差距。例如,重 3,000 磅的汽车或每加仑可行驶 40 英里的汽车都在典型范围内。但是,对于模型的预测而言,重 3,000 磅的汽车每加仑能行驶 40 英里属于离群值,因为模型会预测重 3,000 磅的汽车每加仑能行驶 18 到 20 英里。

选择最佳损失函数时,请考虑您希望模型如何处理离群值。例如,MSE 会使模型更接近离群值,而 MAE 则不会。与 L1 损失相比,L2 损失对离群值造成的惩罚要高得多。例如,下图显示了使用 MAE 训练的模型和使用 MSE 训练的模型。红线表示将用于进行预测的完全训练好的模型。离群值更接近使用 MSE 训练的模型,而不是使用 MAE 训练的模型。

图 10。模型更倾向于离群值。

图 10. 使用 MSE 训练的模型会使模型更接近离群值。

图 11. 模型会进一步远离离群值。

图 11. 使用 MAE 训练的模型离离群值更远。

请注意模型与数据之间的关系:

  • MSE。模型更接近离群值,但与大多数其他数据点的距离更远。

  • MAE。模型离离群值较远,但离大多数其他数据点较近。

检查您的理解情况

请考虑以下两个图表:

由 10 个点构成的曲线图。
      一条线穿过 6 个点。2 个点位于线条上方 1 个单位处;另外 2 个点位于线条下方 1 个单位处。 由 10 个点构成的曲线图。一条线穿过其中 8 个点。1 个点位于线条上方 2 个单位处;另一个点位于线条下方 2 个单位处。
以上两个图表中显示的数据集,哪个数据集的均方误差 (MSE) 较高
左侧的数据集。
该行中的 6 个示例的总损失为 0。四个不在线上的示例离线条不远,因此即使对其偏移量进行平方处理,得到的值仍然很低:$MSE = \frac{0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 0^2} {10} = 0.4$
右侧的数据集。
该行上的 8 个示例的总损失为 0。不过,尽管只有两个点在线外,但这两个点离线的距离依然是左图中离群点的 2 倍。平方损失进一步加大差异,因此两个的偏移量产生的损失是 1 的四倍:$MSE = \frac{0^2 + 0^2 + 0^2 + 2^2 + 0^2 + 0^2 + 0^2 + 2^2 + 0^2} = { 1.0^2} = 0.8