Introduction to Python Data Visualization 8 - Linear Regression Analysis with Numpy

힘센캥거루
November 21, 2023 (edited)
python

In the previous post, we drew graphs using subplots in Matplotlib. 

In this post, we'll perform linear regression analysis using NumPy's Polynomial class.

1. What is linear regression analysis?

Wikipedia describes linear regression as follows.

In statistics, linear regression is a regression analysis method for modeling the linear relationship between a dependent variable y and one or more independent variables X.
-Wikipedia-

To put it very simply, linear regression finds the straight line that best fits the points plotted on a graph. The slope of that line summarizes how much y tends to change when x increases by one.

Drawing this regression line on a graph makes the relationship between two values much clearer than judging the scattered points by eye alone.
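In practice, the "best-fitting line" is usually found by ordinary least squares: pick the slope and intercept that minimize the sum of squared vertical distances between the points and the line. A minimal sketch of the closed-form solution, using a few made-up (math, reading) score pairs rather than the Kaggle data; the helper name `least_squares_line` is hypothetical:

```python
import numpy as np


def least_squares_line(x, y):
    """Hypothetical helper: return (slope, intercept) of the least-squares line y = slope * x + intercept."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    # slope = covariance(x, y) / variance(x)
    slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # The least-squares line always passes through the point of means
    intercept = y_mean - slope * x_mean
    return slope, intercept


# Made-up example scores (not from the Kaggle dataset)
math_scores = [50, 60, 70, 80, 90]
reading_scores = [55, 62, 71, 77, 90]

slope, intercept = least_squares_line(math_scores, reading_scores)
print(f"reading ≈ {slope:.2f} * math + {intercept:.2f}")  # reading ≈ 0.85 * math + 11.50

# Cross-check with NumPy's own least-squares fit (coefficients come highest degree first)
check_slope, check_intercept = np.polyfit(math_scores, reading_scores, 1)
print(np.allclose([slope, intercept], [check_slope, check_intercept]))  # True
```

The slope is positive here, meaning higher math scores go with higher reading scores in this toy data.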

[Figure: math vs. reading scores from the Kaggle dataset]

The graph above is a visualization of data downloaded from Kaggle.

The dataset contains students' test scores along with background information such as their parents' level of education and their race/ethnicity. Here we plot the relationship between math and reading scores.

You can download the attached file or get the original from Kaggle. The attached file has been localized into Korean (for example, the column names 수학점수 and 읽기점수 mean "math score" and "reading score"), so choose whichever is more convenient for you.

2. Drawing the graph

First, let’s draw a scatter plot of students’ math and reading scores, like the graph above.

Start with exactly the same code we used in the previous lesson, and only change the path of the file you load.

import pandas as pd

# Import modules and set up a Korean font
import matplotlib.pyplot as plt
import matplotlib

# Font setting on macOS
# matplotlib.rcParams["font.family"] = "AppleGothic"

# Font setting on Windows
matplotlib.rcParams["font.family"] = "Malgun Gothic"

# Font size
matplotlib.rcParams["font.size"] = 13

# Fix the minus sign rendering incorrectly with Korean fonts
plt.rcParams['axes.unicode_minus'] = False

# Load the Korean-localized Kaggle file (reading .xlsx files requires the openpyxl package)
score = pd.read_excel("./StudentsPerformance.xlsx")
score.head(3)
[Output: the first three rows of score]
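The font setup above is switched between macOS and Windows by commenting lines in and out. As a hedged alternative, a minimal sketch that picks an installed Korean font automatically; the helper name `pick_korean_font` is hypothetical, and "NanumGothic" is only an assumed fallback that exists if you have installed it (common on Linux):

```python
import platform

import matplotlib
from matplotlib import font_manager


def pick_korean_font():
    """Hypothetical helper: return the name of an installed Korean-capable font, or None."""
    installed = {font.name for font in font_manager.fontManager.ttflist}
    # OS-specific default first, then an assumed cross-platform fallback
    candidates = {"Darwin": ["AppleGothic"], "Windows": ["Malgun Gothic"]}.get(platform.system(), [])
    candidates = candidates + ["NanumGothic"]
    for name in candidates:
        if name in installed:
            return name
    return None


font = pick_korean_font()
if font is not None:
    matplotlib.rcParams["font.family"] = font
# Still needed with Korean fonts so negative tick labels render correctly
matplotlib.rcParams["axes.unicode_minus"] = False
```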

Now pass the math scores (수학점수) and reading scores (읽기점수) to plt.scatter as the x and y values.

plt.scatter(score["수학점수"], score["읽기점수"])
[Figure: the default scatter plot]

The plot still looks plain, so let's style it a little.

I set the marker color and transparency (alpha) and added a label to each axis.

plt.scatter(score["수학점수"], score["읽기점수"], alpha=0.4, color="green")
plt.xlabel("수학점수")
plt.ylabel("읽기점수")
[Figure: the scatter plot with color, transparency, and axis labels]

Now that the basic graph is done, let’s use numpy to perform linear regression analysis.

3. Polynomial

Import Polynomial from numpy.polynomial, then call Polynomial.fit with the x values, the y values, and the degree of the polynomial you want to fit.

We’ll perform linear regression by fitting math and reading scores with a first-degree (linear) function.

from numpy.polynomial import Polynomial

f = Polynomial.fit(score["수학점수"], score["읽기점수"], 1)

Polynomial.fit performs a least-squares fit and returns the result as a Polynomial object.

That object is callable, so f works like a function: pass it an x value and it returns the predicted y value.
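One detail worth knowing: Polynomial.fit maps the x data onto the window [-1, 1] before fitting, so f.coef is expressed in that scaled variable and is not directly the intercept and slope. Calling f.convert() returns the same polynomial in ordinary x, whose .coef is [intercept, slope]. A minimal sketch with made-up, exactly linear data (y = 2x + 10) so the recovered coefficients are easy to check:

```python
import numpy as np
from numpy.polynomial import Polynomial

# Made-up data lying exactly on y = 2x + 10
x = np.array([0, 25, 50, 75, 100], dtype=float)
y = 2 * x + 10

f = Polynomial.fit(x, y, 1)  # degree-1 least-squares fit

print(f.domain)  # the data range [0, 100] that was mapped onto the window [-1, 1]
print(f.coef)    # coefficients in the scaled variable, not (intercept, slope)

# Convert back to ordinary x: coefficients are ordered lowest degree first
intercept, slope = f.convert().coef
print(intercept, slope)  # approximately 10.0 and 2.0
```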

Let’s check the predicted value by entering the following.

from numpy.polynomial import Polynomial

f = Polynomial.fit(score["수학점수"], score["읽기점수"], 1)
f(40)
[Output: the value of f(40)]

f(40) gives the reading score that the fitted line predicts for a student who scored 40 in math.
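Because a Polynomial object is evaluated element-wise, f accepts a single score or a whole NumPy array of scores, which is what makes plotting f(x) possible later. A short sketch with made-up scores lying exactly on y = 0.5x + 30:

```python
import numpy as np
from numpy.polynomial import Polynomial

# Made-up scores lying exactly on reading = 0.5 * math + 30
math_scores = np.array([20, 40, 60, 80, 100], dtype=float)
reading_scores = 0.5 * math_scores + 30

f = Polynomial.fit(math_scores, reading_scores, 1)

print(f(40))                      # a single prediction, about 50.0
print(f(np.array([0, 50, 100])))  # one prediction per input: about 30, 55, 80
```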

Now let’s draw the graph.

4. Linear regression graph

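In the dataset, the students' math scores are not sorted from 0 to 100, and plt.plot connects points in the order it receives them.

So if you pass the raw math scores into f and plot the result, the line jumps back and forth between points. With a straight line this is barely visible, because every segment lies on the same line, but the higher the degree, the more tangled the graph becomes.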

[Figures: fitted curves plotted directly at the unsorted math scores]
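Another common fix is to sort the points by x before plotting, so the curve is drawn left to right. This is a different technique from the one used in this post; a minimal NumPy-only sketch with made-up, unsorted scores, using np.argsort to get the sorting order:

```python
import numpy as np
from numpy.polynomial import Polynomial

# Made-up, unsorted math scores and matching reading scores
math_scores = np.array([72, 35, 90, 58, 64, 81], dtype=float)
reading_scores = np.array([70, 45, 88, 60, 68, 79], dtype=float)

# A degree-2 fit shows the zigzag problem most clearly
f = Polynomial.fit(math_scores, reading_scores, 2)

# Indices that put the math scores in ascending order
order = np.argsort(math_scores)
x_sorted = math_scores[order]
y_fit = f(x_sorted)  # fitted values, now in left-to-right order

# plt.plot(x_sorted, y_fit) would now draw one smooth curve instead of a zigzag
print(x_sorted)
```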

Instead, generate evenly spaced numbers from 0 to 100 and use them as the x values for drawing the fitted line.

np.linspace(start, stop, num) returns num evenly spaced values from start to stop; by default both endpoints are included.
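A quick sketch of what linspace produces, first with a small count so the spacing is easy to see, then with the values used in this post:

```python
import numpy as np

# 5 evenly spaced values from 0 to 100; both endpoints are included by default
small = np.linspace(0, 100, 5)
print(small)  # [  0.  25.  50.  75. 100.]

# The call used in this post: 200 values from 0 to 100
x = np.linspace(0, 100, 200)
print(x[0], x[-1], len(x))  # 0.0 100.0 200

# With 200 points there are 199 gaps, so the spacing is 100 / 199 (about 0.5025)
step = x[1] - x[0]
```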

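import numpy as np

# 200 evenly spaced values from 0 to 100 (inclusive)
x = np.linspace(0, 100, 200)
# f is evaluated element-wise, giving one predicted y value per x
plt.plot(x, f(x))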

If you print x, you can see the generated values (left).

Plotting f(x) against them gives the fitted line (right).

[Figures: left, the values in x; right, the plot of f(x)]

5. Completing the graph

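Now all we have to do is overlay the two graphs: draw the scatter plot, then plot the fitted line on top of it. The format string "r--" draws a red dashed line.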

plt.scatter(score["수학점수"], score["읽기점수"], alpha=0.4, color="green")
plt.xlabel("수학점수")
plt.ylabel("읽기점수")
plt.plot(x, f(x),"r--")
[Figure: the scatter plot with the red dashed regression line]
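For readers without the Kaggle file, here is a self-contained sketch of the same overlay. It assumes synthetic, randomly generated scores (not the real dataset), English axis labels to avoid font setup, and the non-interactive Agg backend so it runs without a display:

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; omit this line in Jupyter
import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)

# Synthetic scores: reading roughly follows math, plus noise (not the Kaggle data)
math_scores = rng.uniform(20, 100, size=300)
reading_scores = np.clip(0.8 * math_scores + 15 + rng.normal(0, 8, size=300), 0, 100)

f = Polynomial.fit(math_scores, reading_scores, 1)
x = np.linspace(0, 100, 200)

fig, ax = plt.subplots()
ax.scatter(math_scores, reading_scores, alpha=0.4, color="green")
ax.plot(x, f(x), "r--")  # red dashed regression line
ax.set_xlabel("Math score")
ax.set_ylabel("Reading score")

# Render to memory; use plt.show() instead in an interactive session
buf = io.BytesIO()
fig.savefig(buf, format="png")
plt.close(fig)
```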

6. In closing

In this post, we used linear regression analysis to draw a graph showing what kind of relationship two values have.

In the next post, I’m going to think about how to fill out high school student records using data visualization.
