Künstliche Intelligenz in JAVA

Machine Learning: Lineare Regression

2023-05-30 | credit: sdecoret.stock.adobe

Thema in Kurzform

Die lineare Regression ist ein Algorithmus des überwachten maschinellen Lernens. Sie ermöglicht es auf Grundlage von Trainingsdaten einen gesuchten Zielwert zu ermitteln. Der Zielwert (Kriterium) steht in linearer Abhängigkeit zu einem Ausgangswert (Prädiktor). 

In diesem Tutorial stellen wir dir einen Java-Algorithmus zur linearen Regression vor.

Was ist lineare Regression?

Die lineare Regression (genauer gesagt: einfache lineare Regression) gehört zu den grundlegenden Modellen des überwachten Maschinellen Lernens. Überwachtes Lernen bedeutet, dass wir dem Programm Trainingsdaten zur Verfügung stellen. 

Mithilfe der Trainingsdaten wird es möglich, eine Vorhersage zu einem gesuchten Zielwert zu machen. Lineare Modelle gehen nämlich davon aus, dass es einen linearen Zusammenhang (Korrelation) zwischen den beobachteten Daten und einem abhängigen (von uns gesuchten) Zielwert gibt. Beim Zielwert handelt es sich um eine mehr oder weniger genaue Schätzung. 

Für unser Beispiel wählen wir Trainingsdaten aus den Sonntags-Umsätzen einer gut besuchten Eisdiele. Der Zusammenhang besteht zwischen der Temperatur (Grad Celsius) und den Einnahmen in Euro. Die Höhe der Temperatur korreliert also mit dem Umsatz:

Temperatur (Prädiktoren) 27 22 29 30 23 19 26 ...
Umsatz (Kriterium) 540 300 500 580 310 320 400 ...

Die Temperatur-Werte sind unabhängige Variablen (Prädiktoren). Da die Umsatz-Werte sind von den Temperaturen abhängig sind, handelt es sich bei diesen um abhängige Variablen (Kriterium).

Die Regressionsgerade

Die einzelnen Trainingsdaten lassen sich als Datenpunkte  in einem Koordinatensystem abbilden. Die Temperaturen (Prädiktoren) bilden die X-Achse, die Umsätze (Kriterium) die Y-Achse. 

Um nun zu einem beliebigen x-Wert einen gesuchten Zielwert auf der y-Achse vorauszusagen, benötigen wir die Regressionsgerade

Die Regressionsgerade beschreibt den linearen Zusammenhang zwischen den einzelnen Datenpunkten und durchquert optimal, d.h. maximal gleichmäßig, die „Datenwolke“:

Java Machine Learning Lineare Regression

Erst mit der Regressionsgeraden wird eine Vorhersage für neue Daten möglich, das heißt, welcher Prädiktor mit welchem Kriterium im linearen Zusammenhang steht. 

Wollen wir etwa für die Temperatur von 35 Grad den Umsatz schätzen, wäre der vorausgesagte Wert 679,39. Das können wir so auch grob an der Geraden ablesen.

Doch wie kommen wir zu der Regressionsgeraden? Das ist die eigentliche Kernfrage und die wollen wir nun mit einem Java-Algorithmus lösen. 

Was ist eine lineare Funktion?

Zunächst müssen wir uns aber noch klar machen, was eine Gerade überhaupt ist. Dazu bedarf es eines kleinen Ausflugs in die Mathematik.

Die Regressionsgerade ist nämlich eine lineare Funktion und lässt sich wie folgt darstellen: 

y = m * x + b

Die Funktion beschreibt das Verhältnis der Variablen y und x. Daneben gibt es noch die Variablen m und b

  • m: Steigung der Geraden
  • b: y-Achsenabschnitt

Wenn die Werte für m und b bekannt sind, können wir durch das Einsetzen des x-Werts (Prädiktor) den gewünschten y-wert (Kriterium) ermitteln. 

Zum Beispiel: Bei der Funktion y = 2*x + 3 erhalten wir durch das Einsetzen x = 1 den y-Wert 5.

Die Regressionsgerade ermitteln

Gehen wir jetzt zum Java-Algorithmus. Zum Ermitteln der Funktion für die Regressionsgeraden benöten wir folgende Zutaten: 

  • den Trainingsdatensatz
  • den Mittelwert aller x-Werte (Prädiktoren)
  • den Mittelwert aller y-Werte (Kriterium)
  • die Summen der quadratischen Abweichungen

Trainingsdatensatz

Der Trainingsdatensatz für die lineare Regression besteht immer aus einem Paar von zwei Listen (Prädiktoren und Kriterien). Es gibt verschiedene Wege, die beiden Listen in dein Java-Programm zu bringen: Mit einer csv-Datei, einer Datenbank etc. Wir haben uns hier der Einfachheit halber für zwei Arrays entschieden: 

double[] x = {27,22,29,30,23,19,26,25,23,24}; // Prädiktorwerte
double[] y = {540,300,500,580,310,320,400,340,300,400}; // Kriterium

Arithmetischen Mittelwert errechnen

Nach dem Einlesen der Trainingsdaten benötigen wir die Mittelwerte aller Prädiktoren und Kriteriums-Werte. Dazu summieren wir jeweils die Werte der beiden Arrays und dividieren die Summen anschließend durch die Anzahl der Elemente des jeweiligen Arrays: 

// Arithmetisches Mittel aller x-Werte
double x_sum = 0;
for(Double sV : x){
	x_sum += sV;
}
double x_avg = x_sum/ x.length;

// Arithmetisches Mittel aller y-Werte
double y_sum = 0;
for(Double sV : y){
	y_sum += sV;
}
double y_avg = y_sum/ y.length;

Summen der quadratischen Abweichungen

Als letzte Zutat benötigen wir noch zwei Summen von quadrarischen Abweichungen:

// Sxx ermitteln
double sxx_sum = 0;
for(Double sV : x){
	sxx_sum += Math.pow((sV - x_avg), 2);
}

// Sxy ermitteln
double sxy_sum = 0;
for(int i = 0; i < x.length; i++){
	sxy_sum += (x[i] - x_avg) * (y[i] - y_avg);
}

Funktion für Regressionsgerade fertigstellen

Jetzt haben wir alle Bestandteile, die wir für das Aufstellen der Regressionsfunktion benötigen und wir können die festen Werte für die  Geradensteigung (m) und den y-Achsenabschnitt (b) bestimmen: 

double m = sxy_sum/ sxx_sum; 
double b = y_avg - (m * x_avg);

Algorithmus testen und Zielwert bestimmen

Endlich können wir loslegen! Setzen wir den Prädiktor-Wert von (z.B.) 35 in die Regressionsgleichung und erhalten den passenden Zielwert. 

double praediktor = 35;
double zielwert = m * praediktor + b;  // y = m * x + b

System.out.println(zielwert); // 679.3975903614457

Nice 😎

Vollständiger Code

Den vollstänigen Code siehst du hier. Passe ihn einfach entsprechend deinen Trainingsdaten an und lege einen Prädiktor fest, um an den gewünschten Zielwert zu kommen. Viel Spaß damit 🙂

package sample;

public class LineareRegression {

    public static void main(String[] args){

        // Trainingsdaten
        double[] x = {27,22,29,30,23,19,26,25,23,24}; 
        double[] y = {540,300,500,580,310,320,400,340,300,400}; 

        // Arithmetisches Mittel (Mittelwert) x
        double x_sum = 0;
        for(Double sV : x){
            x_sum += sV;
        }
        double x_avg = x_sum/ x.length;

        // Arithmetisches Mittel (Mittelwert) y
        double y_sum = 0;
        for(Double sV : y){
            y_sum += sV;
        }
        double y_avg = y_sum/ y.length;

        // Summen der quadratischen Abweichungen
        double sxx_sum = 0;
        for(Double sV : x){
            sxx_sum += Math.pow((sV - x_avg), 2);
        }
        double sxy_sum = 0;
        for(int i = 0; i < x.length; i++){
            sxy_sum += (x[i] - x_avg) * (y[i] - y_avg);
        }

        // Geraden-Funktion ermitteln:
        double m = sxy_sum/ sxx_sum; // Steigung
        double b = y_avg - (m * x_avg); // y-Achsenabschnitt

        // Zielwert bestimmen
        double praediktor = 35;
        double zielwert = m * praediktor + b;
        System.out.println(zielwert); // 679.3975903614457

    }
}
Werbung

Java lernen

Werde zum Java Profi!

PHP Lernen

Lerne serverbasierte Programmierung

JavaScript lernen

Skille dein Webcoding

FALCONBYTE.NET

Handmade with 🖤️

© 2018-2023 Stefan E. Heller

Impressum | Datenschutz | Changelog

Falconbyte Youtube Falconbyte GitHub facebook programmieren lernen twitter programmieren lernen discord programmieren lernen