Процес навчання моделі логістичної регресії такий самий, що й моделі лінійної регресії, але має дві ключові відмінності.
- У моделях логістичної регресії використовується логістична функція втрат, а не квадратична.
- Дуже важливо застосовувати регуляризацію, щоб запобігти надмірному навчанню.
У наступних розділах ми докладніше розглянемо ці два аспекти.
Логістична функція втрат
Вивчаючи модуль "Лінійна регресія", ви використовували квадратичну функцію втрат (інша назва – втрати L2) як функцію втрат. Квадратична функція втрат добре підходить для лінійної моделі, у якої швидкість зміни вихідних значень постійна. Візьмімо лінійну модель $y' = b + 3x_1$: щоразу, коли вхідні дані $x_1$ збільшуються на 1, вихідне значення $y'$ збільшується на 3.
Однак швидкість змін у моделі логістичної регресії не є постійною. Як пояснювалося в розділі "Обчислення ймовірності", сигмоїдна крива має S-подібну форму, а не лінійну. Коли логарифмічні шанси ($z$) близькі до 0, їх незначне збільшення призводить до набагато помітніших змін, ніж якби значення $z$ було великим додатним чи від’ємним числом. У наведеній нижче таблиці показано вихідні дані сигмоїдної функції для вхідних значень від 5 до 10, а також відповідну влучність, необхідну для відображення відмінностей у результатах.
Вхідні дані | Логістичні вихідні дані | Необхідна кількість цифр для влучності |
---|---|---|
5 | 0,993 | 3 |
6 | 0,997 | 3 |
7 | 0,999 | 3 |
8 | 0,9997 | 4 |
9 | 0,9999 | 4 |
10 | 0,99998 | 5 |
Якби для обчислення помилок сигмоїдної функції використовувалася квадратична функція втрат, у міру наближення результату до 0
й 1
знадобилося б більше пам’яті, щоб зберегти влучність, необхідну, щоб відстежувати ці значення.
Натомість для логістичної регресії використовується логістична функція втрат. Рівняння логістичної функції втрат виводить логарифм величини зміни, а не просто відстань від точки даних до прогнозного значення. Логістична функція втрат обчислюється так:
\(\text{Log Loss} = \sum_{(x,y)\in D} -y\log(y') - (1 - y)\log(1 - y')\)
Визначення змінних:
- \((x,y)\in D\) – це набір даних, що містить багато прикладів із мітками, які є парами \((x,y)\) ;
- \(y\) – мітка такого прикладу (оскільки це логістична регресія, кожне значення \(y\) має бути 0 або 1);
- \(y'\) – прогноз моделі (значення між 0 і 1) з урахуванням набору ознак, які має \(x\).
Регуляризація в логістичній регресії
Регуляризація – механізм штрафування складності моделі під час навчання – надзвичайно важлива при моделюванні логістичної регресії. Без регуляризації через асимптотичний характер логістичної регресії втрати продовжували б збільшуватися до 0, якщо в моделі велика кількість ознак. Тому в більшості моделей логістичної регресії використовується одна з двох стратегій для зменшення їх складності:
- регуляризація L2;
- рання зупинка (обмеження кількості навчальних кроків із метою припинити тренування моделі, коли значення втрат усе ще зменшується).