Surprisingly, someone actually read my earlier, disorganized article. There is only one Earth → all humans are brothers → a sense of fellowship welled up on its own, so I write from time to time. (It is also good practice for me.)
- Use LSTM layers for the model
- Predict the cancer type (5 classes) by multi-class classification of RNA-Seq data (801 × 20531)

- A laptop (any ordinary one)
- Optional: an NVIDIA GPU (a 1050 Ti here); AMD GPUs will not work, because DL4J relies on CUDA

- Ubuntu 18.04 (since this is Java, the OS hardly matters)
- Maven + the DL4J-related dependencies
This data is a randomly extracted subset of the dataset acquired by The Cancer Genome Atlas Pan-Cancer analysis project using Illumina's HiSeq high-throughput RNA sequencer. For more information, see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3919969/
RNA-Seq quantifies, gene by gene, how strongly each gene is expressed relative to a reference sequence, and lines those values up side by side. The reference sequence is a gene pattern that researchers fix once the pattern is known to some extent. Expression patterns attract attention because they may explain the properties and attributes of an individual; in particular, RNA in cancer cells has long been studied in drug-discovery research. According to people in the field, even with the human genome sequenced, there is still plenty of work left to do.
The following cancer types are examined in the project, and the sample data contains five of them (★). These five cancer types are what the model classifies.
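Each of the five class names in the labels file has to be turned into a 5-dimensional one-hot vector before training. A minimal plain-Java sketch of that mapping (the `OneHotSketch` class and its `oneHot` helper are my own illustration, not part of the article's code; the class ordering matches the switch statement used later):

```java
import java.util.Arrays;
import java.util.List;

public class OneHotSketch {
    // fixed class order; the index in this list is the position of the 1 in the vector
    static final List<String> CLASSES = Arrays.asList("BRCA", "PRAD", "LUAD", "KIRC", "COAD");

    static double[] oneHot(String classname) {
        double[] vec = new double[CLASSES.size()];
        // an unknown class name would make indexOf return -1 and throw here
        vec[CLASSES.indexOf(classname)] = 1.0;
        return vec;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(oneHot("LUAD"))); // [0.0, 0.0, 1.0, 0.0, 0.0]
    }
}
```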
Create a Maven project and add a class file with any name.
//load the CSV files
File dataFile = new File("./TCGA-PANCAN-HiSeq-801x20531/data.csv");
File labelFile = new File("./TCGA-PANCAN-HiSeq-801x20531/labels.csv");

int numClasses = 5; //5 classes
int batchSize = 801; //total number of samples

//read the feature matrix first
RecordReader reader = new CSVRecordReader(1, ','); //skip the header row
try {
    reader.initialize(new FileSplit(dataFile));
} catch (IOException | InterruptedException e) {
    e.printStackTrace();
}

double[][] dataObj = new double[batchSize][];
int itr = 0;
while (reader.hasNext()) {
    List<Writable> row = reader.next();
    //column 0 is the sample id, so skip it and shift the remaining columns down by one
    double[] scalers = new double[row.size() - 1];
    for (int i = 1; i < row.size(); i++) {
        scalers[i - 1] = row.get(i).toDouble();
    }
    dataObj[itr] = scalers;
    itr++;
}
System.out.println("Data samples " + dataObj.length); //801
//read the labels and convert each class name to a one-hot vector
try {
    reader = new CSVRecordReader(1, ','); //skip the header row
    reader.initialize(new FileSplit(labelFile));
} catch (IOException | InterruptedException e) {
    e.printStackTrace();
}
double[][] labels = new double[batchSize][];
itr = 0;
while (reader.hasNext()) {
    List<Writable> row = reader.next();
    //column 0 is the sample id; column 1 holds the class name
    String classname = row.get(1).toString();
    double[] scalers = null;
    switch (classname) {
        case "BRCA":
            scalers = new double[]{1, 0, 0, 0, 0};
            break;
        case "PRAD":
            scalers = new double[]{0, 1, 0, 0, 0};
            break;
        case "LUAD":
            scalers = new double[]{0, 0, 1, 0, 0};
            break;
        case "KIRC":
            scalers = new double[]{0, 0, 0, 1, 0};
            break;
        case "COAD":
            scalers = new double[]{0, 0, 0, 0, 1};
            break;
        default:
            break;
    }
    labels[itr] = scalers;
    itr++;
}
System.out.println("LABEL : " + labels.length); //801
//Create a DataSet
INDArray dataArray = Nd4j.create(dataObj,'c');
System.out.println(dataArray.shapeInfoToString());
INDArray labelArray = Nd4j.create(labels,'c');
System.out.println(labelArray.shapeInfoToString());
//Rank: 2,Offset: 0
// Order: c Shape: [801,20531], stride: [20531,1]
//Rank: 2,Offset: 0
// Order: c Shape: [801,5], stride: [5,1]
DataSet dataset = new DataSet(dataArray, labelArray);
SplitTestAndTrain sp = dataset.splitTestAndTrain(600, new Random(42L));//600 train, 201 test
DataSet train = sp.getTrain();
DataSet test = sp.getTest();
System.out.println(train.labelCounts());
System.out.println(test.labelCounts());
//{0=220.0, 1=105.0, 2=104.0, 3=109.0, 4=62.0}
//{0=80.0, 1=31.0, 2=37.0, 3=37.0, 4=16.0}
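`splitTestAndTrain(600, new Random(42L))` shuffles the 801 examples with the given seed and then slices off the first 600 for training, which is why the per-class counts above differ between train and test. The same idea in plain Java, shuffling indices only (this `SplitSketch` class is my own illustration of the concept, not DL4J's actual shuffle implementation):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SplitSketch {
    // returns {trainIndices, testIndices} after a seeded shuffle
    static List<List<Integer>> split(int total, int numTrain, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < total; i++) indices.add(i);
        Collections.shuffle(indices, new Random(seed)); // deterministic for a fixed seed
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(indices.subList(0, numTrain)));
        parts.add(new ArrayList<>(indices.subList(numTrain, total)));
        return parts;
    }

    public static void main(String[] args) {
        List<List<Integer>> parts = split(801, 600, 42L);
        System.out.println(parts.get(0).size() + " / " + parts.get(1).size()); // 600 / 201
    }
}
```

Because the shuffle is seeded, the split is reproducible run to run, which matters when comparing model configurations.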
//MODEL TRAIN AND EVALUATION
int numInput = 20531;
int numOutput = numClasses;
int hiddenNode = 500; //kept small because the machine is underpowered
int numEpochs = 50;
MultiLayerConfiguration LSTMConf = new NeuralNetConfiguration.Builder()
        .seed(123)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .weightInit(WeightInit.XAVIER)
        .updater(new Adam(0.001))
        .list()
        .layer(0, new LSTM.Builder()
                .nIn(numInput)
                .nOut(hiddenNode)
                .activation(Activation.RELU)
                .build())
        .layer(1, new LSTM.Builder()
                .nIn(hiddenNode)
                .nOut(hiddenNode)
                .activation(Activation.RELU)
                .build())
        .layer(2, new LSTM.Builder()
                .nIn(hiddenNode)
                .nOut(hiddenNode)
                .activation(Activation.RELU)
                .build())
        .layer(3, new RnnOutputLayer.Builder()
                .nIn(hiddenNode)
                .nOut(numOutput)
                .activation(Activation.SOFTMAX)
                .lossFunction(LossFunction.MCXENT) //multi-class cross entropy
                .build())
        .pretrain(false)
        .backprop(true)
        .build();
MultiLayerNetwork model = new MultiLayerNetwork(LSTMConf);
model.init();
System.out.println("TRAIN START...");
for (int i = 0; i < numEpochs; i++) {
    model.fit(train);
}

System.out.println("EVALUATION START...");
Evaluation eval = new Evaluation(5);
for (DataSet row : test.asList()) {
    INDArray testdata = row.getFeatures();
    INDArray pred = model.output(testdata);
    eval.eval(row.getLabels(), pred);
}
System.out.println(eval.stats());
TRAIN START...
EVALUATION START...
Predictions labeled as 0 classified by model as 0: 80 times
Predictions labeled as 1 classified by model as 1: 31 times
Predictions labeled as 2 classified by model as 0: 3 times
Predictions labeled as 2 classified by model as 2: 34 times
Predictions labeled as 3 classified by model as 2: 1 times
Predictions labeled as 3 classified by model as 3: 36 times
Predictions labeled as 4 classified by model as 2: 4 times
Predictions labeled as 4 classified by model as 4: 12 times
==========================Scores========================================
# of classes: 5
Accuracy: 0.9602
Precision: 0.9671
Recall: 0.9284
F1 Score: 0.9440
Precision, recall & F1: macro-averaged (equally weighted avg. of 5 classes)
========================================================================
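As a sanity check, the macro-averaged scores can be recomputed directly from the confusion counts in the log above. This `MacroMetrics` class is my own verification sketch, not DL4J code; rounding its output to four decimals reproduces the printed values:

```java
public class MacroMetrics {
    // returns {accuracy, macroPrecision, macroRecall, macroF1}
    static double[] metrics(int[][] cm) {
        int n = cm.length, total = 0, correct = 0;
        double precSum = 0, recSum = 0, f1Sum = 0;
        for (int c = 0; c < n; c++) {
            int tp = cm[c][c], rowSum = 0, colSum = 0;
            for (int j = 0; j < n; j++) {
                rowSum += cm[c][j]; // samples whose true class is c
                colSum += cm[j][c]; // samples predicted as class c
                total += cm[c][j];
            }
            correct += tp;
            double prec = tp / (double) colSum; // every class is predicted at least once here
            double rec = tp / (double) rowSum;
            precSum += prec;
            recSum += rec;
            f1Sum += 2 * prec * rec / (prec + rec);
        }
        return new double[]{correct / (double) total, precSum / n, recSum / n, f1Sum / n};
    }

    public static void main(String[] args) {
        // rows = true class 0..4, columns = predicted class 0..4, taken from the log above
        int[][] cm = {
                {80, 0, 0, 0, 0},
                {0, 31, 0, 0, 0},
                {3, 0, 34, 0, 0},
                {0, 0, 1, 36, 0},
                {0, 0, 4, 0, 12},
        };
        double[] m = metrics(cm);
        System.out.printf("Accuracy: %.4f%n", m[0]);  // 0.9602
        System.out.printf("Precision: %.4f%n", m[1]); // 0.9671
        System.out.printf("Recall: %.4f%n", m[2]);    // 0.9284
        System.out.printf("F1 Score: %.4f%n", m[3]);  // 0.9440
    }
}
```

The lower macro recall relative to accuracy comes mostly from class 4 (COAD), where only 12 of 16 test samples are recovered; macro averaging weights that small class equally with the others.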
Logistic regression or a linear SVM is the usual choice once the number of features runs into the hundreds or more, but I thought an LSTM might do better. Training did not take as long as I expected. It was just a quick experiment, but I am glad I tried it.
pom.xml
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <groupId>com.vis</groupId>
  <artifactId>CancerGenomeTest</artifactId>
  <version>0.0.1-SNAPSHOT</version>
  <properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <java.version>1.8</java.version>
    <nd4j.version>1.0.0-alpha</nd4j.version>
    <dl4j.version>1.0.0-alpha</dl4j.version>
    <datavec.version>1.0.0-alpha</datavec.version>
    <arbiter.version>1.0.0-alpha</arbiter.version>
    <logback.version>1.2.3</logback.version>
    <dl4j.spark.version>1.0.0-alpha_spark_2</dl4j.spark.version>
  </properties>
  <dependencies>
    <dependency>
      <groupId>org.nd4j</groupId>
      <artifactId>nd4j-native</artifactId>
      <version>${nd4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>dl4j-spark_2.11</artifactId>
      <version>${dl4j.spark.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-core</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-nlp</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>deeplearning4j-zoo</artifactId>
      <version>${dl4j.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>arbiter-deeplearning4j</artifactId>
      <version>${arbiter.version}</version>
    </dependency>
    <dependency>
      <groupId>org.deeplearning4j</groupId>
      <artifactId>arbiter-ui_2.11</artifactId>
      <version>${arbiter.version}</version>
    </dependency>
    <dependency>
      <groupId>org.datavec</groupId>
      <artifactId>datavec-data-codec</artifactId>
      <version>${datavec.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.httpcomponents</groupId>
      <artifactId>httpclient</artifactId>
      <version>4.3.5</version>
    </dependency>
    <dependency>
      <groupId>ch.qos.logback</groupId>
      <artifactId>logback-classic</artifactId>
      <version>${logback.version}</version>
    </dependency>
    <dependency>
      <groupId>com.fasterxml.jackson.core</groupId>
      <artifactId>jackson-annotations</artifactId>
      <version>2.11.0</version>
    </dependency>
  </dependencies>
</project>