PyTorch Mobile Es kam ungefähr im Oktober letzten Jahres (2019) heraus. Mit Tensolflow Lite war maschinelles Lernen mit Android iOS möglich, aber schließlich wurde Pytorch 1.3 für Handys veröffentlicht. Es ist das Beste von der Seite der Verwendung von Pytorch anstelle von Tensorflow! Wie Tensorflow Lite kann es mit Android iOS verwendet werden.
Hier klicken für Details Offizielle Website von PyTorch Mobile: https://pytorch.org/mobile/home/
Von der offiziellen Website
Machen Sie das Tutorial auf der offiziellen Website. Schreiben Sie in Kotlin! Klassifizieren Sie Bilder mit dem trainierten Modell von resNet. (Nur Inferenz)
github wird unter https://github.com/SY-BETA/PyTorchMobile veröffentlicht
So ↓ Es ist so einfach wie das Anzeigen der zu klassifizierenden Bilder, der beiden besten Klassifizierungsergebnisse und ihrer Ergebnisse. (Was ist Canis Lupus?)
--python Ausführungsumgebung (ich habe es mit Jupyter Notebook gemacht) --pytorch, torchVision (neueste Version empfohlen)
Nur das
Erstellen Sie zunächst ein neues Projekt in Android Studio. Erstellen Sie einen Assets-Ordner in diesem Projekt. (Sie können dies tun, indem Sie mit der rechten Maustaste auf die App auf der linken Seite des Ordners UI-> Neu-> Ordner-> Assets klicken.) Führen Sie nach der Erstellung den folgenden Python-Code in derselben Hierarchie wie der App-Ordner für dieses Projekt aus
createModel.py
import torch
import torchvision
#Verwenden Sie das Resnet-Modell
model = torchvision.models.resnet18(pretrained=True)
#Im Inferenzmodus
model.eval()
example = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example)
traced_script_module.save("app/src/main/assets/resnet.pt")
Wenn es erfolgreich ausgeführt werden kann, wird eine Datei mit dem Namen "resnet.pt" zum zuvor erstellten Assets-Ordner hinzugefügt.
Speichern Sie die folgenden Beispielbilder im Assets-Ordner und im Zeichenordner mit dem Namen "image.jpg "
Folgendes wurde zu gradle hinzugefügt (Stand: 4. Januar 2020)
dependencies {
implementation 'org.pytorch:pytorch_android:1.3.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.3.0'
}
Layout entsprechend erstellen Layout mit nur 1 Bild und 6 Texten vertikal
activity_main.xml
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
xmlns:app="http://schemas.android.com/apk/res-auto"
xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent"
tools:context=".MainActivity">
<TextView
android:id="@+id/textView"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Input"
android:textSize="30sp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toTopOf="parent" />
<ImageView
android:id="@+id/imageView"
android:layout_width="wrap_content"
android:layout_height="230dp"
android:scaleType="fitCenter"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/textView"
app:srcCompat="@drawable/image" />
<TextView
android:id="@+id/textView2"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:text="Result"
android:textSize="30sp"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/imageView" />
<TextView
android:id="@+id/result1Score"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="32dp"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result1Class"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/textView2" />
<TextView
android:id="@+id/result1Class"
android:layout_width="250dp"
android:layout_height="wrap_content"
android:layout_marginStart="40dp"
android:layout_marginTop="8dp"
android:layout_marginEnd="40dp"
android:gravity="center"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result2Score"
app:layout_constraintEnd_toEndOf="@+id/result1Score"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="@+id/result1Score"
app:layout_constraintTop_toBottomOf="@+id/result1Score" />
<TextView
android:id="@+id/result2Score"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
android:layout_marginTop="24dp"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toTopOf="@+id/result2Class"
app:layout_constraintEnd_toEndOf="parent"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="parent"
app:layout_constraintTop_toBottomOf="@+id/result1Class"
app:layout_constraintVertical_bias="0.94" />
<TextView
android:id="@+id/result2Class"
android:layout_width="250dp"
android:layout_height="wrap_content"
android:layout_marginStart="40dp"
android:layout_marginTop="8dp"
android:layout_marginEnd="40dp"
android:layout_marginBottom="32dp"
android:gravity="center"
android:text="TextView"
android:textSize="18sp"
app:layout_constraintBottom_toBottomOf="parent"
app:layout_constraintEnd_toEndOf="@+id/result2Score"
app:layout_constraintHorizontal_bias="0.5"
app:layout_constraintStart_toStartOf="@+id/result2Score"
app:layout_constraintTop_toBottomOf="@+id/result2Score" />
</androidx.constraintlayout.widget.ConstraintLayout>
Laden Sie die zuvor erstellte resnet.pt
MainActivity.kt
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
setContentView(R.layout.activity_main)
////Funktion zum Abrufen des Pfads aus der Asset-Datei
fun assetFilePath(context: Context, assetName: String): String {
val file = File(context.filesDir, assetName)
if (file.exists() && file.length() > 0) {
return file.absolutePath
}
context.assets.open(assetName).use { inputStream ->
FileOutputStream(file).use { outputStream ->
val buffer = ByteArray(4 * 1024)
var read: Int
while (inputStream.read(buffer).also { read = it } != -1) {
outputStream.write(buffer, 0, read)
}
outputStream.flush()
}
return file.absolutePath
}
}
///Laden Sie Modelle und Bilder
///Serialisiertes Modell laden
val bitmap = BitmapFactory.decodeStream(assets.open("image.jpg "))
val module = Module.load(assetFilePath(this, "resnet.pt"))
}
Beachten Sie, dass das Laden von Bildern und Modellen aus dem Assets-Ordner sehr umständlich sein kann.
Geben Sie ein Beispielbild mit dem Modul ein, das zu Abhängigkeiten hinzugefügt und neu vernetzt wurde, und geben Sie das Ergebnis aus
MainActivity.kt
///In Tensor umwandeln
val inputTensor = TensorImageUtils.bitmapToFloat32Tensor(
bitmap,
TensorImageUtils.TORCHVISION_NORM_MEAN_RGB, TensorImageUtils.TORCHVISION_NORM_STD_RGB
)
///Folgerung und ihre Folgen
///Vorwärtsausbreitung
val outputTensor = module.forward(IValue.from(inputTensor)).toTensor()
val scores = outputTensor.dataAsFloatArray
Extrahieren Sie die höhere Punktzahl
MainActivity.kt
///Variable zum Speichern der Punktzahl
var maxScore: Float = 0F
var maxScoreIdx = -1
var maxSecondScore: Float = 0F
var maxSecondScoreIdx = -1
///Nehmen Sie die ersten beiden mit der höchsten Punktzahl
for (i in scores.indices) {
if (scores[i] > maxScore) {
maxSecondScore = maxScore
maxSecondScoreIdx = maxScoreIdx
maxScore = scores[i]
maxScoreIdx = i
}
}
Der Name der zu klassifizierenden Klasse
Wird weggelassen, weil es sehr lang ist (es handelt sich um eine Klassifizierung der imageNet 1000-Klasse).
Da es auf github veröffentlicht ist, kopieren Sie bitte den Inhalt von ImageNetClasses.kt
github Klassennamenliste (ImageNetClasses.kt)
ImageNetClasses.kt
class ImageNetClasses {
var IMAGENET_CLASSES = arrayOf(
"tench, Tinca tinca",
"goldfish, Carassius auratus",
//~~~~~~~~~~~~~~Abkürzung(Bitte kopieren Sie von Github)~~~~~~~~~~~~~~~~//
"toilet tissue, toilet paper, bathroom tissue"
)
}
Ruft den aus dem Index abgeleiteten Klassennamen ab Zeigen Sie abschließend das Inferenzergebnis im Layout an
MainActivity.kt
///Ruft den Klassennamen aus dem Index ab
val className = ImageNetClasses().IMAGENET_CLASSES[maxScoreIdx]
val className2 = ImageNetClasses().IMAGENET_CLASSES[maxSecondScoreIdx]
result1Score.text = "score: $maxScore"
result1Class.text = "Klassifizierungsergebnis:$className"
result2Score.text = "score:$maxSecondScore"
result2Class.text = "Klassifizierungsergebnis:$className2"
Erledigt! !! Wenn Sie es erstellen, sollten Sie einen Bildschirm wie am Anfang erhalten. Bitte fügen Sie verschiedene Bilder ein und spielen Sie mit ihnen.
Die Bibliothek ist bequem. Was ist damit eine Bildklassifizierung möglich? Ich hatte das Gefühl, dass die Umstellung auf Tensor ein wenig stecken geblieben ist, aber jetzt kann ich Pytorch für Android verwenden. Abgesehen davon war die Version von pytorch zunächst nicht die neueste und ich bekam einen Fehler beim Laden des Modells und konnte es überhaupt nicht tun, und ich war ziemlich süchtig danach, den Pfad des Assets-Ordners zu finden.
Recommended Posts