diff --git a/build.gradle b/build.gradle
index 2cc13c3..fb400e1 100755
--- a/build.gradle
+++ b/build.gradle
@@ -1,6 +1,16 @@
 // Apply plugins
-apply plugin: 'java'
-apply plugin: 'application'
+
+plugins {
+	// adds './gradlew shadowJar', which creates a fat Jar (includes all dependencies) in build/libs/
+	id 'com.github.johnrengelman.shadow' version '4.0.2'
+	// Needed for all Java projects
+	id 'java'
+	id 'application'
+}
+
+shadowJar {
+	transform(com.github.jengelman.gradle.plugins.shadow.transformers.Log4j2PluginsCacheFileTransformer)
+}
 
 // Basic configuration and settings for all (sub-)projects
 allprojects {
@@ -18,6 +28,8 @@ allprojects {
 
 	// Declare global dependencies
 	dependencies {
+		implementation group: 'org.apache.hadoop', name: 'hadoop-common', version: '3.3.1'
+		implementation group: 'org.apache.hadoop', name: 'hadoop-mapreduce-client-core', version: '3.3.1'
 		implementation group: 'org.apache.commons', name: 'commons-compress', version: '1.19'
 		implementation 'info.picocli:picocli:4.5.2'
diff --git a/src/main/java/org/netspeak/hadoop/Merge.java b/src/main/java/org/netspeak/hadoop/Merge.java
new file mode 100644
index 0000000..96b74e5
--- /dev/null
+++ b/src/main/java/org/netspeak/hadoop/Merge.java
@@ -0,0 +1,129 @@
+package org.netspeak.hadoop;
+
+import java.io.IOException;
+import java.util.Collection;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Job;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
+import org.apache.hadoop.mapreduce.lib.reduce.LongSumReducer;
+import org.netspeak.lang.Agnostic;
+import org.netspeak.lang.En;
+import org.netspeak.lang.MapperConfig;
+import org.netspeak.lang.SingleMapProcessor;
+import org.netspeak.preprocessing.PhraseMapper;
+
+public class Merge {
+	private static final String CONFIG_LOWERCASE = "preprocessing.lowercase";
+	private static final String CONFIG_MAX_N_GRAM = "preprocessing.max-n-gram";
+	private static final String CONFIG_LANG = "preprocessing.lang";
+
+	private static final String LANG_NONE = "none";
+	private static final String LANG_EN = "en";
+	private static final String LANG_DE = "de";
+
+	public static class TokenizerMapper extends Mapper<Object, Text, Text, LongWritable> {
+
+		private final Text phrase = new Text();
+		private final LongWritable frequency = new LongWritable();
+
+		private PhraseMapper[] mappers = new PhraseMapper[0];
+
+		@Override
+		public void setup(Context context) throws IOException, InterruptedException {
+			final Configuration conf = context.getConfiguration();
+
+			final MapperConfig config = new MapperConfig();
+			config.lowercase = conf.getBoolean(CONFIG_LOWERCASE, false);
+			config.maxNGram = conf.getInt(CONFIG_MAX_N_GRAM, Integer.MAX_VALUE);
+
+			SingleMapProcessor processor;
+			final String lang = conf.get(CONFIG_LANG, LANG_NONE).toLowerCase();
+			switch (lang) {
+			case LANG_NONE:
+				processor = Agnostic.INSTANCE;
+				break;
+			case LANG_DE:
+				throw new IllegalArgumentException("DE is not supported for Hadoop.");
+			case LANG_EN:
+				processor = En.INSTANCE;
+				break;
+			default:
+				throw new IllegalArgumentException("Unknown language: " + lang);
+			}
+
+			try {
+				mappers = processor.getMappers(config).toArray(new PhraseMapper[0]);
+			} catch (final Exception e) {
+				throw new RuntimeException(e);
+			}
+		}
+
+		private String map(String phrase, long frequency) {
+			for (final PhraseMapper mapper : mappers) {
+				phrase = mapper.map(phrase, frequency);
+				if (phrase == null) {
+					break;
+				}
+			}
+			return phrase;
+		}
+
+		@Override
+		public void map(Object key, Text value, Context context) throws IOException, InterruptedException {
+			// format: <word> *( <space> <word> ) <tab> <frequency>
+			final String line = value.toString().trim();
+			if (line.isEmpty()) {
+				// ignore line
+				return;
+			}
+
+			final int tabIndex = line.indexOf('\t');
+			if (tabIndex == -1) {
+				throw new IOException("Invalid format: Unable to find tab character.");
+			}
+
+			final long freq = Long.parseLong(line.substring(tabIndex + 1));
+			final String p = map(line.substring(0, tabIndex), freq);
+
+			if (p == null) {
+				return;
+			}
+
+			phrase.set(p);
+			frequency.set(freq);
+
+			context.write(phrase, frequency);
+		}
+	}
+
+	public static void run(Collection<String> input, String outputDir, String lang, MapperConfig config)
+			throws Exception {
+		final Configuration conf = new Configuration();
+		conf.set(CONFIG_LANG, lang);
+		conf.setBoolean(CONFIG_LOWERCASE, config.lowercase);
+		conf.setInt(CONFIG_MAX_N_GRAM, config.maxNGram);
+
+		final Job job = Job.getInstance(conf, "Netspeak index preprocessing (" + lang + ")");
+		job.setJarByClass(Merge.class);
+		job.setMapperClass(TokenizerMapper.class);
+		job.setCombinerClass(LongSumReducer.class);
+		job.setReducerClass(LongSumReducer.class);
+		job.setOutputKeyClass(Text.class);
+		job.setOutputValueClass(LongWritable.class);
+		job.setNumReduceTasks(1000);
+
+		FileInputFormat.setInputPaths(job, input.stream().map(Path::new).toArray(Path[]::new));
+		FileOutputFormat.setOutputPath(job, new Path(outputDir));
+
+		if (!job.waitForCompletion(true)) {
+			throw new RuntimeException("Job failed.");
+		}
+	}
+
+}
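For orientation, a hypothetical driver for the new job. All paths and settings below are invented placeholders, not part of the patch; it only shows how Merge.run maps option values onto the Hadoop job:

	// Sketch only: submit the merge job over two (hypothetical) corpus directories.
	import java.util.Arrays;

	import org.netspeak.hadoop.Merge;
	import org.netspeak.lang.MapperConfig;

	public class SubmitSketch {
		public static void main(String[] args) throws Exception {
			final MapperConfig config = new MapperConfig();
			config.lowercase = true;
			config.maxNGram = 5;

			// each input String becomes an org.apache.hadoop.fs.Path, so HDFS URIs work
			Merge.run(Arrays.asList("hdfs:///corpora/web", "hdfs:///corpora/books"),
					"hdfs:///netspeak/merged", "en", config);
		}
	}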
diff --git a/src/main/java/org/netspeak/lang/De.java b/src/main/java/org/netspeak/lang/De.java
index 2ea197a..fca06fd 100644
--- a/src/main/java/org/netspeak/lang/De.java
+++ b/src/main/java/org/netspeak/lang/De.java
@@ -39,7 +39,6 @@ public void process(Config config) throws Exception {
 		final StandardMappers stdMappers = new StandardMappers();
 		stdMappers.setSuperBlacklist(Util.readResourceWordList("/super-blacklist.txt"));
 		stdMappers.setBlacklist(Util.readResourceWordList("/blacklist.txt"));
-		stdMappers.setBlacklistCombinations(4);
 		stdMappers.setMaxNGram(config.maxNGram);
 		stdMappers.setToLowerCase(config.lowercase);
 
diff --git a/src/main/java/org/netspeak/lang/En.java b/src/main/java/org/netspeak/lang/En.java
index 7aa9fc1..3aae67f 100644
--- a/src/main/java/org/netspeak/lang/En.java
+++ b/src/main/java/org/netspeak/lang/En.java
@@ -23,7 +23,6 @@ public Collection<PhraseMapper> getMappers(MapperConfig config) throws IOException {
 		final StandardMappers stdMappers = new StandardMappers();
 		stdMappers.setSuperBlacklist(Util.readResourceWordList("/super-blacklist.txt"));
 		stdMappers.setBlacklist(Util.readResourceWordList("/blacklist.txt"));
-		stdMappers.setBlacklistCombinations(4);
 		stdMappers.setMaxNGram(config.maxNGram);
 		stdMappers.setToLowerCase(config.lowercase);
 
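With setBlacklistCombinations gone, De and En build their pipelines purely from MapperConfig. A minimal sketch (not from the patch, field values arbitrary) of calling the same getMappers entry point that TokenizerMapper.setup uses:

	// Sketch: build the English mapper pipeline exactly as TokenizerMapper.setup does.
	import org.netspeak.lang.En;
	import org.netspeak.lang.MapperConfig;
	import org.netspeak.preprocessing.PhraseMapper;

	public class PipelineSketch {
		public static PhraseMapper[] englishMappers() throws Exception {
			final MapperConfig config = new MapperConfig();
			config.lowercase = false;
			config.maxNGram = Integer.MAX_VALUE;
			return En.INSTANCE.getMappers(config).toArray(new PhraseMapper[0]);
		}
	}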
diff --git a/src/main/java/org/netspeak/preprocessing/mappers/PhraseMappers.java b/src/main/java/org/netspeak/preprocessing/mappers/PhraseMappers.java
index 7cd4561..c21bb9f 100755
--- a/src/main/java/org/netspeak/preprocessing/mappers/PhraseMappers.java
+++ b/src/main/java/org/netspeak/preprocessing/mappers/PhraseMappers.java
@@ -4,7 +4,6 @@
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.HashSet;
-import java.util.List;
 import java.util.Set;
 import java.util.function.Predicate;
 import java.util.regex.Pattern;
@@ -93,7 +92,11 @@ public static PhraseMapper normalizeApostrophe() {
 	 * @return
 	 */
 	public static PhraseMapper blacklist(final Collection<String> words) {
-		return PhraseMapper.rename(blacklist(words, 1));
+		final HashSet<String> set = new HashSet<>(words);
+		set.remove(null);
+		set.remove("");
+
+		return PhraseMapper.rename(filterByWords(w -> !set.contains(w)));
 	}
 
 	/**
@@ -135,50 +138,51 @@ public static PhraseMapper removeControlCharacters() {
 	 * Returns a new {@link PhraseMapper} that removes phrases which contain at
 	 * least one word that is contained in a given blacklist vocabulary.
 	 * <p>
-	 * Phrases which contains a word which can be constructed by concatenating
-	 * {@code <= repeating} many words from the blacklist will also be removed. I.e.
-	 * if {@code "} and {@code ?} are in the blacklist and {@code repeating} is 3,
-	 * then {@code """}, {@code "?"}, {@code "?}, and {@code ??} will all be
-	 * removed.
-	 * <p>
-	 * Please note that the blacklist will consume {@code O(n ** repeat)}
-	 * many bytes of memory where {@code n} is the number of blacklist entries.
+	 * Phrases which contain a word which can be constructed by concatenating
+	 * blacklist words will also be removed. E.g. if {@code "} and {@code ?} are in
+	 * the blacklist, then {@code """}, {@code "?"}, {@code "?}, and {@code ??} will
+	 * all be removed.
 	 *
 	 * @param words
 	 * @return
 	 */
-	public static PhraseMapper blacklist(final Collection<String> words, int repeat) {
-		HashSet<String> tempBlacklist = new HashSet<>(words);
-		// just to be safe
-		tempBlacklist.remove(null);
-		tempBlacklist.remove("");
-
-		if (repeat > 1) {
-			tempBlacklist = new HashSet<>(getAllCombinations(tempBlacklist, repeat));
+	public static PhraseMapper blacklistRepeated(final Collection<String> words) {
+		// create a regex for the words
+
+		// split by length
+		final ArrayList<String> singleChar = new ArrayList<>();
+		final ArrayList<String> multipleChar = new ArrayList<>();
+		for (final String word : words) {
+			if (word == null || word.isEmpty()) {
+				// skip
+			} else if (word.length() == 1) {
+				singleChar.add(word);
+			} else {
+				multipleChar.add(word);
+			}
 		}
-		// thanks Java
-		final Set<String> blacklist = tempBlacklist;
 
+		final StringBuilder sb = new StringBuilder();
+		sb.append("[");
+		for (final String singleCharWord : singleChar) {
+			appendLiteral(sb, singleCharWord);
+		}
+		sb.append("]");
-		return PhraseMapper.rename(filterByWords(w -> !blacklist.contains(w)));
-	}
 
+		for (final String word : multipleChar) {
+			sb.append("|");
+			appendLiteral(sb, word);
+		}
-	private static List<String> getAllCombinations(Collection<String> words, int repeat) {
-		final ArrayList<String> combinations = new ArrayList<>((int) Math.pow(words.size(), repeat));
-		combinations.addAll(words);
 
+		final Pattern regex = Pattern.compile("(?:" + sb.toString() + ")+");
-		int start = 0;
-		for (; repeat > 1; repeat--) {
-			final int size = combinations.size();
-			for (int i = start; i < size; i++) {
-				for (final String word : words) {
-					combinations.add(combinations.get(i) + word);
-				}
-			}
-			start = size;
-		}
+
+		return PhraseMapper.rename(filterByWords(w -> !regex.matcher(w).matches()));
+	}
 
-		return combinations;
+	private static void appendLiteral(StringBuilder sb, String value) {
+		for (final char c : value.toCharArray()) {
+			sb.append("\\u").append(String.format("%04x", (int) c));
+		}
 	}
 
 	/**
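The regex built by blacklistRepeated replaces the exponential getAllCombinations precomputation: single-character words form a character class, longer words become alternations, and the whole group repeats with "+", so arbitrarily long concatenations match without any precomputed combinations. A standalone sketch with invented sample words:

	import java.util.regex.Pattern;

	public class BlacklistRegexDemo {
		// Same escaping idea as appendLiteral: every character becomes a \uXXXX
		// literal, so regex metacharacters like '?' or '-' cannot break the pattern.
		static String literal(String value) {
			final StringBuilder sb = new StringBuilder();
			for (final char c : value.toCharArray()) {
				sb.append("\\u").append(String.format("%04x", (int) c));
			}
			return sb.toString();
		}

		public static void main(String[] args) {
			// blacklist: the single chars '"' and '?' plus the two-char word "--"
			final String pattern = "(?:[" + literal("\"") + literal("?") + "]|" + literal("--") + ")+";
			final Pattern regex = Pattern.compile(pattern); // (?:[\u0022\u003f]|\u002d\u002d)+

			System.out.println(regex.matcher("\"?\"").matches()); // true -> word removed
			System.out.println(regex.matcher("??????????").matches()); // true, any length
			System.out.println(regex.matcher("?a?").matches()); // false -> word kept
		}
	}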
diff --git a/src/main/java/org/netspeak/preprocessing/mappers/StandardMappers.java b/src/main/java/org/netspeak/preprocessing/mappers/StandardMappers.java
index 4590a6e..f66fb98 100644
--- a/src/main/java/org/netspeak/preprocessing/mappers/StandardMappers.java
+++ b/src/main/java/org/netspeak/preprocessing/mappers/StandardMappers.java
@@ -19,11 +19,10 @@ public class StandardMappers {
 	 */
 	boolean toLowerCase = false;
 	/**
-	 * All phrases with at least one word which can be constructed from at most
-	 * {@link #blacklistCombinations} many blacklisted word will be removed.
+	 * All phrases with at least one word which can be constructed from blacklisted
+	 * words will be removed.
 	 */
 	Collection<String> blacklist = null;
-	int blacklistCombinations = 4;
 	/**
 	 * @see PhraseMappers#superBlacklist(Iterable)
 	 */
@@ -37,10 +36,6 @@ public void setBlacklist(Collection<String> blacklist) {
 		this.blacklist = blacklist;
 	}
 
-	public void setBlacklistCombinations(int blacklistCombinations) {
-		this.blacklistCombinations = blacklistCombinations;
-	}
-
 	public void setSuperBlacklist(Path superBlacklist) throws IOException {
 		this.superBlacklist = Util.readWordList(superBlacklist);
 	}
@@ -79,7 +74,7 @@ public Collection<PhraseMapper> getMappers() {
 		mappers.add(PhraseMappers.joinWordsWithLeadingApostrophe());
 
 		if (blacklist != null) {
-			mappers.add(PhraseMappers.blacklist(blacklist, blacklistCombinations));
+			mappers.add(PhraseMappers.blacklistRepeated(blacklist));
 		}
 		if (maxNGram < Integer.MAX_VALUE) {
 			mappers.add(PhraseMappers.maxNGram(maxNGram));
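Call sites drop exactly one line. A before/after sketch (word lists as in De/En; the concrete values are placeholders):

	final StandardMappers stdMappers = new StandardMappers();
	stdMappers.setSuperBlacklist(Util.readResourceWordList("/super-blacklist.txt"));
	stdMappers.setBlacklist(Util.readResourceWordList("/blacklist.txt"));
	// stdMappers.setBlacklistCombinations(4); // removed: blacklistRepeated() matches
	// any number of concatenated blacklist words, so no depth bound is needed
	stdMappers.setMaxNGram(5);
	stdMappers.setToLowerCase(true);
	final Collection<PhraseMapper> mappers = stdMappers.getMappers();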
diff --git a/src/main/java/org/netspeak/usage/Cli.java b/src/main/java/org/netspeak/usage/Cli.java
index a158ee0..a5ca29c 100644
--- a/src/main/java/org/netspeak/usage/Cli.java
+++ b/src/main/java/org/netspeak/usage/Cli.java
@@ -14,11 +14,13 @@
 import java.util.Properties;
 import java.util.stream.Collectors;
 
+import org.netspeak.hadoop.Merge;
 import org.netspeak.io.GoogleBooksCsvReader;
 import org.netspeak.lang.Agnostic;
 import org.netspeak.lang.Config;
 import org.netspeak.lang.De;
 import org.netspeak.lang.En;
+import org.netspeak.lang.MapperConfig;
 import org.netspeak.lang.Processor;
 import org.netspeak.preprocessing.PhraseSource;
 import org.netspeak.preprocessing.SimplePhraseSourceFile;
@@ -48,18 +50,19 @@ public class Cli implements Runnable {
 			"---" })
 	Lang lang;
 
-	@Option(names = { "-i", "--input" }, type = Path.class, arity = "1..*", description = {
+	@Option(names = { "-i", "--input" }, type = String.class, arity = "1..*", description = {
 			"A list of input directories.", "The file and directory formats will be automatically detected.",
 			"The given files will not be modified." })
-	List<Path> input;
+	List<String> input;
 
 	@Option(names = { "-o", "--output" }, description = { "The output path of the preprocessing step.",
 			"The given directory has to be either empty or not exist." })
-	Path output;
+	String output;
 
 	@Option(names = { "-t", "--temp" }, description = { "A temporary path used to store temporary files.",
 			"This path should point to a fast SSD."
 					+ " Most processing and IO-heavy operations will be done in the temp directories."
-					+ " All temporary files will be deleted during or at the end of execution." })
-	Path temp;
+					+ " All temporary files will be deleted during or at the end of execution.",
+			"This option will be ignored when run with Hadoop." })
+	String temp;
 
 	@Option(names = { "--lowercase" }, description = { "Whether the whole data set will be lowercased." })
 	Boolean lowercase;
@@ -74,6 +77,9 @@ public class Cli implements Runnable {
 	@Option(names = { "--merge" }, description = { "Whether duplicate phrases in the data set will be merged.",
 			"Defaults to true." })
 	Boolean merge;
+	@Option(names = { "--hadoop" }, description = { "Whether to do the given operation on a Hadoop cluster.",
+			"Defaults to false." })
+	Boolean hadoop;
 
 	private void readConfig() throws Throwable {
 		if (config == null) {
@@ -95,7 +101,7 @@ private void readConfig() throws Throwable {
 		if (input == null) {
 			p = props.getProperty("input");
 			if (p != null) {
-				input = Arrays.stream(p.split(";")).map(String::trim).filter(s -> !s.isEmpty()).map(Paths::get)
+				input = Arrays.stream(p.split(";")).map(String::trim).filter(s -> !s.isEmpty())
 						.collect(Collectors.toList());
 			}
 		}
@@ -103,14 +109,14 @@ private void readConfig() throws Throwable {
 		if (output == null) {
 			p = props.getProperty("output");
 			if (p != null) {
-				output = Paths.get(p);
+				output = p;
 			}
 		}
 
 		if (temp == null) {
 			p = props.getProperty("temp");
 			if (p != null) {
-				temp = Paths.get(p);
+				temp = p;
 			}
 		}
 
@@ -200,6 +206,39 @@ private PhraseSource toPhraseSource(Path input) throws IOException {
 		return PhraseSource.fromFiles(files.stream().map(SimplePhraseSourceFile::new).collect(Collectors.toList()));
 	}
 
+	private void runLocal() throws Throwable {
+		final PhraseSource source = PhraseSource.combine(input.stream().map(p -> {
+			try {
+				return toPhraseSource(Paths.get(p));
+			} catch (final Exception e) {
+				throw new RuntimeException(e);
+			}
+		}).collect(toList()));
+
+		final Config config = new Config(source, Paths.get(output));
+		config.temp = temp == null ? null : Paths.get(temp);
+		config.lowercase = lowercase == null ? false : lowercase;
+		config.maxNGram = maxNGram == null ? Integer.MAX_VALUE : maxNGram;
+		config.parallelDegree = parallel == null || parallel <= 0 ? Runtime.getRuntime().availableProcessors()
+				: parallel;
+		config.mergeDuplicates = merge == null ? true : merge;
+
+		lang.processor.process(config);
+	}
+
+	private void runHadoop() throws Throwable {
+		if (merge != null && !merge) {
+			throw new IllegalArgumentException(
+					"When running using Hadoop, duplicates will always be merged. This conflicts with the `merge=false` given.");
+		}
+
+		final MapperConfig config = new MapperConfig();
+		config.lowercase = lowercase == null ? false : lowercase;
+		config.maxNGram = maxNGram == null ? Integer.MAX_VALUE : maxNGram;
+
+		Merge.run(input, output, lang.name(), config);
+	}
+
 	private void runWithExecption() throws Throwable {
 		readConfig();
 
@@ -216,23 +255,11 @@ private void runWithExecption() throws Throwable {
 			throw new IllegalArgumentException("--lang option is not set by config file or argument.");
 		}
 
-		final PhraseSource source = PhraseSource.combine(input.stream().map(p -> {
-			try {
-				return toPhraseSource(p);
-			} catch (final Exception e) {
-				throw new RuntimeException(e);
-			}
-		}).collect(toList()));
-
-		final Config config = new Config(source, output);
-		config.temp = temp;
-		config.lowercase = lowercase == null ? false : lowercase;
-		config.maxNGram = maxNGram == null ? Integer.MAX_VALUE : maxNGram;
-		config.parallelDegree = parallel == null || parallel <= 0 ? Runtime.getRuntime().availableProcessors()
-				: parallel;
-		config.mergeDuplicates = merge == null ? true : merge;
-
-		lang.processor.process(config);
+		if (hadoop != null && hadoop) {
+			runHadoop();
+		} else {
+			runLocal();
+		}
 
 		System.out.println("Done.");
 	}
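The Path-to-String switch in the options keeps HDFS URIs intact until the backend is chosen. A sketch (not from the patch) of why java.nio paths cannot carry them:

	// Conversion is deferred until we know which backend runs:
	final String dir = "hdfs://cluster/corpus";
	final org.apache.hadoop.fs.Path hadoopPath = new org.apache.hadoop.fs.Path(dir); // ok: hdfs://cluster/corpus
	final java.nio.file.Path localPath = java.nio.file.Paths.get(dir); // wrong for HDFS: yields the relative
	                                                                   // file path "hdfs:/cluster/corpus"
	                                                                   // (and throws InvalidPathException on Windows)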
diff --git a/src/test/java/org/netspeak/preprocessing/PhraseMappersTest.java b/src/test/java/org/netspeak/preprocessing/PhraseMappersTest.java
index 050e57d..e9d43d4 100644
--- a/src/test/java/org/netspeak/preprocessing/PhraseMappersTest.java
+++ b/src/test/java/org/netspeak/preprocessing/PhraseMappersTest.java
@@ -15,25 +15,25 @@ public class PhraseMappersTest {
 	private void phraseMapperTest(PhraseMapper mapper, Collection<String> unchanged, Collection<String> removed,
 			Map<String, String> changed) {
-		String name = mapper.getName();
+		final String name = mapper.getName();
 
 		if (unchanged != null) {
-			for (String expected : unchanged) {
-				String actual = mapper.map(expected, 100);
+			for (final String expected : unchanged) {
+				final String actual = mapper.map(expected, 100);
 				assertEquals("Expected unchanged for " + name, expected, actual);
 			}
 		}
 
 		if (removed != null) {
-			for (String expected : removed) {
-				String actual = mapper.map(expected, 100);
+			for (final String expected : removed) {
+				final String actual = mapper.map(expected, 100);
 				assertEquals("Expected removed for " + name, null, actual);
 			}
 		}
 
 		if (changed != null) {
-			for (Map.Entry<String, String> transform : changed.entrySet()) {
-				String actual = mapper.map(transform.getKey(), 100);
+			for (final Map.Entry<String, String> transform : changed.entrySet()) {
+				final String actual = mapper.map(transform.getKey(), 100);
 				assertEquals("Expected changed for " + name, transform.getValue(), actual);
 			}
 		}
@@ -41,8 +41,8 @@ private void phraseMapperTest(PhraseMapper mapper, Collection<String> unchanged,
 
 	@Test
 	public void blacklist() {
-		Set<String> blacklistedWords = new HashSet<>();
-		for (String word : ". - ( ) \" '".split(" ")) {
+		final Set<String> blacklistedWords = new HashSet<>();
+		for (final String word : ". - ( ) \" '".split(" ")) {
 			blacklistedWords.add(word);
 		}
 
@@ -63,7 +63,7 @@ public void blacklist() {
 		sharedRemoved.add("foo - bar");
 
 		{
-			final PhraseMapper mapper = PhraseMappers.blacklist(blacklistedWords, 1);
+			final PhraseMapper mapper = PhraseMappers.blacklist(blacklistedWords);
 
 			final Collection<String> unchanged = new ArrayList<>(sharedUnchanged);
 			unchanged.add("()");
@@ -73,14 +73,15 @@ public void blacklist() {
 			phraseMapperTest(mapper, unchanged, removed, null);
 		}
 		{
-			final PhraseMapper mapper = PhraseMappers.blacklist(blacklistedWords, 4);
+			final PhraseMapper mapper = PhraseMappers.blacklistRepeated(blacklistedWords);
 
 			final Collection<String> unchanged = new ArrayList<>();
-			unchanged.add("()()-");
+			unchanged.add("()()a");
 
 			final Collection<String> removed = new ArrayList<>();
 			removed.add("()()");
 			removed.add("-.-.");
+			removed.add("-.-.-.-.-.-.-.-.-.-");
 			removed.add("-.-. foo");
 			removed.add("foo -.-. foo");
 
@@ -90,8 +91,8 @@ public void blacklist() {
 
 	@Test
 	public void superBlacklist() {
-		Set<String> blacklistedWords = new HashSet<>();
-		for (String word : ". - ( ) \" '".split(" ")) {
+		final Set<String> blacklistedWords = new HashSet<>();
+		for (final String word : ". - ( ) \" '".split(" ")) {
 			blacklistedWords.add(word);
 		}
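A possible follow-up test, not part of this patch, pinning down the behavior the old depth-bounded blacklist could not express (assumes the test class's existing imports plus java.util.Arrays):

	@Test
	public void blacklistRepeatedLongRun() {
		final Set<String> blacklistedWords = new HashSet<>(Arrays.asList(".", "-"));
		final PhraseMapper mapper = PhraseMappers.blacklistRepeated(blacklistedWords);

		// a word of 2000 blacklisted characters; the old combination-based
		// blacklist would have needed combinations up to depth 2000 to catch this
		final StringBuilder run = new StringBuilder();
		for (int i = 0; i < 1000; i++) {
			run.append("-.");
		}
		assertEquals(null, mapper.map(run.toString() + " foo", 100));
	}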