Wednesday, April 19, 2017

Naive Bayes Classifier in Java Tutorial

Lately I've been implementing some machine learning using the Naive Bayes algorithm in Java. I wanted a side project to get under the hood with Java, and it has also coincided with courses in probability theory I've been taking. There are a lot of Python tutorials online for writing this type of classifier but not many in the Java realm. While Java isn't used much in analytics compared to Python and R, it is used a lot in data engineering pipelines, so hopefully this will be a useful contribution.

Some of the math involved


This classifier is based around Bayes' theorem, which underpins a lot of conditional statistics and describes the probability of an event based on prior knowledge of conditions that might be related to the event. Where X and Y are events and P(Y) ≠ 0, the probability of event X occurring given that Y has occurred is as follows.

$$ P(X|Y) = \dfrac{P(Y|X)P(X)}{P(Y)} $$

Probability theory is fundamental to a lot of machine learning, but that isn't the focus of this tutorial. What is important is how this works as a machine learning classifier. The equation we'll be using in the code is the normal distribution (Gaussian) event model of the Naive Bayes classifier, which uses the probability density function described below. Probability density functions are used for numerical data and describe the relative likelihood of a random variable taking a given value within a distribution. This requires that we know the mean and standard deviation of each variable in the training set in order to model that distribution. These are easily calculated, and once we have them for the vectors associated with each class label (0 and 1 in this case), we can estimate the probability of a value belonging to one class or the other.

$$ \displaystyle pdf(x) = \frac{1}{\sigma \sqrt{2 \pi}}\ e^{-\frac{(x-\mu)^2}{2 \sigma^2}} $$
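To make that concrete, here is a minimal Java sketch of the density function. In the code walkthrough below this lives in the Predictions class as densityFunc, though the exact signature here is my own assumption.

// Gaussian probability density for a single value x, given the
// mean and standard deviation of one variable in the training set.
static double densityFunc(double x, double mean, double stdDev) {
    double exponent = Math.exp(-Math.pow(x - mean, 2) / (2 * Math.pow(stdDev, 2)));
    return (1 / (stdDev * Math.sqrt(2 * Math.PI))) * exponent;
}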

The thing that can be difficult to grasp initially is how the decision is made between the individual classes, and it's one I've often found explained poorly from a beginner's perspective. Even if we have an equation that determines the probability of one variable occurring within a certain distribution, how can a whole vector be assigned a class label? Assuming our class label is Y and the variables in our vector are X1...Xn, the naive Bayes algorithm simplifies things by assuming that the variables are independent of each other given the class, which just means no variable influences another in determining which class label the vector belongs to.

$$ P(X_1...X_n|Y) = \prod_{i=1}^n P(X_i|Y) $$

This equation says that the product of the individual probabilities of all the variables in the vector gives us an overall probability of the variables taking those values for a specific class label. As we will calculate the individual probability of each variable using the normal distribution approximation above, it's a matter of multiplying these together using the mean and standard deviation associated with each class label. Once we have these values, the class label with the highest probability is chosen as the prediction.
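As a sketch of how that product becomes a decision, using the densityFunc sketch above and assuming each class's summary statistics are held in a double[variable][2] array of [mean, stdDev] (a layout I've chosen purely for illustration):

// Product of the per-variable densities for one class label.
static double classProbability(double[] vector, double[][] stats) {
    double probability = 1.0;
    for (int i = 0; i < vector.length; i++) {
        probability *= densityFunc(vector[i], stats[i][0], stats[i][1]);
    }
    return probability;
}

// Whichever class label yields the higher product wins.
static int decide(double[] vector, double[][] statsClass0, double[][] statsClass1) {
    return classProbability(vector, statsClass0) >= classProbability(vector, statsClass1) ? 0 : 1;
}

One thing worth noting: with many variables these products can get very small and underflow, so summing log probabilities is a common variant, but I'll stick with the straight product to match the math above.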

From a programming point of view, it's easier to see how these pieces fit together, so the diagram below shows how the statistics apply to the flow of the data.



Programming and code examples


The dataset I'll be using for testing is an implementation of Leo Breiman's twonorm example, which consists of vectors of 20 random variables drawn from two overlapping normal distributions. This is a synthetic data set, so it will return a very high accuracy, but it works for the purpose of testing the classifier.

The program itself is split into four classes in addition to the class containing the standard public static void main() method that Java executes at run time. Below is a list of the packages I'll be using in the code, which you may need to import manually if your IDE doesn't do it automatically. Thanks IntelliJ :)
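Based on the classes described below, the imports are along these lines (a reconstruction; the exact set is my assumption):

// Likely imports, reconstructed from the class descriptions below.
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Random;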


CSVReader


The first class, CSVReader, contains a single method, parseCSV, that is used to import the data into an array of arrays; an ArrayList<ArrayList<Double>>, to be exact. Not to be confused with regular arrays, ArrayLists are more flexible than standard arrays since they are resizable. This data structure will be the basis of the program and mimics a 2D matrix.

Java reads the data in via a BufferedReader, which we define as br, and we start reading in each line. Each row of the data contains a vector of numbers that are the numerical values for the variables (X1 ... Xn). The last entry is either a 0 or a 1 and is the class label: the category that the vector belongs to.

Our CSV file consists of vectors on each line that include the class label in the final column. The length of the vector determines the number of columns in our matrix, and this is defined as Integer len. We parse each string as a double since we can't be certain that our data consists purely of integers. We loop through the values with a for loop, enter them into an ArrayList<Double>, and then add that row to our final matrix.
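As a rough sketch of what parseCSV might look like (the file-handling details here are my reconstruction from the description above):

public class CSVReader {

    // Reads the CSV file into an ArrayList of ArrayLists,
    // one inner list per row, with every value parsed as a double.
    public static ArrayList<ArrayList<Double>> parseCSV(String path) throws IOException {
        ArrayList<ArrayList<Double>> matrix = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] tokens = line.split(",");
                Integer len = tokens.length;   // number of columns in the matrix
                ArrayList<Double> row = new ArrayList<>();
                for (int i = 0; i < len; i++) {
                    // Parse as double; the data may not be purely integers.
                    row.add(Double.parseDouble(tokens[i]));
                }
                matrix.add(row);
            }
        }
        return matrix;
    }
}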



DataProcess


Our next class is DataProcess, where we separate out the data we intend to use for our test and training sets. Our main method, splitSet, takes the matrix we created and a split ratio between 0 and 1. A random number is generated based on the size of our data, and we use this as an index to retrieve a random vector for inclusion in our training set. We remove this vector from the test set and add it to our training set. Finally, both sets are returned together.
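A minimal sketch of splitSet along those lines. Generics make arrays of ArrayLists awkward in Java, so this sketch returns the two sets in a nested ArrayList; that's my choice here, not necessarily the original's:

public class DataProcess {

    // Moves a random splitRatio fraction of the rows into a training set;
    // whatever remains in the copy becomes the test set.
    public static ArrayList<ArrayList<ArrayList<Double>>> splitSet(
            ArrayList<ArrayList<Double>> matrix, double splitRatio) {
        ArrayList<ArrayList<Double>> testSet = new ArrayList<>(matrix);
        ArrayList<ArrayList<Double>> trainSet = new ArrayList<>();
        int trainSize = (int) (matrix.size() * splitRatio);
        Random random = new Random();
        while (trainSet.size() < trainSize) {
            // Random index into what's left of the test set.
            int index = random.nextInt(testSet.size());
            trainSet.add(testSet.remove(index));
        }
        ArrayList<ArrayList<ArrayList<Double>>> sets = new ArrayList<>();
        sets.add(trainSet);
        sets.add(testSet);
        return sets;
    }
}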


Statistics


Next we need to generate some statistics for our sets, and we do this within the Statistics class. We will use this class to produce a mean and standard deviation for each of the variables: one set of statistics for the positive class results and one for the negative.

We begin by sending the trainSet to the classStats method, after which we use the getCol method in a for loop to retrieve each column (variable). These are then separated into the positive and negative class results by checking whether the idCol (the last column in the matrix) matches 0 or not. Although this is a binary classifier, it would be possible to split out separate classification groups by adding more conditional controls here. Once these class separations are performed, each variable is sent to the individualStats method, which calls the meanList and stdDev methods.

The classStats method eventually returns a hashmap of the form 0:{n x [mean, stdDev]}, 1:{n x [mean, stdDev]}, where n is the number of variables.
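Here's a sketch of how that could look. I've folded the column retrieval and class separation directly into classStats rather than a separate getCol method, and assumed the sample standard deviation, so treat the details as my reconstruction:

public class Statistics {

    // Mean and standard deviation for every variable, separated by class label.
    // Returns {0: n x [mean, stdDev], 1: n x [mean, stdDev]}.
    public static HashMap<Integer, double[][]> classStats(
            ArrayList<ArrayList<Double>> trainSet) {
        HashMap<Integer, double[][]> summaries = new HashMap<>();
        int idCol = trainSet.get(0).size() - 1;   // last column holds the class label
        for (int label = 0; label <= 1; label++) {
            double[][] stats = new double[idCol][2];
            for (int col = 0; col < idCol; col++) {
                ArrayList<Double> column = new ArrayList<>();
                for (ArrayList<Double> row : trainSet) {
                    if (Math.round(row.get(idCol)) == label) {
                        column.add(row.get(col));
                    }
                }
                stats[col] = individualStats(column);
            }
            summaries.put(label, stats);
        }
        return summaries;
    }

    // Bundles mean and standard deviation for a single variable.
    public static double[] individualStats(ArrayList<Double> column) {
        return new double[] { meanList(column), stdDev(column) };
    }

    // Arithmetic mean of one column of values.
    public static double meanList(ArrayList<Double> column) {
        double sum = 0.0;
        for (double value : column) {
            sum += value;
        }
        return sum / column.size();
    }

    // Sample standard deviation of one column of values.
    public static double stdDev(ArrayList<Double> column) {
        double mean = meanList(column);
        double squaredDiffs = 0.0;
        for (double value : column) {
            squaredDiffs += Math.pow(value - mean, 2);
        }
        return Math.sqrt(squaredDiffs / (column.size() - 1));
    }
}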


Predictions


The final class to be implemented is the Predictions class, which is used to calculate the probabilities and decide which class label each vector is predicted to belong to.

The summaries and the testSet are sent to the goPredict method, and it's now time to calculate predictions for the remaining data. Each vector from the testSet is sent to the decidePredict method and the result is added to the finalPredictions ArrayList.

The decidePredict method obtains the combined probability for each class label from the classProbability method. That method determines the probability of each variable in the vector by sending the summary statistics for the class label (0 or 1) and the variable to the densityFunc method, which calculates the probability using the normal distribution's probability density function outlined in the math section above. The higher of the two combined values is selected and returned by decidePredict.
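A sketch of how these methods could hang together, adapting the math-section sketches to the ArrayList-based matrix and the hashmap returned by classStats (again, the details are my reconstruction):

public class Predictions {

    // Predicts a class label for every vector in the test set.
    public static ArrayList<Integer> goPredict(
            HashMap<Integer, double[][]> summaries,
            ArrayList<ArrayList<Double>> testSet) {
        ArrayList<Integer> finalPredictions = new ArrayList<>();
        for (ArrayList<Double> vector : testSet) {
            finalPredictions.add(decidePredict(summaries, vector));
        }
        return finalPredictions;
    }

    // Returns whichever class label has the higher combined probability.
    public static int decidePredict(
            HashMap<Integer, double[][]> summaries, ArrayList<Double> vector) {
        double prob0 = classProbability(vector, summaries.get(0));
        double prob1 = classProbability(vector, summaries.get(1));
        return prob0 >= prob1 ? 0 : 1;
    }

    // Product of the per-variable densities for one class label,
    // skipping the final class-label column of the vector.
    public static double classProbability(ArrayList<Double> vector, double[][] stats) {
        double probability = 1.0;
        for (int i = 0; i < vector.size() - 1; i++) {
            probability *= densityFunc(vector.get(i), stats[i][0], stats[i][1]);
        }
        return probability;
    }

    // The normal distribution pdf from the math section.
    public static double densityFunc(double x, double mean, double stdDev) {
        double exponent = Math.exp(-Math.pow(x - mean, 2) / (2 * Math.pow(stdDev, 2)));
        return (1 / (stdDev * Math.sqrt(2 * Math.PI))) * exponent;
    }
}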

The last method, accuracy, is used to calculate what percentage of the class predictions were correct. We pass the testSet and the resulting class predictions to the method and count the number of successes by comparing each prediction to the last column of the testSet. A percentage is then calculated and returned.
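A sketch of accuracy under the same assumptions; this would sit in the Predictions class alongside the methods above:

    // Percentage of predictions matching the true label held in the
    // last column of each test vector.
    public static double accuracy(
            ArrayList<ArrayList<Double>> testSet, ArrayList<Integer> predictions) {
        int correct = 0;
        for (int i = 0; i < testSet.size(); i++) {
            ArrayList<Double> vector = testSet.get(i);
            int actual = (int) Math.round(vector.get(vector.size() - 1));
            if (actual == predictions.get(i)) {
                correct++;
            }
        }
        return (correct / (double) testSet.size()) * 100.0;
    }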

Lastly, we arrange all the method calls in the public static void main method, which performs all of these operations and reports our prediction accuracy.
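Wiring up the sketches above (and using the imports listed earlier), main might look roughly like this; the file name and the 0.67 split ratio are placeholders I've picked, not values from the original:

public class Main {

    public static void main(String[] args) throws IOException {
        // Load the twonorm data and split it into training and test sets.
        ArrayList<ArrayList<Double>> matrix = CSVReader.parseCSV("twonorm.csv");
        ArrayList<ArrayList<ArrayList<Double>>> sets = DataProcess.splitSet(matrix, 0.67);
        ArrayList<ArrayList<Double>> trainSet = sets.get(0);
        ArrayList<ArrayList<Double>> testSet = sets.get(1);

        // Summarise the training set, predict on the test set, report accuracy.
        HashMap<Integer, double[][]> summaries = Statistics.classStats(trainSet);
        ArrayList<Integer> predictions = Predictions.goPredict(summaries, testSet);
        System.out.println("Accuracy: " + Predictions.accuracy(testSet, predictions) + "%");
    }
}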