From 424c7709ae7048b464f7df60dbc378cbf50028f9 Mon Sep 17 00:00:00 2001 From: eugenp Date: Wed, 22 Apr 2015 13:28:04 +0300 Subject: [PATCH] cleanup work --- .../reddit/classifier/RedditClassifier.java | 42 ++++++++----------- .../classifier/RedditClassifierTest.java | 2 +- 2 files changed, 19 insertions(+), 25 deletions(-) diff --git a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java index 4c58ff67ca..076ac0e65d 100644 --- a/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java +++ b/spring-security-oauth/src/main/java/org/baeldung/reddit/classifier/RedditClassifier.java @@ -22,32 +22,24 @@ import com.google.common.base.Splitter; import com.google.common.io.Files; public class RedditClassifier { - public static int GOOD = 0; public static int BAD = 1; - public static int MIN_SCORE = 10; + public static int MIN_SCORE = 7; + + private final int[] trainCount = { 0, 0 }; + private final int[] evalCount = { 0, 0 }; + private final int[] correctCount = { 0, 0 }; private final AdaptiveLogisticRegression classifier; private final FeatureVectorEncoder titleEncoder; private final FeatureVectorEncoder domainEncoder; - private CrossFoldLearner learner; private final int noOfFeatures; + + private CrossFoldLearner learner; private double accuracy; - private final int[] trainCount = { 0, 0 }; - - private final int[] evalCount = { 0, 0 }; - - private final int[] correctCount = { 0, 0 }; - public RedditClassifier() { - noOfFeatures = 1000; - classifier = new AdaptiveLogisticRegression(2, 1000, new L2()); - classifier.setPoolSize(150); - titleEncoder = new AdaptiveWordValueEncoder("title"); - titleEncoder.setProbes(2); - domainEncoder = new StaticWordValueEncoder("domain"); - domainEncoder.setProbes(1); + this(150, 1000); } public RedditClassifier(final int poolSize, final int noOfFeatures) { @@ -60,6 +52,8 @@ public class RedditClassifier { domainEncoder.setProbes(1); } + // API + public void trainClassifier(final String fileName) throws IOException { final List vectors = extractVectors(readDataFile(fileName)); final int size = vectors.size(); @@ -151,25 +145,25 @@ public class RedditClassifier { final String title = items[3]; final String theRootDomain = items[4]; - final String category = extractCategory(Integer.parseInt(numberOfVotes)); - - final NamedVector vector = new NamedVector(new RandomAccessSparseVector(noOfFeatures), category); + final RandomAccessSparseVector internalVector = new RandomAccessSparseVector(noOfFeatures); final Calendar cal = Calendar.getInstance(TimeZone.getTimeZone("GMT")); cal.setTimeInMillis(Long.parseLong(time) * 1000); - vector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day + internalVector.set(0, cal.get(Calendar.HOUR_OF_DAY)); // hour of day - vector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title + internalVector.set(1, Integer.parseInt(numberOfWordInTitle)); // number of words in the title - domainEncoder.addToVector(theRootDomain, vector); + domainEncoder.addToVector(theRootDomain, internalVector); final String[] words = title.split(" "); // titleEncoder.setProbes(words.length); // TODO: use a Java 8 stream with filter and remove the 1 and 2 character words; example: "a", "of", "to" for (final String word : words) { - titleEncoder.addToVector(word, vector); + titleEncoder.addToVector(word, internalVector); } - return vector; + + final String category = extractCategory(Integer.parseInt(numberOfVotes)); + return new NamedVector(internalVector, category); } private String extractCategory(final int score) { diff --git a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java index d18d683dc7..1bdc843599 100644 --- a/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java +++ b/spring-security-oauth/src/test/java/org/baeldung/classifier/RedditClassifierTest.java @@ -33,7 +33,7 @@ public class RedditClassifierTest { @Test public void givenLargerPoolSizeAndFeatures_whenUsingCustomClassifier_thenAccurate() throws IOException { - final RedditClassifier classifier = new RedditClassifier(200, 2000); + final RedditClassifier classifier = new RedditClassifier(250, 2500); classifier.trainClassifier(RedditDataCollector.DATA_FILE); final double result = classifier.getAccuracy(); System.out.println("==== Custom Classifier (large) Accuracy = " + result);