Here, we explain the basic usage of the machine learning library scikit-learn. Machine learning algorithms themselves will be covered in another article. This article assumes Python 3.
Like other libraries, scikit-learn can be loaded with `import`. However, as shown below, in practice individual modules are usually loaded with `from ... import`.
scikit-learn_1.py
import sklearn
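The two loading styles can be compared side by side. The sketch below imports the package as a whole, prints its version, and then imports one estimator class directly. `LinearRegression` is used here only as an example of a class you would typically pull in with `from ... import`.

```python
# Two common ways to load scikit-learn components.
import sklearn                                     # the whole package
from sklearn.linear_model import LinearRegression  # a single estimator class

# The installed version is exposed as a string attribute.
print(sklearn.__version__)

# With `from ... import`, the class is used without the package prefix.
model = LinearRegression()
print(type(model).__name__)  # LinearRegression
```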
scikit-learn comes with various datasets that can be used for machine learning. You can list the available dataset loaders by running the code below.
scikit-learn_2.py
import sklearn.datasets
[s for s in dir(sklearn.datasets) if s.startswith('load_')]
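Each bundled `load_...` function returns a dict-like `Bunch` object. Note that a few of the listed names, such as `load_files`, are generic loaders rather than bundled datasets. The following is a minimal sketch of inspecting the iris loader before using it.

```python
from sklearn.datasets import load_iris

iris = load_iris()            # a Bunch (dict-like object)
print(iris.data.shape)        # (150, 4): 150 samples, 4 features
print(iris.feature_names)     # column names for iris.data
print(iris.target_names)      # ['setosa' 'versicolor' 'virginica']
print(iris.DESCR[:300])       # beginning of the built-in description
```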
Here, we use the `iris` dataset from the list above. Consider using linear regression to predict sepal width from sepal length. First, prepare the data.
scikit-learn_3.py
from sklearn.datasets import load_iris
import pandas as pd
data_iris = load_iris()
X = pd.DataFrame(data_iris.data, columns=data_iris.feature_names)
x = X.iloc[:, 0]  # sepal length (cm)
y = X.iloc[:, 1]  # sepal width (cm)
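scikit-learn estimators such as `LinearRegression` expect a 2-D feature matrix, whereas `x` above is a 1-D Series. As a sketch, the same values can be kept as a one-column DataFrame and split by row position. The first 10 rows give the training set and the next 5 give the test set, matching the hard-coded lists in the next example. The names `features` and `target` are only illustrative.

```python
from sklearn.datasets import load_iris
import pandas as pd

data_iris = load_iris()
X = pd.DataFrame(data_iris.data, columns=data_iris.feature_names)

# Double brackets keep a one-column DataFrame, i.e. a 2-D feature matrix.
features = X[['sepal length (cm)']]
target = X['sepal width (cm)']

# First 10 rows for training, rows 10-14 for testing.
X_train, y_train = features.iloc[:10], target.iloc[:10]
X_test, y_test = features.iloc[10:15], target.iloc[10:15]

print(X_train.shape, X_test.shape)  # (10, 1) (5, 1)
```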
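When the data is ready, perform a linear regression. For simplicity, the code below hard-codes the first 10 samples of the dataset (sepal length, sepal width) as training data and the next 5 samples as test data.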
scikit-learn_4.py
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score
import matplotlib.pyplot as plt
# Jupyter/IPython only: remove this line when running as a plain .py script
%matplotlib inline
X_train = [[5.1], [4.9], [4.7], [4.6], [5.0], [5.4], [4.6], [5.0], [4.4], [4.9]]
y_train = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]
model = LinearRegression()
model.fit(X_train, y_train)  # fit the linear regression model to the training data
print(model.coef_)  # slope (coefficient)
print(model.intercept_)  # intercept
X_test = [[5.4], [4.8], [4.8], [4.3], [5.8]]
y_test = [3.7, 3.4, 3.0, 3.0, 4.0]
y_pred = model.predict(X_test)  # predictions for the test inputs
print(y_pred)
fig, ax = plt.subplots()
ax.scatter(X_test, y_test, label='Test set')  # scatter plot of measured values
ax.plot(X_test, y_pred, label='Regression line')  # fitted regression line
ax.legend()
plt.savefig('scikit-learn_4.png')  # save before show(); show() may close the figure
plt.show()  # display the test data and the regression line
print(r2_score(y_test, y_pred))  # coefficient of determination R^2
The plot produced by this code shows the test data together with the regression line.
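For a single feature, `LinearRegression` computes the ordinary least-squares fit, so the values in `model.coef_` and `model.intercept_` have a closed form. The slope is the sum of (x - mean x)(y - mean y) divided by the sum of (x - mean x)^2, and the intercept is mean y - slope * mean x. The sketch below recomputes both with NumPy as a sanity check. The names `slope` and `intercept` are illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

X_train = [[5.1], [4.9], [4.7], [4.6], [5.0], [5.4], [4.6], [5.0], [4.4], [4.9]]
y_train = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]

x = np.array(X_train).ravel()  # flatten the 2-D feature matrix to 1-D
y = np.array(y_train)

# Closed-form ordinary least squares for one feature
slope = np.sum((x - x.mean()) * (y - y.mean())) / np.sum((x - x.mean()) ** 2)
intercept = y.mean() - slope * x.mean()

model = LinearRegression().fit(X_train, y_train)
print(slope, model.coef_[0])
print(intercept, model.intercept_)
```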
The final R^2 value indicates how well the model fits the test data, where 1 means a perfect fit. Which evaluation metric you should look at depends on the task, for example whether it is regression or classification.
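R^2 is defined as 1 - SS_res / SS_tot. Here SS_res is the sum of squared residuals and SS_tot is the sum of squared deviations of `y_test` from its mean. This means R^2 can be negative on test data if the model does worse than always predicting that mean. The sketch below computes it by hand and compares it with `r2_score`. The name `r2_manual` is illustrative.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

X_train = [[5.1], [4.9], [4.7], [4.6], [5.0], [5.4], [4.6], [5.0], [4.4], [4.9]]
y_train = [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1]
X_test = [[5.4], [4.8], [4.8], [4.3], [5.8]]
y_test = np.array([3.7, 3.4, 3.0, 3.0, 4.0])

y_pred = LinearRegression().fit(X_train, y_train).predict(X_test)

ss_res = np.sum((y_test - y_pred) ** 2)         # residual sum of squares
ss_tot = np.sum((y_test - y_test.mean()) ** 2)  # total sum of squares
r2_manual = 1 - ss_res / ss_tot
print(r2_manual, r2_score(y_test, y_pred))
```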
This article covered the basics of scikit-learn. It is worth getting a rough picture of the overall workflow: preparing a dataset, preprocessing the data, building a predictive model, and validating it.
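As a sketch of that workflow on the full iris data (still predicting sepal width from sepal length), the steps can be chained with `train_test_split` and a pipeline. The 30% test split, `random_state=0`, and the `StandardScaler` step are illustrative choices, not part of the example above. Scaling does not change the predictions of an ordinary linear regression; it is included only to show where a preprocessing step goes. Because these two features are only weakly related across all three species, expect a low score.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# 1. Prepare the dataset: sepal length (column 0) -> sepal width (column 1)
iris = load_iris()
X = iris.data[:, [0]]  # 2-D array of shape (150, 1)
y = iris.data[:, 1]

# Hold out 30% of the samples for validation; random_state makes the split repeatable
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)

# 2. Preprocessing and 3. model, chained so the scaler is fit on training data only
model = make_pipeline(StandardScaler(), LinearRegression())
model.fit(X_train, y_train)

# 4. Validate on the held-out samples
score = r2_score(y_test, model.predict(X_test))
print(score)
```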