Thema in Kurzform
Die lineare Regression ist ein Algorithmus des überwachten maschinellen Lernens. Sie ermöglicht es auf Grundlage von Trainingsdaten einen gesuchten Zielwert zu ermitteln. Der Zielwert (Kriterium) steht in linearer Abhängigkeit zu einem Ausgangswert (Prädiktor).
In diesem Tutorial stellen wir dir einen Java-Algorithmus zur linearen Regression vor.
Die lineare Regression (genauer gesagt: einfache lineare Regression) gehört zu den grundlegenden Modellen des überwachten Maschinellen Lernens. Überwachtes Lernen bedeutet, dass wir dem Programm Trainingsdaten zur Verfügung stellen.
Mithilfe der Trainingsdaten wird es möglich, eine Vorhersage zu einem gesuchten Zielwert zu machen. Lineare Modelle gehen nämlich davon aus, dass es einen linearen Zusammenhang (Korrelation) zwischen den beobachteten Daten und einem abhängigen (von uns gesuchten) Zielwert gibt. Beim Zielwert handelt es sich um eine mehr oder weniger genaue Schätzung.
Für unser Beispiel wählen wir Trainingsdaten aus den Sonntags-Umsätzen einer gut besuchten Eisdiele. Der Zusammenhang besteht zwischen der Temperatur (Grad Celsius) und den Einnahmen in Euro. Die Höhe der Temperatur korreliert also mit dem Umsatz:
Temperatur (Prädiktoren) | 27 | 22 | 29 | 30 | 23 | 19 | 26 | ... |
Umsatz (Kriterium) | 540 | 300 | 500 | 580 | 310 | 320 | 400 | ... |
Die Temperatur-Werte sind unabhängige Variablen (Prädiktoren). Da die Umsatz-Werte sind von den Temperaturen abhängig sind, handelt es sich bei diesen um abhängige Variablen (Kriterium).
Die einzelnen Trainingsdaten lassen sich als Datenpunkte in einem Koordinatensystem abbilden. Die Temperaturen (Prädiktoren) bilden die X-Achse, die Umsätze (Kriterium) die Y-Achse.
Um nun zu einem beliebigen x-Wert einen gesuchten Zielwert auf der y-Achse vorauszusagen, benötigen wir die Regressionsgerade.
Die Regressionsgerade beschreibt den linearen Zusammenhang zwischen den einzelnen Datenpunkten und durchquert optimal, d.h. maximal gleichmäßig, die „Datenwolke“:
Erst mit der Regressionsgeraden wird eine Vorhersage für neue Daten möglich, das heißt, welcher Prädiktor mit welchem Kriterium im linearen Zusammenhang steht.
Wollen wir etwa für die Temperatur von 35 Grad den Umsatz schätzen, wäre der vorausgesagte Wert 679,39. Das können wir so auch grob an der Geraden ablesen.
Doch wie kommen wir zu der Regressionsgeraden? Das ist die eigentliche Kernfrage und die wollen wir nun mit einem Java-Algorithmus lösen.
Zunächst müssen wir uns aber noch klar machen, was eine Gerade überhaupt ist. Dazu bedarf es eines kleinen Ausflugs in die Mathematik.
Die Regressionsgerade ist nämlich eine lineare Funktion und lässt sich wie folgt darstellen:
y = m * x + b
Die Funktion beschreibt das Verhältnis der Variablen y und x. Daneben gibt es noch die Variablen m und b:
Wenn die Werte für m und b bekannt sind, können wir durch das Einsetzen des x-Werts (Prädiktor) den gewünschten y-wert (Kriterium) ermitteln.
Zum Beispiel: Bei der Funktion y = 2*x + 3 erhalten wir durch das Einsetzen x = 1 den y-Wert 5.
Gehen wir jetzt zum Java-Algorithmus. Zum Ermitteln der Funktion für die Regressionsgeraden benöten wir folgende Zutaten:
Der Trainingsdatensatz für die lineare Regression besteht immer aus einem Paar von zwei Listen (Prädiktoren und Kriterien). Es gibt verschiedene Wege, die beiden Listen in dein Java-Programm zu bringen: Mit einer csv-Datei, einer Datenbank etc. Wir haben uns hier der Einfachheit halber für zwei Arrays entschieden:
double[] x = {27,22,29,30,23,19,26,25,23,24}; // Prädiktorwerte
double[] y = {540,300,500,580,310,320,400,340,300,400}; // Kriterium
Nach dem Einlesen der Trainingsdaten benötigen wir die Mittelwerte aller Prädiktoren und Kriteriums-Werte. Dazu summieren wir jeweils die Werte der beiden Arrays und dividieren die Summen anschließend durch die Anzahl der Elemente des jeweiligen Arrays:
// Arithmetisches Mittel aller x-Werte
double x_sum = 0;
for(Double sV : x){
x_sum += sV;
}
double x_avg = x_sum/ x.length;
// Arithmetisches Mittel aller y-Werte
double y_sum = 0;
for(Double sV : y){
y_sum += sV;
}
double y_avg = y_sum/ y.length;
Als letzte Zutat benötigen wir noch zwei Summen von quadrarischen Abweichungen:
// Sxx ermitteln
double sxx_sum = 0;
for(Double sV : x){
sxx_sum += Math.pow((sV - x_avg), 2);
}
// Sxy ermitteln
double sxy_sum = 0;
for(int i = 0; i < x.length; i++){
sxy_sum += (x[i] - x_avg) * (y[i] - y_avg);
}
Jetzt haben wir alle Bestandteile, die wir für das Aufstellen der Regressionsfunktion benötigen und wir können die festen Werte für die Geradensteigung (m) und den y-Achsenabschnitt (b) bestimmen:
double m = sxy_sum/ sxx_sum;
double b = y_avg - (m * x_avg);
Endlich können wir loslegen! Setzen wir den Prädiktor-Wert von (z.B.) 35 in die Regressionsgleichung und erhalten den passenden Zielwert.
double praediktor = 35;
double zielwert = m * praediktor + b; // y = m * x + b
System.out.println(zielwert); // 679.3975903614457
Nice 😎
Den vollstänigen Code siehst du hier. Passe ihn einfach entsprechend deinen Trainingsdaten an und lege einen Prädiktor fest, um an den gewünschten Zielwert zu kommen. Viel Spaß damit 🙂
package sample;
public class LineareRegression {
public static void main(String[] args){
// Trainingsdaten
double[] x = {27,22,29,30,23,19,26,25,23,24};
double[] y = {540,300,500,580,310,320,400,340,300,400};
// Arithmetisches Mittel (Mittelwert) x
double x_sum = 0;
for(Double sV : x){
x_sum += sV;
}
double x_avg = x_sum/ x.length;
// Arithmetisches Mittel (Mittelwert) y
double y_sum = 0;
for(Double sV : y){
y_sum += sV;
}
double y_avg = y_sum/ y.length;
// Summen der quadratischen Abweichungen
double sxx_sum = 0;
for(Double sV : x){
sxx_sum += Math.pow((sV - x_avg), 2);
}
double sxy_sum = 0;
for(int i = 0; i < x.length; i++){
sxy_sum += (x[i] - x_avg) * (y[i] - y_avg);
}
// Geraden-Funktion ermitteln:
double m = sxy_sum/ sxx_sum; // Steigung
double b = y_avg - (m * x_avg); // y-Achsenabschnitt
// Zielwert bestimmen
double praediktor = 35;
double zielwert = m * praediktor + b;
System.out.println(zielwert); // 679.3975903614457
}
}
Java Basics
[Java einrichten] [Variablen] [Primitive Datentypen] [Operatoren] [if else] [switch-case] [Arrays] [Schleifen]
Objektorientierung
[Einstieg] [Variablen ] [Konstruktor] [Methoden] [Rekursion] [Statische Member] [Initializer] [Pass-by-value] [Objektsammlungen] [Objektinteraktion] [Objekte löschen]
Klassenbibliothek
[Allgemeines] [String ] [Math] [Wrapper] [Scanner] [java.util.Arrays] [Date-Time-API]