線性迴歸:損失

損失是數值指標,用來描述模型的預測有多錯誤。損失可用於評估模型預測結果與實際標籤之間的距離。訓練模型的目標是盡量減少損失,將損失降到最小值。

在下圖中,您可以透過從資料點繪製的箭頭,以視覺化方式呈現損失。箭頭會顯示模型預測結果與實際值的差距。

圖 9:損失線會將資料點連結至模型。

圖 9. 損失是從實際值到預測值的差異。

損失距離

在統計和機器學習中,損失是用來評估預測值和實際值之間的差異。損失函數著重於值之間的距離,而非方向。舉例來說,如果模型預測值為 2,但實際值為 5,我們不會在意損失值為負 $ -3 $ 美元 ($ 2-5=-3 $)。我們在意的其實是值之間的差距為 $ 3 $ 美元。因此,所有計算損失值的方法都會移除符號。

移除標誌最常見的兩種方法如下:

  • 取實際值與預測值之間差異的絕對值。
  • 將實際值與預測值之間的差異平方。

損失類型

在線性迴歸中,損失主要分為四種,如下表所示。

損失類型 定義 方程式
L1 loss 預測值與實際值之間差異的絕對值總和。 $ ∑ | actual\ value - predicted\ value | $
平均絕對誤差 (MAE) 一組例子的 L1 損失值平均值。 $ \frac{1}{N} ∑ | actual\ value - predicted\ value | $
L2 損失 預測值與實際值之間的平方差的總和。 $ ∑(actual\ value - predicted\ value)^2 $
均方誤差 (MSE) 一組例子的 L2 損失值平均值。 $ \frac{1}{N} ∑ (actual\ value - predicted\ value)^2 $

L1 損失函式與 L2 損失函式 (或 MAE 與 MSE) 之間的差異在於平方。當預測值和標籤之間的差異很大時,平方運算會讓損失更大。如果差異很小 (小於 1),平方運算會讓損失更小。

一次處理多個範例時,無論是使用 MAE 還是 MSE,建議您對所有範例的損失值求平均值。

計算損失範例

我們將使用先前的最佳擬合線,為單一示例計算 L2 損失。從最佳擬合線,我們取得了以下的權重和偏差值:

  • $ \small{Weight: -3.6} $
  • $ \small{Bias: 30} $

如果模型預測 2,370 磅重的汽車每加侖可行駛 21.5 英里,但實際上每加侖可行駛 24 英里,我們會計算 L2 損失,如下所示:

方程式 結果
預測

$\small{bias + (weight * feature\ value)}$

$\small{30 + (-3.6*2.37)}$

$\small{21.5}$
實際值 $ \small{ label } $ $ \small{ 24 } $
L2 損失

$ \small{ (prediction - actual\ value)^2} $

$\small{ (21.5 - 24)^2 }$

$\small{6.25}$

在這個例子中,單一資料點的 L2 損失為 6.25。

選擇損失

決定是否使用 MAE 或 MSE 時,可以考量資料集和您要處理特定預測值的方式。資料集中的大部分特徵值通常會落在特定範圍內。舉例來說,汽車通常重 2000 到 5000 磅,每加侖油耗為 8 到 50 英里。重量 8,000 磅的車輛或每加侖行駛 100 英里的車輛,都超出一般範圍,因此會被視為異常值

異常值也可能指模型預測結果與實際值的差距。舉例來說,3,000 磅屬於一般車輛重量範圍,每加侖 40 英里則屬於一般燃油效率範圍。不過,如果 3,000 磅的汽車每加侖可行駛 40 英里,則會是模型預測的異常值,因為模型會預測 3,000 磅的汽車每加侖可行駛 18 到 20 英里。

選擇最佳損失函式時,請考量您希望模型如何處理異常值。舉例來說,MSE 會將模型更偏向異常值,而 MAE 則不會。相較於 L1 損失函式,L2 損失函式會對異常值施加更嚴重的懲罰。舉例來說,下列圖片顯示使用 MAE 訓練的模型,以及使用 MSE 訓練的模型。紅線代表已完全訓練的模型,可用於進行預測。與使用 MAE 訓練的模型相比,使用 MSE 訓練的模型更能預測離群值。

圖 10:模型更偏向離群值。

圖 10. 使用 MSE 訓練的模型會讓模型更接近異常值。

圖 11.模型會傾斜離開離群值。

圖 11. 使用 MAE 訓練的模型與離群值的距離較遠。

請注意模型與資料之間的關係:

  • MSE。模型更接近離群值,但與大多數其他資料點的距離較遠。

  • MAE。模型離離群值較遠,但離大多數其他資料點較近。

隨堂測驗

請參考以下兩個圖表:

10 個點的圖表。其中 6 個點會連成一條線。2 個點位於線條上方 1 個單位,其他 2 個點位於線條下方 1 個單位。 10 個點的圖表。一條線穿過 8 個點。1 個點位於線條上方 2 個單位,另一個點位於線條下方 2 個單位。
在前述圖表中,哪一組資料集的平均平方誤差 (MSE) 較高
左側的資料集。
這行中的六個例子總損失為 0。四個不在直線上的範例與直線的距離並不遠,因此即使將其偏移值平方,仍可產生低值:$MSE = \frac{0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 1^2 + 0^2 + 0^2} {10} = 0.4$
右側的資料集。
這行上的八個例子總損失為 0。不過,雖然只有兩個點位於線條外,但這兩個點位與左圖中的異常值點相比,離線條的距離是兩倍。平方損失會放大這些差異,因此兩個偏移會造成損失,是偏移為一個的四倍:$MSE = \frac{0^2 + 0^2 + 0^2 + 2^2 + 0^2 + 0^2 + 0^2 + 2^2 + 0^2 + 0^2} {10} = 0.8$