Explanation of the concept of regression analysis using Python Part 1

Machine learning and artificial intelligence have become popular recently, and I think a grounding in statistics, which underlies them, is important. So I would like to explain the principle of regression analysis, one of the statistical techniques whose effect is easiest to grasp, with the goal of building a conceptual understanding while doing the calculations and drawing the graphs in Python. I am not a statistician, so if you have any suggestions or comments, please feel free to contact me. Some points may not be mathematically rigorous; please bear with me.

Dataset

First, get the dataset.

cars data

This page uses Python for the explanation, but the data used is the cars dataset bundled with the statistical analysis software R. Please download the CSV data from here. (According to the description of this data, it appears to date from the 1920s, so treat it purely as sample data.) In this explanation, it is assumed to be saved as "cars.csv".

Table definition of the cars data

[,0]    index
[,1]	speed	numeric	Speed (mph)
[,2]	dist	numeric	Stopping distance (ft)

A detailed explanation of the data can be found here: it is a collection of 50 pairs of a car's speed and the stopping distance when the brakes are applied at that speed.

Python preparation

This article uses Python 2.7. It is also assumed that the following libraries (NumPy and matplotlib) are already installed.

Import these.


import numpy as np
import matplotlib.pyplot as plt

Reading data

Read in the data and start by drawing a scatter plot; drawing a graph is the easiest way to get a feel for the data.

By the way, since the original data uses units unfamiliar to Japanese readers, such as miles and feet, we will convert it to metric units.

1 foot ≒ 0.3048 meters
1 mph  ≒ 1.61 km/h

So convert the units of the data as follows.


data = np.loadtxt('cars.csv', delimiter=',', skiprows=1)
data[:,1] = data[:,1] * 1.61    # convert speed from mph to km/h
data[:,2] = data[:,2] * 0.3048  # convert distance from ft to m

Then draw a scatter plot based on that data.


fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
ax.set_title("Stopping Distances of Cars")
ax.set_xlabel("speed(km/h)")
ax.set_ylabel("distance(m)")
plt.scatter(data[:,1],data[:,2])

cardistans_scatter.png

What is regression analysis?

Using this data and regression analysis, we can estimate how far a car traveling at a given speed takes to come to a complete stop when it brakes suddenly. In other words, the relationship


distance = \alpha * speed + \beta

It can be expressed by a linear equation like this. This is called simple regression analysis ("simple" because there is only one explanatory variable). Fitting a straight line in this way is sometimes called a first-order approximation. $\alpha$ is the slope of the line and $\beta$ is its intercept. Let's call this straight line the approximate (fitted) line.

スクリーンショット 2015-02-15 15.23.38.png

Let's try fitting a line by eye. The graph below shows the result, with $\alpha$ set to 0.74 and $\beta$ set to -5. How does it look? A line like this seems to fit reasonably well, doesn't it?

# line y = 0.74x - 5, passing through (0, -5) and (50, 32)
x = [0,  50]
y = [-5, 32]
plt.plot(x,y)

cardistans_scatter2.png
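Incidentally, the fitted line doubles as a predictor. A minimal sketch (the function name `predict_distance` is mine, not from the original; 0.74 and -5 are the eyeballed values above):

```python
def predict_distance(speed_kmh, alpha=0.74, beta=-5.0):
    """Stopping distance (m) predicted by the line distance = alpha * speed + beta."""
    return alpha * speed_kmh + beta

# At 50 km/h the line passes through (50, 32), so this predicts 32 m.
print(predict_distance(50))
```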

However, since this line was fitted purely by eye, we don't know whether it is really the best one. Let's look at how to find the optimal $\alpha$ and $\beta$: the method of least squares.

Least squares

What is the least squares method? What is being minimized, and what is being squared?

What to minimize?

What we minimize is the error between the line and the data. Each error is the vertical segment drawn from a data point to the fitted line. See the graph below.


# line: y = 0.74x - 5
x = [0,  50]
y = [-5, 32]
plt.plot(x,y)

# draw errors
for d in data:
    plt.plot([d[1],d[1]],[d[2],d[1]*0.74-5],"k")

cardistans_scatter_with_errors.png

These errors are the black segments. Intuitively, the line that makes the total of these errors as small as possible should be the best fit for the data.

Let's see what happens when we change $\alpha$ and $\beta$ from the line we fitted by eye.

First, let's look at a line with a different $\alpha$.


# line: y = 1.54x - 5
x = [0,  50]
y = [-5, 72]
plt.plot(x,y)

# draw errors
for d in data:
    plt.plot([d[1],d[1]],[d[2],d[1]*1.54-5],"k")

cardistans_scatter_with_errors_a.png

Next, let's look at a line with a modified $\beta$.


# line: y = 0.74x + 15
x = [0,  50]
y = [15, 52]
plt.plot(x,y)

# draw errors
for d in data:
    plt.plot([d[1],d[1]],[d[2],d[1]*0.74+15],"k")

cardistans_scatter_with_errors_b.png

How does it look? Comparing the line we first fitted by eye with the ones where $\alpha$ and $\beta$ were deliberately shifted, the eyeballed line clearly has the smaller total error. The least squares method is a way to find the $\alpha$ and $\beta$ that minimize this error.

What to square?

As for the squaring: we square each error. We want a measure of the distance between the line and each data point, but left as-is, the error is positive for points above the line and negative for points below it, so the errors can cancel each other out. Squaring removes the signs and makes every term positive.
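To see why the signs matter, here is a tiny sketch on four made-up residuals (not the cars data): the signed errors sum to zero even though the fit is clearly off, while the squared errors do not.

```python
# Residuals y_i - y_hat_i of a deliberately bad fit:
# two points above the line and two below, by the same amounts.
residuals = [2.0, -2.0, 1.0, -1.0]

signed_total = sum(residuals)                    # positive and negative errors cancel
squared_total = sum(e ** 2 for e in residuals)   # squaring makes every term positive

print(signed_total, squared_total)
```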

Let's calculate

Now that you know what is squared and what is minimized, let's proceed to the concrete calculation.

First, some definitions. The $i$-th data point is $(x_i, y_i)$, and the value the line approximates for it is $\hat{y}_i$.

Also, let the error be $\epsilon_i$. Let's zoom in on the 19th data point (index 18).


i = 18                            # the 19th data point
x_i = data[i,1]
y_i = data[i,2]
y_hat = x_i * 0.74 - 5            # its value on the fitted line
ax.set_ylim(0, y_i + 3)           # zoom the earlier axes in around the point
ax.set_xlim(x_i - 5, x_i + 5)

plt.plot([x_i, x_i], [y_i, y_hat], "k")   # the error epsilon_i
plt.plot(x, y)                            # the fitted line from before

plt.scatter([x_i], [y_i])

スクリーンショット 2015-02-15 16.23.03.png

Now, squaring these errors and summing them over all the data, we can write


S = \sum_i^n\epsilon_i^2=\sum_i^n (y_i-\hat{y}_i )^2



\hat{y}_i = \alpha  x_i + \beta

So, substituting this in, we get the following.


S = \sum_i^n \epsilon_i^2 = \sum_i^n  ( y_i-\alpha x_i - \beta )^2 
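As a quick intermediate step, expanding each squared term and grouping by powers of $\alpha$ gives

```latex
(y_i - \alpha x_i - \beta)^2
  = x_i^2\,\alpha^2 + 2 x_i(\beta - y_i)\,\alpha + (y_i - \beta)^2
```

Summing this over $i$ produces the quadratic in $\alpha$ written out below.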

The optimal approximate line is obtained by finding the $\alpha$ and $\beta$ that minimize $S$, which can be done by differentiating $S$ with respect to each parameter. First, writing $S$ as a function of $\alpha$,


S(\alpha) = \left( \sum_i^n x_i^2 \right) \alpha^2
 + 2\left( \sum_i^n (x_i\beta - x_i y_i ) \right) \alpha 
 + n\beta^2 - 2\beta\sum_i^n y_i + \sum_i^n y_i^2

What kind of function is this $S(\alpha)$? It is a quadratic function of $\alpha$. Since the coefficient of $\alpha^2$ is a sum of squares, it is always zero or positive, so the parabola opens upward. To get a feel for its shape, let's draw the graph assuming $\beta = 0$, which reduces it to


S(\alpha) = \left( \sum_i^n x_i^2 \right) \alpha^2
 - 2\left( \sum_i^n x_i y_i \right) \alpha 
 + \sum_i^n y_i^2   ...  (if \beta = 0 )

which is simpler. Let's compute these coefficients from the data.

sum_x2 = np.sum(data[:,1] ** 2)     # \sum x_i^2
sum_y2 = np.sum(data[:,2] ** 2)     # \sum y_i^2
sum_xy = data[:,1].dot(data[:,2])   # \sum x_i y_i

print sum_x2
>>> 34288.2988

print sum_y2
>>> 11603.8684051

print sum_xy
>>> 18884.194896

Therefore,


S(\alpha) ≒ 34288 \alpha^2 - 37768 \alpha + 11604 

When you draw this graph,

x1 = np.linspace(-1, 5, 200)

# S(alpha) = 34288*alpha^2 - 37768*alpha + 11604
y1 = 34288 * x1 ** 2 - 37768 * x1 + 11604

plt.plot(x1, y1)

# horizontal line at S(0) = 11604, and the vertical axis
plt.plot([-1, 5], [11604, 11604])
plt.plot([0, 0], [13000, 0], "--k")

minimize_a3.png

As you shift $\alpha$, you can see that the minimum value is attained somewhere.
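Where exactly is that minimum? For an upward-opening quadratic $a\alpha^2 + b\alpha + c$, the vertex sits at $\alpha = -b/(2a)$. A quick sketch using the rounded coefficients above (remember $\beta$ is frozen at 0 here, so this is not the final fitted slope):

```python
# Coefficients of S(alpha) = a*alpha^2 + b*alpha + c for beta = 0,
# taken from the sums computed above (rounded).
a, b, c = 34288.0, -37768.0, 11604.0

alpha_min = -b / (2.0 * a)                       # vertex of the parabola
S_at_min = a * alpha_min ** 2 + b * alpha_min + c

print(alpha_min)   # roughly 0.55
print(S_at_min)    # smaller than S(0) = 11604, but still positive
```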

Similarly for $\beta$,


S(\beta) = n\beta^2
+ 2 \left( \sum_i^n (x_i\alpha - y_i) \right) \beta
+ \alpha^2\sum_i^n x_i^2 - 2\alpha \sum_i^n x_iy_i + \sum_i^n y_i^2

Let's draw the graph assuming that $\alpha = 0$.


S(\beta) = n\beta^2
- 2 \left( \sum_i^n y_i \right) \beta + \sum_i^n y_i^2
   ...  (if \alpha = 0 )

The value $\sum y_i$ did not appear among the sums calculated earlier, so let's compute it now (along with $\sum x_i$):


sum_x = np.sum(data[:,1])
print sum_x
>>> 1239.7

sum_y = np.sum(data[:,2])
print sum_y
>>> 655.0152

Therefore, the quadratic equation for $\beta$ is


S(\beta) ≒ 50\beta^2 - 1310 \beta + 11604

and drawing the graph:


x1 = np.arange(-100, 100, 0.01)
n = len(data[:,2])                # n = 50

# S(beta) = n*beta^2 - 1310*beta + 11604
y1 = n * x1 ** 2 - 1310 * x1 + 11604

fig = plt.figure(figsize=(10,6))
ax = fig.add_subplot(111)
ax.set_xlim(-5, 30)
ax.set_ylim(0, 20000)
ax.set_xlabel("beta")
plt.plot([-50, 30], [11604, 11604], "g")  # horizontal line at S(0) = 11604
plt.plot([0, 0], [30000, 0], "--k")       # the vertical axis
plt.plot(x1, y1)


minimize_b3.png

And again, as you shift $\beta$, you can see that the minimum is attained at some particular $\beta$.
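The same vertex formula locates this minimum too. A small sketch with the rounded coefficients above (with $\alpha$ frozen at 0, the minimizing $\beta$ is just the mean of the $y_i$, about 13.1):

```python
# Coefficients of S(beta) = n*beta^2 + b*beta + c for alpha = 0,
# taken from the sums computed above (rounded).
n, b, c = 50.0, -1310.0, 11604.0

beta_min = -b / (2.0 * n)                     # vertex: 1310 / 100 = 13.1
S_at_min = n * beta_min ** 2 + b * beta_min + c

print(beta_min)    # 13.1, close to the mean of y (655.0152 / 50)
print(S_at_min)
```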

Continued in Part 2.
