Logistisk regresjon i Java

1. Introduksjon

Logistisk regresjon er et viktig instrument i verktøykasse for maskinlæring (ML).

I denne veiledningen, vi vil utforske hovedideen bak logistisk regresjon.

Først, la oss starte med en kort oversikt over ML-paradigmer og algoritmer.

2. Oversikt

ML tillater oss å løse problemer som vi kan formulere på menneskelig vis. Imidlertid kan dette faktum være en utfordring for oss programvareutviklere. Vi har vant oss til å ta opp problemene vi kan formulere på datamaskinvennlige vilkår. For eksempel, som mennesker, kan vi enkelt oppdage gjenstandene på et bilde eller etablere stemningen til en setning. Hvordan kunne vi formulere et slikt problem for en datamaskin?

For å komme opp med en løsning, i ML er det en spesiell scene som heter opplæring. I løpet av dette stadiet mater vi inngangsdataene til algoritmen vår slik at den prøver å komme med et optimalt sett med parametere (de såkalte vektene). Jo mer inndata vi kan mate til algoritmen, jo mer presise spådommer kan vi forvente av den.

Trening er en del av en iterativ ML-arbeidsflyt:

Vi starter med å skaffe oss data. Ofte kommer dataene fra forskjellige kilder. Derfor må vi få det til å ha samme format. Vi bør også kontrollere at datasettet representerer studienes domene. Hvis modellen aldri har blitt trent på røde epler, kan den knapt forutsi den.

Deretter bør vi bygge en modell som vil konsumere dataene og være i stand til å komme med spådommer. I ML er det ingen forhåndsdefinerte modeller som fungerer bra i alle situasjoner.

Når du søker etter riktig modell, kan det lett skje at vi bygger en modell, trener den, ser dens spådommer og forkaster modellen fordi vi ikke er fornøyd med spådommene den gir. I dette tilfellet bør vi gå tilbake og bygge en annen modell og gjenta prosessen igjen.

3. ML-paradigmer

I ML, basert på hva slags inngangsdata vi har til rådighet, kan vi trekke frem tre hovedparadigmer:

  • veiledet læring (bildeklassifisering, gjenkjenning av gjenstander, sentimentanalyse)
  • uten tilsyn læring (deteksjon av avvik)
  • forsterkningslæring (spillstrategier)

Saken vi skal beskrive i denne opplæringen tilhører veiledet læring.

4. ML Verktøykasse

I ML er det et sett med verktøy som vi kan bruke når vi bygger en modell. La oss nevne noen av dem:

  • Lineær regresjon
  • Logistisk regresjon
  • Nevrale nettverk
  • Support Vector Machine
  • k-Nærmeste naboer

Vi kan kombinere flere verktøy når vi bygger en modell som har høy prediktivitet. Faktisk, for denne opplæringen, vil modellen vår bruke logistisk regresjon og nevrale nettverk.

5. ML-biblioteker

Selv om Java ikke er det mest populære språket for prototyping av ML-modeller,det har et rykte som et pålitelig verktøy for å skape robust programvare på mange områder, inkludert ML. Derfor kan det hende at vi finner ML-biblioteker skrevet på Java.

I denne sammenheng kan vi nevne de-facto standardbiblioteket Tensorflow som også har en Java-versjon. En annen verdt å nevne er et dyp læringsbibliotek kalt Deeplearning4j. Dette er et veldig kraftig verktøy, og vi skal også bruke det i denne opplæringen.

6. Logistisk regresjon ved siffergjenkjenning

Hovedideen med logistisk regresjon er å bygge en modell som forutsier etikettene til inngangsdataene så presist som mulig.

Vi trener modellen til den såkalte tapsfunksjonen eller objektive funksjonen når noen minimal verdi. Tapsfunksjonen avhenger av de faktiske modellspådommene og forventede (etikettene til inngangsdataene). Målet vårt er å minimere avviket mellom faktiske modellspådommer og forventede.

Hvis vi ikke er fornøyd med den minste verdien, bør vi bygge en annen modell og utføre opplæringen igjen.

For å se logistisk regresjon i aksjon illustrerer vi det på gjenkjenningen av håndskrevne sifre. Dette problemet har allerede blitt et klassisk problem. Deeplearning4j-biblioteket har en serie realistiske eksempler som viser hvordan du bruker API-en. Den koderelaterte delen av denne opplæringen er sterkt basert på MNIST klassifisering.

6.1. Inndata

Som inngangsdata bruker vi den velkjente MNIST-databasen med håndskrevne sifre. Som inndata har vi 28 × 28 piksler gråskala bilder. Hvert bilde har en naturlig etikett som er sifferet som bildet representerer:

For å estimere effektiviteten til modellen vi skal bygge, deler vi inndataene i opplæring og testsett:

DataSetIterator-tog = nytt RecordReaderDataSetIterator (...); DataSetIterator test = ny RecordReaderDataSetIterator (...);

Når vi har merket inngangsbildene og delt inn i de to settene, er "datautvikling" -fasen over, og vi kan gå videre til "modellbygningen".

6.2. Modellbygging

Som vi har nevnt, er det ingen modeller som fungerer bra i alle situasjoner. Likevel, etter mange års forskning i ML, har forskere funnet modeller som fungerer veldig bra når de gjenkjenner håndskrevne sifre. Her bruker vi den såkalte LeNet-5-modellen.

LeNet-5 er et nevralt nettverk som består av en serie lag som forvandler 28 × 28 pikselbildet til en ti-dimensjonal vektor:

Den ti-dimensjonale utgangsvektoren inneholder sannsynligheter for at inngangsbildets etikett er enten 0, 1, eller 2, og så videre.

For eksempel hvis utgangsvektoren har følgende form:

{0.1, 0.0, 0.3, 0.2, 0.1, 0.1, 0.0, 0.1, 0.1, 0.0}

det betyr at sannsynligheten for at inngangsbildet er null er 0,1, til ett er 0, for å være to er 0,3 osv. Vi ser at maksimal sannsynlighet (0,3) tilsvarer merkelapp 3.

La oss dykke ned i detaljer om modellbygging. Vi utelater Java-spesifikke detaljer og konsentrerer oss om ML-konsepter.

Vi satte opp modellen ved å lage en MultiLayerNetwork gjenstand:

MultiLayerNetwork-modell = ny MultiLayerNetwork (config);

I sin konstruktør skal vi passere en MultiLayerConfiguration gjenstand. Dette er selve objektet som beskriver geometrien til nevrale nettverk. For å definere nettverksgeometrien, bør vi definere hvert lag.

La oss vise hvordan vi gjør dette med den første og den andre:

ConvolutionLayer layer1 = nytt ConvolutionLayer .Builder (5, 5) .nIn (kanaler) .stride (1, 1) .nOut (20) .aktivering (Activation.IDENTITY) .build (); SubsamplingLayer layer2 = new SubsamplingLayer .Builder (SubsamplingLayer.PoolingType.MAX) .kernelSize (2, 2) .stride (2, 2) .build ();

Vi ser at lagdefinisjonene inneholder en betydelig mengde ad-hoc-parametere som påvirker hele nettverksytelsen betydelig. Det er akkurat der vår evne til å finne en god modell i landskapet til alle blir avgjørende.

Nå er vi klare til å konstruere MultiLayerConfiguration gjenstand:

MultiLayerConfiguration config = ny NeuralNetConfiguration.Builder () // klargjøringstrinn .liste () .lag (lag1) .lag (lag2) // andre lag og slutttrinn .bygg ();

at vi overfører til MultiLayerNetwork konstruktør.

6.3. Opplæring

Modellen som vi konstruerte inneholder 431080 parametere eller vekter. Vi kommer ikke til å gi her den nøyaktige beregningen av dette tallet, men vi bør være klar over at bare tdet første laget har mer enn 24x24x20 = 11520 vekter.

Treningsstadiet er så enkelt som:

model.fit (tog); 

I utgangspunktet har 431080-parametrene noen tilfeldige verdier, men etter opplæringen tilegner de seg noen verdier som bestemmer modellytelsen. Vi kan evaluere modellens prediktivitet:

Evalueringseval = modell.evaluere (test); logger.info (eval.stats ());

LeNet-5-modellen oppnår ganske høy nøyaktighet på nesten 99% selv i bare en enkelt treningsoppgave (epoke). Hvis vi ønsker å oppnå høyere nøyaktighet, bør vi gjøre flere iterasjoner ved hjelp av en vanlig for-loop:

for (int i = 0; i <epoker; i ++) {model.fit (tog); train.reset (); test.reset (); } 

6.4. Prediksjon

Nå, når vi trente modellen og vi er fornøyde med dens spådommer på testdataene, kan vi prøve modellen på noen helt nye innspill. For å oppnå dette, la oss lage en ny klasse MnistPrediksjon der vi laster inn et bilde fra en fil som vi velger fra filsystemet:

INDArray image = new NativeImageLoader (høyde, bredde, kanaler) .asMatrix (fil); ny ImagePreProcessingScaler (0, 1) .transform (bilde);

Variabelen bilde inneholder at bildet vårt blir redusert til 28 × 28 gråtoner. Vi kan mate den til vår modell:

INDArray-utgang = modell.utgang (bilde);

Variabelen produksjon vil inneholde sannsynligheten for at bildet er null, en, to osv.

La oss nå spille litt og skrive et siffer 2, digitalisere dette bildet og mate det modellen. Vi kan få noe slikt:

Som vi ser, har komponenten med maksimal verdi 0,99 indeks to. Det betyr at modellen har korrekt gjenkjent vårt håndskrevne siffer.

7. Konklusjon

I denne opplæringen beskrev vi de generelle begrepene maskinlæring. Vi illustrerte disse begrepene på logistisk regresjonseksempel som vi brukte på en håndskrevet siffergjenkjenning.

Som alltid kan vi finne de tilsvarende kodebitene på GitHub-depotet vårt.


$config[zx-auto] not found$config[zx-overlay] not found