Regresja logistyczna modele są trenowane przy użyciu tego samego procesu regresja liniowa i wyróżniamy 2 główne różnice:
- Wykorzystanie modeli regresji logistycznej Log Loss jako funkcja utraty. zamiast straty kwadratowej.
- Stosuję regularizację. ma kluczowe znaczenie dla zapobiegania zbyt.
W kolejnych sekcjach bardziej szczegółowo omawiamy te 2 kwestie.
Logarytmiczna funkcja utraty danych
W module regresji liniowej użyto straty kwadratowej (nazywanej też strata L2) jako funkcji straty. Kwadratowa strata dobrze sprawdza się w przypadku liniowej model, w którym szybkość zmian wartości wyjściowych jest stała. Przykład: dla modelu liniowego: $y' = b + 3x_1$, za każdym razem, gdy zwiększysz wartość wejściową wartość $x_1$ przez 1, wartość wyjściowa $y'$ zwiększa się o 3.
Jednak szybkość zmian modelu regresji logistycznej nie jest stała. Jak widać w sekcji Obliczanie prawdopodobieństwa, Krzywa sigmoidalna ma kształt s a nie liniowe. Gdy wartość log-odds ($z$) jest bliższa 0, mała wartość wzrost wartości $z$ skutkuje znacznie większymi zmianami wartości $y$ niż wtedy, gdy $z$ jest dużym liczbę dodatnią lub ujemną. Poniższa tabela przedstawia funkcję sigmoidalną dane wyjściowe dla wartości wejściowych od 5 do 10 oraz odpowiadającej im precyzji wymagane do uchwycenia różnic w wynikach.
dane wejściowe | dane wyjściowe logistyczne | wymagane cyfry dokładności |
---|---|---|
5 | 0,993 | 3 |
6 | 0,997 | 3 |
7 | 0,999 | 3 |
8 | 0,9997 | 4 |
9 | 0,9999 | 4 |
10 | 0,99 998 | 5 |
Jeśli do obliczenia błędów funkcji sigmoidalnej użyjesz straty kwadratowej, której wartości
dane wyjściowe są coraz bliżej punktów 0
i 1
, potrzebujesz więcej pamięci, aby
aby zachować precyzję niezbędną do śledzenia tych wartości.
Zamiast tego funkcja utraty dla regresji logistycznej to Logarytmiczne. Równanie logarytmicznej straty zwraca logarytm wielkości zmiany, a nie niż tylko odległość od danych do prognozy. Logarytmiczna funkcja utraty danych jest obliczana według wzoru: następujące:
\(\text{Log Loss} = \sum_{(x,y)\in D} -y\log(y') - (1 - y)\log(1 - y')\)
gdzie:
- \((x,y)\in D\) to zbiór danych zawierający wiele przykładów oznaczonych etykietami, \((x,y)\) pary.
- \(y\) to etykieta w przykładzie oznaczonym etykietą. Jest to regresja logistyczna, każda wartość \(y\) musi wynosić 0 lub 1.
- \(y'\) to prognoza Twojego modelu (gdzieś od 0 do 1) dla danego zbioru funkcji w \(x\).
Regularizacja regresji logistycznej
Regularizację, mechanizm zmniejszanie złożoności modelu podczas trenowania, jest niezwykle istotne w logistyce modelowanie regresji. Bez regularyzacji asymptotyczny charakter logistyki regresja nadal powoduje stratę do 0 w przypadkach, gdy model ma bardzo wiele funkcji. W efekcie większość modeli regresji logistycznej wykorzystuje jeden można zastosować 2 strategie mające na celu zmniejszenie złożoności modelu:
- L2 regularyzacja
- Wczesne zatrzymanie: Ograniczenie liczby kroków trenowania w celu przerwania trenowania, gdy utrata jest nadal maleje.