Comprenons les problèmes des réseaux de neurones récurrents

Partager sur Facebook
Partager sur Twitter
Partager sur lié
Partager sur télégramme
Partager sur WhatsApp

Contenu

Cet article a été publié dans le cadre du Blogathon sur la science des données

introduction

Réseau de neurones récurrents (RNN) c'était l'un des meilleurs concepts introduits qui pourraient utiliser des éléments de mémoire dans notre réseau de neurones. Avant que, nous avions un réseau de neurones qui pouvait se propager dans les deux sens pour mettre à jour les poids et réduire les erreurs dans le réseau. Mais, comme nous savons, de nombreux problèmes dans le monde réel sont de nature temporaire et dépendent fortement du temps.

De nombreuses applications linguistiques sont toujours séquentielles et le mot suivant dans une phrase dépend du précédent. Ces problèmes ont été résolus par un simple RNN. Mais si nous comprenons RNN, nous apprécions le fait que même RNN ne peut pas nous aider lorsque nous voulons garder une trace des mots qui ont déjà été utilisés dans notre phrase. Dans cet article, Je vais discuter de certains des principaux inconvénients de RNN et pourquoi nous utilisons un meilleur modèle pour la plupart des applications basées sur le langage.

Comprendre la rétropropagation dans le temps (BPTT)

RNN utilise une technique appelée rétropropagation au fil du temps pour se propager à travers le réseau pour ajuster ses poids afin que nous puissions réduire les erreurs dans le réseau. A obtenu son nom “à travers le temps”, puisque dans RNN, nous traitons des données séquentielles et chaque fois que nous revenons en arrière, c'est comme remonter le temps dans le passé. Voici comment fonctionne BPTT:

rn-2055804

La source: (http://www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/)

A l'étape BPTT, on calcule la dérivée partielle à chaque poids du réseau. Ensuite, si on est dans le temps t = 3, alors on considère la dérivée de E3 par rapport à celle de S3. À présent, x3 est également connecté à s3. Ensuite, sa dérivée est également considérée. À présent, si on voit que s3 est connecté à s2, alors s3 dépend de la valeur de s2 et ici la dérivée de s3 par rapport à s2 est également considérée. Cela agit comme une règle de chaîne et nous accumulons toute la dépendance avec ses dérivés et l'utilisons pour calculer l'erreur.

Dans E3, nous avons un gradient qui est S3 et son équation à ce moment est:

capture d

Maintenant, nous avons également s2 associé à s3 donc,

capture d

Et s1 est également associé à s2 et, donc, maintenant tout s1, s2, s3 et a un effet sur E3,

capture d

En accumulant tout, on obtient l'équation suivante que Ws a contribué à ce réseau au temps t = 3,

capture d

L'équation générale pour laquelle nous adaptons Ws dans notre réseau BPTT peut s'écrire sous la forme,

capture d

À présent, comme nous l'avons remarqué, Wx est également associé au réseau. Ensuite, en faisant de même, on peut généralement écrire,

capture d

Maintenant que vous avez compris comment fonctionne BPTT, il s'agit essentiellement de la façon dont RNN ajuste ses poids et réduit l'erreur. À présent, le principal défaut ici est que ce n'est fondamentalement que pour un petit réseau avec 4 couvre. Mais imaginez si nous avions des centaines de couches et, à la fois, disons t = 100, on finirait par calculer toutes les dérivées partielles associées au réseau et c'est une énorme multiplication et cela peut réduire la valeur globale à une très petite valeur ou une valeur infime telle qu'il peut être inutile de corriger l'erreur. Ce problème s'appelle Problème de dégradé en train de disparaître.

Problème de dégradé en train de disparaître

Comme nous le savons tous, dans RNN pour prédire une sortie, nous utiliserons une fonction d'activation sigmoïde afin d'obtenir la probabilité de sortie pour une classe particulière. Comme nous l'avons vu dans la section précédente quand il s'agit de dire E3, il y a une dépendance à long terme. Le problème survient quand on prend la dérivée et que la dérivée du sigmoïde est toujours en dessous 0.25 Oui, donc, quand on multiplie plusieurs dérivées ensemble selon la règle de la chaîne, on se retrouve avec une valeur de fuite telle qu'on ne peut pas les utiliser pour le calcul de l'erreur. .

16a3a_rt4ymumhusvtvvtxw-7780692

La source: (https://versdatascience.com/the-vanishing-gradient-problem-69bf08b15484)

Donc, les poids et les biais ne seront pas mis à jour correctement et, au fur et à mesure que les couches continuent d'augmenter, nous sommes tombés plus loin et notre modèle ne fonctionne pas correctement et génère des imprécisions sur tout le réseau.

Certaines façons de résoudre ce problème consistent à initialiser correctement la matrice de poids ou à opter pour quelque chose comme un ReLU au lieu de fonctions sigmoïdes ou tanh.

Problème de gradient explosif

L'explosion de gradient est un problème où la valeur du gradient devient très grande et cela se produit souvent lorsque nous initialisons des poids plus importants et que nous pourrions nous retrouver avec NaN. Si notre modèle souffrait de ce problème, nous ne pouvons pas du tout mettre à jour les poids. Mais heureusement, le recadrage en dégradé est un processus que nous pouvons utiliser pour cela. A une valeur seuil prédéfinie, on coupe le dégradé. Cela empêchera la valeur du gradient de dépasser le seuil et nous ne nous retrouverons jamais avec de grands nombres ou NaN.

Dépendance à long terme aux mots

À présent, considérons une phrase comme, "Les nuages ​​sont dans le ____". Notre modèle RNN peut facilement prédire « Sky’ ici et cela est dû au contexte des nuages ​​et très bientôt cela vient en entrée de votre couche précédente. Mais ce n'est peut-être pas toujours le cas.

Imaginez si nous avions une phrase comme: « Jane est née au Kerala. Jane avait l'habitude de jouer pour l'équipe de football féminine et a également remporté les examens de niveau de l'État. Jane parle ____ couramment “.

C'est une très longue phrase et le problème ici est que, en tant qu'humain, je peux dire que, depuis que Jane est née au Kerala et a réussi son examen d'État, il est évident que vous devez maîtriser le “malayalam” très couramment. Mais, Comment notre machine le sait-elle? Au point où le modèle veut prédire les mots, vous avez peut-être oublié le contexte du Kerala et plus d'autre chose. C'est le problème de la dépendance à long terme sur RNN.

Unidirectionnel dans RNN

Comme nous l'avons déjà commenté, RNN prend les données de manière séquentielle et mot par mot ou lettre par lettre. À présent, quand nous essayons de prédire un mot particulier, nous ne pensons pas dans son contexte futur. C'est-à-dire, disons que nous avons quelque chose comme: "La souris est vraiment bien. La souris sert à ____ pour faciliter l'utilisation des ordinateurs “. À présent, si nous pouvons voyager dans les deux sens et que nous pouvons également voir le contexte futur, on peut dire que « Déplacement’ est le mot approprié ici. Mais, s'il est unidirectionnel, notre modèle n'a jamais vu d'ordinateurs, ensuite, Comment savez-vous si nous parlons de la souris animale ou de la souris d'ordinateur?

Ces problèmes sont résolus plus tard en utilisant des modèles de langage comme BERT, où nous pouvons entrer des phrases complètes et utiliser le mécanisme d'auto-attention pour comprendre le contexte du texte.

Utiliser la mémoire à court terme à long terme (LSTM)

Une façon de résoudre le problème du gradient de fuite et de la dépendance à long terme au RNN est d'opter pour les réseaux LSTM. LSTM a une introduction à trois portes appelées portes d'entrée, sortie et oubli. Dans lequel les portes de l'oubli s'occupent des informations qui doivent être autorisées à traverser le réseau. De cette façon, nous pouvons avoir une mémoire à court et à long terme. Nous pouvons transmettre les informations à travers le réseau et les récupérer même à un stade beaucoup plus tardif pour identifier le contexte de prédiction. Le schéma suivant montre le réseau LSTM.

1280px-the_lstm_cell-svg_-3503279

(https://en.wikipedia.org/wiki/Long_short-term_memory#/media/File:La_LSTM_Cell.svg)

Suivez ce tutoriel pour une meilleure compréhension et un exemple intuitif de LSTM: https://versdatascience.com/illustrated-guide-to-lstms-and-gru-sa-step-by-step-explanation-44e9eb85bf21

Avec chance, maintenant vous avez compris les problèmes d'utilisation d'un RNN et pourquoi nous avons opté pour des réseaux plus complexes comme LSTM.

Les références

1.http: //www.wildml.com/2015/10/recurrent-neural-networks-tutorial-part-3-backpropagation-through-time-and-vanishing-gradients/

2. https://analyticsindiamag.com/what-are-the-challenges-of-training-recurrent-neural-networks/

3. https://versdatascience.com/the-vanishing-gradient-problem-69bf08b15484

4. https://fr.wikipedia.org/wiki/Long_short-term_memory

5. https://www.udacity.com/course/deep-learning-nanodegree–nd101

6. Aperçu de l'image: https://unsplash.com/photos/Sot0f3hQQ4Y

conclusion

N'hésitez pas à me contacter sur:

1. https://www.linkedin.com/in/siddharth-m-426a9614a/

2. https://github.com/Siddharth1698

Les médias présentés dans cet article ne sont pas la propriété de DataPeaker et sont utilisés à la discrétion de l'auteur.

Abonnez-vous à notre newsletter

Nous ne vous enverrons pas de courrier SPAM. Nous le détestons autant que vous.