Diejenigen, die Grundkenntnisse über das Convolutional Neural Network (CNN) von Deep Learning haben und die Bedeutung der folgenden Wörter verstehen Beispiel) --Falten
Es ist eine der CNN-Methoden und kann mehr Schichten als andere CNNs hinzufügen. Als Merkmal werden am Ende des Moduls die Eingabedaten des Moduls zu den im Modul verarbeiteten Daten hinzugefügt (Verknüpfungsverbindung). Weitere Informationen finden Sie unter hier.
GoogleColaboratory
#Installation der erforderlichen Bibliotheken
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.datasets import cifar10
#Eine Klasse, die CIFAR10-Daten in einen Vektor konvertiert
class CIFAR10Dataset():
def __init__(self):
self.image_shape = (32, 32, 3)
self.num_classes = 10
#Erfassen Sie Trainingsdaten und Testdaten.
def get_batch(self):
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = [self.change_vec(img_data) for img_data in [x_train, x_test]]
y_train, y_test = [self.change_vec(img_data, label_data=True) for img_data in [y_train, y_test]]
return x_train, y_train, x_test, y_test
#Wenn es sich um eine Zielvariable handelt, ändern Sie sie in einen Klassenvektor. Erklärende Variablen sind standardisiert.
def change_vec(self, img_data, label=False):
if label:
data = keras.utils.to_categorical(img_data, self.num_classes)
else:
img_data = img_data.astype("float32")
img_data /= 255
shape = (img_data.shape[0],) + self.image_shape
img_data = img_data.reshape(shape)
return img_data
#Eine Funktion, die ein Deep-Learning-Modell festlegt und zurückgibt
def network(input_shape, num_classes, count):
filter_count = 32
inputs = Input(shape=input_shape)
x = Conv2D(32, kernel_size=3, padding="same", activation="relu")(inputs)
x = BatchNormalization()(x)
for i in range(count):
shutcut = x #Rufen Sie die Eingabedaten des Moduls für die Verknüpfungsverbindung ab.
x = Conv2D(filter_count, kernel_size=3, padding="same")(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Dropout(rate=0.3)(x)
x = Conv2D(filter_count, kernel_size=3, padding="same")(x)
x = BatchNormalization()(x)
x = Concatenate()([x, shutcut]) #Verknüpfungsverbindung
if i != count - 1:
x = MaxPooling2D(pool_size=2)(x)
filter_count = filter_count * 2
x = Flatten()(x)
x = BatchNormalization()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(rate=0.3)(x)
x = BatchNormalization()(x)
x = Dense(1024, activation="relu")(x)
x = Dropout(rate=0.3)(x)
x = BatchNormalization()(x)
x = Dense(num_classes, activation="softmax")(x)
model = Model(inputs=inputs, outputs=x)
print(model.summary())
return model
#Klasse, um das Modell zu trainieren
class Trainer():
#Kompilieren Sie das Modell und legen Sie die Einstellungen für das Training in privaten Objekten fest.
def __init__(self, model, loss, optimizer):
self._model = model
self._model.compile(
loss=loss,
optimizer=optimizer,
metrics=["accuracy"]
)
self._verbose = 1
self._batch_size = 128
self._epochs = 30
#Tatsächliches Lernen
def fit(self, x_train, y_train, x_test, y_test):
self._model.fit(
x_train,
y_train,
batch_size=self._batch_size,
epochs=self._epochs,
verbose=self._verbose,
validation_data=(x_test, y_test)
)
return self._model
dataset = CIFAR10Dataset() #Instanziierung des CIFAR10-Datensatzes zum Abrufen von Daten
model = network(dataset.image_shape, dataset.num_classes, 4) #Holen Sie sich Modell
x_train, y_train, x_test, y_test = dataset.get_batch() #Erfassung von Trainingsdaten und Testdaten
trainer = Trainer(model, loss="categorical_crossentropy", optimizer="adam") #Trainer-Instanziierung mit Modell, Verlustfunktion und Optimierungsalgorithmus als Argumenten
model = trainer.fit(x_train, y_train, x_test, y_test) #Modelllernen
#Modellbewertung
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss: ', score[0])
print('Test accuracy: ', score[1])
Intuitives tiefes Lernen Warum ResNet eine gute Leistung zeigt
Recommended Posts