[Python] Implementing linear regression with PyTorch on Google Colaboratory

For beginners! This article assumes basic knowledge of deep learning. We will implement linear regression with PyTorch, a machine learning library. It is simple, which makes it a good introduction to machine learning.

What is PyTorch?

PyTorch is an open-source machine learning library for Python, originally developed by Facebook's AI Research group (FAIR). It is pronounced "pie-torch".

It has become one of the most popular deep learning libraries in recent years. I recommend it because models are very easy to write.

What is Google Colaboratory?

Google Colaboratory is a free Python execution environment provided by Google.

You can also use a GPU, and the libraries needed for machine learning come pre-installed. Setting up a machine learning environment locally can be difficult, so this article uses Google Colaboratory.

The following article explains how to use it; you can be up and running in about a minute: https://qiita.com/shoji9x9/items/0ff0f6f603df18d631ab

Implementing linear regression

Import the required libraries (they are already installed on Google Colaboratory).

import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np

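First, generate random data that roughly follows a straight line and plot it. We will fit a linear regression to these points.

x = torch.randn(100, 1) * 10      # 100 inputs drawn from a normal distribution with std 10
y = x + torch.randn(100, 1) * 3   # targets: y = x plus Gaussian noise with std 3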
plt.plot(x.numpy(), y.numpy(), "o")
plt.ylabel("y")
plt.xlabel("x")

Running this should output a scatter plot like the one below.

(Figure: scatter plot of the generated data, pytorch-liniear-regression01.png)

Next, define the linear regression model. The class inherits from nn.Module; __init__ creates a linear layer with the given input and output sizes, and forward applies that layer to the input.
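Because the data are generated as y = x plus noise, the best-fitting line should have a slope close to 1 and an intercept close to 0. As a reference point that is not part of the original walkthrough, the sketch below computes the closed-form least-squares fit with torch.linalg.lstsq. It regenerates similar data under its own fixed seed (an assumption, so the result is repeatable); the names A, w_ls and b_ls are illustrative.

```python
import torch

# Similar data to the article's, but with a fixed seed (assumption) so the result repeats.
# y = x + Gaussian noise, so the true line is y = 1 * x + 0.
torch.manual_seed(0)
x = torch.randn(100, 1) * 10
y = x + torch.randn(100, 1) * 3

# Design matrix with columns [x, 1], so that A @ [w, b]^T approximates y
A = torch.cat([x, torch.ones_like(x)], dim=1)  # shape (100, 2)
solution = torch.linalg.lstsq(A, y).solution   # shape (2, 1)
w_ls = solution[0, 0].item()
b_ls = solution[1, 0].item()
print(f"least squares: w = {w_ls:.3f}, b = {b_ls:.3f}")
```

The printed values are roughly what the gradient-descent training later in this article should approach.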

class LR(nn.Module):
  def __init__(self, input_size, output_size):
    super().__init__()
    # A single fully connected layer: pred = x @ weight.T + bias
    self.linear = nn.Linear(input_size, output_size)
  def forward(self, x):
    pred = self.linear(x)
    return pred
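Under the hood, nn.Linear applies an affine map: for an input batch x it returns x @ weight.T + bias. The short check below, using a standalone nn.Linear(1, 1) and a random batch chosen purely for illustration, confirms that the layer's output matches this formula.

```python
import torch
from torch import nn

torch.manual_seed(0)
linear = nn.Linear(1, 1)   # the same layer type that LR wraps
x = torch.randn(5, 1)      # a batch of 5 inputs with one feature each

# nn.Linear computes x @ weight.T + bias for every row of x
manual = x @ linear.weight.T + linear.bias
builtin = linear(x)
print(torch.allclose(manual, builtin, atol=1e-6))  # True
```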

Fix the random seed so the initial parameters are reproducible, then create an instance of the linear regression model.

torch.manual_seed(1)
model = LR(1, 1)
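Fixing the seed matters because nn.Linear initializes its weight and bias randomly. A minimal sketch, assuming the same seed of 1: two layers created right after resetting the seed receive identical parameters. LR(1, 1) creates exactly one such layer, so the same reasoning applies to the model above.

```python
import torch
from torch import nn

# nn.Linear draws its initial weight and bias from the global random generator
torch.manual_seed(1)
first = nn.Linear(1, 1)
torch.manual_seed(1)   # reset the generator to the same state
second = nn.Linear(1, 1)

same = all(torch.equal(p, q) for p, q in zip(first.parameters(), second.parameters()))
print(same)  # True
```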

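Define a function that returns the model's current weight w and bias b as Python numbers.

[w, b] = model.parameters()  # weight (shape 1x1) and bias (shape 1)
def get_params():
  # Read the single weight and bias entries as Python floats
  return (w[0][0].item(), b[0].item())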
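The indexing in get_params follows from the parameter shapes: nn.Linear stores its weight as an (out_features, in_features) matrix and its bias as an (out_features,) vector, so for a 1-to-1 model w[0][0] and b[0] are the only entries. The sketch below prints those shapes for a standalone nn.Linear(1, 1), the layer LR wraps; inside LR the parameter names carry a "linear." prefix.

```python
import torch
from torch import nn

torch.manual_seed(1)
layer = nn.Linear(1, 1)  # the layer stored as LR.linear

# weight: (out_features, in_features); bias: (out_features,)
for name, p in layer.named_parameters():
    print(name, tuple(p.shape))
# weight (1, 1)
# bias (1,)

# As in get_params(), w[0][0] and b[0] pick out the single weight and bias
w, b = layer.parameters()
print(w[0][0].item(), b[0].item())
```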

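Define a plotting function. It reads the current w and b from the model, computes y1 = w1 * x1 + b1 at x1 = -30 and 30, and draws that line in red over the data.

def plot_fit(title):
  plt.title(title)
  w1, b1 = get_params()
  # Evaluate the fitted line at the two ends of the x range
  x1 = np.array([-30, 30])
  y1 = w1*x1 + b1
  plt.plot(x1, y1, "r")
  plt.scatter(x.numpy(), y.numpy())
  plt.show()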

Let's plot the model before training. Because the parameters are still random, the red line does not fit the data.

plot_fit("initial Model")

(Figure: initial model, red line over the data, pytorch-liniear-regression02.png)

Now prepare for training. The loss function is the mean squared error (nn.MSELoss), and the optimizer is stochastic gradient descent (torch.optim.SGD) with a learning rate of 0.01. Since every step uses all 100 points, this is effectively full-batch gradient descent.

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
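Note that nn.MSELoss returns the mean of the squared errors without taking a square root; a root mean squared error would need an explicit torch.sqrt. The toy tensors below are made up so the arithmetic is easy to follow by hand.

```python
import torch
from torch import nn

y_pred = torch.tensor([[1.0], [2.0], [4.0]])
y_true = torch.tensor([[1.0], [3.0], [2.0]])

# Mean squared error: ((1-1)^2 + (2-3)^2 + (4-2)^2) / 3 = (0 + 1 + 4) / 3 = 5/3
manual = ((y_pred - y_true) ** 2).mean()
mse = nn.MSELoss()(y_pred, y_true)
rmse = torch.sqrt(mse)  # root mean squared error requires an explicit square root

print(manual.item(), mse.item(), rmse.item())
```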

Now let's train. We run 100 epochs and record the loss at each one.

epochs = 100
losses = []
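for i in range(epochs):
  y_pred = model(x)             # forward pass (calls model.forward)
  loss = criterion(y_pred, y)
  print("epoch:", i, "loss:", loss.item())

  losses.append(loss.item())    # store a plain float so the list can be plotted
  optimizer.zero_grad()         # clear gradients left over from the previous step
  loss.backward()               # compute gradients of the loss w.r.t. w and b
  optimizer.step()              # update w and b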
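Each iteration of the loop performs one gradient-descent step. For the loss L = mean((w*x + b - y)^2), the gradients are dL/dw = 2 * mean((w*x + b - y) * x) and dL/db = 2 * mean(w*x + b - y), and torch.optim.SGD without momentum updates each parameter as p <- p - lr * p.grad. The sketch below checks both facts for one step; it uses a standalone nn.Linear(1, 1) in place of LR(1, 1) and its own seed, both assumptions made to keep it self-contained.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(100, 1) * 10
y = x + torch.randn(100, 1) * 3

model = nn.Linear(1, 1)  # stands in for LR(1, 1)
criterion = nn.MSELoss()
lr = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# Hand-derived gradients of L = mean((w*x + b - y)^2)
with torch.no_grad():
    residual = model(x) - y
    grad_w = (2 * (residual * x).mean()).item()
    grad_b = (2 * residual.mean()).item()
    w_before = model.weight.item()

# One iteration of the training loop
optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()  # autograd fills model.weight.grad and model.bias.grad
print("dL/dw:", model.weight.grad.item(), "by hand:", grad_w)
print("dL/db:", model.bias.grad.item(), "by hand:", grad_b)

optimizer.step()  # plain SGD (no momentum): p <- p - lr * p.grad
print("new w:", model.weight.item(), "expected:", w_before - lr * grad_w)
```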

Let's look at how the loss changed during training.

plt.plot(range(epochs), losses)
plt.ylabel("Loss")
plt.xlabel("epoch")

(Figure: loss versus epoch, pytorch-liniear-regression03.png)
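Whether the loss curve actually goes down depends on the learning rate. Because the MSE loss is quadratic in (w, b), plain gradient descent converges only if lr is below 2 / lam_max, where lam_max is the largest eigenvalue of the loss's Hessian, 2 * A^T A / N with A = [x, 1]. With x drawn at the scale used here, this bound comes out close to 0.01, so lr = 0.01 sits near the edge: depending on the random data (which is generated before the seed is fixed), the loss may oscillate or grow instead of shrinking. The sketch below computes the bound for data generated under its own seed (an assumption); if the bound is below 0.01 for your data, lowering the learning rate (for example to 0.001) keeps training stable, at the cost of the intercept moving more slowly.

```python
import torch

torch.manual_seed(0)
x = torch.randn(100, 1) * 10  # same scale as the article's data

# The MSE loss is quadratic in (w, b); its Hessian is 2 * A^T A / N with A = [x, 1]
A = torch.cat([x, torch.ones_like(x)], dim=1)
hessian = 2 * A.T @ A / len(x)
lam_max = torch.linalg.eigvalsh(hessian).max().item()

# Gradient descent on a quadratic converges only when lr < 2 / lam_max
max_stable_lr = 2 / lam_max
lr = 0.01
print(f"largest stable learning rate: about {max_stable_lr:.4f}")
print("lr = 0.01 converges" if lr < max_stable_lr else "lr = 0.01 is expected to diverge")
```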

The loss decreases, which shows that the model is learning. Let's plot the model after training.

plot_fit("Trained Model")

(Figure: trained model, red line over the data, pytorch-liniear-regression04.png)

The red line now follows the data, so the model has learned properly.

That's all. Thanks for reading!

In this article, I implemented linear regression with PyTorch. PyTorch and Google Colaboratory make it easy to try machine learning, so give it a try!
