אלגוריתם לחזרה
האלגוריתם להטמעה ברקע הוא חיוני לאימון מהיר של רשתות נוירונים
במהירות. במאמר הזה נסביר איך האלגוריתם עובד.
יש לגלול למטה...
רשת נוירונים פשוטה
בצד שמאל תופיע רשת נוירונים עם קלט אחד, צומת פלט אחד ושתי שכבות נסתרות של שני צמתים.
הצמתים בשכבות הסמוכות מחוברים למשקלים wij, שהם הפרמטרים של הרשת.
פונקציית הפעלה
לכל צומת יש סה"כ קלט x, פונקציית הפעלה f(x)
ופלט y=f(x).
f(x) הפונקציה צריכה להיות לא לינארית, אחרת הרשת המוחשית תוכל ללמוד מודלים לינאריים בלבד.
פונקציית הפעלה נפוצה היא פונקציית Sigmoid:
f(x)=11+e−x.
פונקציית שגיאה
המטרה היא לקבל מידע על המשקלים של הרשת באופן אוטומטי על סמך נתונים כך שהפלט החזוי youtput
יהיה קרוב ליעד ytarget בכל הקלט xinput.
כדי לבדוק את המרחק שלנו מהיעד, אנחנו משתמשים בפונקציית שגיאה E.
פונקציית שגיאה נפוצה היא E(youtput,ytarget)=12(youtput−ytarget)2.
ריבוי קדימה
אנחנו מתחילים בדוגמת קלט (xinput,ytarget) ומעדכנים את שכבת הקלט של הרשת.
כדי לשמור על עקביות, אנחנו מחשיבים את הקלט כמו כל צומת אחר, אך ללא פונקציית הפעלה, כך שהפלט שלו שווה לקלט שלו. כלומר, y1=xinput.
ריבוי קדימה
עכשיו אנחנו מעדכנים את השכבה המוסתרת הראשונה. אנחנו מביאים את הפלט y של הצמתים בשכבה הקודמת
ומשתמשים בשקלול כדי לחשב את הקלט x של הצמתים בשכבה הבאה.
ריבוי קדימה
לאחר מכן, אנחנו מעדכנים את הפלט של הצמתים בשכבה המוסתרת הראשונה.
לשם כך אנחנו משתמשים בפונקציית ההפעלה, f(x).
ריבוי קדימה
באמצעות שתי הנוסחאות האלה, אנחנו מפיצים את שאר הרשת ומפיקים את הפלט הסופי של הרשת.
נגזרת של שגיאה
אלגוריתם ההפצה החוזר קובע כמה
לעדכן כל משקל של הרשת אחרי השוואה בין הפלט הצפוי לבין הפלט הרצוי בדוגמה ספציפית.
לשם כך, עלינו לחשב כיצד השגיאה משתנה
ביחס לכל משקל dEdwij.
אחרי שנקבל שגיאות נגזרות, נוכל לעדכן את המשקלים באמצעות כלל עדכון פשוט:
שבו α הוא קבוע, נקרא 'שיעור למידה', וצריך להתאים אותו בצורה אמפירית.
[הערה] כלל העדכון הוא פשוט מאוד: אם השגיאה יורדת כאשר המשקל עולה (dEdwij<0),
אז מעלים את המשקל, אחרת אם העלייה במשקל גדולה יותר (dEdwij>0)
ואז יורדת המשקל.
נגזרות נוספות
כדי לעזור לחשב dEdwij, אנחנו מאחסנים גם לכל צומת שתי נגזרות נוספות:
איך השגיאה משתנה עם:
- הקלט הכולל של הצומת dEdx
- הפלט של הצומת dEdy.
ריבוי גב
שנתחיל להפיץ את הנגזרים של השגיאות?
מכיוון שיש לנו את הפלט הצפוי של דוגמת הקלט הספציפית הזו, אנחנו יכולים לחשב איך השגיאה משתנה בפלט הזה.
בהתאם לפונקציית השגיאות שלנו E=12(youtput−ytarget)2 יש לנו:
ריבוי גב
עכשיו dEdy אנחנו יכולים לקבל dEdx באמצעות כלל הרשת.
כאשר ddxf(x)=f(x)(1−f(x))
f(x) היא פונקציית ההפעלה של Sigmoid.
ריבוי גב
ברגע שנקבל את השגיאה נגזרת ביחס לקלט הכולל של הצומת,
נוכל לקבל את השגיאה שנגזרת מהמשקלים שמגיעים לצומת.
ריבוי גב
באמצעות כלל השרשרת, אנחנו יכולים גם לקבל dEdy מהשכבה הקודמת. יצרנו מעגל שלם.
ריבוי גב
כל מה שנשאר לעשות הוא לחזור על שלוש הנוסחאות הקודמות עד שהמערכת תחשב את כל נגזרות השגיאה.