Lately I've been implementing some machine learning using the Naive Bayes algorithm in Java. I wanted a side project to get under the hood with Java, and this has also coincided with courses in probability theory I've been taking. There are a lot of Python tutorials online for writing this type of classifier, but not many in the Java realm. While Java isn't used much in analytics compared to Python and R, it is used a lot in data engineering pipelines, so hopefully this will be a useful contribution.
Some of the math involved
This classifier is based around
Bayes' theorem, which underpins a lot of conditional statistics and describes the probability of an event based on prior knowledge of conditions that might be related to it. Where X and Y are events and P(Y) ≠ 0, the probability of event X occurring given that Y has occurred is as follows.
$$ P(X|Y) = \dfrac{P(Y|X)P(X)}{P(Y)} $$
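To make the formula concrete, here is a tiny numeric sketch in Java. The probability values are made up purely for illustration; they aren't taken from any real data.

```java
public class BayesExample {
    // Bayes' theorem: P(X|Y) = P(Y|X) * P(X) / P(Y)
    static double posterior(double likelihood, double prior, double evidence) {
        return likelihood * prior / evidence;
    }

    public static void main(String[] args) {
        // With P(Y|X) = 0.8, P(X) = 0.1 and P(Y) = 0.2,
        // the posterior P(X|Y) works out to 0.8 * 0.1 / 0.2 ≈ 0.4
        System.out.println(posterior(0.8, 0.1, 0.2));
    }
}
```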
Probability theory is fundamental to a lot of machine learning, but that isn't the focus of this tutorial. What is important is how this works as a machine learning classifier. The equation that we'll be using within the code is the
normal distribution approximation (or event model) for the Naive Bayes classifier. This uses the probability density function described below. Probability density functions are used for numerical data and describe the probability of a random variable taking a value within a given distribution. This requires that we know the standard deviation and mean of each variable in the training set in order to model that
distribution. These are easily found, and once we have calculated them for the vectors associated with each class label (0 and 1 in this case), we can predict the probability of a random variable belonging to one class or the other.
$$ \displaystyle pdf(x) = \frac{1}{\sigma \sqrt{2 \pi}}\ e^{-\frac{(x-\mu)^2}{2 \sigma^2}} $$
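The probability density function translates to Java directly. This is a minimal sketch; the class and method names here are my own, not necessarily what you'd use in a full program.

```java
public class Gaussian {
    // Probability density of x under a normal distribution with the
    // given mean (mu) and standard deviation (sigma).
    static double pdf(double x, double mean, double stdev) {
        double exponent = -Math.pow(x - mean, 2) / (2 * stdev * stdev);
        return (1 / (stdev * Math.sqrt(2 * Math.PI))) * Math.exp(exponent);
    }

    public static void main(String[] args) {
        // The density at the mean of a standard normal is 1/sqrt(2*pi) ≈ 0.3989
        System.out.println(pdf(0.0, 0.0, 1.0));
    }
}
```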
The thing that can be difficult to grasp initially is how the decision is made for the individual classes, and it's one that I've often found explained poorly from a beginner's perspective. Even if we have an equation that determines the probability of one variable occurring within a certain distribution, how can each vector be assigned a class label? Assuming our class label is Y and the variables in our vector are X1...Xn, the Naive Bayes algorithm makes the simplifying assumption that the variables are conditionally independent of each other given the class label, which just means that, within a class, no variable influences another.
$$ P(X_1...X_n|Y) = \prod_{i=1}^n P(X_i|Y) $$
This equation implies that the product of all the individual probabilities in the vector gives us an overall probability of the variables being as they are for a specific class label. As we calculate the individual probability of each variable in the vector using the normal distribution approximation above, it's a matter of multiplying these together using the mean and standard deviation associated with each class label. Once we have these values, the class with the highest probability is chosen as the prediction.
From a programming point of view, it's easier to see how these pieces fit together, so the diagram below does some of that work in showing how the statistics work on the flow of the data.
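Putting the product rule and the density function together, the decision step can be sketched like this. The array layout and names are illustrative assumptions, not the exact structures used later in the program.

```java
public class Predict {
    static double pdf(double x, double mean, double stdev) {
        double exponent = -Math.pow(x - mean, 2) / (2 * stdev * stdev);
        return (1 / (stdev * Math.sqrt(2 * Math.PI))) * Math.exp(exponent);
    }

    // means[c][i] and stdevs[c][i] hold the training statistics for
    // variable i under class label c; the class whose product of
    // densities is largest wins.
    static int predict(double[] vector, double[][] means, double[][] stdevs) {
        int best = -1;
        double bestProb = -1.0;
        for (int c = 0; c < means.length; c++) {
            double prob = 1.0;
            for (int i = 0; i < vector.length; i++) {
                prob *= pdf(vector[i], means[c][i], stdevs[c][i]);
            }
            if (prob > bestProb) {
                bestProb = prob;
                best = c;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Two one-variable classes centred at 0 and 5; a value near 0
        // should be assigned class 0.
        double[][] means = {{0.0}, {5.0}};
        double[][] stdevs = {{1.0}, {1.0}};
        System.out.println(predict(new double[]{0.2}, means, stdevs));
    }
}
```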
Programming and code examples
The dataset I'll be using for testing is an implementation of Leo Breiman's
twonorm example, which is a vector of 20 random variables drawn from two normally distributed models that lie inside each other. This is a synthetic data set, so it will return a very high accuracy; however, it works for the purpose of testing the classifier.
The program itself is split into four classes in addition to the class containing the standard public static void main() method that Java executes at run time. Here is a list of the packages that I'll be using in the code, which you may need to import manually if your IDE doesn't do it automatically. Thanks IntelliJ :)
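A typical set of imports for this kind of program looks like the following; the exact list will depend on your own code.

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Random;
```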
CSVReader
The first class,
CSVReader, contains a single method,
parseCSV, that is used to import the data into an array of arrays (an ArrayList<ArrayList> to be exact). Not to be confused with regular arrays, ArrayLists are more flexible than standard Arrays since they are mutable. This data structure will be the basis of the program and mimics a 2D matrix.
Java imports data via a BufferedReader and we define that as br and start reading in each line. Each row of the data contains a vector of numbers that are the numerical values for the variables (X1 ... Xn). The last entry is either a 0 or 1 and is the class label. This is the category that this vector belongs to.
Our CSV file consists of vectors on each line that include a prediction in the final column. The length of the vector determines the number of columns in our matrix, and this is defined as Integer len. We parse each string as a double, since we can't be certain that our data consists purely of integers. We then loop through the values with a for loop, enter them into an ArrayList<Double>, and dump these into our final matrix.
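The steps above can be sketched as follows. This is a minimal version of the parseCSV method; details such as error handling may differ in the final program.

```java
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

public class CSVReader {
    // Reads a CSV of numeric rows into an ArrayList of ArrayLists,
    // mimicking a 2D matrix. The final column (the class label) is
    // parsed like any other value here.
    public static ArrayList<ArrayList<Double>> parseCSV(String path) throws IOException {
        ArrayList<ArrayList<Double>> matrix = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(path))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] tokens = line.split(",");
                ArrayList<Double> row = new ArrayList<>();
                for (String token : tokens) {
                    // parse the string as a double, since the data
                    // may not consist purely of integers
                    row.add(Double.parseDouble(token.trim()));
                }
                matrix.add(row);
            }
        }
        return matrix;
    }
}
```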
where we will separate out the data we intend to use for our test and training sets. Its main method
takes the matrix we created and a split ratio between 0 and 1.
A random number is generated based on the size of our data, and we use it as an index to retrieve a random vector for inclusion in our training set. We remove this vector from the test set and add it to our training set. Finally, both sets are returned via an array.
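A minimal sketch of that split logic is below; the class name and return type are my own choices for illustration.

```java
import java.util.ArrayList;
import java.util.Random;

public class TrainTestSplit {
    // Moves a random fraction of the rows (splitRatio, between 0 and 1)
    // from a copy of the data into the training set; whatever remains
    // becomes the test set. Both are returned in a two-element array.
    public static ArrayList<ArrayList<Double>>[] split(
            ArrayList<ArrayList<Double>> data, double splitRatio) {
        ArrayList<ArrayList<Double>> test = new ArrayList<>(data);
        ArrayList<ArrayList<Double>> train = new ArrayList<>();
        int trainSize = (int) (data.size() * splitRatio);
        Random random = new Random();
        while (train.size() < trainSize) {
            // pick a random index, then move that vector
            // from the test set into the training set
            int index = random.nextInt(test.size());
            train.add(test.remove(index));
        }
        @SuppressWarnings("unchecked")
        ArrayList<ArrayList<Double>>[] result = new ArrayList[] { train, test };
        return result;
    }
}
```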