I tried Tribuo, published by Oracle. Tribuo documentation: Intro classification with Irises

Classification tutorial

This tutorial shows how to use Tribuo's classification models to predict iris species from Fisher's famous Iris dataset. (Yes, it's 2020 and we're still demoing with a dataset from 1936; rest assured that next time we'll use MNIST from the 90s.) Here we focus on a simple logistic regression and investigate the provenance and metadata that Tribuo stores inside each model.

Setup

You'll need a copy of the Iris dataset.

wget https://archive.ics.uci.edu/ml/machine-learning-databases/iris/bezdekIris.data

First, load the Tribuo jars we need. Here we use the classification experiments jar, plus the json interop jar to read and write provenance information.

%jars ./tribuo-classification-experiments-4.0.0-jar-with-dependencies.jar
%jars ./tribuo-json-4.0.0-jar-with-dependencies.jar
import java.nio.file.Paths;

We import everything from the base org.tribuo package, along with the simple CSV loader, the train/test splitter, and the classification and classification evaluation packages. Since we're going to build a logistic regression, we also need its trainer.

import org.tribuo.*;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.classification.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.linear.LogisticRegressionTrainer;

These imports are for the provenance system.

import com.fasterxml.jackson.databind.*;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.config.json.*;

Reading the data

In Tribuo, every prediction type has an associated OutputFactory implementation, which creates the appropriate Output subclass from an input. Here we're performing multi-class classification, so we use LabelFactory. We then pass the labelFactory to a simple CSVLoader, which loads all the columns into a DataSource.

var labelFactory = new LabelFactory();
var csvLoader = new CSVLoader<>(labelFactory);

Our copy of the Iris dataset has no header row, so we create the headers ourselves and pass them to loadDataSource along with the path and the name of the output column ("species" in this case). The Iris dataset also has no predefined train/test split, so we create one that uses 70% of the data for training, with a fixed seed (1L) so the split is reproducible.

var irisHeaders = new String[]{"sepalLength", "sepalWidth", "petalLength", "petalWidth", "species"};
var irisesSource = csvLoader.loadDataSource(Paths.get("bezdekIris.data"),"species",irisHeaders);
var irisSplitter = new TrainTestSplitter<>(irisesSource,0.7,1L);
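Conceptually, a seeded split shuffles the row indices with a fixed random seed and takes the first 70% as training rows. The plain-Java sketch below illustrates only that idea; it is not Tribuo's TrainTestSplitter code, and the shuffle and rounding rule are assumptions, so it reproduces the split sizes (105 and 45) but not Tribuo's exact row assignment.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Conceptual sketch of a seeded train/test split over row indices.
// Not Tribuo's TrainTestSplitter: the shuffle and rounding rule are assumptions.
final class SplitSketch {
    static final class Split {
        final List<Integer> train;
        final List<Integer> test;

        Split(List<Integer> train, List<Integer> test) {
            this.train = train;
            this.test = test;
        }
    }

    static Split split(int size, double trainProportion, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < size; i++) {
            indices.add(i);
        }
        // A fixed seed makes the shuffle, and therefore the split, reproducible.
        Collections.shuffle(indices, new Random(seed));
        int trainSize = (int) Math.round(size * trainProportion);
        return new Split(new ArrayList<>(indices.subList(0, trainSize)),
                new ArrayList<>(indices.subList(trainSize, size)));
    }

    public static void main(String[] args) {
        Split s = split(150, 0.7, 1L);
        System.out.println("train = " + s.train.size() + ", test = " + s.test.size()); // train = 105, test = 45
    }
}
```

Calling split twice with the same seed gives identical index lists, which is the property the recorded seed in the provenance relies on.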

Next, we feed the training and test data sources into their respective datasets. These datasets compute all the required metadata, such as the feature domain and the output domain. We recommend using MutableDataset for training datasets. Once we have the datasets, we're ready to train a model.

var trainingDataset = new MutableDataset<>(irisSplitter.getTrain());
var testingDataset = new MutableDataset<>(irisSplitter.getTest());
System.out.println(String.format("Training data size = %d, number of features = %d, number of classes = %d",trainingDataset.size(),trainingDataset.getFeatureMap().size(),trainingDataset.getOutputInfo().size()));
System.out.println(String.format("Testing data size = %d, number of features = %d, number of classes = %d",testingDataset.size(),testingDataset.getFeatureMap().size(),testingDataset.getOutputInfo().size()));
Training data size = 105, number of features = 4, number of classes = 3
Testing data size = 45, number of features = 4, number of classes = 3

Training the model

Now let's create a trainer instance and look at its default hyperparameters. For full control over these parameters, you can use the fully configurable LinearSGDTrainer directly.

Trainer<Label> trainer = new LogisticRegressionTrainer();
System.out.println(trainer.toString());
LinearSGDTrainer(objective=LogMulticlass,optimiser=AdaGrad(initialLearningRate=1.0,epsilon=0.1,initialValue=0.0),epochs=5,minibatchSize=1,seed=12345)

So this is a linear model with a logistic loss, trained for 5 epochs using AdaGrad.
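To make those hyperparameters concrete, here is a minimal sketch of multinomial logistic regression trained one example at a time (minibatch size 1) with the textbook AdaGrad step, where each weight moves by learningRate * g / (epsilon + sqrt(sum of its squared gradients)). It uses hypothetical toy data, a constant learning rate and no shuffling, and it is not Tribuo's LinearSGDTrainer or AdaGrad code; reading initialValue as the starting value of the squared-gradient accumulator is an assumption.

```java
import java.util.Arrays;

// Sketch: multinomial logistic regression trained one example at a time with
// an AdaGrad-scaled step. Not Tribuo's implementation.
final class AdaGradLogisticSketch {
    private final double[][] weights;     // [class][feature], last column is the bias
    private final double[][] gradSquares; // AdaGrad's running sum of squared gradients
    private final double learningRate;
    private final double epsilon;

    AdaGradLogisticSketch(int numClasses, int numFeatures,
                          double learningRate, double epsilon, double initialValue) {
        this.weights = new double[numClasses][numFeatures + 1];
        this.gradSquares = new double[numClasses][numFeatures + 1];
        for (double[] row : gradSquares) {
            Arrays.fill(row, initialValue); // assumed meaning of initialValue
        }
        this.learningRate = learningRate;
        this.epsilon = epsilon;
    }

    // Class probabilities: softmax of the linear scores.
    double[] predict(double[] x) {
        double[] probs = new double[weights.length];
        double max = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < weights.length; c++) {
            double score = weights[c][x.length]; // bias term
            for (int j = 0; j < x.length; j++) {
                score += weights[c][j] * x[j];
            }
            probs[c] = score;
            max = Math.max(max, score);
        }
        double sum = 0.0;
        for (int c = 0; c < probs.length; c++) {
            probs[c] = Math.exp(probs[c] - max); // shift by max for numerical stability
            sum += probs[c];
        }
        for (int c = 0; c < probs.length; c++) {
            probs[c] /= sum;
        }
        return probs;
    }

    // One SGD step on a single example (minibatch size 1).
    void step(double[] x, int label) {
        double[] p = predict(x);
        for (int c = 0; c < weights.length; c++) {
            double error = p[c] - (c == label ? 1.0 : 0.0); // d(log loss)/d(score_c)
            for (int j = 0; j <= x.length; j++) {
                double g = error * (j < x.length ? x[j] : 1.0);
                gradSquares[c][j] += g * g;
                weights[c][j] -= learningRate * g / (epsilon + Math.sqrt(gradSquares[c][j]));
            }
        }
    }

    // Mean negative log-likelihood of the true labels.
    double meanLogLoss(double[][] xs, int[] ys) {
        double total = 0.0;
        for (int i = 0; i < xs.length; i++) {
            total -= Math.log(predict(xs[i])[ys[i]]);
        }
        return total / xs.length;
    }

    public static void main(String[] args) {
        // Hypothetical toy data with three well-separated classes (not the Iris data).
        double[][] xs = {{2, 0}, {0, 2}, {-2, -2}, {3, 0.5}, {0.5, 3}, {-3, -2}};
        int[] ys = {0, 1, 2, 0, 1, 2};
        // Mirrors the printed defaults: learning rate 1.0, epsilon 0.1, initialValue 0.0, 5 epochs.
        AdaGradLogisticSketch model = new AdaGradLogisticSketch(3, 2, 1.0, 0.1, 0.0);
        System.out.printf("mean log loss before: %.4f%n", model.meanLogLoss(xs, ys));
        for (int epoch = 0; epoch < 5; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                model.step(xs[i], ys[i]);
            }
        }
        System.out.printf("mean log loss after:  %.4f%n", model.meanLogLoss(xs, ys));
    }
}
```

Because every weight starts at zero, the initial predictions are uniform and the starting loss is ln 3; AdaGrad then shrinks each weight's step as its squared gradients accumulate.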

Now let's train the model. As with most packages, training is simple: all you need is a trainer and some training data.

Model<Label> irisModel = trainer.train(trainingDataset);

Model evaluation

Once you've trained a model, you need to evaluate how well it performs. To do this, ask the labelFactory for the appropriate evaluator (or instantiate one directly, as we do here) and pass it the model and the test dataset. You can also pass a data source instead of a dataset. The LabelEvaluator class implements all the common classification metrics, each of which can be inspected individually, and LabelEvaluation.toString() produces a nicely formatted summary of them.

var evaluator = new LabelEvaluator();
var evaluation = evaluator.evaluate(irisModel,testingDataset);
System.out.println(evaluation.toString());
Class                           n          tp          fn          fp      recall        prec       f1
Iris-versicolor                16          16           0           1       1.000       0.941       0.970
Iris-virginica                 15          14           1           0       0.933       1.000       0.966
Iris-setosa                    14          14           0           0       1.000       1.000       1.000
Total                          45          44           1           1
Accuracy                                                                    0.978
Micro Average                                                               0.978       0.978       0.978
Macro Average                                                               0.978       0.980       0.978
Balanced Error Rate                                                         0.022

Precision, recall and F1 are the standard metrics for evaluating multi-class classifiers.
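The per-class figures follow directly from the tp, fn and fp columns: precision = tp / (tp + fp), recall = tp / (tp + fn), F1 is their harmonic mean, and the macro average is the unweighted mean over classes. For Iris-versicolor that gives 16 / 17 ≈ 0.941 precision and 32 / 33 ≈ 0.970 F1. The sketch below recomputes the table in plain Java from the counts copied out of it; returning 0 for an empty denominator is an assumed convention, not necessarily Tribuo's.

```java
// Recompute per-class and macro-averaged metrics from the tp/fn/fp counts.
final class ClassMetrics {
    static double precision(int tp, int fp) {
        return tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
    }

    static double recall(int tp, int fn) {
        return tp + fn == 0 ? 0.0 : (double) tp / (tp + fn);
    }

    // Harmonic mean of precision and recall, which simplifies to 2tp / (2tp + fn + fp).
    static double f1(int tp, int fn, int fp) {
        int denominator = 2 * tp + fn + fp;
        return denominator == 0 ? 0.0 : 2.0 * tp / denominator;
    }

    public static void main(String[] args) {
        String[] classes = {"Iris-versicolor", "Iris-virginica", "Iris-setosa"};
        int[][] counts = {{16, 0, 1}, {14, 1, 0}, {14, 0, 0}}; // {tp, fn, fp} per class, from the table
        double macroRecall = 0, macroPrecision = 0, macroF1 = 0;
        for (int i = 0; i < classes.length; i++) {
            int tp = counts[i][0], fn = counts[i][1], fp = counts[i][2];
            double r = recall(tp, fn), p = precision(tp, fp), f = f1(tp, fn, fp);
            System.out.printf("%-16s recall=%.3f prec=%.3f f1=%.3f%n", classes[i], r, p, f);
            macroRecall += r / classes.length;
            macroPrecision += p / classes.length;
            macroF1 += f / classes.length;
        }
        System.out.printf("%-16s recall=%.3f prec=%.3f f1=%.3f%n",
                "Macro Average", macroRecall, macroPrecision, macroF1);
    }
}
```

The macro line comes out as recall 0.978, precision 0.980 and F1 0.978, the same as Tribuo's summary.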

You can also display the confusion matrix.

System.out.println(evaluation.getConfusionMatrix().toString());
                   Iris-versicolor   Iris-virginica      Iris-setosa
Iris-versicolor                 16                0                0
Iris-virginica                   1               14                0
Iris-setosa                      0                0               14
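In this matrix the rows are the true class and the columns the predicted class: the single Iris-virginica example predicted as Iris-versicolor is the one fn for virginica and the one fp for versicolor in the table above. Accuracy is the diagonal over the total, 44 / 45 ≈ 0.978, and the reported balanced error rate equals 1 minus the mean per-class recall. Here is a plain-Java sketch, assuming that definition of balanced error rate, using the counts from the evaluation table.

```java
// Accuracy and balanced error rate from a confusion matrix
// (rows = true class, columns = predicted class).
final class ConfusionMetrics {
    static double accuracy(int[][] matrix) {
        long correct = 0, total = 0;
        for (int i = 0; i < matrix.length; i++) {
            for (int j = 0; j < matrix[i].length; j++) {
                total += matrix[i][j];
                if (i == j) {
                    correct += matrix[i][j];
                }
            }
        }
        return (double) correct / total;
    }

    // Balanced error rate taken as 1 - (mean per-class recall);
    // classes with no true examples are skipped.
    static double balancedErrorRate(int[][] matrix) {
        double recallSum = 0.0;
        int classesSeen = 0;
        for (int i = 0; i < matrix.length; i++) {
            long rowTotal = 0;
            for (int v : matrix[i]) {
                rowTotal += v;
            }
            if (rowTotal > 0) {
                recallSum += (double) matrix[i][i] / rowTotal;
                classesSeen++;
            }
        }
        return 1.0 - recallSum / classesSeen;
    }

    public static void main(String[] args) {
        // Order: Iris-versicolor, Iris-virginica, Iris-setosa.
        int[][] matrix = {{16, 0, 0}, {1, 14, 0}, {0, 0, 14}};
        System.out.printf("accuracy = %.3f, balanced error rate = %.3f%n",
                accuracy(matrix), balancedErrorRate(matrix));
    }
}
```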

Model metadata

Tribuo tracks the feature and output domains of every model it builds. This lets you run LIME-style explanation techniques without access to the original training data, or add checks that a particular input falls within the domain the model was trained on.

Let's take a look at the feature domain of the Irises model.

var featureMap = irisModel.getFeatureIDMap();
for (var v : featureMap) {
    System.out.println(v.toString());
    System.out.println();
}
CategoricalFeature(name=petalLength,id=0,count=105,map={1.2=1, 6.9=1, 3.6=1, 3.0=1, 1.7=4, 4.9=4, 4.4=3, 3.5=2, 5.9=2, 5.4=1, 4.0=4, 1.4=12, 4.5=4, 5.0=2, 5.5=3, 6.7=2, 3.7=1, 1.9=1, 6.0=2, 5.2=1, 5.7=2, 4.2=2, 4.7=2, 4.8=4, 1.6=4, 5.8=2, 3.8=1, 6.3=1, 3.3=1, 1.0=1, 5.6=4, 5.1=5, 4.6=3, 4.1=2, 1.5=9, 1.3=4, 3.9=3, 6.6=1, 6.1=2})

CategoricalFeature(name=petalWidth,id=1,count=105,map={2.0=3, 0.5=1, 1.2=3, 0.3=6, 1.6=2, 0.1=3, 0.4=5, 2.5=3, 2.3=4, 1.7=2, 1.1=3, 2.1=4, 0.6=1, 1.4=6, 1.0=5, 2.4=1, 1.8=12, 0.2=20, 1.9=4, 1.5=7, 1.3=8, 2.2=2})

CategoricalFeature(name=sepalLength,id=2,count=105,map={6.9=3, 6.4=3, 7.4=1, 4.9=4, 4.4=1, 5.9=3, 5.4=5, 7.2=3, 7.7=3, 5.0=8, 6.2=2, 5.5=5, 6.7=7, 6.0=3, 5.2=2, 6.5=3, 5.7=4, 4.7=2, 4.8=3, 5.8=4, 5.3=1, 6.8=3, 6.3=5, 7.3=1, 5.6=6, 5.1=7, 4.6=4, 7.6=1, 7.1=1, 6.6=2, 6.1=5})

CategoricalFeature(name=sepalWidth,id=3,count=105,map={2.0=1, 2.8=10, 3.6=4, 2.3=3, 2.5=5, 3.1=8, 3.8=4, 3.0=19, 2.6=4, 4.4=1, 3.3=4, 3.5=4, 2.4=2, 3.2=10, 2.9=5, 3.7=3, 3.4=6, 2.2=2, 3.9=2, 4.2=1, 2.7=7})

We can see the four features along with a histogram of their values. This information can be used to sample from each feature when building candidate examples for local explainers such as LIME, or to check a feature's range. The feature information is frozen at training time, so if the features are sparse (as is common in NLP problems), it also tells you how often each feature was observed in the training set.
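Here is a sketch of both uses on a value-to-count histogram like the maps printed above: a range check for a new input, and inverse-CDF sampling that draws values in proportion to how often they were seen. The histogram is a hand-copied subset of the sepalWidth map, and the FeatureHistogram class is hypothetical; it doesn't call Tribuo's API.

```java
import java.util.Map;
import java.util.NavigableMap;
import java.util.Random;
import java.util.TreeMap;

// Hypothetical helper over a value -> count histogram copied from the feature map.
final class FeatureHistogram {
    private final NavigableMap<Double, Integer> counts;
    private final int total;

    FeatureHistogram(Map<Double, Integer> counts) {
        if (counts.isEmpty()) {
            throw new IllegalArgumentException("histogram must not be empty");
        }
        this.counts = new TreeMap<>(counts); // sorted by value
        int sum = 0;
        for (int c : this.counts.values()) {
            sum += c;
        }
        this.total = sum;
    }

    // True if the value lies between the smallest and largest training values.
    boolean inObservedRange(double value) {
        return value >= counts.firstKey() && value <= counts.lastKey();
    }

    // How many training examples had exactly this value (0 if never seen).
    int observationCount(double value) {
        return counts.getOrDefault(value, 0);
    }

    // Draw a value with probability proportional to its training count.
    double sample(Random rng) {
        int target = rng.nextInt(total); // uniform in [0, total)
        int cumulative = 0;
        for (Map.Entry<Double, Integer> e : counts.entrySet()) {
            cumulative += e.getValue();
            if (target < cumulative) {
                return e.getKey();
            }
        }
        throw new IllegalStateException("counts must be positive");
    }

    public static void main(String[] args) {
        // Subset of the sepalWidth histogram above (not the full map).
        FeatureHistogram sepalWidth = new FeatureHistogram(
                Map.of(2.0, 1, 2.8, 10, 3.0, 19, 3.2, 10, 4.4, 1));
        System.out.println(sepalWidth.inObservedRange(5.0)); // false: wider than any training value
        System.out.println(sepalWidth.observationCount(3.0)); // 19
        Random rng = new Random(42);
        for (int i = 0; i < 5; i++) {
            System.out.print(sepalWidth.sample(rng) + " ");
        }
        System.out.println();
    }
}
```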

Model provenance

Many different kinds of ML models are deployed in modern applications, supporting many different aspects of those applications. However, most ML packages don't support tracking how a model was built, or rebuilding it. In Tribuo, every model tracks its provenance: how it was created, when it was created, and what data was involved. Let's take a look at the data provenance of the iris model. By default, Tribuo prints provenance in a reasonably human-readable format using each provenance object's toString() method, and all of the information is also accessible programmatically.

var provenance = irisModel.getProvenance();
System.out.println(ProvenanceUtil.formattedProvenanceString(provenance.getDatasetProvenance().getSourceProvenance()));
TrainTestSplitter(
	class-name = org.tribuo.evaluation.TrainTestSplitter
	source = CSVLoader(
			class-name = org.tribuo.data.csv.CSVLoader
			outputFactory = LabelFactory(
					class-name = org.tribuo.classification.LabelFactory
				)
			response-name = species
			separator = ,
			quote = "
			path = file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data
			file-modified-time = 1999-12-14T15:12:39-05:00
			resource-hash = 0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC
		)
	train-proportion = 0.7
	seed = 1
	size = 150
	is-train = true
)

We can see that the model was trained on a data source that was split in two, using a specific random seed (1) and split proportion (0.7). The original data source is a CSV file, and the provenance also records the file's modification time and its SHA-256 hash.
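That recorded hash lets you check that a local copy of the data is the one the model was trained on. The sketch below recomputes a file's SHA-256 with java.security.MessageDigest and compares it with the resource-hash value above; it assumes Tribuo hashes the raw file bytes (the JSON provenance tags the hash as SHA256), and the ResourceHashCheck class is hypothetical.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

// Sketch: recompute a file's SHA-256 and compare it with the resource-hash recorded
// in the CSVLoader provenance. Assumes the hash covers the raw file bytes.
final class ResourceHashCheck {
    // Upper-case hex SHA-256 digest, matching the style of the recorded hash.
    static String sha256Hex(byte[] data) {
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256").digest(data);
            StringBuilder hex = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                hex.append(String.format("%02X", b));
            }
            return hex.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to support SHA-256.
            throw new IllegalStateException(e);
        }
    }

    static boolean matchesRecordedHash(Path file, String recordedHex) throws IOException {
        return sha256Hex(Files.readAllBytes(file)).equalsIgnoreCase(recordedHex);
    }

    public static void main(String[] args) throws IOException {
        String recorded = "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC";
        Path data = Paths.get("bezdekIris.data");
        System.out.println("matches recorded hash: " + matchesRecordedHash(data, recorded));
    }
}
```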

Similarly, you can find out which training algorithm was used by looking at the trainer provenance (provenance.getTrainerProvenance()).

Here, as expected, we can see that our model was trained using LogisticRegressionTrainer, with AdaGrad as the gradient descent optimiser.

If you want to keep a separate record, you can extract the provenance from the model and save it as a JSON file (or restore the provenance from a deployed model).

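ObjectMapper objMapper = new ObjectMapper();
objMapper.registerModule(new JsonProvenanceModule());
objMapper = objMapper.enable(SerializationFeature.INDENT_OUTPUT);
// Marshal the model provenance into a list of marshalled objects, then serialize it to JSON.
String jsonProvenance = objMapper.writeValueAsString(ProvenanceUtil.marshalProvenance(provenance));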

The JSON provenance is verbose, but it provides an alternative human-readable serialization format.

System.out.println(jsonProvenance);
[ {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "linearsgdmodel-0",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
  "provenance-class" : "org.tribuo.provenance.ModelProvenance",
  "map" : {
    "instance-values" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.MapMarshalledProvenance",
      "map" : { }
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "trainer" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trainer",
      "value" : "logisticregressiontrainer-2",
      "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "trained-at" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "trained-at",
      "value" : "2020-08-31T20:24:37.854775-04:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "dataset" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "dataset",
      "value" : "mutabledataset-1",
      "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LinearSGDModel",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "mutabledataset-1",
  "object-class-name" : "org.tribuo.MutableDataset",
  "provenance-class" : "org.tribuo.provenance.DatasetProvenance",
  "map" : {
    "num-features" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-features",
      "value" : "4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-examples" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-examples",
      "value" : "105",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "num-outputs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "num-outputs",
      "value" : "3",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "tribuo-version" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "tribuo-version",
      "value" : "4.0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "datasource" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "datasource",
      "value" : "traintestsplitter-3",
      "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "transformations" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ListMarshalledProvenance",
      "list" : [ ]
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-dense" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-dense",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.MutableDataset",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logisticregressiontrainer-2",
  "object-class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
  "provenance-class" : "org.tribuo.provenance.impl.TrainerProvenanceImpl",
  "map" : {
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "12345",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "minibatchSize" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "minibatchSize",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "train-invocation-count" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-invocation-count",
      "value" : "0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-sequence" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-sequence",
      "value" : "false",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "shuffle" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "shuffle",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "epochs" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epochs",
      "value" : "5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "optimiser" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "optimiser",
      "value" : "adagrad-4",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "Trainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "objective" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "objective",
      "value" : "logmulticlass-5",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
      "additional" : "",
      "is-reference" : true
    },
    "loggingInterval" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "loggingInterval",
      "value" : "1000",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "traintestsplitter-3",
  "object-class-name" : "org.tribuo.evaluation.TrainTestSplitter",
  "provenance-class" : "org.tribuo.evaluation.TrainTestSplitter$SplitDataSourceProvenance",
  "map" : {
    "train-proportion" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "train-proportion",
      "value" : "0.7",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "seed" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "seed",
      "value" : "1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.LongProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "size" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "size",
      "value" : "150",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.IntProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "source" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "source",
      "value" : "csvloader-6",
      "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.evaluation.TrainTestSplitter",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "is-train" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "is-train",
      "value" : "true",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.BooleanProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "adagrad-4",
  "object-class-name" : "org.tribuo.math.optimisers.AdaGrad",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "epsilon" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "epsilon",
      "value" : "0.1",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialLearningRate" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialLearningRate",
      "value" : "1.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "initialValue" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "initialValue",
      "value" : "0.0",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DoubleProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "StochasticGradientOptimiser",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.math.optimisers.AdaGrad",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "logmulticlass-5",
  "object-class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
  "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.impl.ConfiguredObjectProvenanceImpl",
  "map" : {
    "host-short-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "host-short-name",
      "value" : "LabelObjective",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.sgd.objectives.LogMulticlass",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "csvloader-6",
  "object-class-name" : "org.tribuo.data.csv.CSVLoader",
  "provenance-class" : "org.tribuo.data.csv.CSVLoader$CSVLoaderProvenance",
  "map" : {
    "resource-hash" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "resource-hash",
      "value" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.HashProvenance",
      "additional" : "SHA256",
      "is-reference" : false
    },
    "path" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "path",
      "value" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.URLProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "file-modified-time" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "file-modified-time",
      "value" : "1999-12-14T15:12:39-05:00",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.DateTimeProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "quote" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "quote",
      "value" : "\"",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "response-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "response-name",
      "value" : "species",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "outputFactory" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "outputFactory",
      "value" : "labelfactory-7",
      "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
      "additional" : "",
      "is-reference" : true
    },
    "separator" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "separator",
      "value" : ",",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.CharProvenance",
      "additional" : "",
      "is-reference" : false
    },
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.data.csv.CSVLoader",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
}, {
  "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.ObjectMarshalledProvenance",
  "object-name" : "labelfactory-7",
  "object-class-name" : "org.tribuo.classification.LabelFactory",
  "provenance-class" : "org.tribuo.classification.LabelFactory$LabelFactoryProvenance",
  "map" : {
    "class-name" : {
      "marshalled-class" : "com.oracle.labs.mlrg.olcut.provenance.io.SimpleMarshalledProvenance",
      "key" : "class-name",
      "value" : "org.tribuo.classification.LabelFactory",
      "provenance-class" : "com.oracle.labs.mlrg.olcut.provenance.primitives.StringProvenance",
      "additional" : "",
      "is-reference" : false
    }
  }
} ]

The model provenance also appears in the output of Model.toString() (here, irisModel.toString()), but that format isn't machine readable.

linear-sgd-model - Model(class-name=org.tribuo.classification.sgd.linear.LinearSGDModel,dataset=Dataset(class-name=org.tribuo.MutableDataset,datasource=SplitDataSourceProvenance(className=org.tribuo.evaluation.TrainTestSplitter,innerSourceProvenance=CSV(class-name=org.tribuo.data.csv.CSVLoader,outputFactory=OutputFactory(class-name=org.tribuo.classification.LabelFactory),response-name=species,separator=,,quote=",path=file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data,file-modified-time=1999-12-14T15:12:39-05:00,resource-hash=SHA-256[0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC]),trainProportion=0.7,seed=1,size=150,isTrain=true),transformations=[],is-sequence=false,is-dense=false,num-examples=105,num-features=4,num-outputs=3,tribuo-version=4.0.1),trainer=Trainer(class-name=org.tribuo.classification.sgd.linear.LogisticRegressionTrainer,seed=12345,minibatchSize=1,shuffle=true,epochs=5,optimiser=StochasticGradientOptimiser(class-name=org.tribuo.math.optimisers.AdaGrad,epsilon=0.1,initialLearningRate=1.0,initialValue=0.0,host-short-name=StochasticGradientOptimiser),objective=LabelObjective(class-name=org.tribuo.classification.sgd.objectives.LogMulticlass,host-short-name=LabelObjective),loggingInterval=1000,train-invocation-count=0,is-sequence=false,host-short-name=Trainer),trained-at=2020-08-31T20:24:37.854775-04:00,instance-values={},tribuo-version=4.0.1)

Evaluations also have a provenance, which records the model provenance along with the provenance of the test data. Here we use an alternative JSON format for the provenance, which is a little less precise but easier to read. It's fine for reference, but because every value is converted to a String, it can't be used to rebuild the original provenance object.

String jsonEvaluationProvenance = objMapper.writeValueAsString(ProvenanceUtil.convertToMap(evaluation.getProvenance()));
System.out.println(jsonEvaluationProvenance);
{
  "tribuo-version" : "4.0.1",
  "dataset-provenance" : {
    "num-features" : "4",
    "num-examples" : "45",
    "num-outputs" : "3",
    "tribuo-version" : "4.0.1",
    "datasource" : {
      "train-proportion" : "0.7",
      "seed" : "1",
      "size" : "150",
      "source" : {
        "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
        "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
        "file-modified-time" : "1999-12-14T15:12:39-05:00",
        "quote" : "\"",
        "response-name" : "species",
        "outputFactory" : {
          "class-name" : "org.tribuo.classification.LabelFactory"
        },
        "separator" : ",",
        "class-name" : "org.tribuo.data.csv.CSVLoader"
      },
      "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
      "is-train" : "false"
    },
    "transformations" : [ ],
    "is-sequence" : "false",
    "is-dense" : "false",
    "class-name" : "org.tribuo.MutableDataset"
  },
  "class-name" : "org.tribuo.provenance.EvaluationProvenance",
  "model-provenance" : {
    "instance-values" : { },
    "tribuo-version" : "4.0.1",
    "trainer" : {
      "seed" : "12345",
      "minibatchSize" : "1",
      "train-invocation-count" : "0",
      "is-sequence" : "false",
      "shuffle" : "true",
      "epochs" : "5",
      "optimiser" : {
        "epsilon" : "0.1",
        "initialLearningRate" : "1.0",
        "initialValue" : "0.0",
        "host-short-name" : "StochasticGradientOptimiser",
        "class-name" : "org.tribuo.math.optimisers.AdaGrad"
      },
      "host-short-name" : "Trainer",
      "class-name" : "org.tribuo.classification.sgd.linear.LogisticRegressionTrainer",
      "objective" : {
        "host-short-name" : "LabelObjective",
        "class-name" : "org.tribuo.classification.sgd.objectives.LogMulticlass"
      },
      "loggingInterval" : "1000"
    },
    "trained-at" : "2020-08-31T20:24:37.854775-04:00",
    "dataset" : {
      "num-features" : "4",
      "num-examples" : "105",
      "num-outputs" : "3",
      "tribuo-version" : "4.0.1",
      "datasource" : {
        "train-proportion" : "0.7",
        "seed" : "1",
        "size" : "150",
        "source" : {
          "resource-hash" : "0FED2A99DB77EC533A62DC66894D3EC6DF3B58B6A8F3CF4A6B47E4086B7F97DC",
          "path" : "file:/Users/apocock/Development/Tribuo/tutorials/bezdekIris.data",
          "file-modified-time" : "1999-12-14T15:12:39-05:00",
          "quote" : "\"",
          "response-name" : "species",
          "outputFactory" : {
            "class-name" : "org.tribuo.classification.LabelFactory"
          },
          "separator" : ",",
          "class-name" : "org.tribuo.data.csv.CSVLoader"
        },
        "class-name" : "org.tribuo.evaluation.TrainTestSplitter",
        "is-train" : "true"
      },
      "transformations" : [ ],
      "is-sequence" : "false",
      "is-dense" : "false",
      "class-name" : "org.tribuo.MutableDataset"
    },
    "class-name" : "org.tribuo.classification.sgd.linear.LinearSGDModel"
  }
}

You can see that this evaluation provenance contains all the fields of the model provenance, plus the provenance of the test data, including its train/test split and CSV source (note "is-train" : "false" for the test data).
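Once provenance is flattened into nested string maps like this, comparing two provenances is just a recursive map diff. The sketch below is a hypothetical helper, not part of Tribuo or OLCUT; it diffs hand-copied fragments of the training "dataset" section and the testing "dataset-provenance" section above and reports the key paths whose values differ.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;

// Hypothetical helper: recursively compare two nested provenance maps and
// list the key paths whose values differ.
final class ProvenanceDiff {
    static List<String> diff(Map<String, ?> a, Map<String, ?> b) {
        List<String> differences = new ArrayList<>();
        diff("", a, b, differences);
        return differences;
    }

    @SuppressWarnings("unchecked")
    private static void diff(String prefix, Map<String, ?> a, Map<String, ?> b, List<String> out) {
        TreeSet<String> keys = new TreeSet<>(a.keySet()); // sorted for stable output
        keys.addAll(b.keySet());
        for (String key : keys) {
            String path = prefix.isEmpty() ? key : prefix + "." + key;
            Object left = a.get(key);
            Object right = b.get(key);
            if (left instanceof Map && right instanceof Map) {
                diff(path, (Map<String, ?>) left, (Map<String, ?>) right, out);
            } else if (!Objects.equals(left, right)) {
                out.add(path + ": " + left + " -> " + right);
            }
        }
    }

    public static void main(String[] args) {
        // Fragments copied from the training and test dataset provenance above.
        Map<String, Object> train = Map.of(
                "num-examples", "105",
                "datasource", Map.of("train-proportion", "0.7", "seed", "1", "is-train", "true"));
        Map<String, Object> test = Map.of(
                "num-examples", "45",
                "datasource", Map.of("train-proportion", "0.7", "seed", "1", "is-train", "false"));
        diff(train, test).forEach(System.out::println);
        // datasource.is-train: true -> false
        // num-examples: 105 -> 45
    }
}
```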

Provenance is useful on its own for tracking models, but combined with the configuration system described in the configuration tutorial, it becomes a powerful way to rebuild models and experiments, giving almost perfect reproducibility for ML models.

Conclusion

We looked at Tribuo's CSV loading mechanism, how to train a simple classifier, how to evaluate a classifier on test data, and the metadata and provenance information stored inside Tribuo's model and evaluation objects.
