Gradient Descent i Java

1. Introduksjon

I denne opplæringen lærer vi om Gradient Descent-algoritmen. Vi implementerer algoritmen i Java og illustrerer den trinn for trinn.

2. Hva er Gradient Descent?

Gradient Descent er en optimaliseringsalgoritme som brukes til å finne et lokalt minimum av en gitt funksjon. Den brukes mye innen maskinlæringsalgoritmer på høyt nivå for å minimere tap av funksjoner.

Gradient er et annet ord for skråning, og nedstigning betyr å gå ned. Som navnet antyder, går Gradient Descent nedover skråningen til en funksjon til den når slutten.

3. Egenskaper ved Gradient Descent

Gradient Descent finner et lokalt minimum, som kan være forskjellig fra det globale minimumet. Utgangspunktet lokalt er gitt som en parameter til algoritmen.

Det er en iterativ algoritme, og i hvert trinn prøver den å bevege seg nedover skråningen og komme nærmere det lokale minimumet.

I praksis går algoritmen tilbake. Vi illustrerer og implementerer backtracking Gradient Descent i denne opplæringen.

4. Trinnvis illustrasjon

Gradient Descent trenger en funksjon og et startpunkt som input. La oss definere og plotte en funksjon:

Vi kan starte når som helst. La oss starte kl x=1:

I det første trinnet går Gradient Descent nedover skråningen med en forhåndsdefinert trinnstørrelse:

Deretter går det videre med samme trinnstørrelse. Imidlertid ender det denne gangen på en større y enn det siste trinnet:

Dette indikerer at algoritmen har passert det lokale minimumet, så den går bakover med en senket trinnstørrelse:

Deretter, når gjeldende y er større enn den forrige y, trinnstørrelsen senkes og negeres. Iterasjonen fortsetter til ønsket presisjon er oppnådd.

Som vi kan se, fant Gradient Descent et lokalt minimum her, men det er ikke det globale minimumet. Hvis vi begynner kl x= -1 i stedet for x= 1, vil det globale minimumet bli funnet.

5. Implementering i Java

Det er flere måter å implementere Gradient Descent på. Her beregner vi ikke den avledede funksjonen for å finne retningen på skråningen, så implementeringen vår fungerer også for ikke-differensierbare funksjoner.

La oss definere presisjon og trinnkoeffisient og gi dem innledende verdier:

dobbel presisjon = 0,000001; dobbelt trinn Koeffisient = 0,1;

I det første trinnet har vi ingen tidligere y til sammenligning. Vi kan enten øke eller redusere verdien av x for å se om y senker eller hever. En positiv trinnkoeffisient betyr at vi øker verdien av x.

La oss nå utføre det første trinnet:

doble forrigeX = initialX; doble forrigeY = f.apply (forrigeX); currentX + = trinnkoeffisient * forrige Y;

I koden ovenfor, f er en Funksjon, og initialX er en dobbelt, begge leveres som input.

Et annet viktig poeng å vurdere er at Gradient Descent ikke garanteres å konvergere. For å unngå å bli sittende fast i løkken, la oss ha en grense for antall iterasjoner:

int iter = 100;

Senere reduserer vi iter av en ved hver iterasjon. Derfor vil vi komme ut av løkken med maksimalt 100 iterasjoner.

Nå som vi har en previousX, kan vi sette opp løkken vår:

mens (forrige trinn> presisjon && iter> 0) {iter--; dobbeltstrøm Y = f.apply (currentX); hvis (nåværendeY> forrigeY) {stepCoefficient = -stepCoefficient / 2; } forrigeX = nåværendeX; currentX + = trinnkoeffisient * forrige Y; previousY = currentY; previousStep = StrictMath.abs (currentX - previousX); }

I hver iterasjon beregner vi den nye y og sammenlign den med den forrige y. Hvis gjeldende Y er større enn forrige, endrer vi retning og reduserer trinnstørrelsen.

Sløyfen fortsetter til trinnstørrelsen vår er mindre enn ønsket presisjon. Endelig kan vi komme tilbake currentX som det lokale minimumet:

returstrømX;

6. Konklusjon

I denne artikkelen gikk vi gjennom algoritmen Gradient Descent med en trinnvis illustrasjon.

Vi implementerte også Gradient Descent i Java. Koden er tilgjengelig på GitHub.


$config[zx-auto] not found$config[zx-overlay] not found