[PYTHON] [PyTorch] Laufzeitfehler: Erwartetes Objekt vom Skalartyp Float, aber Skalartyp Double für Argument # 4'mat1 '

Fehler in PyTorch

Ich werde es denen überlassen, die am selben Ort festsitzen. Bei der Verwendung von PyTorch wurde der folgende Fehler angezeigt.

RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #4 'mat1'

Als Ergebnis verschiedener Untersuchungen scheint das Problem darin zu bestehen, dass bei der Konvertierung in den Tensortyp die Zahl im Tensor zum Typ fackel.double wird. (Es gibt viele Methoden in der PyTorch-Klasse, die auf dem Typ ** torch.float ** basieren.)

Damit

Vor der Korrektur

X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)
X_test = torch.from_numpy(X_test)
y_test = torch.from_numpy(y_test)

Überarbeitet

X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).long()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).long() 

Es scheint, dass Sie es mit .float () oder .long () wie folgt konvertieren sollten. (.long () ist eine Konvertierung in ein Label)

Verweise: 2. PyTorch Tensol & Datentyp Cheet Sheet

Recommended Posts

[PyTorch] Laufzeitfehler: Erwartetes Objekt vom Skalartyp Float, aber Skalartyp Double für Argument # 4'mat1 '
Hinweis zur Unterstützung von Python-Fehlern: "... unterstützt kein Argument 0 vom Typ float ..."