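The chi-square test is often used in statistics. The version here is the test of independence. It checks whether two categorical variables in a contingency table (for example, sample group and response) are related, and its null hypothesis is that they are independent. I thought it would be fun to program it myself. For how the chi-square statistic is calculated, refer to the following page: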
http://kogolab.chillout.jp/elearn/hamburger/chap3/sec0.html
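The calculation that both programs below implement can be summarized as follows. The notation is my own and is not taken from the linked page. Take an r × c table with observed counts O_ij, row totals R_i, column totals C_j and grand total N. The worked numbers use the waku_mogu table [[435, 165], [265, 135]] from the code.

```latex
E_{ij} = \frac{R_i \, C_j}{N}, \qquad
\chi^2 = \sum_{i=1}^{r} \sum_{j=1}^{c} \frac{(O_{ij} - E_{ij})^2}{E_{ij}}, \qquad
\mathrm{df} = (r-1)(c-1)

% Worked example: R = (600, 400), C = (700, 300), N = 1000
% E = [[420, 180], [280, 120]], so every |O_{ij} - E_{ij}| = 15
\chi^2 = \frac{15^2}{420} + \frac{15^2}{180} + \frac{15^2}{280} + \frac{15^2}{120} \approx 4.464,
\qquad \mathrm{df} = (2-1)(2-1) = 1
```

With 1 degree of freedom, the 5% critical value of the chi-square distribution is about 3.841. A statistic of 4.464 would therefore reject the null hypothesis of independence at the 5% level.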
First, Python.
kai.py
import numpy as np

def chi_squared(array):
    # Get the row totals and the column totals
    row_sums = np.array(array).sum(axis=1)
    col_sums = np.array(array).sum(axis=0)
    # Get the grand total
    the_total = sum(row_sums)
    # Get the expected frequency of each cell: row total * column total / grand total
    ex_freq = []
    for i in row_sums:
        for j in col_sums:
            ex_freq.append(i * j / the_total)
    # Split the flat list back into one row per row of the original table
    ex_freq = np.array(np.array_split(ex_freq, len(array)))
    diff = np.array(array) - ex_freq
    # Flatten so that tables with any number of rows and columns are supported
    ex_freq_flt = ex_freq.flatten()
    diff_flt = diff.flatten()
    return sum(diff_flt ** 2 / ex_freq_flt)

def d_f(array):
    # Degrees of freedom: (number of rows - 1) * (number of columns - 1)
    s = list(np.array(array).shape)
    d = 1
    for i in s:
        d *= (i - 1)
    return d

waku_mogu = [[435, 165],
             [265, 135]]
print("The chi square is", chi_squared(waku_mogu))
print("The degree of freedom is", d_f(waku_mogu))
"""
Execution result
The chi square is 4.464285714285714
The degree of freedom is 1
"""
With NumPy, summing the columns is a one-liner, sum(axis=0). It's great.
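The same one-liner spirit can be taken further by replacing the expected-frequency loop with NumPy's outer product. Below is a minimal sketch of that variant. chi_squared_vec is a hypothetical name, not from the original. The sketch assumes a 2-D table of counts with no zero row or column totals.

```python
import numpy as np

def chi_squared_vec(table):
    """Return (chi-square statistic, degrees of freedom) for a contingency table."""
    observed = np.asarray(table, dtype=float)
    row_sums = observed.sum(axis=1)  # total of each row
    col_sums = observed.sum(axis=0)  # total of each column
    total = observed.sum()           # grand total
    # Expected frequency of cell (i, j) = row_sums[i] * col_sums[j] / total
    expected = np.outer(row_sums, col_sums) / total
    chi2 = ((observed - expected) ** 2 / expected).sum()
    # Degrees of freedom: (rows - 1) * (columns - 1)
    rows, cols = observed.shape
    df = (rows - 1) * (cols - 1)
    return float(chi2), df

waku_mogu = [[435, 165],
             [265, 135]]
chi2, df = chi_squared_vec(waku_mogu)
print("The chi square is", chi2)
print("The degree of freedom is", df)
```

Unlike the loop version, this never builds an intermediate Python list. np.outer produces the r × c table of expected frequencies directly, and broadcasting handles the cell-by-cell arithmetic.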
Next, a Java program that does the same thing.
Kai.java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

public class Kai {
    public static void main(String[] args) {
        Calc c = new Calc();
        double[][] waku_mogu = {{435, 165}, {265, 135}};
        System.out.println("The chi square is " + c.chi_squared(waku_mogu));
        System.out.println("The degree of freedom is " + c.d_f(waku_mogu));
    }
}

class Calc {
    public double chi_squared(double[][] arr) {
        List<Double> row_sums = new ArrayList<>();
        List<Double> col_sums = new ArrayList<>();
        List<Double> ex_freq = new ArrayList<>();
        List<Double> diff = new ArrayList<>();
        // Find the total of each row
        for (int i = 0; i < arr.length; i++) {
            double total_r = 0;
            for (int j = 0; j < arr[i].length; j++) {
                total_r += arr[i][j];
            }
            row_sums.add(total_r);
        }
        // Find the total of each column. This was the hardest part... in NumPy it is one line
        for (int j = 0; j < arr[0].length; j++) {
            double total_c = 0;
            for (int i = 0; i < arr.length; i++) {
                total_c += arr[i][j];
            }
            col_sums.add(total_c);
        }
        // Grand total
        double the_total = 0;
        Iterator<Double> iterator = row_sums.iterator();
        while (iterator.hasNext()) {
            double i = iterator.next();
            the_total += i;
        }
        // Expected frequency of each cell (row total * column total / grand total), stored row by row
        iterator = row_sums.iterator();
        while (iterator.hasNext()) {
            double i = iterator.next();
            Iterator<Double> iterator2 = col_sums.iterator();
            while (iterator2.hasNext()) {
                double j = iterator2.next();
                ex_freq.add(i * j / the_total);
            }
        }
        // The second hardest part: supporting any number of rows and columns.
        // ex_freq is flat, so a running index walks it in the same row-by-row order as arr
        int count = 0;
        for (int i = 0; i < arr.length; i++) {
            for (int j = 0; j < arr[i].length; j++) {
                diff.add(arr[i][j] - ex_freq.get(count));
                count++;
            }
        }
        double chi_val = 0;
        for (int i = 0; i < ex_freq.size(); i++) {
            chi_val += Math.pow(diff.get(i), 2) / ex_freq.get(i);
        }
        return chi_val;
    }

    // Degrees of freedom: (number of rows - 1) * (number of columns - 1)
    public int d_f(double[][] arr) {
        return (arr.length - 1) * (arr[0].length - 1);
    }
}
/*
Execution result
The chi square is 4.464285714285714
The degree of freedom is 1
*/
I had a lot of trouble getting the column sums.
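For comparison, plain Python without NumPy can also get the column sums in one line, by transposing the table with zip(*table). The sketch below follows the same steps as the Java version using only the standard library. chi_squared_plain is a hypothetical name, and the sketch assumes every row has the same length.

```python
def chi_squared_plain(table):
    """Chi-square statistic of a contingency table, standard library only."""
    row_sums = [sum(row) for row in table]
    # zip(*table) yields the columns, so this line is the column-sum step
    col_sums = [sum(col) for col in zip(*table)]
    total = sum(row_sums)
    chi2 = 0.0
    for i, row in enumerate(table):
        for j, observed in enumerate(row):
            # Expected frequency = row total * column total / grand total
            expected = row_sums[i] * col_sums[j] / total
            chi2 += (observed - expected) ** 2 / expected
    return chi2

print(chi_squared_plain([[435, 165], [265, 135]]))
```

Computing each expected frequency inside the loop also removes the need for the flat ex_freq list and the running count index used in the Java code.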
My skills are still limited, of course, but Python is much cleaner to write. Still, I also like how Java has you build things up step by step, checking the safety yourself as you go.
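Neither program turns the statistic into a p-value for the null hypothesis of independence. In the special case of 1 degree of freedom, as in this 2 × 2 example, the chi-square tail probability reduces to a standard normal tail: P(χ² > x) = erfc(√(x/2)). For that case the standard library is enough. The sketch below assumes df = 1. p_value_df1 is a hypothetical name, and other degrees of freedom would need the incomplete gamma function, for example scipy.stats.chi2.sf.

```python
import math

def p_value_df1(chi2):
    """Upper-tail probability of the chi-square distribution with 1 degree of freedom."""
    # If Z ~ N(0, 1) then Z**2 ~ chi-square(1), so
    # P(Z**2 > x) = P(|Z| > sqrt(x)) = erfc(sqrt(x / 2))
    return math.erfc(math.sqrt(chi2 / 2))

chi2 = 4.464285714285714  # statistic computed above for waku_mogu
p = p_value_df1(chi2)
print("p-value:", p)
print("Reject independence at the 5% level:", p < 0.05)
```

For the waku_mogu table this gives a p-value of roughly 0.035, which is below 0.05.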