diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java index 4772f8927f..c6cc99933f 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java @@ -25,7 +25,7 @@ public class RedditClassifier { public static int GOOD = 0; public static int BAD = 1; - public static int MIN_SCORE = 7; + public static int MIN_SCORE = 10; public static int NUM_OF_FEATURES = 1000; private final AdaptiveLogisticRegression classifier; @@ -42,9 +42,9 @@ public class RedditClassifier { public RedditClassifier() { classifier = new AdaptiveLogisticRegression(2, NUM_OF_FEATURES, new L2()); - classifier.setPoolSize(50); + classifier.setPoolSize(150); titleEncoder = new AdaptiveWordValueEncoder("title"); - titleEncoder.setProbes(1); + titleEncoder.setProbes(2); domainEncoder = new StaticWordValueEncoder("domain"); domainEncoder.setProbes(1); } @@ -65,13 +65,15 @@ public class RedditClassifier { } public Vector convertPost(String title, String domain, int hour) { - final Vector features = new RandomAccessSparseVector(4); - final int noOfWords = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title).size(); - titleEncoder.addToVector(title, features); - domainEncoder.addToVector(domain, features); - features.set(2, hour); - features.set(3, noOfWords); - return features; + final Vector vector = new RandomAccessSparseVector(NUM_OF_FEATURES); + final List<String> words = Splitter.onPattern("\\W").omitEmptyStrings().splitToList(title); + vector.set(0, hour); + vector.set(1, words.size()); + domainEncoder.addToVector(domain, vector); + for (final String word : words) { + titleEncoder.addToVector(word, vector); + } + return vector; } public int classify(Vector features) { @@ -106,6 +108,7 @@ public class RedditClassifier { 
System.out.println("Eval count ========= Good = " + evalCount[0] + " ___ Bad = " + evalCount[1]); System.out.println("Test result ======== Correct prediction = " + correct + " ----- Wrong prediction = " + wrong); System.out.println("Test result ======== Correct Good = " + correctCount[0] + " ----- Correct Bad = " + correctCount[1]); + System.out.println("Test result ======== Good accuracy = " + (correctCount[0] / (evalCount[0] + 0.0)) + " ----- Bad accuracy = " + (correctCount[1] / (evalCount[1] + 0.0))); this.accuracy = correct / (wrong + correct + 0.0); }