Lately I've been implementing some machine learning using the Naive Bayes algorithm in Java. I wanted a side project to get under the hood with Java, and this has also coincided with courses in probability theory I've been taking. There are a lot of Python tutorials online for writing this type of classifier but not many in the Java realm. While Java isn't used as much in analytics as Python and R, it is used a lot in data engineering pipelines, so hopefully this will be a useful contribution.
Some of the math involved
This classifier is based around Bayes' theorem, which underpins a lot of conditional statistics and describes the probability of an event based on prior knowledge of conditions that might be related to the event. Where X and Y are events and P(Y) ≠ 0, the probability of event X occurring given that Y has occurred is as follows.
$$P(X \mid Y) = \frac{P(Y \mid X)\,P(X)}{P(Y)}$$
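To make the theorem concrete, here is a small worked example. It is only a sketch: the class name, the method name and all three input probabilities are made up for illustration, and P(Y) is expanded with the law of total probability.

```java
// Hypothetical worked example of Bayes' theorem; every number here is invented.
class BayesExample {
    // P(X|Y) = P(Y|X) * P(X) / P(Y), where
    // P(Y) = P(Y|X) * P(X) + P(Y|not X) * P(not X)
    static double posterior(double pX, double pYGivenX, double pYGivenNotX) {
        double pY = pYGivenX * pX + pYGivenNotX * (1.0 - pX);
        if (pY == 0.0) {
            throw new IllegalArgumentException("P(Y) must be non-zero");
        }
        return pYGivenX * pX / pY;
    }

    public static void main(String[] args) {
        // Assumed inputs: P(X) = 0.01, P(Y|X) = 0.9, P(Y|not X) = 0.05
        System.out.println(posterior(0.01, 0.9, 0.05));
    }
}
```

With these inputs P(Y) works out to 0.0585, so the posterior is 0.009 / 0.0585, or roughly 0.154. Even strong evidence leaves the posterior fairly low when the prior is small.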
Probability theory is fundamental to a lot of machine learning, but that isn't the focus of this tutorial. What matters here is how it works as a machine learning classifier. The equation that we'll be using within the code is the normal distribution approximation (or event model) to the Naive Bayes classifier, which uses the probability density function described below. Probability density functions are used for numerical data and describe the relative likelihood of a random variable taking a given value within a given distribution. To model that distribution we need to know the mean and standard deviation of each variable in the training set. These are easily found, and by calculating them for the vectors associated with each class label (0 and 1 in this case) we can predict the probability of a random variable belonging to one class or the other.
$$\mathrm{pdf}(x) = \frac{1}{\sigma\sqrt{2\pi}}\, e^{-\frac{(x-\mu)^2}{2\sigma^2}}$$
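As a quick sanity check of the density formula, the sketch below evaluates it at a few points. The class and method names are hypothetical and the inputs are chosen only for illustration.

```java
// Hypothetical helper that evaluates the normal density at a single point.
class GaussianDensityExample {
    static double density(double x, double mean, double stdDev) {
        // Exponent: -(x - mean)^2 / (2 * stdDev^2)
        double exponent = -Math.pow(x - mean, 2) / (2 * stdDev * stdDev);
        // Scale factor: 1 / (stdDev * sqrt(2 * pi))
        return Math.exp(exponent) / (stdDev * Math.sqrt(2 * Math.PI));
    }

    public static void main(String[] args) {
        System.out.println(density(0.0, 0.0, 1.0)); // peak of the standard normal
        System.out.println(density(1.0, 0.0, 1.0)); // one standard deviation away
        System.out.println(density(0.0, 0.0, 0.1)); // narrow distribution
    }
}
```

The last call returns a value greater than 1, which is a reminder that a density is not itself a probability. For classification that is fine, because only the relative size of the values for each class matters.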
The thing that can be difficult to grasp initially is how the decision is made between the individual classes, and it's one that I've often found explained poorly from a beginner's perspective. Even if we have an equation that determines the probability of one variable occurring within a certain distribution, how can each vector be assigned a class label? Assuming our class labels are Y and the variables in our vector are X1...Xn, the naive Bayes algorithm makes the simplifying assumption that the variables are independent of each other given the class label, which just means that, within a class, the value of one variable has no influence on the others.
$$P(X_1, \ldots, X_n \mid Y) = \prod_{i=1}^{n} P(X_i \mid Y)$$
This equation implies that the product of all the individual probabilities in the vector gives us an overall probability of all the variables being as they are for a specific class label. As we calculate the individual probability of each variable in the vector using the normal distribution approximation above, it's a matter of multiplying these together using the mean and standard deviation associated with each class label. Once we have a value for each class label, the class with the highest probability is chosen as the prediction.
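The multiply-then-compare step can be sketched in isolation. In this minimal example the per-variable likelihoods are invented numbers rather than values computed from real data, and the class and method names are hypothetical. Like the description above, it compares likelihoods only, which amounts to assuming both classes are equally common.

```java
// Hypothetical illustration of the product rule followed by picking the best class.
class ProductRuleExample {
    // Multiply the per-variable likelihoods P(Xi | Y) for a single class label
    static double jointLikelihood(double[] perVariable) {
        double product = 1.0;
        for (double p : perVariable) {
            product *= p;
        }
        return product;
    }

    // Return the index of the class whose joint likelihood is largest
    static int predict(double[][] likelihoodsByClass) {
        int bestClass = -1;
        double bestScore = Double.NEGATIVE_INFINITY;
        for (int label = 0; label < likelihoodsByClass.length; label++) {
            double score = jointLikelihood(likelihoodsByClass[label]);
            if (score > bestScore) {
                bestScore = score;
                bestClass = label;
            }
        }
        return bestClass;
    }

    public static void main(String[] args) {
        // Assumed values of P(Xi | Y) for a three-variable vector
        double[][] likelihoods = {
            {0.30, 0.05, 0.20}, // class 0: product is about 0.003
            {0.10, 0.25, 0.40}  // class 1: product is about 0.010
        };
        System.out.println("Predicted class: " + predict(likelihoods));
    }
}
```

Class 0 has the single largest factor, but class 1 wins because its product across all three variables is larger.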
From a programming point of view, it's easier to see how these pieces fit together, so the diagram below does some of that work in showing how the statistics work on the flow of the data.
Programming and code examples
The dataset I'll be using for testing is an implementation of Leo Breiman's twonorm example, in which each row is a vector of 20 random variables drawn from one of two normally distributed models that lie inside each other. This is a synthetic data set, so it will return a very high accuracy; however, it works for the purpose of testing the classifier.
The program itself is split into four classes in addition to the class holding the public static void main method, the standard main() entry point that Java executes at run time. Here is a list of the packages that I'll be using in the code, which you may need to add to your classpath if your IDE doesn't import them automatically. Thanks IntelliJ :)
```java
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVRecord;

import java.io.*;
import java.util.*;

// JDK-internal class (not part of the public Java API), so it may be missing on your JDK
import static oracle.jrockit.jfr.events.Bits.intValue;
```
CSVReader
The first class, CSVReader, contains a single method, parseCSV, that is used to import the data into an array of arrays, an ArrayList<ArrayList> to be exact. Not to be confused with regular arrays, ArrayLists are more flexible than standard arrays since they can grow and shrink after creation. This data structure will be the basis of the program and mimics a 2D matrix.
The file is read through a BufferedReader, which we define as br and use to read the first line; the rows themselves are iterated as CSVRecords. Each row of the data contains a vector of numbers that are the numerical values for the variables (X1 ... Xn). The last entry is either a 0 or 1 and is the class label, the category that this vector belongs to.
Our CSV file consists of one vector per line with the class label in the final column. The length of the vector determines the number of columns in our matrix and this is defined as Integer len. We parse each string as a double since we can't be certain that our data consists purely of integers. We then loop through the row with a for loop and enter each value into an ArrayList<Double> before adding that row to our final matrix.
```java
class CSVReader {
    ArrayList<ArrayList> parseCSV(String pathto) {
        // try-with-resources closes both readers once parsing has finished
        try (Reader in = new FileReader(pathto);
             BufferedReader br = new BufferedReader(new FileReader(pathto))) {
            // Parse the file into records
            Iterable<CSVRecord> records = CSVFormat.EXCEL.parse(in);
            // Read the first line so we can count the columns
            String line = br.readLine();
            // Determine the number of variables by splitting at the commas
            Integer len = line.split(",").length;
            // Defining our matrix
            ArrayList<ArrayList> colMatrix = new ArrayList<>();
            // For loop that iterates over each line in the CSV
            for (CSVRecord record : records) {
                ArrayList<Double> tempVector = new ArrayList<>();
                // Now we use the length to add each variable to the vector
                for (int i = 0; i < len; i++) {
                    Double rec = Double.parseDouble(record.get(i));
                    tempVector.add(rec);
                }
                // And add this vector as a row to the matrix
                colMatrix.add(tempVector);
            }
            return colMatrix;
        } catch (IOException e) {
            e.printStackTrace();
        }
        // Returned only when the file could not be read
        return null;
    }
}
```
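If you'd rather not pull in Commons CSV for a file this simple, a single row can be parsed with plain String.split. This is only a sketch under the assumption that the file contains nothing but comma-separated numbers (no quoted fields, no header row); the class and method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical, dependency-free parser for one line of a purely numeric CSV file.
class RowParseExample {
    static List<Double> parseRow(String line) {
        String[] fields = line.split(",");
        List<Double> row = new ArrayList<>(fields.length);
        for (String field : fields) {
            // trim() tolerates spaces after the commas
            row.add(Double.parseDouble(field.trim()));
        }
        return row;
    }

    public static void main(String[] args) {
        // Made-up row: two variables followed by the class label
        List<Double> row = parseRow("0.5, -1.25, 1");
        double label = row.get(row.size() - 1);
        System.out.println(row + " has class label " + label);
    }
}
```

The trade-off is robustness: a proper CSV library copes with quoting and escaped separators, which a bare split does not.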
DataProcess
Our next class is DataProcess, where we separate out the data we intend to use for our training and test sets. Its main method, splitSet, takes the matrix we created and a split ratio between 0 and 1.
A random number is generated, bounded by the number of vectors remaining in the test set, and we use this as an index to retrieve a random vector for inclusion in our training set. We remove this vector from the test set and add it to our training set. Finally, both sets are returned in an ArrayList.
```java
class DataProcess {
    ArrayList<ArrayList> splitSet(ArrayList<ArrayList> dataset, double splitRatio) {
        // Getting the dataset size and trainSet size
        int dataSize = dataset.size();
        double trainSize = dataSize * splitRatio;
        // Training and test set variables; the test set starts as a copy of the
        // full dataset so the caller's list is left untouched
        ArrayList<ArrayList> trainSet = new ArrayList<>();
        ArrayList<ArrayList> testSet = new ArrayList<>(dataset);
        // Seeding with the dataset size makes the split repeatable
        Random rn = new Random(dataSize);
        while (trainSet.size() < trainSize) {
            // Create a new index; a primitive int selects remove(int index) below
            int rand = rn.nextInt(testSet.size());
            // Retrieve the vector that corresponds to the index
            ArrayList<Double> tempVector = testSet.get(rand);
            // Switch the vector from one set to another
            trainSet.add(tempVector);
            testSet.remove(rand);
        }
        // Returning sets in an ArrayList
        ArrayList<ArrayList> splitsets = new ArrayList<>(Arrays.asList(trainSet, testSet));
        return splitsets;
    }

    // Method for retrieving a single column from the matrix
    ArrayList<Double> getCol(ArrayList<ArrayList> dataset, int col) {
        ArrayList<Double> colArray = new ArrayList<>();
        for (int i = 0; i < dataset.size(); i++) {
            ArrayList row = dataset.get(i);
            Double column = (Double) row.get(col);
            colArray.add(column);
        }
        return colArray;
    }
}
```
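An alternative way to get the same kind of random split is to shuffle a copy of the rows and cut it at the split point. This is not how splitSet works; it is a sketch of a different technique, with a hypothetical class name, method name and seed parameter.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical shuffle-based train/test split.
class ShuffleSplitExample {
    // Returns a two-element list: index 0 is the training set, index 1 the test set
    static <T> List<List<T>> split(List<T> rows, double trainRatio, long seed) {
        // Shuffle a copy so the caller's list keeps its original order
        List<T> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed));
        int trainSize = (int) Math.round(rows.size() * trainRatio);
        List<List<T>> result = new ArrayList<>();
        result.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        result.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return result;
    }

    public static void main(String[] args) {
        // Ten stand-in "rows"; real rows would be the vectors from the CSV
        List<Integer> rows = new ArrayList<>();
        for (int i = 0; i < 10; i++) {
            rows.add(i);
        }
        List<List<Integer>> parts = split(rows, 0.7, 42L);
        System.out.println("train: " + parts.get(0).size() + ", test: " + parts.get(1).size());
    }
}
```

Shuffling once avoids the repeated remove calls on an ArrayList, each of which shifts the remaining elements, which may matter on larger datasets.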
Statistics
Next we need to generate some statistics for our sets, and we do this within the Statistics class. We will use this class to produce the mean and standard deviation of each variable: one set for the positive class results and one for the negative class results.
We begin by sending the trainSet to the classStats method, after which we use the getCol method in a for loop to retrieve each column (variable). These are then separated into the positive and negative class results by checking whether the idCol (the last column in the matrix) matches 0 or not. Although this is a binary classifier, it would be possible to split out separate classification groups by adding more conditional controls here. Once these class separations are performed, each variable is sent to the individualStats method, which calls the meanList and stdDev methods.
The method classStats eventually returns a hashmap with 0:{n x [mean, stdDev]}, 1:{n x [mean, stdDev]} where n is the number of variables.
```java
class Statistics {
    ArrayList<Double> individualStats(ArrayList<Double> variable) {
        // Calculate mean and standard deviation
        double seMean = meanList(variable);
        double seSD = stdDev(variable);
        ArrayList<Double> eachStats = new ArrayList<>(Arrays.asList(seMean, seSD));
        return eachStats;
    }

    HashMap<Integer, ArrayList<ArrayList>> classStats(ArrayList<ArrayList> trainset) {
        // Instance of class
        DataProcess process = new DataProcess();
        // Hashmap used to store summaries
        HashMap<Integer, ArrayList<ArrayList>> summaries = new HashMap<>();
        // Number of columns, taken from the first row
        int len = trainset.get(0).size();
        // Arrays used to store the 0 and 1 class statistics
        ArrayList<ArrayList> ready0 = new ArrayList<>();
        ArrayList<ArrayList> ready1 = new ArrayList<>();
        // Gets last column that is the class identifier (fetched once, outside the loop)
        List<Double> idCol = process.getCol(trainset, (len - 1));
        // Iterates across columns
        for (int i = 0; i < (len - 1); i++) {
            // Values of variable i for every vector
            List<Double> valCol = process.getCol(trainset, i);
            // Lists for the two classes
            ArrayList<Double> list1 = new ArrayList<>();
            ArrayList<Double> list0 = new ArrayList<>();
            // Loop to separate into two classes
            for (int j = 0; j < idCol.size(); j++) {
                // Splits out vectors based on their class ID
                if (idCol.get(j) == 0) {
                    list0.add(valCol.get(j));
                } else {
                    list1.add(valCol.get(j));
                }
            }
            // Creates mean and SD for each variable after being split into classes
            ready0.add(individualStats(list0));
            ready1.add(individualStats(list1));
        }
        // Stores these in the hashmap for return
        summaries.put(0, ready0);
        summaries.put(1, ready1);
        return summaries;
    }

    // Used to sum a column
    double sumList(ArrayList<Double> a) {
        double sum = 0;
        for (int i = 0; i < a.size(); i++) {
            double in = a.get(i);
            sum = sum + in;
        }
        return sum;
    }

    // Takes the sum and returns the mean
    double meanList(ArrayList<Double> a) {
        double mean = sumList(a) / a.size();
        return mean;
    }

    // Takes the mean and calculates the sample standard deviation (divides by n - 1)
    double stdDev(ArrayList<Double> a) {
        double mean = meanList(a);
        double num = 0;
        for (int i = 0; i < a.size(); i++) {
            double listVal = a.get(i);
            num = num + Math.pow((listVal - mean), 2);
        }
        num = Math.sqrt(num / (a.size() - 1));
        return num;
    }
}
```
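To check the arithmetic by hand, here is a small worked example of the mean and the sample standard deviation (the n - 1 form used by stdDev). The column values are made up and the class and method names are hypothetical.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical worked example of the mean and sample standard deviation of one column.
class SampleStatsExample {
    static double mean(List<Double> values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum / values.size();
    }

    static double sampleStdDev(List<Double> values) {
        if (values.size() < 2) {
            throw new IllegalArgumentException("need at least two values");
        }
        double mean = mean(values);
        double squaredDiffs = 0.0;
        for (double v : values) {
            squaredDiffs += (v - mean) * (v - mean);
        }
        // Divide by n - 1 rather than n: the sample (not population) estimate
        return Math.sqrt(squaredDiffs / (values.size() - 1));
    }

    public static void main(String[] args) {
        // Made-up column: sum is 40 and the squared differences total 32
        List<Double> column = Arrays.asList(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0);
        System.out.println("mean = " + mean(column));
        System.out.println("sd   = " + sampleStdDev(column));
    }
}
```

Here the mean is 5 and the standard deviation is the square root of 32/7, about 2.14. The explicit size check is an addition of this sketch: with fewer than two values the n - 1 divisor would otherwise produce NaN.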
Predictions
The final class to be implemented is the Predictions class which will be used to calculate probability and decide which class label each vector is predicted to be.
The summaries and the testSet are sent to the goPredict method. It's now time to calculate the predictions for the remaining data. Each vector from the testSet is sent to the decidePredict method and the result is added to the finalPredictions ArrayList.
The decidePredict method obtains the combined probability for each of the class labels from the classProbability method. That method determines the probability of each variable in the vector by sending the summary statistics of each class label (0 and 1) and the variable's value to the densityFunc method, which calculates the probability using the probability density function for the normal distribution outlined in the math section above. The class label with the higher of these two values is selected and returned by the decidePredict method.
The last method accuracy is used to calculate what percentage of the class predictions were correct. We pass the testSet and the resulting class predictions to the method and count the number of successes by comparing the prediction to the last column in the testSet. A percentage is then calculated and returned.
```java
class Predictions {
    Double densityFunc(double x, double mean, double stdDev) {
        // Probability density function used for determining probability
        double expTerm = Math.exp(-(Math.pow(x - mean, 2.0) / (2 * Math.pow(stdDev, 2.0))));
        return (1 / (Math.sqrt((2 * Math.PI)) * stdDev)) * expTerm;
    }

    HashMap<Integer, Double> classProbability(HashMap<Integer, ArrayList<ArrayList>> summaries,
                                              ArrayList vector) {
        HashMap<Integer, Double> probability = new HashMap<>();
        // For each class split in the summaries, in this case 0 and 1
        for (Map.Entry<Integer, ArrayList<ArrayList>> entry : summaries.entrySet()) {
            // Assign the class key to the new hashmap
            probability.put(entry.getKey(), 1.0);
            // Iterating through each random variable in the vector and then
            // the matching mean and SD for the variables.
            for (int i = 0; i < entry.getValue().size(); i++) {
                // Get the mean and SD of the variable at column i
                Double mean = (Double) entry.getValue().get(i).get(0);
                Double stdDev = (Double) entry.getValue().get(i).get(1);
                // Get the random variable from the vector at index i
                double x = (Double) vector.get(i);
                // Probabilities are multiplied together
                double probval = probability.get(entry.getKey()) * densityFunc(x, mean, stdDev);
                probability.put(entry.getKey(), probval);
            }
        }
        return probability;
    }

    double decidePredict(HashMap<Integer, ArrayList<ArrayList>> summaries, ArrayList vector) {
        // Retrieves a new hashmap from classProbability method
        HashMap<Integer, Double> probability = classProbability(summaries, vector);
        // Sets up placeholder values
        // These are overwritten as the probabilities are compared
        double hiLabel = 99;
        double hiProb = -1;
        // Step through the returned hash with probability for each class
        for (Map.Entry<Integer, Double> entry : probability.entrySet()) {
            // Makes the decision which class label to assign to the vector
            // If the next probability is higher, it will replace the lower one
            if (entry.getValue() > hiProb) {
                hiProb = entry.getValue();
                hiLabel = entry.getKey();
            }
        }
        return hiLabel;
    }

    ArrayList goPredict(HashMap<Integer, ArrayList<ArrayList>> summaries,
                        ArrayList<ArrayList> testSet) {
        ArrayList<Double> finalPredictions = new ArrayList<>();
        // Loops through every vector in the testSet
        for (int i = 0; i < testSet.size(); i++) {
            // summaries and vector sent to predict and stored in final predictions
            double result = decidePredict(summaries, testSet.get(i));
            finalPredictions.add(result);
        }
        return finalPredictions;
    }

    Double accuracy(ArrayList<ArrayList> matrix, ArrayList<Double> predictions) {
        int correct = 0;
        int len = matrix.get(0).size();
        // The class labels are checked against the predictions
        for (int i = 0; i < matrix.size(); i++) {
            double var_a = (Double) matrix.get(i).get(len - 1);
            double var_b = predictions.get(i);
            // Increase count for correct predictions
            if (var_a == var_b) {
                correct = correct + 1;
            }
        }
        double msize = matrix.size();
        // Normalize to a percent
        double accuracy = correct / msize * 100;
        return accuracy;
    }
}
```
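One thing to be aware of if you adapt this to much wider vectors: multiplying many small densities together, as classProbability does, can eventually underflow to 0.0 for every class. A common workaround, not used in the code above, is to add log densities instead of multiplying densities. The sketch below uses hypothetical names and made-up inputs to show the difference.

```java
import java.util.Arrays;

// Hypothetical log-space variant of the per-class score.
class LogLikelihoodExample {
    // Natural log of the normal density, computed without calling Math.exp
    static double logDensity(double x, double mean, double stdDev) {
        double z = (x - mean) / stdDev;
        return -0.5 * z * z - Math.log(stdDev) - 0.5 * Math.log(2 * Math.PI);
    }

    // Sum of per-variable log densities, i.e. the log of the product of densities
    static double logLikelihood(double[] vector, double[] means, double[] stdDevs) {
        double total = 0.0;
        for (int i = 0; i < vector.length; i++) {
            total += logDensity(vector[i], means[i], stdDevs[i]);
        }
        return total;
    }

    public static void main(String[] args) {
        // Assumed wide vector: 2000 variables, each sitting exactly on its class mean
        int n = 2000;
        double[] vector = new double[n];  // all zeros
        double[] means = new double[n];   // all zeros
        double[] stdDevs = new double[n];
        Arrays.fill(stdDevs, 1.0);

        double logScore = logLikelihood(vector, means, stdDevs);
        System.out.println("log-likelihood: " + logScore);
        // Converting back to a plain product underflows to 0.0
        System.out.println("as a product:   " + Math.exp(logScore));
    }
}
```

Because the logarithm is monotonic, picking the class with the highest log score gives the same prediction as picking the highest product. With the 20 variables in the twonorm data the plain product is unlikely to cause trouble, so treat this as an optional refinement.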
```java
public class Classifier {
    public static void main(String[] args) {
        // Creating instances of all of our classes
        CSVReader parse = new CSVReader();
        DataProcess process = new DataProcess();
        Predictions pred = new Predictions();
        Statistics stats = new Statistics();
        // This block imports the data and saves it into the matrix
        String s = "C:/Path_to/TwoNormDataset.csv";
        ArrayList<ArrayList> matrix = parse.parseCSV(s);
        // Here we split the data into training and test set based on the split ratio
        ArrayList splitsets = process.splitSet(matrix, 0.7);
        // Retrieving the arrays using get
        ArrayList trainSet = (ArrayList) splitsets.get(0);
        ArrayList testSet = (ArrayList) splitsets.get(1);
        // Create summaries from the trainingSet for each of the variables
        HashMap summaries = stats.classStats(trainSet);
        // Finally send the summaries to make a prediction from the testSet
        ArrayList predictions = pred.goPredict(summaries, testSet);
        // Finally we return the accuracy of our prediction
        System.out.println("accuracy");
        System.out.println(pred.accuracy(testSet, predictions));
    }
}
```