Loading [MathJax]/jax/output/HTML-CSS/jax.js
Demonstração de retropropagação

Algoritmo de retropropagação

O algoritmo de retropropagação é essencial para treinar rapidamente grandes redes neurais. Neste artigo, explicamos como o algoritmo funciona.

Role para baixo...

Rede neural simples

À direita, você vê uma rede neural com uma entrada, um nó de saída e duas camadas ocultas de dois nós.

Os nós em camadas vizinhas são conectados com pesos wij, que são os parâmetros de rede.

Função de ativação

Cada nó tem uma entrada total x, uma função de ativação f(x)e uma saída y=f(x). f(x) precisa ser uma função não linear. Caso contrário, a rede neural só poderá aprender modelos lineares.

Uma função de ativação usada com frequência é a função sigmoide: f(x)=11+ex.

Função erro

A meta é aprender os pesos da rede automaticamente a partir dos dados para que a saída prevista youtput esteja próxima do destino ytarget para todas as entradas xinput.

Para medir a distância da meta, usamos uma função de erro E. Uma função de erro usada com frequência é E(youtput,ytarget)=12(youtputytarget)2.

Propagação para frente

Começamos com um exemplo de entrada (xinput,ytarget) e atualizamos a camada de entrada da rede.

Para manter a consistência, consideramos a entrada como qualquer outro nó, mas sem uma função de ativação. Portanto, a saída é igual à entrada, ou seja, y1=xinput.

Propagação para frente

Agora, atualizamos a primeira camada escondida. Usamos a saída y dos nós na camada anterior e usamos os pesos para calcular a entrada x dos nós na camada seguinte.
xj=
iin(j)wijyi+bj

Propagação para frente

Depois atualizamos a saída dos nós na primeira camada escondida. Para isso, usamos a função de ativação, f(x).
y=f(x)

Propagação para frente

Usando essas duas fórmulas, propagamos o resto da rede e recebemos a saída final.
y=f(x)
xj=
iin(j)wijyi+bj

Derivado ao erro

O algoritmo de retropropagação decide quanto atualizar cada peso da rede após comparar a saída prevista com a saída desejada para um exemplo específico. Para isso, precisamos calcular como o erro muda em relação a cada peso dEdwij.
Depois de usar os derivados de erros, podemos atualizar os pesos usando uma regra de atualização simples:
wij=wijαdEdwij
em que α é uma constante positiva, chamada de taxa de aprendizado, que precisamos ajustar empiricamente.

[Observação] A regra de atualização é muito simples: se o erro diminuir quando o peso aumentar (dEdwij<0), aumente o peso. Caso contrário, se o erro aumentar quando o peso aumentar (dEdwij>0), então diminua o peso.

Derivados adicionais

Para ajudar a calcular dEdwij, também armazenamos para cada nó mais dois derivados: como o erro muda com:
  • a entrada total do nó dEdx e
  • a saída do nó dEdy.

Propagação de retorno

Vamos começar a propagar a retropropagação dos derivados de erros. Como temos a saída prevista desse exemplo de entrada específica, podemos calcular como o erro muda com essa saída. Considerando nossa função de erro E=12(youtputytarget)2 :
Eyoutput=youtputytarget

Propagação de retorno

Agora que temos dEdy podemos dEdx usar a regra da cadeia.
Ex=dydxEy=ddxf(x)Ey
em que ddxf(x)=f(x)(1f(x)) quando f(x) é a função de ativação Sigmoid.

Propagação de retorno

Assim que tivermos o derivado de erro em relação à entrada total de um nó, poderemos ver o derivado de erro em relação aos pesos que chegam a esse nó.
Ewij=xjwijExj=yiExj

Propagação de retorno

Usando a regra de cadeia, também é possível conseguir dEdy da camada anterior. Fizemos um círculo completo.
Eyi=jout(i)xjyiExj=jout(i)wijExj

Propagação de retorno

Só falta repetir as três fórmulas anteriores até calcularmos todos os derivados de erros.

Fim.

1y1xinput2y2x2dE/dy2dE/dx2fw12dE/dw3y3x3dE/dy3dE/dx3fw13dE/dw4y4x4dE/dy4dE/dx4fw24dE/dww34dE/dw5y5x5dE/dy5dE/dx5fw25dE/dww35dE/dw6youtputx6dE/dy6dE/dx6fw46dE/dww56dE/dwEytarget
Computando...