diff --git a/LICENSE b/LICENSE index 1ef1f86fd704..6b169b1447f1 100644 --- a/LICENSE +++ b/LICENSE @@ -216,6 +216,7 @@ core/src/main/resources/org/apache/spark/ui/static/bootstrap* core/src/main/resources/org/apache/spark/ui/static/jsonFormatter* core/src/main/resources/org/apache/spark/ui/static/vis* docs/js/vendor/bootstrap.js +external/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java Python Software Foundation License diff --git a/LICENSE-binary b/LICENSE-binary index 7865d9df6314..6858193515a8 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -243,10 +243,10 @@ com.vlkan:flatbuffers com.ning:compress-lzf io.airlift:aircompressor io.dropwizard.metrics:metrics-core -io.dropwizard.metrics:metrics-ganglia io.dropwizard.metrics:metrics-graphite io.dropwizard.metrics:metrics-json io.dropwizard.metrics:metrics-jvm +io.dropwizard.metrics:metrics-jmx org.iq80.snappy:snappy com.clearspring.analytics:stream com.jamesmurty.utils:java-xmlbuilder diff --git a/NOTICE b/NOTICE index fefe08b38afc..d5ea8dedb311 100644 --- a/NOTICE +++ b/NOTICE @@ -26,3 +26,16 @@ The following provides more details on the included cryptographic software: This software uses Apache Commons Crypto (https://commons.apache.org/proper/commons-crypto/) to support authentication, and encryption and decryption of data sent across the network between services. + + +Metrics +Copyright 2010-2013 Coda Hale and Yammer, Inc. + +This product includes software developed by Coda Hale and Yammer, Inc. + +This product includes code derived from the JSR-166 project (ThreadLocalRandom, Striped64, +LongAdder), which was released with the following comments: + + Written by Doug Lea with assistance from members of JCP JSR-166 + Expert Group and released to the public domain, as explained at + http://creativecommons.org/publicdomain/zero/1.0/ \ No newline at end of file diff --git a/NOTICE-binary b/NOTICE-binary index d99c2d1c64c2..4ce8bf2f86b2 100644 --- a/NOTICE-binary +++ b/NOTICE-binary @@ -1515,3 +1515,16 @@ Copyright 2014-2017 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). + + +Metrics +Copyright 2010-2013 Coda Hale and Yammer, Inc. + +This product includes software developed by Coda Hale and Yammer, Inc. 
+ +This product includes code derived from the JSR-166 project (ThreadLocalRandom, Striped64, +LongAdder), which was released with the following comments: + + Written by Doug Lea with assistance from members of JCP JSR-166 + Expert Group and released to the public domain, as explained at + http://creativecommons.org/publicdomain/zero/1.0/ \ No newline at end of file diff --git a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java index 498dc51cdc81..916c14062167 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RpcIntegrationSuite.java @@ -260,14 +260,14 @@ public void onFailure(Throwable e) { @Test public void singleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!")); + assertEquals(Sets.newHashSet("Hello, Aaron!"), res.successMessages); assertTrue(res.errorMessages.isEmpty()); } @Test public void doubleRPC() throws Exception { RpcResult res = sendRPC("hello/Aaron", "hello/Reynold"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!")); + assertEquals(Sets.newHashSet("Hello, Aaron!", "Hello, Reynold!"), res.successMessages); assertTrue(res.errorMessages.isEmpty()); } @@ -295,7 +295,7 @@ public void doubleTrouble() throws Exception { @Test public void sendSuccessAndFailure() throws Exception { RpcResult res = sendRPC("hello/Bob", "throw error/the", "hello/Builder", "return error/!"); - assertEquals(res.successMessages, Sets.newHashSet("Hello, Bob!", "Hello, Builder!")); + assertEquals(Sets.newHashSet("Hello, Bob!", "Hello, Builder!"), res.successMessages); assertErrorsContain(res.errorMessages, Sets.newHashSet("Thrown: the", "Returned: !")); } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 184ddac9a71a..5e8b33455075 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -18,8 +18,11 @@ package org.apache.spark.unsafe.types; import java.io.Serializable; -import java.util.regex.Matcher; -import java.util.regex.Pattern; +import java.math.BigDecimal; +import java.time.Duration; +import java.time.Period; +import java.time.temporal.ChronoUnit; +import java.util.Objects; /** * The internal representation of interval type. 
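The hunk that follows drops the string-parsing helpers and re-models the interval as (months, days, microseconds), adding java.time extractors; a minimal usage sketch of the class as patched (the demo class name is illustrative):

import java.time.Duration;
import java.time.Period;

import org.apache.spark.unsafe.types.CalendarInterval;

class IntervalDemo {
  public static void main(String[] args) {
    // 14 months, 3 days, and 2 hours (7,200,000,000 microseconds)
    CalendarInterval interval = new CalendarInterval(14, 3, 7200000000L);
    Period datePart = interval.extractAsPeriod();      // P14M3D (months and days only)
    Duration timePart = interval.extractAsDuration();  // PT2H
    System.out.println(datePart + " / " + timePart);
  }
}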
@@ -32,249 +35,51 @@ public final class CalendarInterval implements Serializable { public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24; public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7; - private static Pattern yearMonthPattern = Pattern.compile( - "^([+|-])?(\\d+)-(\\d+)$"); - - private static Pattern dayTimePattern = Pattern.compile( - "^([+|-])?((\\d+) )?((\\d+):)?(\\d+):(\\d+)(\\.(\\d+))?$"); - - public static long toLongWithRange(String fieldName, - String s, long minValue, long maxValue) throws IllegalArgumentException { - long result = 0; - if (s != null) { - result = Long.parseLong(s); - if (result < minValue || result > maxValue) { - throw new IllegalArgumentException(String.format("%s %d outside range [%d, %d]", - fieldName, result, minValue, maxValue)); - } - } - return result; - } - - /** - * Parse YearMonth string in form: [-]YYYY-MM - * - * adapted from HiveIntervalYearMonth.valueOf - */ - public static CalendarInterval fromYearMonthString(String s) throws IllegalArgumentException { - CalendarInterval result = null; - if (s == null) { - throw new IllegalArgumentException("Interval year-month string was null"); - } - s = s.trim(); - Matcher m = yearMonthPattern.matcher(s); - if (!m.matches()) { - throw new IllegalArgumentException( - "Interval string does not match year-month format of 'y-m': " + s); - } else { - try { - int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; - int years = (int) toLongWithRange("year", m.group(2), 0, Integer.MAX_VALUE); - int months = (int) toLongWithRange("month", m.group(3), 0, 11); - result = new CalendarInterval(sign * (years * 12 + months), 0); - } catch (Exception e) { - throw new IllegalArgumentException( - "Error parsing interval year-month string: " + e.getMessage(), e); - } - } - return result; - } - - /** - * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn and [-]HH:mm:ss.nnnnnnnnn - * - * adapted from HiveIntervalDayTime.valueOf - */ - public static CalendarInterval fromDayTimeString(String s) throws IllegalArgumentException { - return fromDayTimeString(s, "day", "second"); - } - - /** - * Parse dayTime string in form: [-]d HH:mm:ss.nnnnnnnnn and [-]HH:mm:ss.nnnnnnnnn - * - * adapted from HiveIntervalDayTime.valueOf. - * Below interval conversion patterns are supported: - * - DAY TO (HOUR|MINUTE|SECOND) - * - HOUR TO (MINUTE|SECOND) - * - MINUTE TO SECOND - */ - public static CalendarInterval fromDayTimeString(String s, String from, String to) - throws IllegalArgumentException { - CalendarInterval result = null; - if (s == null) { - throw new IllegalArgumentException("Interval day-time string was null"); - } - s = s.trim(); - Matcher m = dayTimePattern.matcher(s); - if (!m.matches()) { - throw new IllegalArgumentException( - "Interval string does not match day-time format of 'd h:m:s.n': " + s); - } else { - try { - int sign = m.group(1) != null && m.group(1).equals("-") ? -1 : 1; - long days = m.group(2) == null ? 
0 : toLongWithRange("day", m.group(3), - 0, Integer.MAX_VALUE); - long hours = 0; - long minutes; - long seconds = 0; - if (m.group(5) != null || from.equals("minute")) { // 'HH:mm:ss' or 'mm:ss minute' - hours = toLongWithRange("hour", m.group(5), 0, 23); - minutes = toLongWithRange("minute", m.group(6), 0, 59); - seconds = toLongWithRange("second", m.group(7), 0, 59); - } else if (m.group(8) != null){ // 'mm:ss.nn' - minutes = toLongWithRange("minute", m.group(6), 0, 59); - seconds = toLongWithRange("second", m.group(7), 0, 59); - } else { // 'HH:mm' - hours = toLongWithRange("hour", m.group(6), 0, 23); - minutes = toLongWithRange("second", m.group(7), 0, 59); - } - // Hive allow nanosecond precision interval - String nanoStr = m.group(9) == null ? null : (m.group(9) + "000000000").substring(0, 9); - long nanos = toLongWithRange("nanosecond", nanoStr, 0L, 999999999L); - switch (to) { - case "hour": - minutes = 0; - seconds = 0; - nanos = 0; - break; - case "minute": - seconds = 0; - nanos = 0; - break; - case "second": - // No-op - break; - default: - throw new IllegalArgumentException( - String.format("Cannot support (interval '%s' %s to %s) expression", s, from, to)); - } - result = new CalendarInterval(0, sign * ( - days * MICROS_PER_DAY + hours * MICROS_PER_HOUR + minutes * MICROS_PER_MINUTE + - seconds * MICROS_PER_SECOND + nanos / 1000L)); - } catch (Exception e) { - throw new IllegalArgumentException( - "Error parsing interval day-time string: " + e.getMessage(), e); - } - } - return result; - } - - public static CalendarInterval fromUnitStrings(String[] units, String[] values) - throws IllegalArgumentException { - assert units.length == values.length; - int months = 0; - long microseconds = 0; - - for (int i = 0; i < units.length; i++) { - try { - switch (units[i]) { - case "year": - months = Math.addExact(months, Math.multiplyExact(Integer.parseInt(values[i]), 12)); - break; - case "month": - months = Math.addExact(months, Integer.parseInt(values[i])); - break; - case "week": - microseconds = Math.addExact( - microseconds, - Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_WEEK)); - break; - case "day": - microseconds = Math.addExact( - microseconds, - Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_DAY)); - break; - case "hour": - microseconds = Math.addExact( - microseconds, - Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_HOUR)); - break; - case "minute": - microseconds = Math.addExact( - microseconds, - Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MINUTE)); - break; - case "second": { - microseconds = Math.addExact(microseconds, parseSecondNano(values[i])); - break; - } - case "millisecond": - microseconds = Math.addExact( - microseconds, - Math.multiplyExact(Long.parseLong(values[i]), MICROS_PER_MILLI)); - break; - case "microsecond": - microseconds = Math.addExact(microseconds, Long.parseLong(values[i])); - break; - } - } catch (Exception e) { - throw new IllegalArgumentException("Error parsing interval string: " + e.getMessage(), e); - } - } - return new CalendarInterval(months, microseconds); - } - - /** - * Parse second_nano string in ss.nnnnnnnnn format to microseconds - */ - public static long parseSecondNano(String secondNano) throws IllegalArgumentException { - String[] parts = secondNano.split("\\."); - if (parts.length == 1) { - return toLongWithRange("second", parts[0], Long.MIN_VALUE / MICROS_PER_SECOND, - Long.MAX_VALUE / MICROS_PER_SECOND) * MICROS_PER_SECOND; - - } else if (parts.length == 2) { - long seconds = 
parts[0].equals("") ? 0L : toLongWithRange("second", parts[0], - Long.MIN_VALUE / MICROS_PER_SECOND, Long.MAX_VALUE / MICROS_PER_SECOND); - long nanos = toLongWithRange("nanosecond", parts[1], 0L, 999999999L); - return seconds * MICROS_PER_SECOND + nanos / 1000L; - - } else { - throw new IllegalArgumentException( - "Interval string does not match second-nano format of ss.nnnnnnnnn"); - } - } - public final int months; + public final int days; public final long microseconds; public long milliseconds() { return this.microseconds / MICROS_PER_MILLI; } - public CalendarInterval(int months, long microseconds) { + public CalendarInterval(int months, int days, long microseconds) { this.months = months; + this.days = days; this.microseconds = microseconds; } public CalendarInterval add(CalendarInterval that) { int months = this.months + that.months; + int days = this.days + that.days; long microseconds = this.microseconds + that.microseconds; - return new CalendarInterval(months, microseconds); + return new CalendarInterval(months, days, microseconds); } public CalendarInterval subtract(CalendarInterval that) { int months = this.months - that.months; + int days = this.days - that.days; long microseconds = this.microseconds - that.microseconds; - return new CalendarInterval(months, microseconds); + return new CalendarInterval(months, days, microseconds); } public CalendarInterval negate() { - return new CalendarInterval(-this.months, -this.microseconds); + return new CalendarInterval(-this.months, -this.days, -this.microseconds); } @Override - public boolean equals(Object other) { - if (this == other) return true; - if (other == null || !(other instanceof CalendarInterval)) return false; - - CalendarInterval o = (CalendarInterval) other; - return this.months == o.months && this.microseconds == o.microseconds; + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + CalendarInterval that = (CalendarInterval) o; + return months == that.months && + days == that.days && + microseconds == that.microseconds; } @Override public int hashCode() { - return 31 * months + (int) microseconds; + return Objects.hash(months, days, microseconds); } @Override @@ -286,22 +91,19 @@ public String toString() { appendUnit(sb, months % 12, "month"); } + appendUnit(sb, days, "day"); + if (microseconds != 0) { long rest = microseconds; - appendUnit(sb, rest / MICROS_PER_WEEK, "week"); - rest %= MICROS_PER_WEEK; - appendUnit(sb, rest / MICROS_PER_DAY, "day"); - rest %= MICROS_PER_DAY; appendUnit(sb, rest / MICROS_PER_HOUR, "hour"); rest %= MICROS_PER_HOUR; appendUnit(sb, rest / MICROS_PER_MINUTE, "minute"); rest %= MICROS_PER_MINUTE; - appendUnit(sb, rest / MICROS_PER_SECOND, "second"); - rest %= MICROS_PER_SECOND; - appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); - rest %= MICROS_PER_MILLI; - appendUnit(sb, rest, "microsecond"); - } else if (months == 0) { + if (rest != 0) { + String s = BigDecimal.valueOf(rest, 6).stripTrailingZeros().toPlainString(); + sb.append(' ').append(s).append(" seconds"); + } + } else if (months == 0 && days == 0) { sb.append(" 0 microseconds"); } @@ -313,4 +115,19 @@ private void appendUnit(StringBuilder sb, long value, String unit) { sb.append(' ').append(value).append(' ').append(unit).append('s'); } } + + /** + * Extracts the date part of the interval. + * @return an instance of {@code java.time.Period} based on the months and days fields + * of the given interval, not null. 
+ */ + public Period extractAsPeriod() { return Period.of(0, months, days); } + + /** + * Extracts the time part of the interval. + * @return an instance of {@code java.time.Duration} based on the microseconds field + * of the given interval, not null. + * @throws ArithmeticException if a numeric overflow occurs + */ + public Duration extractAsDuration() { return Duration.of(microseconds, ChronoUnit.MICROS); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 3ad9ac7b4de9..19e4182b38a4 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -114,25 +114,25 @@ public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1); Assert.assertEquals( - Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset()), - MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE, + Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset())); MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); Object onheap1BaseObject = onheap1.getBaseObject(); long onheap1BaseOffset = onheap1.getBaseOffset(); MemoryAllocator.HEAP.free(onheap1); Assert.assertEquals( - Platform.getByte(onheap1BaseObject, onheap1BaseOffset), - MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); + MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE, + Platform.getByte(onheap1BaseObject, onheap1BaseOffset)); MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); Assert.assertEquals( - Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset()), - MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE, + Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset())); MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( - Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), - MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE, + Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset())); MemoryAllocator.UNSAFE.free(offheap); } @@ -150,11 +150,11 @@ public void heapMemoryReuse() { // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, // reuse the previous memory which has released. 
MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); - Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Assert.assertEquals(1024 * 1024 + 1, onheap3.size()); Object obj3 = onheap3.getBaseObject(); heapMem.free(onheap3); MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); - Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(1024 * 1024 + 7, onheap4.size()); Assert.assertEquals(obj3, onheap4.getBaseObject()); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 9f3262bf2aaa..7f607e65eaa0 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -19,6 +19,9 @@ import org.junit.Test; +import java.time.Duration; +import java.time.Period; + import static org.junit.Assert.*; import static org.apache.spark.unsafe.types.CalendarInterval.*; @@ -26,125 +29,79 @@ public class CalendarIntervalSuite { @Test public void equalsTest() { - CalendarInterval i1 = new CalendarInterval(3, 123); - CalendarInterval i2 = new CalendarInterval(3, 321); - CalendarInterval i3 = new CalendarInterval(1, 123); - CalendarInterval i4 = new CalendarInterval(3, 123); + CalendarInterval i1 = new CalendarInterval(3, 2, 123); + CalendarInterval i2 = new CalendarInterval(3, 2,321); + CalendarInterval i3 = new CalendarInterval(3, 4,123); + CalendarInterval i4 = new CalendarInterval(1, 2, 123); + CalendarInterval i5 = new CalendarInterval(1, 4, 321); + CalendarInterval i6 = new CalendarInterval(3, 2, 123); assertNotSame(i1, i2); assertNotSame(i1, i3); + assertNotSame(i1, i4); assertNotSame(i2, i3); - assertEquals(i1, i4); + assertNotSame(i2, i4); + assertNotSame(i3, i4); + assertNotSame(i1, i5); + assertEquals(i1, i6); } @Test public void toStringTest() { CalendarInterval i; - i = new CalendarInterval(0, 0); + i = new CalendarInterval(0, 0, 0); assertEquals("interval 0 microseconds", i.toString()); - i = new CalendarInterval(34, 0); + i = new CalendarInterval(34, 0, 0); assertEquals("interval 2 years 10 months", i.toString()); - i = new CalendarInterval(-34, 0); + i = new CalendarInterval(-34, 0, 0); assertEquals("interval -2 years -10 months", i.toString()); - i = new CalendarInterval(0, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals("interval 3 weeks 13 hours 123 microseconds", i.toString()); + i = new CalendarInterval(0, 31, 0); + assertEquals("interval 31 days", i.toString()); - i = new CalendarInterval(0, -3 * MICROS_PER_WEEK - 13 * MICROS_PER_HOUR - 123); - assertEquals("interval -3 weeks -13 hours -123 microseconds", i.toString()); + i = new CalendarInterval(0, -31, 0); + assertEquals("interval -31 days", i.toString()); - i = new CalendarInterval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123); - assertEquals("interval 2 years 10 months 3 weeks 13 hours 123 microseconds", i.toString()); - } + i = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123); + assertEquals("interval 3 hours 13 minutes 0.000123 seconds", i.toString()); - @Test - public void fromYearMonthStringTest() { - String input; - CalendarInterval i; + i = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123); + assertEquals("interval -3 hours -13 minutes -0.000123 seconds", i.toString()); - input = "99-10"; - i = new CalendarInterval(99 * 12 + 10, 0L); - 
assertEquals(fromYearMonthString(input), i); - - input = "-8-10"; - i = new CalendarInterval(-8 * 12 - 10, 0L); - assertEquals(fromYearMonthString(input), i); - - try { - input = "99-15"; - fromYearMonthString(input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("month 15 outside range")); - } + i = new CalendarInterval(34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123); + assertEquals("interval 2 years 10 months 31 days 3 hours 13 minutes 0.000123 seconds", + i.toString()); } @Test - public void fromDayTimeStringTest() { - String input; - CalendarInterval i; + public void addTest() { + CalendarInterval input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR); + CalendarInterval input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR); + assertEquals(new CalendarInterval(5, 5, 101 * MICROS_PER_HOUR), input1.add(input2)); - input = "5 12:40:30.999999999"; - i = new CalendarInterval(0, 5 * MICROS_PER_DAY + 12 * MICROS_PER_HOUR + - 40 * MICROS_PER_MINUTE + 30 * MICROS_PER_SECOND + 999999L); - assertEquals(fromDayTimeString(input), i); - - input = "10 0:12:0.888"; - i = new CalendarInterval(0, 10 * MICROS_PER_DAY + 12 * MICROS_PER_MINUTE + - 888 * MICROS_PER_MILLI); - assertEquals(fromDayTimeString(input), i); - - input = "-3 0:0:0"; - i = new CalendarInterval(0, -3 * MICROS_PER_DAY); - assertEquals(fromDayTimeString(input), i); - - try { - input = "5 30:12:20"; - fromDayTimeString(input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("hour 30 outside range")); - } - - try { - input = "5 30-12"; - fromDayTimeString(input); - fail("Expected to throw an exception for the invalid input"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("not match day-time format")); - } - - try { - input = "5 1:12:20"; - fromDayTimeString(input, "hour", "microsecond"); - fail("Expected to throw an exception for the invalid convention type"); - } catch (IllegalArgumentException e) { - assertTrue(e.getMessage().contains("Cannot support (interval")); - } + input1 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR); + input2 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR); + assertEquals(new CalendarInterval(65, 120, 119 * MICROS_PER_HOUR), input1.add(input2)); } @Test - public void addTest() { - CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR); - CalendarInterval input2 = new CalendarInterval(2, 100 * MICROS_PER_HOUR); - assertEquals(input1.add(input2), new CalendarInterval(5, 101 * MICROS_PER_HOUR)); + public void subtractTest() { + CalendarInterval input1 = new CalendarInterval(3, 1, 1 * MICROS_PER_HOUR); + CalendarInterval input2 = new CalendarInterval(2, 4, 100 * MICROS_PER_HOUR); + assertEquals(new CalendarInterval(1, -3, -99 * MICROS_PER_HOUR), input1.subtract(input2)); - input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR); - input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR); - assertEquals(input1.add(input2), new CalendarInterval(65, 119 * MICROS_PER_HOUR)); + input1 = new CalendarInterval(-10, -30, -81 * MICROS_PER_HOUR); + input2 = new CalendarInterval(75, 150, 200 * MICROS_PER_HOUR); + assertEquals(new CalendarInterval(-85, -180, -281 * MICROS_PER_HOUR), input1.subtract(input2)); } @Test - public void subtractTest() { - CalendarInterval input1 = new CalendarInterval(3, 1 * MICROS_PER_HOUR); - CalendarInterval input2 = new 
CalendarInterval(2, 100 * MICROS_PER_HOUR); - assertEquals(input1.subtract(input2), new CalendarInterval(1, -99 * MICROS_PER_HOUR)); - - input1 = new CalendarInterval(-10, -81 * MICROS_PER_HOUR); - input2 = new CalendarInterval(75, 200 * MICROS_PER_HOUR); - assertEquals(input1.subtract(input2), new CalendarInterval(-85, -281 * MICROS_PER_HOUR)); + public void periodAndDurationTest() { + CalendarInterval interval = new CalendarInterval(120, -40, 123456); + assertEquals(Period.of(0, 120, -40), interval.extractAsPeriod()); + assertEquals(Duration.ofNanos(123456000), interval.extractAsDuration()); } } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index cd253c0cbc90..dbede9bc7f12 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -38,11 +38,11 @@ public class UTF8StringSuite { private static void checkBasic(String str, int len) { UTF8String s1 = fromString(str); UTF8String s2 = fromBytes(str.getBytes(StandardCharsets.UTF_8)); - assertEquals(s1.numChars(), len); - assertEquals(s2.numChars(), len); + assertEquals(len, s1.numChars()); + assertEquals(len, s2.numChars()); - assertEquals(s1.toString(), str); - assertEquals(s2.toString(), str); + assertEquals(str, s1.toString()); + assertEquals(str, s2.toString()); assertEquals(s1, s2); assertEquals(s1.hashCode(), s2.hashCode()); @@ -375,20 +375,20 @@ public void pad() { @Test public void substringSQL() { UTF8String e = fromString("example"); - assertEquals(e.substringSQL(0, 2), fromString("ex")); - assertEquals(e.substringSQL(1, 2), fromString("ex")); - assertEquals(e.substringSQL(0, 7), fromString("example")); - assertEquals(e.substringSQL(1, 2), fromString("ex")); - assertEquals(e.substringSQL(0, 100), fromString("example")); - assertEquals(e.substringSQL(1, 100), fromString("example")); - assertEquals(e.substringSQL(2, 2), fromString("xa")); - assertEquals(e.substringSQL(1, 6), fromString("exampl")); - assertEquals(e.substringSQL(2, 100), fromString("xample")); - assertEquals(e.substringSQL(0, 0), fromString("")); - assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); - assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); - assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); - assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + assertEquals(fromString("ex"), e.substringSQL(0, 2)); + assertEquals(fromString("ex"), e.substringSQL(1, 2)); + assertEquals(fromString("example"), e.substringSQL(0, 7)); + assertEquals(fromString("ex"), e.substringSQL(1, 2)); + assertEquals(fromString("example"), e.substringSQL(0, 100)); + assertEquals(fromString("example"), e.substringSQL(1, 100)); + assertEquals(fromString("xa"), e.substringSQL(2, 2)); + assertEquals(fromString("exampl"), e.substringSQL(1, 6)); + assertEquals(fromString("xample"), e.substringSQL(2, 100)); + assertEquals(fromString(""), e.substringSQL(0, 0)); + assertEquals(EMPTY_UTF8, e.substringSQL(100, 4)); + assertEquals(fromString("example"), e.substringSQL(0, Integer.MAX_VALUE)); + assertEquals(fromString("example"), e.substringSQL(1, Integer.MAX_VALUE)); + assertEquals(fromString("xample"), e.substringSQL(2, Integer.MAX_VALUE)); } @Test @@ -506,50 +506,50 @@ public void findInSet() { @Test public void soundex() { - assertEquals(fromString("Robert").soundex(), fromString("R163")); - 
assertEquals(fromString("Rupert").soundex(), fromString("R163")); - assertEquals(fromString("Rubin").soundex(), fromString("R150")); - assertEquals(fromString("Ashcraft").soundex(), fromString("A261")); - assertEquals(fromString("Ashcroft").soundex(), fromString("A261")); - assertEquals(fromString("Burroughs").soundex(), fromString("B620")); - assertEquals(fromString("Burrows").soundex(), fromString("B620")); - assertEquals(fromString("Ekzampul").soundex(), fromString("E251")); - assertEquals(fromString("Example").soundex(), fromString("E251")); - assertEquals(fromString("Ellery").soundex(), fromString("E460")); - assertEquals(fromString("Euler").soundex(), fromString("E460")); - assertEquals(fromString("Ghosh").soundex(), fromString("G200")); - assertEquals(fromString("Gauss").soundex(), fromString("G200")); - assertEquals(fromString("Gutierrez").soundex(), fromString("G362")); - assertEquals(fromString("Heilbronn").soundex(), fromString("H416")); - assertEquals(fromString("Hilbert").soundex(), fromString("H416")); - assertEquals(fromString("Jackson").soundex(), fromString("J250")); - assertEquals(fromString("Kant").soundex(), fromString("K530")); - assertEquals(fromString("Knuth").soundex(), fromString("K530")); - assertEquals(fromString("Lee").soundex(), fromString("L000")); - assertEquals(fromString("Lukasiewicz").soundex(), fromString("L222")); - assertEquals(fromString("Lissajous").soundex(), fromString("L222")); - assertEquals(fromString("Ladd").soundex(), fromString("L300")); - assertEquals(fromString("Lloyd").soundex(), fromString("L300")); - assertEquals(fromString("Moses").soundex(), fromString("M220")); - assertEquals(fromString("O'Hara").soundex(), fromString("O600")); - assertEquals(fromString("Pfister").soundex(), fromString("P236")); - assertEquals(fromString("Rubin").soundex(), fromString("R150")); - assertEquals(fromString("Robert").soundex(), fromString("R163")); - assertEquals(fromString("Rupert").soundex(), fromString("R163")); - assertEquals(fromString("Soundex").soundex(), fromString("S532")); - assertEquals(fromString("Sownteks").soundex(), fromString("S532")); - assertEquals(fromString("Tymczak").soundex(), fromString("T522")); - assertEquals(fromString("VanDeusen").soundex(), fromString("V532")); - assertEquals(fromString("Washington").soundex(), fromString("W252")); - assertEquals(fromString("Wheaton").soundex(), fromString("W350")); - - assertEquals(fromString("a").soundex(), fromString("A000")); - assertEquals(fromString("ab").soundex(), fromString("A100")); - assertEquals(fromString("abc").soundex(), fromString("A120")); - assertEquals(fromString("abcd").soundex(), fromString("A123")); - assertEquals(fromString("").soundex(), fromString("")); - assertEquals(fromString("123").soundex(), fromString("123")); - assertEquals(fromString("世界千世").soundex(), fromString("世界千世")); + assertEquals(fromString("R163"), fromString("Robert").soundex()); + assertEquals(fromString("R163"), fromString("Rupert").soundex()); + assertEquals(fromString("R150"), fromString("Rubin").soundex()); + assertEquals(fromString("A261"), fromString("Ashcraft").soundex()); + assertEquals(fromString("A261"), fromString("Ashcroft").soundex()); + assertEquals(fromString("B620"), fromString("Burroughs").soundex()); + assertEquals(fromString("B620"), fromString("Burrows").soundex()); + assertEquals(fromString("E251"), fromString("Ekzampul").soundex()); + assertEquals(fromString("E251"), fromString("Example").soundex()); + assertEquals(fromString("E460"), fromString("Ellery").soundex()); + 
assertEquals(fromString("E460"), fromString("Euler").soundex()); + assertEquals(fromString("G200"), fromString("Ghosh").soundex()); + assertEquals(fromString("G200"), fromString("Gauss").soundex()); + assertEquals(fromString("G362"), fromString("Gutierrez").soundex()); + assertEquals(fromString("H416"), fromString("Heilbronn").soundex()); + assertEquals(fromString("H416"), fromString("Hilbert").soundex()); + assertEquals(fromString("J250"), fromString("Jackson").soundex()); + assertEquals(fromString("K530"), fromString("Kant").soundex()); + assertEquals(fromString("K530"), fromString("Knuth").soundex()); + assertEquals(fromString("L000"), fromString("Lee").soundex()); + assertEquals(fromString("L222"), fromString("Lukasiewicz").soundex()); + assertEquals(fromString("L222"), fromString("Lissajous").soundex()); + assertEquals(fromString("L300"), fromString("Ladd").soundex()); + assertEquals(fromString("L300"), fromString("Lloyd").soundex()); + assertEquals(fromString("M220"), fromString("Moses").soundex()); + assertEquals(fromString("O600"), fromString("O'Hara").soundex()); + assertEquals(fromString("P236"), fromString("Pfister").soundex()); + assertEquals(fromString("R150"), fromString("Rubin").soundex()); + assertEquals(fromString("R163"), fromString("Robert").soundex()); + assertEquals(fromString("R163"), fromString("Rupert").soundex()); + assertEquals(fromString("S532"), fromString("Soundex").soundex()); + assertEquals(fromString("S532"), fromString("Sownteks").soundex()); + assertEquals(fromString("T522"), fromString("Tymczak").soundex()); + assertEquals(fromString("V532"), fromString("VanDeusen").soundex()); + assertEquals(fromString("W252"), fromString("Washington").soundex()); + assertEquals(fromString("W350"), fromString("Wheaton").soundex()); + + assertEquals(fromString("A000"), fromString("a").soundex()); + assertEquals(fromString("A100"), fromString("ab").soundex()); + assertEquals(fromString("A120"), fromString("abc").soundex()); + assertEquals(fromString("A123"), fromString("abcd").soundex()); + assertEquals(fromString(""), fromString("").soundex()); + assertEquals(fromString("123"), fromString("123").soundex()); + assertEquals(fromString("世界千世"), fromString("世界千世").soundex()); } @Test @@ -849,7 +849,7 @@ public void skipWrongFirstByte() { for (int i = 0; i < wrongFirstBytes.length; ++i) { c[0] = (byte)wrongFirstBytes[i]; - assertEquals(fromBytes(c).numChars(), 1); + assertEquals(1, fromBytes(c).numChars()); } } } diff --git a/core/pom.xml b/core/pom.xml index 38eb8adac500..3eedc69c9593 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -292,6 +292,16 @@ io.dropwizard.metrics metrics-graphite + + + com.rabbitmq + amqp-client + + + + + io.dropwizard.metrics + metrics-jmx com.fasterxml.jackson.core diff --git a/core/src/main/java/org/apache/spark/api/plugin/DriverPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/DriverPlugin.java new file mode 100644 index 000000000000..0c0d0df8ae68 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/plugin/DriverPlugin.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.plugin; + +import java.util.Collections; +import java.util.Map; + +import org.apache.spark.SparkContext; +import org.apache.spark.annotation.DeveloperApi; + +/** + * :: DeveloperApi :: + * Driver component of a {@link SparkPlugin}. + * + * @since 3.0.0 + */ +@DeveloperApi +public interface DriverPlugin { + + /** + * Initialize the plugin. + *
<p>
+ * This method is called early in the initialization of the Spark driver. Explicitly, it is + * called before the Spark driver's task scheduler is initialized. This means that a lot + * of other Spark subsystems may not yet have been initialized. This call also blocks driver + * initialization. + *
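A minimal sketch of a driver component built on this hook, assuming the hypothetical class name and config key below; expensive work would be pushed off this path since it blocks startup:

import java.util.Collections;
import java.util.Map;

import org.apache.spark.SparkContext;
import org.apache.spark.api.plugin.DriverPlugin;
import org.apache.spark.api.plugin.PluginContext;

class ExampleDriverPlugin implements DriverPlugin {
  @Override
  public Map<String, String> init(SparkContext sc, PluginContext pluginContext) {
    // Runs before the task scheduler exists and blocks driver startup; keep it cheap.
    // The returned map is handed to the executor component's init().
    return Collections.singletonMap("example.key", "example.value");
  }
}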
<p>
+ * It's recommended that plugins be careful about what operations are performed in this call, + * preferably performing expensive operations in a separate thread, or postponing them until + * the application has fully started. + * + * @param sc The SparkContext loading the plugin. + * @param pluginContext Additional plugin-specific information about the Spark application where the plugin + * is running. + * @return A map that will be provided to the {@link ExecutorPlugin#init(PluginContext,Map)} + * method. + */ + default Map<String, String> init(SparkContext sc, PluginContext pluginContext) { + return Collections.emptyMap(); + } + + /** + * Register metrics published by the plugin with Spark's metrics system. + *
<p>
+ * This method is called later in the initialization of the Spark application, after most + * subsystems are up and the application ID is known. If there are metrics registered in + * the registry ({@link PluginContext#metricRegistry()}), then a metrics source with the + * plugin name will be created. + *
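As a rough illustration (the class and metric names are invented), a driver component could publish a gauge from this hook:

import com.codahale.metrics.Gauge;

import org.apache.spark.api.plugin.DriverPlugin;
import org.apache.spark.api.plugin.PluginContext;

class GaugeDriverPlugin implements DriverPlugin {
  @Override
  public void registerMetrics(String appId, PluginContext pluginContext) {
    // Published under a metrics source named after the plugin class.
    pluginContext.metricRegistry().register("appIdLength", (Gauge<Integer>) () -> appId.length());
  }
}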
<p>
+ * Note that even though the metric registry is still accessible after this method is called, + * registering new metrics afterwards may result in the metrics not being + * available. + * + * @param appId The application ID from the cluster manager. + * @param pluginContext Additional plugin-specific information about the Spark application where the plugin + * is running. + */ + default void registerMetrics(String appId, PluginContext pluginContext) {} + + /** + * RPC message handler. + *
<p>
+ * Plugins can use Spark's RPC system to send messages from executors to the driver (but not + * the other way around, currently). Messages sent by the executor component of the plugin will + * be delivered to this method, and the returned value will be sent back to the executor as + * the reply, if the executor has requested one. + *
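A sketch of a driver component answering executor messages; the message handling below is purely illustrative:

import org.apache.spark.api.plugin.DriverPlugin;

class EchoDriverPlugin implements DriverPlugin {
  @Override
  public Object receive(Object message) {
    // Runs on shared RPC dispatch threads, so keep the handler fast and thread-safe.
    if (message instanceof String) {
      return "ack: " + message;  // delivered to the executor only if it asked for a reply
    }
    throw new IllegalArgumentException("Unexpected message: " + message);
  }
}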
<p>
+ * Any exception thrown will be sent back to the executor as an error, in case it is expecting + * a reply. In case a reply is not expected, a log message will be written to the driver log. + *
<p>
+ * The implementation of this handler should be thread-safe. + *
<p>
+ * Note all plugins share RPC dispatch threads, and this method is called synchronously. So + * performing expensive operations in this handler may affect the operation of other active + * plugins. Internal Spark endpoints are not directly affected, though, since they use different + * threads. + *
<p>
+ * Spark guarantees that the driver component will be ready to receive messages through this + * handler when executors are started. + * + * @param message The incoming message. + * @return Value to be returned to the caller. Ignored if the caller does not expect a reply. + */ + default Object receive(Object message) throws Exception { + throw new UnsupportedOperationException(); + } + + /** + * Informs the plugin that the Spark application is shutting down. + *
<p>
+ * This method is called during the driver shutdown phase. It is recommended that plugins + * not use any Spark functions (e.g. send RPC messages) during this call. + */ + default void shutdown() {} + +} diff --git a/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java new file mode 100644 index 000000000000..496130803516 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/plugin/ExecutorPlugin.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.plugin; + +import java.util.Map; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * :: DeveloperApi :: + * Executor component of a {@link SparkPlugin}. + * + * @since 3.0.0 + */ +@DeveloperApi +public interface ExecutorPlugin { + + /** + * Initialize the executor plugin. + *
<p>
+ * When a Spark plugin provides an executor plugin, this method will be called during the + * initialization of the executor process. It will block executor initialization until it + * returns. + *
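A minimal sketch of an executor component, assuming the hypothetical names below; per the metrics note that follows, any metrics are registered inside init:

import java.util.Map;

import org.apache.spark.api.plugin.ExecutorPlugin;
import org.apache.spark.api.plugin.PluginContext;

class ExampleExecutorPlugin implements ExecutorPlugin {
  @Override
  public void init(PluginContext ctx, Map<String, String> extraConf) {
    // extraConf carries whatever the driver component returned from its init(),
    // e.g. extraConf.get("example.key") for the driver sketch above.
    // Register metrics now; registrations made after init may not be published.
    ctx.metricRegistry().counter("initCalls").inc();
  }
}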
<p>
+ * Executor plugins that publish metrics should register all metrics with the context's + * registry ({@link PluginContext#metricRegistry()}) when this method is called. Metrics + * registered afterwards are not guaranteed to show up. + * + * @param ctx Context information for the executor where the plugin is running. + * @param extraConf Extra configuration provided by the driver component during its + * initialization. + */ + default void init(PluginContext ctx, Map<String, String> extraConf) {} + + /** + * Clean up and terminate this plugin. + *
<p>
+ * This method is called during the executor shutdown phase, and blocks executor shutdown. + */ + default void shutdown() {} + +} diff --git a/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java b/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java new file mode 100644 index 000000000000..b9413cf828aa --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/plugin/PluginContext.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.plugin; + +import java.io.IOException; + +import com.codahale.metrics.MetricRegistry; + +import org.apache.spark.SparkConf; +import org.apache.spark.annotation.DeveloperApi; + +/** + * :: DeveloperApi :: + * Context information and operations for plugins loaded by Spark. + *
<p>
+ * An instance of this class is provided to plugins in their initialization method. It is safe + * for plugins to keep a reference to the instance for later use (for example, to send messages + * to the plugin's driver component). + *
<p>
+ * Context instances are plugin-specific, so metrics and messages are tied to each plugin. It is + * not possible for a plugin to directly interact with other plugins. + * + * @since 3.0.0 + */ +@DeveloperApi +public interface PluginContext { + + /** + * Registry where the plugin associated with this context should register its metrics. + */ + MetricRegistry metricRegistry(); + + /** Configuration of the Spark application. */ + SparkConf conf(); + + /** Executor ID of the process. On the driver, this will identify the driver. */ + String executorID(); + + /** The host name which is being used by the Spark process for communication. */ + String hostname(); + + /** + * Send a message to the plugin's driver-side component. + *
<p>
+ * This method sends a message to the driver-side component of the plugin, without expecting + * a reply. It returns as soon as the message is enqueued for sending. + *
<p>
+ * The message must be serializable. + * + * @param message Message to be sent. + */ + void send(Object message) throws IOException; + + /** + * Send an RPC to the plugin's driver-side component. + *
<p>
+ * This method sends a message to the driver-side component of the plugin, and blocks until a + * reply arrives, or the configured RPC ask timeout (spark.rpc.askTimeout) elapses. + *
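A sketch of how executor-side code holding the context might use both calls; the messages are placeholders:

import org.apache.spark.api.plugin.PluginContext;

class DriverMessenger {
  void report(PluginContext ctx) throws Exception {
    ctx.send("heartbeat");                     // fire-and-forget; returns once the message is enqueued
    Object reply = ctx.ask("current-status");  // blocks until the driver replies or the ask times out
    System.out.println("Driver replied: " + reply);
  }
}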
<p>
+ * If the driver replies with an error, an exception with the corresponding error will be thrown. + *
<p>
+ * The message must be serializable. + * + * @param message Message to be sent. + * @return The reply from the driver-side component. + */ + Object ask(Object message) throws Exception; + +} diff --git a/core/src/main/java/org/apache/spark/api/plugin/SparkPlugin.java b/core/src/main/java/org/apache/spark/api/plugin/SparkPlugin.java new file mode 100644 index 000000000000..a500f5d2188f --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/plugin/SparkPlugin.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.plugin; + +import org.apache.spark.annotation.DeveloperApi; + +/** + * :: DeveloperApi :: + * A plugin that can be dynamically loaded into a Spark application. + *
<p>
+ * Plugins can be loaded by adding the plugin's class name to the appropriate Spark configuration. + * Check the Spark configuration documentation for details. + *
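Tying the pieces together, a plugin class could wire up the two hypothetical components sketched earlier and be enabled through the new spark.plugins setting:

import org.apache.spark.api.plugin.DriverPlugin;
import org.apache.spark.api.plugin.ExecutorPlugin;
import org.apache.spark.api.plugin.SparkPlugin;

public class ExamplePlugin implements SparkPlugin {
  @Override
  public DriverPlugin driverPlugin() { return new ExampleDriverPlugin(); }

  @Override
  public ExecutorPlugin executorPlugin() { return new ExampleExecutorPlugin(); }
}

// Enabled with e.g.: spark-submit --conf spark.plugins=com.example.ExamplePlugin ...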
<p>
+ * Plugins have two optional components: a driver-side component, of which a single instance is + * created per application, inside the Spark driver. And an executor-side component, of which one + * instance is created in each executor that is started by Spark. Details of each component can be + * found in the documentation for {@link DriverPlugin} and {@link ExecutorPlugin}. + * + * @since 3.0.0 + */ +@DeveloperApi +public interface SparkPlugin { + + /** + * Return the plugin's driver-side component. + * + * @return The driver-side component, or null if one is not needed. + */ + DriverPlugin driverPlugin(); + + /** + * Return the plugin's executor-side component. + * + * @return The executor-side component, or null if one is not needed. + */ + ExecutorPlugin executorPlugin(); + +} diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.js b/core/src/main/resources/org/apache/spark/ui/static/webui.js index cf04db28804c..fac464e1353c 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.js +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.js @@ -87,6 +87,9 @@ $(function() { collapseTablePageLoad('collapse-aggregated-runningExecutions','aggregated-runningExecutions'); collapseTablePageLoad('collapse-aggregated-completedExecutions','aggregated-completedExecutions'); collapseTablePageLoad('collapse-aggregated-failedExecutions','aggregated-failedExecutions'); + collapseTablePageLoad('collapse-aggregated-sessionstat','aggregated-sessionstat'); + collapseTablePageLoad('collapse-aggregated-sqlstat','aggregated-sqlstat'); + collapseTablePageLoad('collapse-aggregated-sqlsessionstat','aggregated-sqlsessionstat'); }); $(function() { diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 2db880976c3a..cad88ad8aec6 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -48,6 +48,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Tests._ import org.apache.spark.internal.config.UI._ +import org.apache.spark.internal.plugin.PluginContainer import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.source.JVMCPUSource import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} @@ -220,6 +221,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _heartbeater: Heartbeater = _ private var _resources: scala.collection.immutable.Map[String, ResourceInformation] = _ private var _shuffleDriverComponents: ShuffleDriverComponents = _ + private var _plugins: Option[PluginContainer] = None /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -539,6 +541,9 @@ class SparkContext(config: SparkConf) extends Logging { _heartbeatReceiver = env.rpcEnv.setupEndpoint( HeartbeatReceiver.ENDPOINT_NAME, new HeartbeatReceiver(this)) + // Initialize any plugins before the task scheduler is initialized. 
+ _plugins = PluginContainer(this) + // Create and start the scheduler val (sched, ts) = SparkContext.createTaskScheduler(this, master, deployMode) _schedulerBackend = sched @@ -621,6 +626,7 @@ class SparkContext(config: SparkConf) extends Logging { _env.metricsSystem.registerSource(e.executorAllocationManagerSource) } appStatusSource.foreach(_env.metricsSystem.registerSource(_)) + _plugins.foreach(_.registerMetrics(applicationId)) // Make sure the context is stopped if the user forgets about it. This avoids leaving // unfinished event logs around after the JVM exits cleanly. It doesn't help if the JVM // is killed, though. @@ -1976,6 +1982,9 @@ class SparkContext(config: SparkConf) extends Logging { _listenerBusStarted = false } } + Utils.tryLogNonFatalError { + _plugins.foreach(_.shutdown()) + } Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index ce6d0322bafd..0f595d095a22 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -37,6 +37,7 @@ import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.internal.plugin.PluginContainer import org.apache.spark.memory.{SparkOutOfMemoryError, TaskMemoryManager} import org.apache.spark.metrics.source.JVMCPUSource import org.apache.spark.rpc.RpcTimeout @@ -165,6 +166,11 @@ private[spark] class Executor( } } + // Plugins need to load using a class loader that includes the executor's user classpath + private val plugins: Option[PluginContainer] = Utils.withContextClassLoader(replClassLoader) { + PluginContainer(env) + } + // Max size of direct result. If task result is bigger than this, we use the block manager // to send the result back. 
private val maxDirectResultSize = Math.min( @@ -297,6 +303,7 @@ private[spark] class Executor( logWarning("Plugin " + plugin.getClass().getCanonicalName() + " shutdown failed", e) } } + plugins.foreach(_.shutdown()) } if (!isLocal) { env.stop() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 444a1544777a..295fe28e8b9a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1159,6 +1159,17 @@ package object config { s"The value must be in allowed range [1,048,576, ${MAX_BUFFER_SIZE_BYTES}].") .createWithDefault(1024 * 1024) + private[spark] val DEFAULT_PLUGINS_LIST = "spark.plugins.defaultList" + + private[spark] val PLUGINS = + ConfigBuilder("spark.plugins") + .withPrepended(DEFAULT_PLUGINS_LIST, separator = ",") + .doc("Comma-separated list of class names implementing " + + "org.apache.spark.api.plugin.SparkPlugin to load into the application.") + .stringConf + .toSequence + .createWithDefault(Nil) + private[spark] val EXECUTOR_PLUGINS = ConfigBuilder("spark.executor.plugins") .doc("Comma-separated list of class names for \"plugins\" implementing " + diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala new file mode 100644 index 000000000000..fc7a9d85957c --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContainer.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.plugin + +import scala.collection.JavaConverters._ +import scala.util.{Either, Left, Right} + +import org.apache.spark.{SparkContext, SparkEnv} +import org.apache.spark.api.plugin._ +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +sealed abstract class PluginContainer { + + def shutdown(): Unit + def registerMetrics(appId: String): Unit + +} + +private class DriverPluginContainer(sc: SparkContext, plugins: Seq[SparkPlugin]) + extends PluginContainer with Logging { + + private val driverPlugins: Seq[(String, DriverPlugin, PluginContextImpl)] = plugins.flatMap { p => + val driverPlugin = p.driverPlugin() + if (driverPlugin != null) { + val name = p.getClass().getName() + val ctx = new PluginContextImpl(name, sc.env.rpcEnv, sc.env.metricsSystem, sc.conf, + sc.env.executorId) + + val extraConf = driverPlugin.init(sc, ctx) + if (extraConf != null) { + extraConf.asScala.foreach { case (k, v) => + sc.conf.set(s"${PluginContainer.EXTRA_CONF_PREFIX}$name.$k", v) + } + } + logInfo(s"Initialized driver component for plugin $name.") + Some((p.getClass().getName(), driverPlugin, ctx)) + } else { + None + } + } + + if (driverPlugins.nonEmpty) { + val pluginsByName = driverPlugins.map { case (name, plugin, _) => (name, plugin) }.toMap + sc.env.rpcEnv.setupEndpoint(classOf[PluginEndpoint].getName(), + new PluginEndpoint(pluginsByName, sc.env.rpcEnv)) + } + + override def registerMetrics(appId: String): Unit = { + driverPlugins.foreach { case (_, plugin, ctx) => + plugin.registerMetrics(appId, ctx) + ctx.registerMetrics() + } + } + + override def shutdown(): Unit = { + driverPlugins.foreach { case (name, plugin, _) => + try { + logDebug(s"Stopping plugin $name.") + plugin.shutdown() + } catch { + case t: Throwable => + logInfo(s"Exception while shutting down plugin $name.", t) + } + } + } + +} + +private class ExecutorPluginContainer(env: SparkEnv, plugins: Seq[SparkPlugin]) + extends PluginContainer with Logging { + + private val executorPlugins: Seq[(String, ExecutorPlugin)] = { + val allExtraConf = env.conf.getAllWithPrefix(PluginContainer.EXTRA_CONF_PREFIX) + + plugins.flatMap { p => + val executorPlugin = p.executorPlugin() + if (executorPlugin != null) { + val name = p.getClass().getName() + val prefix = name + "." + val extraConf = allExtraConf + .filter { case (k, v) => k.startsWith(prefix) } + .map { case (k, v) => k.substring(prefix.length()) -> v } + .toMap + .asJava + val ctx = new PluginContextImpl(name, env.rpcEnv, env.metricsSystem, env.conf, + env.executorId) + executorPlugin.init(ctx, extraConf) + ctx.registerMetrics() + + logInfo(s"Initialized executor component for plugin $name.") + Some(p.getClass().getName() -> executorPlugin) + } else { + None + } + } + } + + override def registerMetrics(appId: String): Unit = { + throw new IllegalStateException("Should not be called for the executor container.") + } + + override def shutdown(): Unit = { + executorPlugins.foreach { case (name, plugin) => + try { + logDebug(s"Stopping plugin $name.") + plugin.shutdown() + } catch { + case t: Throwable => + logInfo(s"Exception while shutting down plugin $name.", t) + } + } + } +} + +object PluginContainer { + + val EXTRA_CONF_PREFIX = "spark.plugins.internal.conf." 
+ + def apply(sc: SparkContext): Option[PluginContainer] = PluginContainer(Left(sc)) + + def apply(env: SparkEnv): Option[PluginContainer] = PluginContainer(Right(env)) + + private def apply(ctx: Either[SparkContext, SparkEnv]): Option[PluginContainer] = { + val conf = ctx.fold(_.conf, _.conf) + val plugins = Utils.loadExtensions(classOf[SparkPlugin], conf.get(PLUGINS).distinct, conf) + if (plugins.nonEmpty) { + ctx match { + case Left(sc) => Some(new DriverPluginContainer(sc, plugins)) + case Right(env) => Some(new ExecutorPluginContainer(env, plugins)) + } + } else { + None + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala new file mode 100644 index 000000000000..279f3d388fb2 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginContextImpl.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.plugin + +import com.codahale.metrics.MetricRegistry + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.api.plugin.PluginContext +import org.apache.spark.internal.Logging +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.RpcUtils + +private class PluginContextImpl( + pluginName: String, + rpcEnv: RpcEnv, + metricsSystem: MetricsSystem, + override val conf: SparkConf, + override val executorID: String) + extends PluginContext with Logging { + + override def hostname(): String = rpcEnv.address.hostPort.split(":")(0) + + private val registry = new MetricRegistry() + + private lazy val driverEndpoint = try { + RpcUtils.makeDriverRef(classOf[PluginEndpoint].getName(), conf, rpcEnv) + } catch { + case e: Exception => + logWarning(s"Failed to create driver plugin endpoint ref.", e) + null + } + + override def metricRegistry(): MetricRegistry = registry + + override def send(message: AnyRef): Unit = { + if (driverEndpoint == null) { + throw new IllegalStateException("Driver endpoint is not known.") + } + driverEndpoint.send(PluginMessage(pluginName, message)) + } + + override def ask(message: AnyRef): AnyRef = { + try { + if (driverEndpoint != null) { + driverEndpoint.askSync[AnyRef](PluginMessage(pluginName, message)) + } else { + throw new IllegalStateException("Driver endpoint is not known.") + } + } catch { + case e: SparkException if e.getCause() != null => + throw e.getCause() + } + } + + def registerMetrics(): Unit = { + if (!registry.getMetrics().isEmpty()) { + val src = new PluginMetricsSource(s"plugin.$pluginName", registry) + metricsSystem.registerSource(src) + } + } + + class PluginMetricsSource( + 
override val sourceName: String, + override val metricRegistry: MetricRegistry) + extends Source + +} diff --git a/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala new file mode 100644 index 000000000000..9a59b6bf678f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/plugin/PluginEndpoint.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.plugin + +import org.apache.spark.api.plugin.DriverPlugin +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} + +case class PluginMessage(pluginName: String, message: AnyRef) + +private class PluginEndpoint( + plugins: Map[String, DriverPlugin], + override val rpcEnv: RpcEnv) + extends IsolatedRpcEndpoint with Logging { + + override def receive: PartialFunction[Any, Unit] = { + case PluginMessage(pluginName, message) => + plugins.get(pluginName) match { + case Some(plugin) => + try { + val reply = plugin.receive(message) + if (reply != null) { + logInfo( + s"Plugin $pluginName returned reply for one-way message of type " + + s"${message.getClass().getName()}.") + } + } catch { + case e: Exception => + logWarning(s"Error in plugin $pluginName when handling message of type " + + s"${message.getClass().getName()}.", e) + } + + case None => + throw new IllegalArgumentException(s"Received message for unknown plugin $pluginName.") + } + } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case PluginMessage(pluginName, message) => + plugins.get(pluginName) match { + case Some(plugin) => + context.reply(plugin.receive(message)) + + case None => + throw new IllegalArgumentException(s"Received message for unknown plugin $pluginName.") + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 9e94a868ccc3..a7b7b5573cfe 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -19,7 +19,8 @@ package org.apache.spark.metrics.sink import java.util.Properties -import com.codahale.metrics.{JmxReporter, MetricRegistry} +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.jmx.JmxReporter import org.apache.spark.SecurityManager diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala index 7da0a9d2285b..a5850fc2ac4b 100644 --- a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala +++ 
b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -355,6 +355,8 @@ private[spark] class AppStatusListener( val lastStageInfo = event.stageInfos.sortBy(_.stageId).lastOption val jobName = lastStageInfo.map(_.name).getOrElse("") + val description = Option(event.properties) + .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) } val jobGroup = Option(event.properties) .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } val sqlExecutionId = Option(event.properties) @@ -363,6 +365,7 @@ private[spark] class AppStatusListener( val job = new LiveJob( event.jobId, jobName, + description, if (event.time > 0) Some(new Date(event.time)) else None, event.stageIds, jobGroup, diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala index 00c991b49920..a0ef8da0a4b6 100644 --- a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -62,6 +62,7 @@ private[spark] abstract class LiveEntity { private class LiveJob( val jobId: Int, name: String, + description: Option[String], val submissionTime: Option[Date], val stageIds: Seq[Int], jobGroup: Option[String], @@ -92,7 +93,7 @@ private class LiveJob( val info = new v1.JobData( jobId, name, - None, // description is always None? + description, submissionTime, completionTime, stageIds, diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f2113947f6bf..ee43b76e1701 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.io.{File, IOException} import java.util.UUID +import scala.util.control.NonFatal + import org.apache.spark.SparkConf import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.internal.{config, Logging} @@ -117,20 +119,38 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** Produces a unique block id and File suitable for storing local intermediate results. */ def createTempLocalBlock(): (TempLocalBlockId, File) = { - var blockId = new TempLocalBlockId(UUID.randomUUID()) - while (getFile(blockId).exists()) { - blockId = new TempLocalBlockId(UUID.randomUUID()) + var blockId = TempLocalBlockId(UUID.randomUUID()) + var tempLocalFile = getFile(blockId) + var count = 0 + while (!canCreateFile(tempLocalFile) && count < Utils.MAX_DIR_CREATION_ATTEMPTS) { + blockId = TempLocalBlockId(UUID.randomUUID()) + tempLocalFile = getFile(blockId) + count += 1 } - (blockId, getFile(blockId)) + (blockId, tempLocalFile) } /** Produces a unique block id and File suitable for storing shuffled intermediate results. 
*/ def createTempShuffleBlock(): (TempShuffleBlockId, File) = { - var blockId = new TempShuffleBlockId(UUID.randomUUID()) - while (getFile(blockId).exists()) { - blockId = new TempShuffleBlockId(UUID.randomUUID()) + var blockId = TempShuffleBlockId(UUID.randomUUID()) + var tempShuffleFile = getFile(blockId) + var count = 0 + while (!canCreateFile(tempShuffleFile) && count < Utils.MAX_DIR_CREATION_ATTEMPTS) { + blockId = TempShuffleBlockId(UUID.randomUUID()) + tempShuffleFile = getFile(blockId) + count += 1 + } + (blockId, tempShuffleFile) + } + + private def canCreateFile(file: File): Boolean = { + try { + file.createNewFile() + } catch { + case NonFatal(_) => + logError("Failed to create temporary block file: " + file.getAbsoluteFile) + false } - (blockId, getFile(blockId)) } /** diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala index 2488197814ff..fb43af357f7b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala @@ -25,6 +25,7 @@ import scala.xml.Node import org.apache.spark.status.{AppStatusStore, StreamBlockData} import org.apache.spark.status.api.v1 import org.apache.spark.ui._ +import org.apache.spark.ui.storage.ToolTips._ import org.apache.spark.util.Utils /** Page showing list of RDD's currently stored in the cluster */ @@ -56,7 +57,8 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends rddHeader, rddRow(request, _: v1.RDDStorageInfo), rdds, - id = Some("storage-by-rdd-table"))} + id = Some("storage-by-rdd-table"), + tooltipHeaders = tooltips)} } @@ -72,6 +74,16 @@ private[ui] class StoragePage(parent: SparkUITab, store: AppStatusStore) extends "Size in Memory", "Size on Disk") + /** Tooltips for header fields of the RDD table */ + val tooltips = Seq( + None, + Some(RDD_NAME), + Some(STORAGE_LEVEL), + Some(CACHED_PARTITIONS), + Some(FRACTION_CACHED), + Some(SIZE_IN_MEMORY), + Some(SIZE_ON_DISK)) + /** Render an HTML row representing an RDD */ private def rddRow(request: HttpServletRequest, rdd: v1.RDDStorageInfo): Seq[Node] = { // scalastyle:off diff --git a/core/src/main/scala/org/apache/spark/ui/storage/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/storage/ToolTips.scala new file mode 100644 index 000000000000..4677eba63c83 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/storage/ToolTips.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ui.storage + +private[ui] object ToolTips { + + val RDD_NAME = + "Name of the persisted RDD" + + val STORAGE_LEVEL = + "StorageLevel displays where the persisted RDD is stored, " + + "format of the persisted RDD (serialized or de-serialized) and" + + "replication factor of the persisted RDD" + + val CACHED_PARTITIONS = + "Number of partitions cached" + + val FRACTION_CACHED = + "Fraction of total partitions cached" + + val SIZE_IN_MEMORY = + "Total size of partitions in memory" + + val SIZE_ON_DISK = + "Total size of partitions on the disk" +} + diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index f853ec836836..723fbdf73f8d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -95,7 +95,7 @@ private[spark] object Utils extends Logging { */ val DEFAULT_DRIVER_MEM_MB = JavaUtils.DEFAULT_DRIVER_MEM_MB.toInt - private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 + val MAX_DIR_CREATION_ATTEMPTS: Int = 10 @volatile private var localRootDirs: Array[String] = null /** Scheme used for files that are locally available on worker nodes in the cluster. */ diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 773c390175b6..fb8523856da6 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -323,7 +323,7 @@ public static class InProcessTestApp { public static void main(String[] args) throws Exception { assertNotEquals(0, args.length); - assertEquals(args[0], "hello"); + assertEquals("hello", args[0]); new SparkContext().stop(); synchronized (LOCK) { @@ -340,7 +340,7 @@ public static class ErrorInProcessTestApp { public static void main(String[] args) { assertNotEquals(0, args.length); - assertEquals(args[0], "hello"); + assertEquals("hello", args[0]); throw DUMMY_EXCEPTION; } } diff --git a/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java b/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java index 0944d681599a..28d038a524c8 100644 --- a/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java +++ b/core/src/test/java/org/apache/spark/util/SerializableConfigurationSuite.java @@ -50,6 +50,6 @@ public void testSerializableConfiguration() { hadoopConfiguration.set("test.property", "value"); SerializableConfiguration scs = new SerializableConfiguration(hadoopConfiguration); SerializableConfiguration actual = rdd.map(val -> scs).collect().get(0); - assertEquals(actual.value().get("test.property"), "value"); + assertEquals("value", actual.value().get("test.property")); } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index d5b1a1c5f547..43977717f6c9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -222,7 +222,7 @@ public void testSortingEmptyArrays() throws Exception { public void testSortTimeMetric() throws Exception { final UnsafeExternalSorter sorter = newSorter(); long prevSortTime = sorter.getSortTimeNanos(); - 
assertEquals(prevSortTime, 0); + assertEquals(0, prevSortTime); sorter.insertRecord(null, 0, 0, 0, false); sorter.spill(); @@ -230,7 +230,7 @@ public void testSortTimeMetric() throws Exception { prevSortTime = sorter.getSortTimeNanos(); sorter.spill(); // no sort needed - assertEquals(sorter.getSortTimeNanos(), prevSortTime); + assertEquals(prevSortTime, sorter.getSortTimeNanos()); sorter.insertRecord(null, 0, 0, 0, false); UnsafeSorterIterator iter = sorter.getSortedIterator(); diff --git a/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala new file mode 100644 index 000000000000..24fa01736365 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/internal/plugin/PluginContainerSuite.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.plugin + +import java.io.File +import java.nio.charset.StandardCharsets +import java.util.{Map => JMap} + +import scala.collection.JavaConverters._ +import scala.concurrent.duration._ + +import com.codahale.metrics.Gauge +import com.google.common.io.Files +import org.mockito.ArgumentMatchers.{any, eq => meq} +import org.mockito.Mockito.{mock, spy, verify, when} +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually.{eventually, interval, timeout} + +import org.apache.spark.{ExecutorPlugin => _, _} +import org.apache.spark.api.plugin._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher +import org.apache.spark.util.Utils + +class PluginContainerSuite extends SparkFunSuite with BeforeAndAfterEach with LocalSparkContext { + + override def afterEach(): Unit = { + TestSparkPlugin.reset() + super.afterEach() + } + + test("plugin initialization and communication") { + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local[1]") + .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) + + TestSparkPlugin.extraConf = Map("foo" -> "bar", "bar" -> "baz").asJava + + sc = new SparkContext(conf) + + assert(TestSparkPlugin.driverPlugin != null) + verify(TestSparkPlugin.driverPlugin).init(meq(sc), any()) + + assert(TestSparkPlugin.executorPlugin != null) + verify(TestSparkPlugin.executorPlugin).init(any(), meq(TestSparkPlugin.extraConf)) + + assert(TestSparkPlugin.executorContext != null) + + // One way messages don't block, so need to loop checking whether it arrives. 
+ TestSparkPlugin.executorContext.send("oneway") + eventually(timeout(10.seconds), interval(10.millis)) { + verify(TestSparkPlugin.driverPlugin).receive("oneway") + } + + assert(TestSparkPlugin.executorContext.ask("ask") === "reply") + + val err = intercept[Exception] { + TestSparkPlugin.executorContext.ask("unknown message") + } + assert(err.getMessage().contains("unknown message")) + + // It should be possible for the driver plugin to send a message to itself, even if that doesn't + // make a whole lot of sense. It at least allows the same context class to be used on both + // sides. + assert(TestSparkPlugin.driverContext != null) + assert(TestSparkPlugin.driverContext.ask("ask") === "reply") + + val metricSources = sc.env.metricsSystem + .getSourcesByName(s"plugin.${classOf[TestSparkPlugin].getName()}") + assert(metricSources.size === 2) + + def findMetric(name: String): Int = { + val allFound = metricSources.filter(_.metricRegistry.getGauges().containsKey(name)) + assert(allFound.size === 1) + allFound.head.metricRegistry.getGauges().get(name).asInstanceOf[Gauge[Int]].getValue() + } + + assert(findMetric("driverMetric") === 42) + assert(findMetric("executorMetric") === 84) + + sc.stop() + sc = null + + verify(TestSparkPlugin.driverPlugin).shutdown() + verify(TestSparkPlugin.executorPlugin).shutdown() + } + + test("do nothing if plugins are not configured") { + val conf = new SparkConf() + val env = mock(classOf[SparkEnv]) + when(env.conf).thenReturn(conf) + assert(PluginContainer(env) === None) + } + + test("merging of config options") { + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local[1]") + .set(PLUGINS, Seq(classOf[TestSparkPlugin].getName())) + .set(DEFAULT_PLUGINS_LIST, classOf[TestSparkPlugin].getName()) + + assert(conf.get(PLUGINS).size === 2) + + sc = new SparkContext(conf) + // Just check plugin is loaded. The plugin code below checks whether a single copy was loaded. 
+ assert(TestSparkPlugin.driverPlugin != null) + } + + test("plugin initialization in non-local mode") { + val path = Utils.createTempDir() + + val conf = new SparkConf() + .setAppName(getClass().getName()) + .set(SparkLauncher.SPARK_MASTER, "local-cluster[2,1,1024]") + .set(PLUGINS, Seq(classOf[NonLocalModeSparkPlugin].getName())) + .set(NonLocalModeSparkPlugin.TEST_PATH_CONF, path.getAbsolutePath()) + + sc = new SparkContext(conf) + TestUtils.waitUntilExecutorsUp(sc, 2, 10000) + + eventually(timeout(10.seconds), interval(100.millis)) { + val children = path.listFiles() + assert(children != null) + assert(children.length >= 3) + } + } +} + +class NonLocalModeSparkPlugin extends SparkPlugin { + + override def driverPlugin(): DriverPlugin = { + new DriverPlugin() { + override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = { + NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID()) + Map.empty.asJava + } + } + } + + override def executorPlugin(): ExecutorPlugin = { + new ExecutorPlugin() { + override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = { + NonLocalModeSparkPlugin.writeFile(ctx.conf(), ctx.executorID()) + } + } + } +} + +object NonLocalModeSparkPlugin { + val TEST_PATH_CONF = "spark.nonLocalPlugin.path" + + def writeFile(conf: SparkConf, id: String): Unit = { + val path = conf.get(TEST_PATH_CONF) + Files.write(id, new File(path, id), StandardCharsets.UTF_8) + } +} + +class TestSparkPlugin extends SparkPlugin { + + override def driverPlugin(): DriverPlugin = { + val p = new TestDriverPlugin() + require(TestSparkPlugin.driverPlugin == null, "Driver plugin already initialized.") + TestSparkPlugin.driverPlugin = spy(p) + TestSparkPlugin.driverPlugin + } + + override def executorPlugin(): ExecutorPlugin = { + val p = new TestExecutorPlugin() + require(TestSparkPlugin.executorPlugin == null, "Executor plugin already initialized.") + TestSparkPlugin.executorPlugin = spy(p) + TestSparkPlugin.executorPlugin + } + +} + +private class TestDriverPlugin extends DriverPlugin { + + override def init(sc: SparkContext, ctx: PluginContext): JMap[String, String] = { + TestSparkPlugin.driverContext = ctx + TestSparkPlugin.extraConf + } + + override def registerMetrics(appId: String, ctx: PluginContext): Unit = { + ctx.metricRegistry().register("driverMetric", new Gauge[Int] { + override def getValue(): Int = 42 + }) + } + + override def receive(msg: AnyRef): AnyRef = msg match { + case "oneway" => null + case "ask" => "reply" + case other => throw new IllegalArgumentException(s"unknown: $other") + } + +} + +private class TestExecutorPlugin extends ExecutorPlugin { + + override def init(ctx: PluginContext, extraConf: JMap[String, String]): Unit = { + ctx.metricRegistry().register("executorMetric", new Gauge[Int] { + override def getValue(): Int = 84 + }) + TestSparkPlugin.executorContext = ctx + } + +} + +private object TestSparkPlugin { + var driverPlugin: TestDriverPlugin = _ + var driverContext: PluginContext = _ + + var executorPlugin: TestExecutorPlugin = _ + var executorContext: PluginContext = _ + + var extraConf: JMap[String, String] = _ + + def reset(): Unit = { + driverPlugin = null + driverContext = null + executorPlugin = null + executorContext = null + extraConf = null + } +} diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala index 6bf163506e0c..a289dddbdc9e 100644 --- 
a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -155,6 +155,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) val jobProps = new Properties() + jobProps.setProperty(SparkContext.SPARK_JOB_DESCRIPTION, "jobDescription") jobProps.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "jobGroup") jobProps.setProperty(SparkContext.SPARK_SCHEDULER_POOL, "schedPool") @@ -163,7 +164,7 @@ class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { check[JobDataWrapper](1) { job => assert(job.info.jobId === 1) assert(job.info.name === stages.last.name) - assert(job.info.description === None) + assert(job.info.description === Some("jobDescription")) assert(job.info.status === JobExecutionStatus.RUNNING) assert(job.info.submissionTime === Some(new Date(time))) assert(job.info.jobGroup === Some("jobGroup")) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index c757dee43808..ccc525e85483 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -51,7 +51,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B override def beforeEach(): Unit = { super.beforeEach() val conf = testConf.clone - conf.set("spark.local.dir", rootDirs) + conf.set("spark.local.dir", rootDirs).set("spark.diskStore.subDirectories", "1") diskBlockManager = new DiskBlockManager(conf, deleteFilesOnStop = true) } @@ -90,4 +90,45 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B for (i <- 0 until numBytes) writer.write(i) writer.close() } + + test("temporary shuffle/local file should be able to handle disk failures") { + try { + // the following two lines pre-create subdirectories under each root dir of block manager + diskBlockManager.getFile("1") + diskBlockManager.getFile("2") + + val tempShuffleFile1 = diskBlockManager.createTempShuffleBlock()._2 + val tempLocalFile1 = diskBlockManager.createTempLocalBlock()._2 + assert(tempShuffleFile1.exists(), "There are no bad disks, so temp shuffle file exists") + assert(tempLocalFile1.exists(), "There are no bad disks, so temp local file exists") + + // partial disks damaged + rootDir0.setExecutable(false) + val tempShuffleFile2 = diskBlockManager.createTempShuffleBlock()._2 + val tempLocalFile2 = diskBlockManager.createTempLocalBlock()._2 + // It's possible that after 10 retries we still not able to find the healthy disk. 
we need to + // remove the flakiness of these two asserts + if (tempShuffleFile2.getParentFile.getParentFile.getParent === rootDir1.getAbsolutePath) { + assert(tempShuffleFile2.exists(), + "There is only one bad disk, so temp shuffle file should be created") + } + if (tempLocalFile2.getParentFile.getParentFile.getParent === rootDir1.getAbsolutePath) { + assert(tempLocalFile2.exists(), + "There is only one bad disk, so temp local file should be created") + } + + // all disks damaged + rootDir1.setExecutable(false) + val tempShuffleFile3 = diskBlockManager.createTempShuffleBlock()._2 + val tempLocalFile3 = diskBlockManager.createTempLocalBlock()._2 + assert(!tempShuffleFile3.exists(), + "All disks are broken, so there should be no temp shuffle file created") + assert(!tempLocalFile3.exists(), + "All disks are broken, so there should be no temp local file created") + } finally { + rootDir0.setExecutable(true) + rootDir1.setExecutable(true) + } + + } } diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala index 06f01a60868f..f93ecd3b006b 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StoragePageSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.ui.storage import javax.servlet.http.HttpServletRequest import org.mockito.Mockito._ +import scala.xml.{Node, Text} import org.apache.spark.SparkFunSuite import org.apache.spark.status.StreamBlockData @@ -74,7 +75,21 @@ class StoragePageSuite extends SparkFunSuite { "Fraction Cached", "Size in Memory", "Size on Disk") - assert((xmlNodes \\ "th").map(_.text) === headers) + + val headerRow: Seq[Node] = { + headers.view.zipWithIndex.map { x => + storagePage.tooltips(x._2) match { + case Some(tooltip) => + + + {Text(x._1)} + + + case None => {Text(x._1)} + } + }.toList + } + assert((xmlNodes \\ "th").map(_.text) === headerRow.map(_.text)) assert((xmlNodes \\ "tr").size === 3) assert(((xmlNodes \\ "tr")(0) \\ "td").map(_.text.trim) === diff --git a/dev/.rat-excludes b/dev/.rat-excludes index e12dc994b084..73f461255de4 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -118,3 +118,4 @@ announce.tmpl vote.tmpl SessionManager.java SessionHandler.java +GangliaReporter.java diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 945686de4996..804a178a5fe2 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -30,6 +30,8 @@ + python/pyspark/version.py # Get maven home set by MVN @@ -414,13 +414,13 @@ if [[ "$1" == "publish-release" ]]; then # TODO: revisit for Scala 2.13 support - if ! is_dry_run && [[ $PUBLISH_SCALA_2_11 = 1 ]]; then + if [[ $PUBLISH_SCALA_2_11 = 1 ]]; then ./dev/change-scala-version.sh 2.11 $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests \ $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install fi - if ! 
is_dry_run && [[ $PUBLISH_SCALA_2_12 = 1 ]]; then + if [[ $PUBLISH_SCALA_2_12 = 1 ]]; then ./dev/change-scala-version.sh 2.12 $MVN -DzincPort=$((ZINC_PORT + 2)) -Dmaven.repo.local=$tmp_repo -DskipTests \ $SCALA_2_11_PROFILES $PUBLISH_PROFILES clean install diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index f21e76bf4331..e6d29d04acbf 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -125,7 +125,7 @@ jetty-6.1.26.jar jetty-sslengine-6.1.26.jar jetty-util-6.1.26.jar jline-2.14.6.jar -joda-time-2.9.3.jar +joda-time-2.10.5.jar jodd-core-3.5.2.jar jpam-1.1.jar json4s-ast_2.12-3.6.6.jar @@ -149,10 +149,11 @@ lz4-java-1.6.0.jar machinist_2.12-0.6.8.jar macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.2.6.jar -metrics-graphite-3.2.6.jar -metrics-json-3.2.6.jar -metrics-jvm-3.2.6.jar +metrics-core-4.1.1.jar +metrics-graphite-4.1.1.jar +metrics-jmx-4.1.1.jar +metrics-json-4.1.1.jar +metrics-jvm-4.1.1.jar minlog-1.3.0.jar netty-all-4.1.42.Final.jar objenesis-2.5.1.jar @@ -160,9 +161,9 @@ okapi-shade-0.4.2.jar okhttp-3.12.0.jar okio-1.15.0.jar opencsv-2.3.jar -orc-core-1.5.6-nohive.jar -orc-mapreduce-1.5.6-nohive.jar -orc-shims-1.5.6.jar +orc-core-1.5.7-nohive.jar +orc-mapreduce-1.5.7-nohive.jar +orc-shims-1.5.7.jar oro-2.0.8.jar osgi-resource-locator-1.0.3.jar paranamer-2.8.jar @@ -200,7 +201,7 @@ stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-2.8.3.jar validation-api-2.0.1.Final.jar -xbean-asm7-shaded-4.14.jar +xbean-asm7-shaded-4.15.jar xercesImpl-2.9.1.jar xmlenc-0.52.jar xz-1.5.jar diff --git a/dev/deps/spark-deps-hadoop-3.2 b/dev/deps/spark-deps-hadoop-3.2 index 3ecc3c2b0d35..8f1e7fe125b9 100644 --- a/dev/deps/spark-deps-hadoop-3.2 +++ b/dev/deps/spark-deps-hadoop-3.2 @@ -139,7 +139,7 @@ jersey-server-2.29.jar jetty-webapp-9.4.18.v20190429.jar jetty-xml-9.4.18.v20190429.jar jline-2.14.6.jar -joda-time-2.9.3.jar +joda-time-2.10.5.jar jodd-core-3.5.2.jar jpam-1.1.jar json-1.8.jar @@ -179,10 +179,11 @@ lz4-java-1.6.0.jar machinist_2.12-0.6.8.jar macro-compat_2.12-1.1.1.jar mesos-1.4.0-shaded-protobuf.jar -metrics-core-3.2.6.jar -metrics-graphite-3.2.6.jar -metrics-json-3.2.6.jar -metrics-jvm-3.2.6.jar +metrics-core-4.1.1.jar +metrics-graphite-4.1.1.jar +metrics-jmx-4.1.1.jar +metrics-json-4.1.1.jar +metrics-jvm-4.1.1.jar minlog-1.3.0.jar mssql-jdbc-6.2.1.jre7.jar netty-all-4.1.42.Final.jar @@ -193,9 +194,9 @@ okhttp-2.7.5.jar okhttp-3.12.0.jar okio-1.15.0.jar opencsv-2.3.jar -orc-core-1.5.6.jar -orc-mapreduce-1.5.6.jar -orc-shims-1.5.6.jar +orc-core-1.5.7.jar +orc-mapreduce-1.5.7.jar +orc-shims-1.5.7.jar oro-2.0.8.jar osgi-resource-locator-1.0.3.jar paranamer-2.8.jar @@ -235,7 +236,7 @@ univocity-parsers-2.8.3.jar validation-api-2.0.1.Final.jar velocity-1.5.jar woodstox-core-5.0.3.jar -xbean-asm7-shaded-4.14.jar +xbean-asm7-shaded-4.15.jar xz-1.5.jar zjsonpatch-0.3.0.jar zookeeper-3.4.14.jar diff --git a/dev/github_jira_sync.py b/dev/github_jira_sync.py index fa1736163d4c..b444b74d4027 100755 --- a/dev/github_jira_sync.py +++ b/dev/github_jira_sync.py @@ -116,7 +116,8 @@ def build_pr_component_dic(jira_prs): dic = {} for issue, pr in jira_prs: print(issue) - jira_components = [c.name.upper() for c in jira_client.issue(issue).fields.components] + page = get_json(get_url(JIRA_API_BASE + "/rest/api/2/issue/" + issue)) + jira_components = [c['name'].upper() for c in page['fields']['components']] if pr['number'] in dic: dic[pr['number']][1].update(jira_components) else: @@ -163,7 +164,8 @@ def 
reset_pr_labels(pr_num, jira_components): url = pr['html_url'] title = "[Github] Pull Request #%s (%s)" % (pr['number'], pr['user']['login']) try: - existing_links = map(lambda l: l.raw['object']['url'], jira_client.remote_links(issue)) + page = get_json(get_url(JIRA_API_BASE + "/rest/api/2/issue/" + issue + "/remotelink")) + existing_links = map(lambda l: l['object']['url'], page) except: print("Failure reading JIRA %s (does it exist?)" % issue) print(sys.exc_info()[0]) diff --git a/docs/index.md b/docs/index.md index 9e8af0d5f8e2..5dd8d7816bdd 100644 --- a/docs/index.md +++ b/docs/index.md @@ -48,7 +48,7 @@ or the `JAVA_HOME` environment variable pointing to a Java installation. Spark runs on Java 8/11, Scala 2.12, Python 2.7+/3.4+ and R 3.1+. Java 8 prior to version 8u92 support is deprecated as of Spark 3.0.0. -Python 2 support is deprecated as of Spark 3.0.0. +Python 2 and Python 3 prior to version 3.6 support is deprecated as of Spark 3.0.0. R prior to version 3.4 support is deprecated as of Spark 3.0.0. For the Scala API, Spark {{site.SPARK_VERSION}} uses Scala {{site.SCALA_BINARY_VERSION}}. You will need to use a compatible Scala version diff --git a/docs/monitoring.md b/docs/monitoring.md index 8cb237df0ba7..4062e16a25d3 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -991,6 +991,11 @@ This is the component with the largest amount of instrumented metrics - namespace=JVMCPU - jvmCpuTime +- namespace=plugin.\ + - Optional namespace(s). Metrics in this namespace are defined by user-supplied code, and + configured using the Spark plugin API. See "Advanced Instrumentation" below for how to load + custom plugins into Spark. + ### Component instance = Executor These metrics are exposed by Spark executors. Note, currently they are not available when running in local mode. @@ -1060,10 +1065,10 @@ when running in local mode. - hiveClientCalls.count - sourceCodeSize (histogram) -- namespace= - - Optional namespace(s). Metrics in this namespace are defined by user-supplied code, and - configured using the Spark executor plugin infrastructure. - See also the configuration parameter `spark.executor.plugins` +- namespace=plugin.\ + - Optional namespace(s). Metrics in this namespace are defined by user-supplied code, and + configured using the Spark plugin API. See "Advanced Instrumentation" below for how to load + custom plugins into Spark. ### Source = JVM Source Notes: @@ -1141,3 +1146,21 @@ can provide fine-grained profiling on individual nodes. * JVM utilities such as `jstack` for providing stack traces, `jmap` for creating heap-dumps, `jstat` for reporting time-series statistics and `jconsole` for visually exploring various JVM properties are useful for those comfortable with JVM internals. + +Spark also provides a plugin API so that custom instrumentation code can be added to Spark +applications. There are two configuration keys available for loading plugins into Spark: + +- spark.plugins +- spark.plugins.defaultList + +Both take a comma-separated list of class names that implement the +org.apache.spark.api.plugin.SparkPlugin interface. The two names exist so that it's +possible for one list to be placed in the Spark default config file, allowing users to +easily add other plugins from the command line without overwriting the config file's list. Duplicate +plugins are ignored. + +Distribution of the jar files containing the plugin code is currently not done by Spark. 
The user +or admin should make sure that the jar files are available to Spark applications, for example, by +including the plugin jar with the Spark distribution. The exception to this rule is the YARN +backend, where the --jars command line option (or equivalent config entry) can be +used to make the plugin code available to both executors and cluster-mode drivers. diff --git a/docs/pyspark-migration-guide.md b/docs/pyspark-migration-guide.md index 889941c37bf4..1b8d1fc1c577 100644 --- a/docs/pyspark-migration-guide.md +++ b/docs/pyspark-migration-guide.md @@ -84,6 +84,9 @@ Please refer [Migration Guide: SQL, Datasets and DataFrame](sql-migration-guide. - Since Spark 3.0, `createDataFrame(..., verifySchema=True)` validates `LongType` as well in PySpark. Previously, `LongType` was not verified and resulted in `None` in case the value overflows. To restore this behavior, `verifySchema` can be set to `False` to disable the validation. + - Since Spark 3.0, `Column.getItem` is fixed such that it does not call `Column.apply`. Consequently, if `Column` is used as an argument to `getItem`, the indexing operator should be used. + For example, `map_col.getItem(col('id'))` should be replaced with `map_col[col('id')]`. + ## Upgrading from PySpark 2.3 to 2.4 - In PySpark, when Arrow optimization is enabled, previously `toPandas` just failed when Arrow optimization is unable to be used whereas `createDataFrame` from Pandas DataFrame allowed the fallback to non-optimization. Now, both `toPandas` and `createDataFrame` from Pandas DataFrame allow the fallback by default, which can be switched off by `spark.sql.execution.arrow.fallback.enabled`. diff --git a/docs/sql-data-sources-hive-tables.md b/docs/sql-data-sources-hive-tables.md index e4ce3e938b75..f99b06494934 100644 --- a/docs/sql-data-sources-hive-tables.md +++ b/docs/sql-data-sources-hive-tables.md @@ -88,17 +88,17 @@ creating table, you can create a table using storage handler at Hive side, and u inputFormat, outputFormat - These 2 options specify the name of a corresponding `InputFormat` and `OutputFormat` class as a string literal, - e.g. `org.apache.hadoop.hive.ql.io.orc.OrcInputFormat`. These 2 options must be appeared in a pair, and you can not - specify them if you already specified the `fileFormat` option. + These 2 options specify the name of a corresponding InputFormat and OutputFormat class as a string literal, + e.g. org.apache.hadoop.hive.ql.io.orc.OrcInputFormat. These 2 options must be appeared in a pair, and you can not + specify them if you already specified the fileFormat option. serde - This option specifies the name of a serde class. When the `fileFormat` option is specified, do not specify this option - if the given `fileFormat` already include the information of serde. Currently "sequencefile", "textfile" and "rcfile" + This option specifies the name of a serde class. When the fileFormat option is specified, do not specify this option + if the given fileFormat already include the information of serde. Currently "sequencefile", "textfile" and "rcfile" don't include the serde information and you can use this option with these 3 fileFormats. diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index c3502cbdea8e..b0d37b11c711 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -60,7 +60,7 @@ the following case-insensitive options: The JDBC table that should be read from or written into. 
Note that when using it in the read path anything that is valid in a FROM clause of a SQL query can be used. For example, instead of a full table you could also use a subquery in parentheses. It is not - allowed to specify `dbtable` and `query` options at the same time. + allowed to specify dbtable and query options at the same time. @@ -72,10 +72,10 @@ the following case-insensitive options: SELECT <columns> FROM (<user_specified_query>) spark_gen_alias

Below are a couple of restrictions while using this option.

    -
-  1. It is not allowed to specify `dbtable` and `query` options at the same time.
-  2. It is not allowed to specify `query` and `partitionColumn` options at the same time. When specifying
-     `partitionColumn` option is required, the subquery can be specified using `dbtable` option instead and
-     partition columns can be qualified using the subquery alias provided as part of `dbtable`.
+  1. It is not allowed to specify dbtable and query options at the same time.
+  2. It is not allowed to specify query and partitionColumn options at the same time. When specifying
+     partitionColumn option is required, the subquery can be specified using dbtable option instead and
+     partition columns can be qualified using the subquery alias provided as part of dbtable.
    Example:
    spark.read.format("jdbc")
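The `Example:` snippet above is truncated in this extract. As an editor's illustration only (not part of the patch), the sketch below shows how the restriction plays out in practice: when a partitioned read is needed, the subquery is passed through the dbtable option with an alias rather than through the query option. A SparkSession named spark is assumed, and the connection URL, table, and bounds are placeholders.

{% highlight scala %}
// Editor's sketch: partitioned JDBC read using a subquery in "dbtable" (with an alias)
// instead of the "query" option, since "query" cannot be combined with "partitionColumn".
val jdbcDF = spark.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbserver:5432/mydb")   // placeholder connection URL
  .option("dbtable", "(SELECT id, amount FROM sales WHERE amount > 0) AS sales_alias")
  .option("partitionColumn", "id")
  .option("lowerBound", "1")
  .option("upperBound", "100000")
  .option("numPartitions", "10")
  .option("user", "username")
  .option("password", "password")
  .load()
{% endhighlight %}

Swapping the dbtable option for query in this sketch would hit the restriction described in item 2 above.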
    diff --git a/docs/sql-data-sources-parquet.md b/docs/sql-data-sources-parquet.md index b5309870f485..53a1111cd828 100644 --- a/docs/sql-data-sources-parquet.md +++ b/docs/sql-data-sources-parquet.md @@ -280,12 +280,12 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec used when writing Parquet files. If either `compression` or - `parquet.compression` is specified in the table-specific options/properties, the precedence would be - `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + Sets the compression codec used when writing Parquet files. If either compression or + parquet.compression is specified in the table-specific options/properties, the precedence would be + compression, parquet.compression, spark.sql.parquet.compression.codec. Acceptable values include: none, uncompressed, snappy, gzip, lzo, brotli, lz4, zstd. - Note that `zstd` requires `ZStandardCodec` to be installed before Hadoop 2.9.0, `brotli` requires - `BrotliCodec` to be installed. + Note that zstd requires ZStandardCodec to be installed before Hadoop 2.9.0, brotli requires + BrotliCodec to be installed. diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index d03ca663e8e3..a97a4b04ded6 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -9,9 +9,9 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -218,6 +218,8 @@ license: | - Since Spark 3.0, the `size` function returns `NULL` for the `NULL` input. In Spark version 2.4 and earlier, this function gives `-1` for the same input. To restore the behavior before Spark 3.0, you can set `spark.sql.legacy.sizeOfNull` to `true`. + - Since Spark 3.0, the interval literal syntax does not allow multiple from-to units anymore. For example, `SELECT INTERVAL '1-1' YEAR TO MONTH '2-2' YEAR TO MONTH'` throws parser exception. + ## Upgrading from Spark SQL 2.4 to 2.4.1 - The value of `spark.executor.heartbeatInterval`, when specified without units like "30" rather than "30s", was diff --git a/docs/sql-pyspark-pandas-with-arrow.md b/docs/sql-pyspark-pandas-with-arrow.md index 7f01483d4058..d638278b4235 100644 --- a/docs/sql-pyspark-pandas-with-arrow.md +++ b/docs/sql-pyspark-pandas-with-arrow.md @@ -178,6 +178,41 @@ For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/p [`pyspark.sql.DataFrame.mapsInPandas`](api/python/pyspark.sql.html#pyspark.sql.DataFrame.mapInPandas). +### Cogrouped Map + +Cogrouped map Pandas UDFs allow two DataFrames to be cogrouped by a common key and then a python function applied to +each cogroup. They are used with `groupBy().cogroup().apply()` which consists of the following steps: + +* Shuffle the data such that the groups of each dataframe which share a key are cogrouped together. +* Apply a function to each cogroup. The input of the function is two `pandas.DataFrame` (with an optional Tuple +representing the key). The output of the function is a `pandas.DataFrame`. 
+* Combine the pandas.DataFrames from all groups into a new `DataFrame`. + +To use `groupBy().cogroup().apply()`, the user needs to define the following: +* A Python function that defines the computation for each cogroup. +* A `StructType` object or a string that defines the schema of the output `DataFrame`. + +The column labels of the returned `pandas.DataFrame` must either match the field names in the +defined output schema if specified as strings, or match the field data types by position if not +strings, e.g. integer indices. See [pandas.DataFrame](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html#pandas.DataFrame) +on how to label columns when constructing a `pandas.DataFrame`. + +Note that all data for a cogroup will be loaded into memory before the function is applied. This can lead to out of +memory exceptions, especially if the group sizes are skewed. The configuration for [maxRecordsPerBatch](#setting-arrow-batch-size) +is not applied and it is up to the user to ensure that the cogrouped data will fit into the available memory. + +The following example shows how to use `groupby().cogroup().apply()` to perform an asof join between two datasets. + +
    +
    +{% include_example cogrouped_map_pandas_udf python/sql/arrow.py %} +
    +
    + +For detailed usage, please see [`pyspark.sql.functions.pandas_udf`](api/python/pyspark.sql.html#pyspark.sql.functions.pandas_udf) and +[`pyspark.sql.CoGroupedData.apply`](api/python/pyspark.sql.html#pyspark.sql.CoGroupedData.apply). + + ## Usage Notes ### Supported SQL Types diff --git a/docs/sql-ref-syntax-aux-show-table.md b/docs/sql-ref-syntax-aux-show-table.md index ad549b6b11ec..1d881a73c811 100644 --- a/docs/sql-ref-syntax-aux-show-table.md +++ b/docs/sql-ref-syntax-aux-show-table.md @@ -1,7 +1,7 @@ --- layout: global -title: SHOW TABLE -displayTitle: SHOW TABLE +title: SHOW TABLE EXTENDED +displayTitle: SHOW TABLE EXTENDED license: | Licensed to the Apache Software Foundation (ASF) under one or more contributor license agreements. See the NOTICE file distributed with @@ -18,5 +18,161 @@ license: | See the License for the specific language governing permissions and limitations under the License. --- +### Description -**This page is under construction** +`SHOW TABLE EXTENDED` will show information for all tables matching the given regular expression. +Output includes basic table information and file system information like `Last Access`, +`Created By`, `Type`, `Provider`, `Table Properties`, `Location`, `Serde Library`, `InputFormat`, +`OutputFormat`, `Storage Properties`, `Partition Provider`, `Partition Columns` and `Schema`. + +If a partition specification is present, it outputs the given partition's file-system-specific +information such as `Partition Parameters` and `Partition Statistics`. Note that a table regex +cannot be used with a partition specification. + +### Syntax +{% highlight sql %} +SHOW TABLE EXTENDED [IN|FROM database_name] LIKE 'identifier_with_wildcards' [PARTITION(partition_spec)]; +{% endhighlight %} + +### Parameters +
    +
    IN|FROM database_name
    +
    + Specifies the database name. If not provided, the current database is used. +
    +
    LIKE string_pattern
    +
    + Specifies the regular expression pattern that is used to filter out unwanted tables. +
      +
    • Except for the `*` and `|` characters, the pattern works like a regular expression.
    • +
    • `*` alone matches 0 or more characters and `|` is used to separate multiple different regexes, + any of which can match.
    • +
    • The leading and trailing blanks are trimmed in the input pattern before processing.
    • +
    +
    +
    PARTITION(partition_spec)
    +
    + Specifies the partition column and its value, which must exist in the table. Note that a table regex + cannot be used with a partition specification. +
    +
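As a hedged editorial aside (not part of the patch): the same statement can also be issued through the DataFrame API, returning the database, tableName, isTemporary and information columns shown in the examples that follow. A SparkSession named spark is assumed.

{% highlight scala %}
// Editor's sketch: run SHOW TABLE EXTENDED from Scala and print the full "information" column.
spark.sql("SHOW TABLE EXTENDED LIKE 'employee'").show(truncate = false)
{% endhighlight %}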
    +### Examples +{% highlight sql %} +-- Assumes `employee` table created with partitioned by column `grade` +-- +-------+--------+--+ +-- | name | grade | +-- +-------+--------+--+ +-- | sam | 1 | +-- | suj | 2 | +-- +-------+--------+--+ + + -- Show the details of the table +SHOW TABLE EXTENDED LIKE `employee`; ++--------+---------+-----------+--------------------------------------------------------------- +|database|tableName|isTemporary| information ++--------+---------+-----------+--------------------------------------------------------------- +|default |employee |false |Database: default + Table: employee + Owner: root + Created Time: Fri Aug 30 15:10:21 IST 2019 + Last Access: Thu Jan 01 05:30:00 IST 1970 + Created By: Spark 3.0.0-SNAPSHOT + Type: MANAGED + Provider: hive + Table Properties: [transient_lastDdlTime=1567158021] + Location: file:/opt/spark1/spark/spark-warehouse/employee + Serde Library: org.apache.hadoop.hive.serde2.lazy + .LazySimpleSerDe + InputFormat: org.apache.hadoop.mapred.TextInputFormat + OutputFormat: org.apache.hadoop.hive.ql.io + .HiveIgnoreKeyTextOutputFormat + Storage Properties: [serialization.format=1] + Partition Provider: Catalog + Partition Columns: [`grade`] + Schema: root + |-- name: string (nullable = true) + |-- grade: integer (nullable = true) + ++--------+---------+-----------+--------------------------------------------------------------- + +-- showing the multiple table details with pattern matching +SHOW TABLE EXTENDED LIKE `employe*`; ++--------+---------+-----------+--------------------------------------------------------------- +|database|tableName|isTemporary| information ++--------+---------+-----------+--------------------------------------------------------------- +|default |employee |false |Database: default + Table: employee + Owner: root + Created Time: Fri Aug 30 15:10:21 IST 2019 + Last Access: Thu Jan 01 05:30:00 IST 1970 + Created By: Spark 3.0.0-SNAPSHOT + Type: MANAGED + Provider: hive + Table Properties: [transient_lastDdlTime=1567158021] + Location: file:/opt/spark1/spark/spark-warehouse/employee + Serde Library: org.apache.hadoop.hive.serde2.lazy + .LazySimpleSerDe + InputFormat: org.apache.hadoop.mapred.TextInputFormat + OutputFormat: org.apache.hadoop.hive.ql.io + .HiveIgnoreKeyTextOutputFormat + Storage Properties: [serialization.format=1] + Partition Provider: Catalog + Partition Columns: [`grade`] + Schema: root + |-- name: string (nullable = true) + |-- grade: integer (nullable = true) + +|default |employee1|false |Database: default + Table: employee1 + Owner: root + Created Time: Fri Aug 30 15:22:33 IST 2019 + Last Access: Thu Jan 01 05:30:00 IST 1970 + Created By: Spark 3.0.0-SNAPSHOT + Type: MANAGED + Provider: hive + Table Properties: [transient_lastDdlTime=1567158753] + Location: file:/opt/spark1/spark/spark-warehouse/employee1 + Serde Library: org.apache.hadoop.hive.serde2.lazy + .LazySimpleSerDe + InputFormat: org.apache.hadoop.mapred.TextInputFormat + OutputFormat: org.apache.hadoop.hive.ql.io + .HiveIgnoreKeyTextOutputFormat + Storage Properties: [serialization.format=1] + Partition Provider: Catalog + Schema: root + |-- name: string (nullable = true) + ++--------+---------+----------+---------------------------------------------------------------- + +-- show partition file system details +SHOW TABLE EXTENDED IN `default` LIKE `employee` PARTITION (`grade=1`); ++--------+---------+-----------+--------------------------------------------------------------- +|database|tableName|isTemporary| information 
++--------+---------+-----------+--------------------------------------------------------------- +|default |employee |false | Partition Values: [grade=1] + Location: file:/opt/spark1/spark/spark-warehouse/employee + /grade=1 + Serde Library: org.apache.hadoop.hive.serde2.lazy + .LazySimpleSerDe + InputFormat: org.apache.hadoop.mapred.TextInputFormat + OutputFormat: org.apache.hadoop.hive.ql.io + .HiveIgnoreKeyTextOutputFormat + Storage Properties: [serialization.format=1] + Partition Parameters: {rawDataSize=-1, numFiles=1, + transient_lastDdlTime=1567158221, totalSize=4, + COLUMN_STATS_ACCURATE=false, numRows=-1} + Created Time: Fri Aug 30 15:13:41 IST 2019 + Last Access: Thu Jan 01 05:30:00 IST 1970 + Partition Statistics: 4 bytes + | ++--------+---------+-----------+--------------------------------------------------------------- + +-- show partition file system details with regex fails as shown below +SHOW TABLE EXTENDED IN `default` LIKE `empl*` PARTITION (`grade=1`); +Error: Error running query: org.apache.spark.sql.catalyst.analysis.NoSuchTableException: + Table or view 'emplo*' not found in database 'default'; (state=,code=0) + +{% endhighlight %} +### Related Statements +- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +- [DESCRIBE TABLE](sql-ref-syntax-aux-describe-table.html) diff --git a/docs/sql-ref-syntax-ddl-drop-table.md b/docs/sql-ref-syntax-ddl-drop-table.md index a036e66c3906..f9129d5114fa 100644 --- a/docs/sql-ref-syntax-ddl-drop-table.md +++ b/docs/sql-ref-syntax-ddl-drop-table.md @@ -19,4 +19,69 @@ license: | limitations under the License. --- -**This page is under construction** +### Description + +`DROP TABLE` deletes the table and removes the directory associated with the table from the file system +if the table is not `EXTERNAL` table. If the table is not present it throws an exception. + +In case of an external table, only the associated metadata information is removed from the metastore database. + +### Syntax +{% highlight sql %} +DROP TABLE [IF EXISTS] [database_name.]table_name +{% endhighlight %} + +### Parameter +
    +
    IF EXISTS
    +
    + If specified, no exception is thrown if the table does not exist. +
    +
    database_name
    +
    + Specifies the database name where the table is present. +
    +
    table_name
    +
    + Specifies the table name to be dropped. +
    +
    + +### Example +{% highlight sql %} +-- Assumes a table named `employeetable` exists. +DROP TABLE employeetable; ++---------+--+ +| Result | ++---------+--+ ++---------+--+ + +-- Assumes a table named `employeetable` exists in the `userdb` database +DROP TABLE userdb.employeetable; ++---------+--+ +| Result | ++---------+--+ ++---------+--+ + +-- Assumes a table named `employeetable` does not exists. +-- Throws exception +DROP TABLE employeetable; +Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeetable; +(state=,code=0) + +-- Assumes a table named `employeetable` does not exists,Try with IF EXISTS +-- this time it will not throw exception +DROP TABLE IF EXISTS employeetable; ++---------+--+ +| Result | ++---------+--+ ++---------+--+ + +{% endhighlight %} + +### Related Statements +- [CREATE TABLE](sql-ref-syntax-ddl-create-table.html) +- [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html) +- [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html) + + diff --git a/docs/sql-ref-syntax-ddl-drop-view.md b/docs/sql-ref-syntax-ddl-drop-view.md index 9ad22500fd9e..f095a3456772 100644 --- a/docs/sql-ref-syntax-ddl-drop-view.md +++ b/docs/sql-ref-syntax-ddl-drop-view.md @@ -19,4 +19,63 @@ license: | limitations under the License. --- -**This page is under construction** +### Description +`DROP VIEW` removes the metadata associated with a specified view from the catalog. + +### Syntax +{% highlight sql %} +DROP VIEW [IF EXISTS] [database_name.]view_name +{% endhighlight %} + +### Parameter +
+
+IF EXISTS
+
+  If specified, no exception is thrown when the view does not exist.
+
+database_name
+
+  Specifies the database name where the view resides.
+
+view_name
+
+  Specifies the name of the view to be dropped.
+
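As with `DROP TABLE` above, this is an illustrative aside rather than part of the reference: a view can also be registered and dropped programmatically. The SparkSession value `spark` and the view name below are assumptions for the sketch.

{% highlight scala %}
// Hypothetical sketch: register a temporary view, then drop it with DROP VIEW.
spark.range(5).createOrReplaceTempView("employeeView")
spark.sql("DROP VIEW IF EXISTS employeeView")
{% endhighlight %}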
+### Example
+{% highlight sql %}
+-- Assumes a view named `employeeView` exists.
+DROP VIEW employeeView;
++---------+--+
+| Result  |
++---------+--+
++---------+--+
+
+-- Assumes a view named `employeeView` exists in the `userdb` database.
+DROP VIEW userdb.employeeView;
++---------+--+
+| Result  |
++---------+--+
++---------+--+
+
+-- Assumes a view named `employeeView` does not exist.
+-- Throws an exception.
+DROP VIEW employeeView;
+Error: org.apache.spark.sql.AnalysisException: Table or view not found: employeeView;
+(state=,code=0)
+
+-- Assumes a view named `employeeView` does not exist.
+-- This time, with IF EXISTS, no exception is thrown.
+DROP VIEW IF EXISTS employeeView;
++---------+--+
+| Result  |
++---------+--+
++---------+--+
+
+{% endhighlight %}
+
+### Related Statements
+- [CREATE VIEW](sql-ref-syntax-ddl-create-view.html)
+- [CREATE DATABASE](sql-ref-syntax-ddl-create-database.html)
+- [DROP DATABASE](sql-ref-syntax-ddl-drop-database.html)
diff --git a/docs/structured-streaming-kafka-integration.md b/docs/structured-streaming-kafka-integration.md
index badf0429545f..8c17de92f348 100644
--- a/docs/structured-streaming-kafka-integration.md
+++ b/docs/structured-streaming-kafka-integration.md
@@ -473,8 +473,8 @@ The following configurations are optional:
   Desired minimum number of partitions to read from Kafka.
   By default, Spark has a 1-1 mapping of topicPartitions to Spark partitions consuming from Kafka.
   If you set this option to a value greater than your topicPartitions, Spark will divvy up large
-  Kafka partitions to smaller pieces. Please note that this configuration is like a `hint`: the
-  number of Spark tasks will be **approximately** `minPartitions`. It can be less or more depending on
+  Kafka partitions to smaller pieces. Please note that this configuration is like a hint: the
+  number of Spark tasks will be approximately minPartitions. It can be less or more depending on
   rounding errors or Kafka partitions that didn't receive any new data.
@@ -482,7 +482,7 @@ The following configurations are optional:
   string
   spark-kafka-source
   streaming and batch
-  Prefix of consumer group identifiers (`group.id`) that are generated by structured streaming
+  Prefix of consumer group identifiers (group.id) that are generated by structured streaming
   queries. If "kafka.group.id" is set, this option will be ignored.
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
index 2a405f36fd5f..01679e5defe1 100644
--- a/docs/structured-streaming-programming-guide.md
+++ b/docs/structured-streaming-programming-guide.md
@@ -546,6 +546,13 @@ Here are the details of all the sources in Spark.
    "s3://a/dataset.txt"
    "s3n://a/b/dataset.txt"
    "s3a://a/b/c/dataset.txt"
    + cleanSource: option to clean up completed files after processing (a usage sketch follows these notes).
    + Available options are "archive", "delete", "off". If the option is not provided, the default value is "off".
    + When "archive" is provided, the additional option sourceArchiveDir must be provided as well. The value of sourceArchiveDir must be a directory at least two levels deep (e.g. /archived/here), which ensures archived files are never picked up again as new source files.
    + Spark moves source files while preserving their original path. For example, if the path of a source file is /a/b/dataset.txt and the archive directory is /archived/here, the file is moved to /archived/here/a/b/dataset.txt.
    + NOTE: Both archiving (via moving) and deleting completed files add overhead to each micro-batch (slowing it down), so understand the cost of these operations in your file system before enabling this option. On the other hand, enabling this option reduces the cost of listing source files, which can be an expensive operation.
    + NOTE 2: The source path should not be used by multiple sources or queries when this option is enabled.
    + NOTE 3: Both delete and move actions are best effort; failing to delete or move files will not fail the streaming query.
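As an aside, here is a minimal sketch of how the cleanSource and sourceArchiveDir options described above could be wired into a streaming file source. The format, schema, and paths are assumptions made up for the example; only the two option names come from the text above.

{% highlight scala %}
// Minimal sketch (hypothetical paths/schema): archive completed files after processing.
val stream = spark.readStream
  .format("csv")
  .schema("id INT, value STRING")                // file sources require a user-defined schema
  .option("cleanSource", "archive")              // or "delete" / "off" (default)
  .option("sourceArchiveDir", "/archived/here")  // at least two directories deep, per the notes above
  .load("/data/incoming")
{% endhighlight %}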

    For file-format-specific options, see the related methods in DataStreamReader (Scala/Java/Python/Append, Update, Complete Append mode uses watermark to drop old aggregation state. But the output of a - windowed aggregation is delayed the late threshold specified in `withWatermark()` as by + windowed aggregation is delayed the late threshold specified in withWatermark() as by the modes semantics, rows can be added to the Result Table only once after they are finalized (i.e. after watermark is crossed). See the Late Data section for more details. @@ -2324,7 +2331,7 @@ Here are the different kinds of triggers that are supported. One-time micro-batch - The query will execute *only one* micro-batch to process all the available data and then + The query will execute only one micro-batch to process all the available data and then stop on its own. This is useful in scenarios you want to periodically spin up a cluster, process everything that is available since the last period, and then shutdown the cluster. In some case, this may lead to significant cost savings. diff --git a/examples/src/main/python/sql/arrow.py b/examples/src/main/python/sql/arrow.py index de8d4f755de6..d5a3173ff9c0 100644 --- a/examples/src/main/python/sql/arrow.py +++ b/examples/src/main/python/sql/arrow.py @@ -258,6 +258,36 @@ def filter_func(batch_iter): # $example off:map_iter_pandas_udf$ +def cogrouped_map_pandas_udf_example(spark): + # $example on:cogrouped_map_pandas_udf$ + import pandas as pd + + from pyspark.sql.functions import pandas_udf, PandasUDFType + + df1 = spark.createDataFrame( + [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], + ("time", "id", "v1")) + + df2 = spark.createDataFrame( + [(20000101, 1, "x"), (20000101, 2, "y")], + ("time", "id", "v2")) + + @pandas_udf("time int, id int, v1 double, v2 string", PandasUDFType.COGROUPED_MAP) + def asof_join(l, r): + return pd.merge_asof(l, r, on="time", by="id") + + df1.groupby("id").cogroup(df2.groupby("id")).apply(asof_join).show() + # +--------+---+---+---+ + # | time| id| v1| v2| + # +--------+---+---+---+ + # |20000101| 1|1.0| x| + # |20000102| 1|3.0| x| + # |20000101| 2|2.0| y| + # |20000102| 2|4.0| y| + # +--------+---+---+---+ + # $example off:cogrouped_map_pandas_udf$ + + if __name__ == "__main__": spark = SparkSession \ .builder \ @@ -276,5 +306,7 @@ def filter_func(batch_iter): grouped_agg_pandas_udf_example(spark) print("Running pandas_udf map iterator example") map_iter_pandas_udf_example(spark) + print("Running pandas_udf cogrouped map example") + cogrouped_map_pandas_udf_example(spark) spark.stop() diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala index 5bdc1b5fe9f3..8b907065af1d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -25,9 +25,9 @@ import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecor import org.apache.kafka.common.header.Header import org.apache.kafka.common.header.internals.RecordHeader -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType} +import org.apache.spark.sql.catalyst.InternalRow +import 
org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, UnsafeProjection} +import org.apache.spark.sql.types.BinaryType /** * Writes out data in a single Spark task, without any concerns about how @@ -116,66 +116,13 @@ private[kafka010] abstract class KafkaRowWriter( } private def createProjection = { - val topicExpression = topic.map(Literal(_)).orElse { - inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) - }.getOrElse { - throw new IllegalStateException(s"topic option required when no " + - s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present") - } - topicExpression.dataType match { - case StringType => // good - case t => - throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t. ${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + - s"must be a ${StringType.catalogString}") - } - val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) - .getOrElse(Literal(null, BinaryType)) - keyExpression.dataType match { - case StringType | BinaryType => // good - case t => - throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.catalogString}") - } - val valueExpression = inputSchema - .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( - throw new IllegalStateException("Required attribute " + - s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") - ) - valueExpression.dataType match { - case StringType | BinaryType => // good - case t => - throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.catalogString}") - } - val headersExpression = inputSchema - .find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse( - Literal(CatalystTypeConverters.convertToCatalyst(null), - KafkaRecordToRowConverter.headersType) - ) - headersExpression.dataType match { - case KafkaRecordToRowConverter.headersType => // good - case t => - throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " + - s"attribute unsupported type ${t.catalogString}") - } - val partitionExpression = - inputSchema.find(_.name == KafkaWriter.PARTITION_ATTRIBUTE_NAME) - .getOrElse(Literal(null, IntegerType)) - partitionExpression.dataType match { - case IntegerType => // good - case t => - throw new IllegalStateException(s"${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " + - s"attribute unsupported type $t. 
${KafkaWriter.PARTITION_ATTRIBUTE_NAME} " + - s"must be a ${IntegerType.catalogString}") - } UnsafeProjection.create( Seq( - topicExpression, - Cast(keyExpression, BinaryType), - Cast(valueExpression, BinaryType), - headersExpression, - partitionExpression + KafkaWriter.topicExpression(inputSchema, topic), + Cast(KafkaWriter.keyExpression(inputSchema), BinaryType), + Cast(KafkaWriter.valueExpression(inputSchema), BinaryType), + KafkaWriter.headersExpression(inputSchema), + KafkaWriter.partitionExpression(inputSchema) ), inputSchema ) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index 9b0d11f137ce..5ef4b3a1c19d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.QueryExecution -import org.apache.spark.sql.types.{BinaryType, IntegerType, MapType, StringType} +import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType} import org.apache.spark.util.Utils /** @@ -49,51 +49,14 @@ private[kafka010] object KafkaWriter extends Logging { schema: Seq[Attribute], kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( - if (topic.isEmpty) { - throw new AnalysisException(s"topic option required when no " + - s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. 
Use the " + - s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") - } else { - Literal.create(topic.get, StringType) - } - ).dataType match { - case StringType => // good - case _ => - throw new AnalysisException(s"Topic type must be a ${StringType.catalogString}") - } - schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( - Literal(null, StringType) - ).dataType match { - case StringType | BinaryType => // good - case _ => - throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") - } - schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( - throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") - ).dataType match { - case StringType | BinaryType => // good - case _ => - throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + - s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}") - } - schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse( - Literal(CatalystTypeConverters.convertToCatalyst(null), - KafkaRecordToRowConverter.headersType) - ).dataType match { - case KafkaRecordToRowConverter.headersType => // good - case _ => - throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " + - s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}") - } - schema.find(_.name == PARTITION_ATTRIBUTE_NAME).getOrElse( - Literal(null, IntegerType) - ).dataType match { - case IntegerType => // good - case _ => - throw new AnalysisException(s"$PARTITION_ATTRIBUTE_NAME attribute type " + - s"must be an ${IntegerType.catalogString}") + try { + topicExpression(schema, topic) + keyExpression(schema) + valueExpression(schema) + headersExpression(schema) + partitionExpression(schema) + } catch { + case e: IllegalStateException => throw new AnalysisException(e.getMessage) } } @@ -110,4 +73,53 @@ private[kafka010] object KafkaWriter extends Logging { finallyBlock = writeTask.close()) } } + + def topicExpression(schema: Seq[Attribute], topic: Option[String] = None): Expression = { + topic.map(Literal(_)).getOrElse( + expression(schema, TOPIC_ATTRIBUTE_NAME, Seq(StringType)) { + throw new IllegalStateException(s"topic option required when no " + + s"'${TOPIC_ATTRIBUTE_NAME}' attribute is present. 
Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") + } + ) + } + + def keyExpression(schema: Seq[Attribute]): Expression = { + expression(schema, KEY_ATTRIBUTE_NAME, Seq(StringType, BinaryType)) { + Literal(null, BinaryType) + } + } + + def valueExpression(schema: Seq[Attribute]): Expression = { + expression(schema, VALUE_ATTRIBUTE_NAME, Seq(StringType, BinaryType)) { + throw new IllegalStateException(s"Required attribute '${VALUE_ATTRIBUTE_NAME}' not found") + } + } + + def headersExpression(schema: Seq[Attribute]): Expression = { + expression(schema, HEADERS_ATTRIBUTE_NAME, Seq(KafkaRecordToRowConverter.headersType)) { + Literal(CatalystTypeConverters.convertToCatalyst(null), + KafkaRecordToRowConverter.headersType) + } + } + + def partitionExpression(schema: Seq[Attribute]): Expression = { + expression(schema, PARTITION_ATTRIBUTE_NAME, Seq(IntegerType)) { + Literal(null, IntegerType) + } + } + + private def expression( + schema: Seq[Attribute], + attrName: String, + desired: Seq[DataType])( + default: => Expression): Expression = { + val expr = schema.find(_.name == attrName).getOrElse(default) + if (!desired.exists(_.sameType(expr.dataType))) { + throw new IllegalStateException(s"$attrName attribute unsupported type " + + s"${expr.dataType.catalogString}. $attrName must be a(n) " + + s"${desired.map(_.catalogString).mkString(" or ")}") + } + expr + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index cbf4952406c0..031f609cb92b 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.kafka010 import java.util.Locale +import scala.reflect.ClassTag + import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} import org.apache.spark.util.Utils @@ -192,24 +195,9 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() testUtils.createTopic(topic) - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "CAST(null as STRING) as topic", "value" - ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() + runAndVerifyException[StreamingQueryException](inputTopic, "null topic present in the data.") { + createKafkaWriter(input.toDF())(withSelectExpr = "CAST(null as STRING) as topic", "value") } - assert(ex.getCause.getCause.getMessage - .toLowerCase(Locale.ROOT) - .contains("null topic present in the data.")) } test("streaming - write data with bad schema") { @@ -226,24 +214,10 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() 
testUtils.createTopic(topic) - val ex = intercept[AnalysisException] { - /* No topic field or topic option */ - createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic option required when no 'topic' attribute is present")) - - val ex2 = intercept[AnalysisException] { - /* No value field */ - createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - } - assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) + assertWrongSchema(topic, input, Seq("value as key", "value"), + "topic option required when no 'topic' attribute is present") + assertWrongSchema(topic, input, Seq(s"'$topic' as topic", "value as key"), + "required attribute 'value' not found") } test("streaming - write data with valid schema but wrong types") { @@ -258,43 +232,18 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { .option("startingOffsets", "earliest") .load() .selectExpr("CAST(value as STRING) value") + .toDF() val topic = newTopic() testUtils.createTopic(topic) - val ex = intercept[AnalysisException] { - /* topic field wrong type */ - createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - val ex2 = intercept[AnalysisException] { - /* value field wrong type */ - createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - } - assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) - - val ex3 = intercept[AnalysisException] { - /* key field wrong type */ - createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - } - assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) - - val ex4 = intercept[AnalysisException] { - /* partition field wrong type */ - createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as partition", "value" - ) - } - assert(ex4.getMessage.toLowerCase(Locale.ROOT).contains( - "partition attribute type must be an int")) + assertWrongSchema(topic, input, Seq("CAST('1' as INT) as topic", "value"), + "topic must be a(n) string") + assertWrongSchema(topic, input, Seq(s"'$topic' as topic", "CAST(value as INT) as value"), + "value must be a(n) string or binary") + assertWrongSchema(topic, input, Seq(s"'$topic' as topic", "CAST(value as INT) as key", "value"), + "key must be a(n) string or binary") + assertWrongSchema(topic, input, Seq(s"'$topic' as topic", "value as partition", "value"), + "partition must be a(n) int") } test("streaming - write to non-existing topic") { @@ -310,21 +259,9 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { .load() val topic = newTopic() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - } - throw writer.exception.get - } - } finally { - writer.stop() + runAndVerifyException[StreamingQueryException](inputTopic, "job aborted") { + createKafkaWriter(input.toDF(), withTopic = Some(topic))() } - 
assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { @@ -339,21 +276,10 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { .option("subscribe", inputTopic) .load() - val ex = intercept[IllegalArgumentException] { - createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - - val ex2 = intercept[IllegalArgumentException] { - createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - } - assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) + assertWrongOption(inputTopic, input.toDF(), Map("kafka.key.serializer" -> "foo"), + "kafka option 'key.serializer' is not supported") + assertWrongOption(inputTopic, input.toDF(), Map("kafka.value.serializer" -> "foo"), + "kafka option 'value.serializer' is not supported") } test("generic - write big data with small producer buffer") { @@ -422,4 +348,48 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { withOptions.foreach(opt => stream.option(opt._1, opt._2)) stream.start() } + + private def runAndVerifyException[T <: Exception : ClassTag]( + inputTopic: String, + expectErrorMsg: String)( + writerFn: => StreamingQuery): Unit = { + var writer: StreamingQuery = null + val ex: Exception = try { + intercept[T] { + writer = writerFn + testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) + eventually(timeout(streamingTimeout)) { + assert(writer.exception.isDefined) + } + throw writer.exception.get + } + } finally { + if (writer != null) writer.stop() + } + val rootException = ex match { + case e: StreamingQueryException => e.getCause.getCause + case e => e + } + assert(rootException.getMessage.toLowerCase(Locale.ROOT).contains(expectErrorMsg)) + } + + private def assertWrongSchema( + inputTopic: String, + input: DataFrame, + selectExpr: Seq[String], + expectErrorMsg: String): Unit = { + runAndVerifyException[AnalysisException](inputTopic, expectErrorMsg) { + createKafkaWriter(input)(withSelectExpr = selectExpr: _*) + } + } + + private def assertWrongOption( + inputTopic: String, + input: DataFrame, + options: Map[String, String], + expectErrorMsg: String): Unit = { + runAndVerifyException[IllegalArgumentException](inputTopic, expectErrorMsg) { + createKafkaWriter(input, withOptions = options)() + } + } } diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index aacb10f5197b..1705d76de758 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -211,38 +211,10 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest { val topic = newTopic() testUtils.createTopic(topic) - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = "value as key", "value" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage - .toLowerCase(Locale.ROOT) - .contains("topic 
option required when no 'topic' attribute is present")) - - try { - /* No value field */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value as key" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "required attribute 'value' not found")) + assertWrongSchema(input, Seq("value as key", "value"), + "topic option required when no 'topic' attribute is present") + assertWrongSchema(input, Seq(s"'$topic' as topic", "value as key"), + "required attribute 'value' not found") } test("streaming - write data with valid schema but wrong types") { @@ -250,109 +222,31 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest { val topic = newTopic() testUtils.createTopic(topic) - var writer: StreamingQuery = null - var ex: Exception = null - try { - /* topic field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"CAST('1' as INT) as topic", "value" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - - try { - /* value field wrong type */ - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "value attribute type must be a string or binary")) - - try { - ex = intercept[StreamingQueryException] { - /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "key attribute type must be a string or binary")) - - try { - ex = intercept[StreamingQueryException] { - /* partition field wrong type */ - writer = createKafkaWriter(input.toDF())( - withSelectExpr = s"'$topic' as topic", "value", "value as partition" - ) - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "partition attribute type must be an int")) + assertWrongSchema(input, Seq("CAST('1' as INT) as topic", "value"), + "topic must be a(n) string") + assertWrongSchema(input, Seq(s"'$topic' as topic", "CAST(value as INT) as value"), + "value must be a(n) string or binary") + assertWrongSchema(input, Seq(s"'$topic' as topic", "CAST(value as INT) as key", "value"), + "key must be a(n) string or binary") + assertWrongSchema(input, Seq(s"'$topic' as topic", "value", "value as partition"), + "partition must be a(n) int") } test("streaming - write to non-existing topic") { val input = MemoryStream[String] - val topic = newTopic() - var writer: StreamingQuery = null - var ex: Exception = null - try { - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() - input.addData("1", "2", "3", "4", "5") - writer.processAllAvailable() - } - } finally { - writer.stop() + runAndVerifyStreamingQueryException(input, "job 
aborted") { + createKafkaWriter(input.toDF(), withTopic = Some(newTopic()))() } - assert(ex.getCause.getCause.getMessage.toLowerCase(Locale.ROOT).contains("job aborted")) } test("streaming - exception on config serializer") { val input = MemoryStream[String] - var writer: StreamingQuery = null - var ex: Exception = null - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.key.serializer" -> "foo"))() - input.addData("1") - writer.processAllAvailable() - } - assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - - ex = intercept[StreamingQueryException] { - writer = createKafkaWriter( - input.toDF(), - withOptions = Map("kafka.value.serializer" -> "foo"))() - input.addData("1") - writer.processAllAvailable() - } - assert(ex.getCause.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) + + assertWrongOption(input, Map("kafka.key.serializer" -> "foo"), + "kafka option 'key.serializer' is not supported") + assertWrongOption(input, Map("kafka.value.serializer" -> "foo"), + "kafka option 'value.serializer' is not supported") } private def createKafkaWriter( @@ -379,6 +273,41 @@ class KafkaSinkStreamingSuite extends KafkaSinkSuiteBase with StreamTest { } stream.start() } + + private def runAndVerifyStreamingQueryException( + input: MemoryStream[String], + expectErrorMsg: String)( + writerFn: => StreamingQuery): Unit = { + var writer: StreamingQuery = null + val ex: Exception = try { + intercept[StreamingQueryException] { + writer = writerFn + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + if (writer != null) writer.stop() + } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains(expectErrorMsg)) + } + + private def assertWrongSchema( + input: MemoryStream[String], + selectExpr: Seq[String], + expectErrorMsg: String): Unit = { + runAndVerifyStreamingQueryException(input, expectErrorMsg) { + createKafkaWriter(input.toDF())(withSelectExpr = selectExpr: _*) + } + } + + private def assertWrongOption( + input: MemoryStream[String], + options: Map[String, String], + expectErrorMsg: String): Unit = { + runAndVerifyStreamingQueryException(input, expectErrorMsg) { + createKafkaWriter(input.toDF(), withOptions = options)() + } + } } abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase { diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala index 88d6d0eea536..a449a8bb7213 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala @@ -237,9 +237,10 @@ private[spark] class DirectKafkaInputDStream[K, V]( val description = offsetRanges.filter { offsetRange => // Don't display empty ranges. 
offsetRange.fromOffset != offsetRange.untilOffset - }.map { offsetRange => + }.toSeq.sortBy(-_.count()).map { offsetRange => s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" + - s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}" + s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}\t" + + s"count: ${offsetRange.count()}" }.mkString("\n") // Copy offsetRanges to immutable.List to prevent from being modified by the user val metadata = Map( diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index a23d255f9187..db64b201abc2 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -39,10 +39,10 @@ spark-core_${scala.binary.version} ${project.version} - - io.dropwizard.metrics - metrics-ganglia + info.ganglia.gmetric4j + gmetric4j + 1.0.10 diff --git a/external/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java b/external/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java new file mode 100644 index 000000000000..019ee08e0918 --- /dev/null +++ b/external/spark-ganglia-lgpl/src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java @@ -0,0 +1,426 @@ +// Copied from +// https://raw.githubusercontent.com/dropwizard/metrics/v3.2.6/metrics-ganglia/ +// src/main/java/com/codahale/metrics/ganglia/GangliaReporter.java + +package com.codahale.metrics.ganglia; + +import com.codahale.metrics.*; +import com.codahale.metrics.MetricAttribute; +import info.ganglia.gmetric4j.gmetric.GMetric; +import info.ganglia.gmetric4j.gmetric.GMetricSlope; +import info.ganglia.gmetric4j.gmetric.GMetricType; +import info.ganglia.gmetric4j.gmetric.GangliaException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.SortedMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + +import static com.codahale.metrics.MetricRegistry.name; +import static com.codahale.metrics.MetricAttribute.*; + +/** + * A reporter which announces metric values to a Ganglia cluster. + * + * @see Ganglia Monitoring System + */ +public class GangliaReporter extends ScheduledReporter { + + private static final Pattern SLASHES = Pattern.compile("\\\\"); + + /** + * Returns a new {@link Builder} for {@link GangliaReporter}. + * + * @param registry the registry to report + * @return a {@link Builder} instance for a {@link GangliaReporter} + */ + public static Builder forRegistry(MetricRegistry registry) { + return new Builder(registry); + } + + /** + * A builder for {@link GangliaReporter} instances. Defaults to using a {@code tmax} of {@code 60}, + * a {@code dmax} of {@code 0}, converting rates to events/second, converting durations to + * milliseconds, and not filtering metrics. 
+ */ + public static class Builder { + private final MetricRegistry registry; + private String prefix; + private int tMax; + private int dMax; + private TimeUnit rateUnit; + private TimeUnit durationUnit; + private MetricFilter filter; + private ScheduledExecutorService executor; + private boolean shutdownExecutorOnStop; + private Set disabledMetricAttributes = Collections.emptySet(); + + private Builder(MetricRegistry registry) { + this.registry = registry; + this.tMax = 60; + this.dMax = 0; + this.rateUnit = TimeUnit.SECONDS; + this.durationUnit = TimeUnit.MILLISECONDS; + this.filter = MetricFilter.ALL; + this.executor = null; + this.shutdownExecutorOnStop = true; + } + + /** + * Specifies whether or not, the executor (used for reporting) will be stopped with same time with reporter. + * Default value is true. + * Setting this parameter to false, has the sense in combining with providing external managed executor via {@link #scheduleOn(ScheduledExecutorService)}. + * + * @param shutdownExecutorOnStop if true, then executor will be stopped in same time with this reporter + * @return {@code this} + */ + public Builder shutdownExecutorOnStop(boolean shutdownExecutorOnStop) { + this.shutdownExecutorOnStop = shutdownExecutorOnStop; + return this; + } + + /** + * Specifies the executor to use while scheduling reporting of metrics. + * Default value is null. + * Null value leads to executor will be auto created on start. + * + * @param executor the executor to use while scheduling reporting of metrics. + * @return {@code this} + */ + public Builder scheduleOn(ScheduledExecutorService executor) { + this.executor = executor; + return this; + } + + /** + * Use the given {@code tmax} value when announcing metrics. + * + * @param tMax the desired gmond {@code tmax} value + * @return {@code this} + */ + public Builder withTMax(int tMax) { + this.tMax = tMax; + return this; + } + + /** + * Prefix all metric names with the given string. + * + * @param prefix the prefix for all metric names + * @return {@code this} + */ + public Builder prefixedWith(String prefix) { + this.prefix = prefix; + return this; + } + + /** + * Use the given {@code dmax} value when announcing metrics. + * + * @param dMax the desired gmond {@code dmax} value + * @return {@code this} + */ + public Builder withDMax(int dMax) { + this.dMax = dMax; + return this; + } + + /** + * Convert rates to the given time unit. + * + * @param rateUnit a unit of time + * @return {@code this} + */ + public Builder convertRatesTo(TimeUnit rateUnit) { + this.rateUnit = rateUnit; + return this; + } + + /** + * Convert durations to the given time unit. + * + * @param durationUnit a unit of time + * @return {@code this} + */ + public Builder convertDurationsTo(TimeUnit durationUnit) { + this.durationUnit = durationUnit; + return this; + } + + /** + * Only report metrics which match the given filter. + * + * @param filter a {@link MetricFilter} + * @return {@code this} + */ + public Builder filter(MetricFilter filter) { + this.filter = filter; + return this; + } + + /** + * Don't report the passed metric attributes for all metrics (e.g. "p999", "stddev" or "m15"). + * See {@link MetricAttribute}. 
+ * + * @param disabledMetricAttributes a {@link MetricFilter} + * @return {@code this} + */ + public Builder disabledMetricAttributes(Set disabledMetricAttributes) { + this.disabledMetricAttributes = disabledMetricAttributes; + return this; + } + + /** + * Builds a {@link GangliaReporter} with the given properties, announcing metrics to the + * given {@link GMetric} client. + * + * @param gmetric the client to use for announcing metrics + * @return a {@link GangliaReporter} + */ + public GangliaReporter build(GMetric gmetric) { + return new GangliaReporter(registry, gmetric, null, prefix, tMax, dMax, rateUnit, durationUnit, filter, + executor, shutdownExecutorOnStop, disabledMetricAttributes); + } + + /** + * Builds a {@link GangliaReporter} with the given properties, announcing metrics to the + * given {@link GMetric} client. + * + * @param gmetrics the clients to use for announcing metrics + * @return a {@link GangliaReporter} + */ + public GangliaReporter build(GMetric... gmetrics) { + return new GangliaReporter(registry, null, gmetrics, prefix, tMax, dMax, rateUnit, durationUnit, + filter, executor, shutdownExecutorOnStop , disabledMetricAttributes); + } + } + + private static final Logger LOGGER = LoggerFactory.getLogger(GangliaReporter.class); + + private final GMetric gmetric; + private final GMetric[] gmetrics; + private final String prefix; + private final int tMax; + private final int dMax; + + private GangliaReporter(MetricRegistry registry, + GMetric gmetric, + GMetric[] gmetrics, + String prefix, + int tMax, + int dMax, + TimeUnit rateUnit, + TimeUnit durationUnit, + MetricFilter filter, + ScheduledExecutorService executor, + boolean shutdownExecutorOnStop, + Set disabledMetricAttributes) { + super(registry, "ganglia-reporter", filter, rateUnit, durationUnit, executor, shutdownExecutorOnStop, + disabledMetricAttributes); + this.gmetric = gmetric; + this.gmetrics = gmetrics; + this.prefix = prefix; + this.tMax = tMax; + this.dMax = dMax; + } + + @Override + public void report(SortedMap gauges, + SortedMap counters, + SortedMap histograms, + SortedMap meters, + SortedMap timers) { + for (Map.Entry entry : gauges.entrySet()) { + reportGauge(entry.getKey(), entry.getValue()); + } + + for (Map.Entry entry : counters.entrySet()) { + reportCounter(entry.getKey(), entry.getValue()); + } + + for (Map.Entry entry : histograms.entrySet()) { + reportHistogram(entry.getKey(), entry.getValue()); + } + + for (Map.Entry entry : meters.entrySet()) { + reportMeter(entry.getKey(), entry.getValue()); + } + + for (Map.Entry entry : timers.entrySet()) { + reportTimer(entry.getKey(), entry.getValue()); + } + } + + private void reportTimer(String name, Timer timer) { + final String sanitizedName = escapeSlashes(name); + final String group = group(name); + try { + final Snapshot snapshot = timer.getSnapshot(); + + announceIfEnabled(MAX, sanitizedName, group, convertDuration(snapshot.getMax()), getDurationUnit()); + announceIfEnabled(MEAN, sanitizedName, group, convertDuration(snapshot.getMean()), getDurationUnit()); + announceIfEnabled(MIN, sanitizedName, group, convertDuration(snapshot.getMin()), getDurationUnit()); + announceIfEnabled(STDDEV, sanitizedName, group, convertDuration(snapshot.getStdDev()), getDurationUnit()); + + announceIfEnabled(P50, sanitizedName, group, convertDuration(snapshot.getMedian()), getDurationUnit()); + announceIfEnabled(P75, sanitizedName, + group, + convertDuration(snapshot.get75thPercentile()), + getDurationUnit()); + announceIfEnabled(P95, sanitizedName, + group, + 
convertDuration(snapshot.get95thPercentile()), + getDurationUnit()); + announceIfEnabled(P98, sanitizedName, + group, + convertDuration(snapshot.get98thPercentile()), + getDurationUnit()); + announceIfEnabled(P99, sanitizedName, + group, + convertDuration(snapshot.get99thPercentile()), + getDurationUnit()); + announceIfEnabled(P999, sanitizedName, + group, + convertDuration(snapshot.get999thPercentile()), + getDurationUnit()); + + reportMetered(sanitizedName, timer, group, "calls"); + } catch (GangliaException e) { + LOGGER.warn("Unable to report timer {}", sanitizedName, e); + } + } + + private void reportMeter(String name, Meter meter) { + final String sanitizedName = escapeSlashes(name); + final String group = group(name); + try { + reportMetered(sanitizedName, meter, group, "events"); + } catch (GangliaException e) { + LOGGER.warn("Unable to report meter {}", name, e); + } + } + + private void reportMetered(String name, Metered meter, String group, String eventName) throws GangliaException { + final String unit = eventName + '/' + getRateUnit(); + announceIfEnabled(COUNT, name, group, meter.getCount(), eventName); + announceIfEnabled(M1_RATE, name, group, convertRate(meter.getOneMinuteRate()), unit); + announceIfEnabled(M5_RATE, name, group, convertRate(meter.getFiveMinuteRate()), unit); + announceIfEnabled(M15_RATE, name, group, convertRate(meter.getFifteenMinuteRate()), unit); + announceIfEnabled(MEAN_RATE, name, group, convertRate(meter.getMeanRate()), unit); + } + + private void reportHistogram(String name, Histogram histogram) { + final String sanitizedName = escapeSlashes(name); + final String group = group(name); + try { + final Snapshot snapshot = histogram.getSnapshot(); + + announceIfEnabled(COUNT, sanitizedName, group, histogram.getCount(), ""); + announceIfEnabled(MAX, sanitizedName, group, snapshot.getMax(), ""); + announceIfEnabled(MEAN, sanitizedName, group, snapshot.getMean(), ""); + announceIfEnabled(MIN, sanitizedName, group, snapshot.getMin(), ""); + announceIfEnabled(STDDEV, sanitizedName, group, snapshot.getStdDev(), ""); + announceIfEnabled(P50, sanitizedName, group, snapshot.getMedian(), ""); + announceIfEnabled(P75, sanitizedName, group, snapshot.get75thPercentile(), ""); + announceIfEnabled(P95, sanitizedName, group, snapshot.get95thPercentile(), ""); + announceIfEnabled(P98, sanitizedName, group, snapshot.get98thPercentile(), ""); + announceIfEnabled(P99, sanitizedName, group, snapshot.get99thPercentile(), ""); + announceIfEnabled(P999, sanitizedName, group, snapshot.get999thPercentile(), ""); + } catch (GangliaException e) { + LOGGER.warn("Unable to report histogram {}", sanitizedName, e); + } + } + + private void reportCounter(String name, Counter counter) { + final String sanitizedName = escapeSlashes(name); + final String group = group(name); + try { + announce(prefix(sanitizedName, COUNT.getCode()), group, Long.toString(counter.getCount()), GMetricType.DOUBLE, ""); + } catch (GangliaException e) { + LOGGER.warn("Unable to report counter {}", name, e); + } + } + + private void reportGauge(String name, Gauge gauge) { + final String sanitizedName = escapeSlashes(name); + final String group = group(name); + final Object obj = gauge.getValue(); + final String value = String.valueOf(obj); + final GMetricType type = detectType(obj); + try { + announce(name(prefix, sanitizedName), group, value, type, ""); + } catch (GangliaException e) { + LOGGER.warn("Unable to report gauge {}", name, e); + } + } + + private static final double MIN_VAL = 1E-300; + + private void 
announceIfEnabled(MetricAttribute metricAttribute, String metricName, String group, double value, String units) + throws GangliaException { + if (getDisabledMetricAttributes().contains(metricAttribute)) { + return; + } + final String string = Math.abs(value) < MIN_VAL ? "0" : Double.toString(value); + announce(prefix(metricName, metricAttribute.getCode()), group, string, GMetricType.DOUBLE, units); + } + + private void announceIfEnabled(MetricAttribute metricAttribute, String metricName, String group, long value, String units) + throws GangliaException { + if (getDisabledMetricAttributes().contains(metricAttribute)) { + return; + } + announce(prefix(metricName, metricAttribute.getCode()), group, Long.toString(value), GMetricType.DOUBLE, units); + } + + private void announce(String name, String group, String value, GMetricType type, String units) + throws GangliaException { + if (gmetric != null) { + gmetric.announce(name, value, type, units, GMetricSlope.BOTH, tMax, dMax, group); + } else { + for (GMetric gmetric : gmetrics) { + gmetric.announce(name, value, type, units, GMetricSlope.BOTH, tMax, dMax, group); + } + } + } + + private GMetricType detectType(Object o) { + if (o instanceof Float) { + return GMetricType.FLOAT; + } else if (o instanceof Double) { + return GMetricType.DOUBLE; + } else if (o instanceof Byte) { + return GMetricType.INT8; + } else if (o instanceof Short) { + return GMetricType.INT16; + } else if (o instanceof Integer) { + return GMetricType.INT32; + } else if (o instanceof Long) { + return GMetricType.DOUBLE; + } + return GMetricType.STRING; + } + + private String group(String name) { + final int i = name.lastIndexOf('.'); + if (i < 0) { + return ""; + } + return name.substring(0, i); + } + + private String prefix(String name, String n) { + return name(prefix, name, n); + } + + // ganglia metric names can't contain slashes. 
+ private String escapeSlashes(String name) { + return SLASHES.matcher(name).replaceAll("_"); + } +} diff --git a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java index fe44efd2e46a..d1b350fd9f48 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/ChildProcAppHandleSuite.java @@ -77,11 +77,11 @@ public void testRedirectsSimple() throws Exception { SparkLauncher launcher = new SparkLauncher(); launcher.redirectError(ProcessBuilder.Redirect.PIPE); assertNotNull(launcher.errorStream); - assertEquals(launcher.errorStream.type(), ProcessBuilder.Redirect.Type.PIPE); + assertEquals(ProcessBuilder.Redirect.Type.PIPE, launcher.errorStream.type()); launcher.redirectOutput(ProcessBuilder.Redirect.PIPE); assertNotNull(launcher.outputStream); - assertEquals(launcher.outputStream.type(), ProcessBuilder.Redirect.Type.PIPE); + assertEquals(ProcessBuilder.Redirect.Type.PIPE, launcher.outputStream.type()); } @Test @@ -89,11 +89,11 @@ public void testRedirectLastWins() throws Exception { SparkLauncher launcher = new SparkLauncher(); launcher.redirectError(ProcessBuilder.Redirect.PIPE) .redirectError(ProcessBuilder.Redirect.INHERIT); - assertEquals(launcher.errorStream.type(), ProcessBuilder.Redirect.Type.INHERIT); + assertEquals(ProcessBuilder.Redirect.Type.INHERIT, launcher.errorStream.type()); launcher.redirectOutput(ProcessBuilder.Redirect.PIPE) .redirectOutput(ProcessBuilder.Redirect.INHERIT); - assertEquals(launcher.outputStream.type(), ProcessBuilder.Redirect.Type.INHERIT); + assertEquals(ProcessBuilder.Redirect.Type.INHERIT, launcher.outputStream.type()); } @Test diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index e467228b4cc1..27cf2988aae8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -37,6 +37,7 @@ import org.apache.spark.ml.util.Instrumentation.instrumented import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.storage.StorageLevel /** Params for linear SVM Classifier. 
*/ private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam @@ -159,7 +160,10 @@ class LinearSVC @Since("2.2.0") ( override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra) override protected def train(dataset: Dataset[_]): LinearSVCModel = instrumented { instr => + val handlePersistence = dataset.storageLevel == StorageLevel.NONE + val instances = extractInstances(dataset) + if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) instr.logPipelineStage(this) instr.logDataset(dataset) @@ -268,6 +272,8 @@ class LinearSVC @Since("2.2.0") ( (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result()) } + if (handlePersistence) instances.unpersist() + copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 99c0a0df5367..fbccfb1041d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkException import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCols, HasOutputCols} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -32,7 +32,8 @@ import org.apache.spark.sql.types._ /** * Params for [[Imputer]] and [[ImputerModel]]. */ -private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCols { +private[feature] trait ImputerParams extends Params with HasInputCol with HasInputCols + with HasOutputCol with HasOutputCols with HasRelativeError { /** * The imputation strategy. Currently only "mean" and "median" are supported. @@ -63,15 +64,26 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu /** @group getParam */ def getMissingValue: Double = $(missingValue) + /** Returns the input and output column names corresponding in pair. */ + private[feature] def getInOutCols(): (Array[String], Array[String]) = { + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + ($(inputCols), $(outputCols)) + } + } + /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { - require($(inputCols).length == $(inputCols).distinct.length, s"inputCols contains" + - s" duplicates: (${$(inputCols).mkString(", ")})") - require($(outputCols).length == $(outputCols).distinct.length, s"outputCols contains" + - s" duplicates: (${$(outputCols).mkString(", ")})") - require($(inputCols).length == $(outputCols).length, s"inputCols(${$(inputCols).length})" + - s" and outputCols(${$(outputCols).length}) should have the same length") - val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols)) + val (inputColNames, outputColNames) = getInOutCols() + require(inputColNames.length == inputColNames.distinct.length, s"inputCols contains" + + s" duplicates: (${inputColNames.mkString(", ")})") + require(outputColNames.length == outputColNames.distinct.length, s"outputCols contains" + + s" duplicates: (${outputColNames.mkString(", ")})") + require(inputColNames.length == outputColNames.length, s"inputCols(${inputColNames.length})" + + s" and outputCols(${outputColNames.length}) should have the same length") + val outputFields = inputColNames.zip(outputColNames).map { case (inputCol, outputCol) => val inputField = schema(inputCol) SchemaUtils.checkNumericType(schema, inputCol) StructField(outputCol, inputField.dataType, inputField.nullable) @@ -101,6 +113,14 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) @Since("2.2.0") def this() = this(Identifiable.randomUID("imputer")) + /** @group setParam */ + @Since("3.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("3.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ @Since("2.2.0") def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -120,13 +140,19 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) @Since("2.2.0") def setMissingValue(value: Double): this.type = set(missingValue, value) + /** @group expertSetParam */ + @Since("3.0.0") + def setRelativeError(value: Double): this.type = set(relativeError, value) + setDefault(strategy -> Imputer.mean, missingValue -> Double.NaN) override def fit(dataset: Dataset[_]): ImputerModel = { transformSchema(dataset.schema, logging = true) val spark = dataset.sparkSession - val cols = $(inputCols).map { inputCol => + val (inputColumns, _) = getInOutCols() + + val cols = inputColumns.map { inputCol => when(col(inputCol).equalTo($(missingValue)), null) .when(col(inputCol).isNaN, null) .otherwise(col(inputCol)) @@ -139,7 +165,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) // Function avg will ignore null automatically. // For a column only containing null, avg will return null. val row = dataset.select(cols.map(avg): _*).head() - Array.range(0, $(inputCols).length).map { i => + Array.range(0, inputColumns.length).map { i => if (row.isNullAt(i)) { Double.NaN } else { @@ -150,7 +176,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) case Imputer.median => // Function approxQuantile will ignore null automatically. // For a column only containing null, approxQuantile will return an empty array. 
- dataset.select(cols: _*).stat.approxQuantile($(inputCols), Array(0.5), 0.001) + dataset.select(cols: _*).stat.approxQuantile(inputColumns, Array(0.5), $(relativeError)) .map { array => if (array.isEmpty) { Double.NaN @@ -160,7 +186,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) } } - val emptyCols = $(inputCols).zip(results).filter(_._2.isNaN).map(_._1) + val emptyCols = inputColumns.zip(results).filter(_._2.isNaN).map(_._1) if (emptyCols.nonEmpty) { throw new SparkException(s"surrogate cannot be computed. " + s"All the values in ${emptyCols.mkString(",")} are Null, Nan or " + @@ -168,7 +194,7 @@ class Imputer @Since("2.2.0") (@Since("2.2.0") override val uid: String) } val rows = spark.sparkContext.parallelize(Seq(Row.fromSeq(results))) - val schema = StructType($(inputCols).map(col => StructField(col, DoubleType, nullable = false))) + val schema = StructType(inputColumns.map(col => StructField(col, DoubleType, nullable = false))) val surrogateDF = spark.createDataFrame(rows, schema) copyValues(new ImputerModel(uid, surrogateDF).setParent(this)) } @@ -205,6 +231,14 @@ class ImputerModel private[ml] ( import ImputerModel._ + /** @group setParam */ + @Since("3.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("3.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ def setInputCols(value: Array[String]): this.type = set(inputCols, value) @@ -213,9 +247,11 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq + val (inputColumns, outputColumns) = getInOutCols + val surrogates = surrogateDF.select(inputColumns.map(col): _*).head().toSeq + - val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { + val newCols = inputColumns.zip(outputColumns).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol).cast(DoubleType) @@ -224,7 +260,7 @@ class ImputerModel private[ml] ( .otherwise(ic) .cast(inputType) } - dataset.withColumns($(outputCols), newCols).toDF() + dataset.withColumns(outputColumns, newCols).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index ec9792cbbda8..459994c352da 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -25,7 +25,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCols, HasOutputCols} +import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoder and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid - with HasInputCols with 
HasOutputCols { + with HasInputCol with HasInputCols with HasOutputCol with HasOutputCols { /** * Param for how to handle invalid data during transform(). @@ -68,12 +68,21 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid @Since("2.3.0") def getDropLast: Boolean = $(dropLast) + /** Returns the input and output column names corresponding in pair. */ + private[feature] def getInOutCols(): (Array[String], Array[String]) = { + if (isSet(inputCol)) { + (Array($(inputCol)), Array($(outputCol))) + } else { + ($(inputCols), $(outputCols)) + } + } + protected def validateAndTransformSchema( schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { - val inputColNames = $(inputCols) - val outputColNames = $(outputCols) + ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols)) + val (inputColNames, outputColNames) = getInOutCols() require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -83,7 +92,7 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid inputColNames.foreach(SchemaUtils.checkNumericType(schema, _)) // Prepares output columns with proper attributes by examining input columns. - val inputFields = $(inputCols).map(schema(_)) + val inputFields = inputColNames.map(schema(_)) val outputFields = inputFields.zip(outputColNames).map { case (inputField, outputColName) => OneHotEncoderCommon.transformOutputColumnSchema( @@ -123,6 +132,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) @Since("3.0.0") def this() = this(Identifiable.randomUID("oneHotEncoder")) + /** @group setParam */ + @Since("3.0.0") + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + @Since("3.0.0") + def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ @Since("3.0.0") def setInputCols(values: Array[String]): this.type = set(inputCols, values) @@ -150,13 +167,14 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) override def fit(dataset: Dataset[_]): OneHotEncoderModel = { transformSchema(dataset.schema) + val (inputColumns, outputColumns) = getInOutCols() // Compute the plain number of categories without `handleInvalid` and // `dropLast` taken into account. val transformedSchema = validateAndTransformSchema(dataset.schema, dropLast = false, keepInvalid = false) - val categorySizes = new Array[Int]($(outputCols).length) + val categorySizes = new Array[Int](outputColumns.length) - val columnToScanIndices = $(outputCols).zipWithIndex.flatMap { case (outputColName, idx) => + val columnToScanIndices = outputColumns.zipWithIndex.flatMap { case (outputColName, idx) => val numOfAttrs = AttributeGroup.fromStructField( transformedSchema(outputColName)).size if (numOfAttrs < 0) { @@ -170,8 +188,8 @@ class OneHotEncoder @Since("3.0.0") (@Since("3.0.0") override val uid: String) // Some input columns don't have attributes or their attributes don't have necessary info. // We need to scan the data to get the number of values for each column. 
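(Aside, illustrative only: with the single-column setters added above, OneHotEncoder no longer requires array-valued params for the common one-column case. A minimal sketch under assumed toy data; the object and column names are made up.)

import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.sql.SparkSession

object OneHotEncoderSingleColSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("OneHotEncoderSingleColSketch").getOrCreate()
    import spark.implicits._

    // Toy categorical data encoded as doubles.
    val df = Seq(0.0, 1.0, 2.0, 1.0, 0.0).toDF("category")

    val encoder = new OneHotEncoder()
      .setInputCol("category")      // single-column setter introduced by this patch
      .setOutputCol("categoryVec")
      .setDropLast(false)

    encoder.fit(df).transform(df).show(false)
    spark.stop()
  }
}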
if (columnToScanIndices.length > 0) { - val inputColNames = columnToScanIndices.map($(inputCols)(_)) - val outputColNames = columnToScanIndices.map($(outputCols)(_)) + val inputColNames = columnToScanIndices.map(inputColumns(_)) + val outputColNames = columnToScanIndices.map(outputColumns(_)) // When fitting data, we want the plain number of categories without `handleInvalid` and // `dropLast` taken into account. @@ -287,7 +305,7 @@ class OneHotEncoderModel private[ml] ( @Since("3.0.0") override def transformSchema(schema: StructType): StructType = { - val inputColNames = $(inputCols) + val (inputColNames, _) = getInOutCols() require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -306,8 +324,9 @@ class OneHotEncoderModel private[ml] ( */ private def verifyNumOfValues(schema: StructType): StructType = { val configedSizes = getConfigedCategorySizes - $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => - val inputColName = $(inputCols)(idx) + val (inputColNames, outputColNames) = getInOutCols() + outputColNames.zipWithIndex.foreach { case (outputColName, idx) => + val inputColName = inputColNames(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) // If the input metadata specifies number of category for output column, @@ -327,10 +346,11 @@ class OneHotEncoderModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoder.KEEP_INVALID + val (inputColNames, outputColNames) = getInOutCols() - val encodedColumns = $(inputCols).indices.map { idx => - val inputColName = $(inputCols)(idx) - val outputColName = $(outputCols)(idx) + val encodedColumns = inputColNames.indices.map { idx => + val inputColName = inputColNames(idx) + val outputColName = outputColNames(idx) val outputAttrGroupFromSchema = AttributeGroup.fromStructField(transformedSchema(outputColName)) @@ -345,7 +365,7 @@ class OneHotEncoderModel private[ml] ( encoder(col(inputColName).cast(DoubleType), lit(idx)) .as(outputColName, metadata) } - dataset.withColumns($(outputCols), encodedColumns) + dataset.withColumns(outputColNames, encodedColumns) } @Since("3.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index eb78d8224fc3..216d99d01f2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -22,7 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasInputCols, HasOutputCol, HasOutputCols} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StructType * Params for [[QuantileDiscretizer]]. 
*/ private[feature] trait QuantileDiscretizerBase extends Params - with HasHandleInvalid with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols { + with HasHandleInvalid with HasInputCol with HasOutputCol with HasInputCols with HasOutputCols + with HasRelativeError { /** * Number of buckets (quantiles, or categories) into which data points are grouped. Must @@ -67,22 +68,6 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBucketsArray: Array[Int] = $(numBucketsArray) - /** - * Relative error (see documentation for - * `org.apache.spark.sql.DataFrameStatFunctions.approxQuantile` for description) - * Must be in the range [0, 1]. - * Note that in multiple columns case, relative error is applied to all columns. - * default: 0.001 - * @group param - */ - val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + - "for the approximate quantile algorithm used to generate buckets. " + - "Must be in the range [0, 1].", ParamValidators.inRange(0.0, 1.0)) - setDefault(relativeError -> 0.001) - - /** @group getParam */ - def getRelativeError: Double = getOrDefault(relativeError) - /** * Param for how to handle invalid entries. Options are 'skip' (filter out rows with * invalid values), 'error' (throw an error), or 'keep' (keep invalid values in a special @@ -98,7 +83,6 @@ private[feature] trait QuantileDiscretizerBase extends Params "error (throw an error), or keep (keep invalid values in a special additional bucket).", ParamValidators.inArray(Bucketizer.supportedHandleInvalids)) setDefault(handleInvalid, Bucketizer.ERROR_INVALID) - } /** @@ -110,7 +94,8 @@ private[feature] trait QuantileDiscretizerBase extends Params * parameter. If both of the `inputCol` and `inputCols` parameters are set, an Exception will be * thrown. To specify the number of buckets for each column, the `numBucketsArray` parameter can * be set, or if the number of buckets should be the same across columns, `numBuckets` can be - * set as a convenience. + * set as a convenience. Note that in multiple columns case, relative error is applied to all + * columns. * * NaN handling: * null and NaN values will be ignored from the column during `QuantileDiscretizer` fitting. This @@ -134,7 +119,7 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("1.6.0") def this() = this(Identifiable.randomUID("quantileDiscretizer")) - /** @group setParam */ + /** @group expertSetParam */ @Since("2.0.0") def setRelativeError(value: Double): this.type = set(relativeError, value) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala index 9dae39756d31..1d609ef3190d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RobustScaler.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg._ import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasRelativeError} import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql._ @@ -34,7 +34,8 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * Params for [[RobustScaler]] and [[RobustScalerModel]]. 
*/ -private[feature] trait RobustScalerParams extends Params with HasInputCol with HasOutputCol { +private[feature] trait RobustScalerParams extends Params with HasInputCol with HasOutputCol + with HasRelativeError { /** * Lower quantile to calculate quantile range, shared by all features @@ -141,8 +142,12 @@ class RobustScaler (override val uid: String) /** @group setParam */ def setWithScaling(value: Boolean): this.type = set(withScaling, value) + /** @group expertSetParam */ + def setRelativeError(value: Double): this.type = set(relativeError, value) + override def fit(dataset: Dataset[_]): RobustScalerModel = { transformSchema(dataset.schema, logging = true) + val localRelativeError = $(relativeError) val summaries = dataset.select($(inputCol)).rdd.map { case Row(vec: Vector) => vec @@ -152,7 +157,7 @@ class RobustScaler (override val uid: String) val vec = iter.next() if (agg == null) { agg = Array.fill(vec.size)( - new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, 0.001)) + new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, localRelativeError)) } require(vec.size == agg.length, s"Number of dimensions must be ${agg.length} but got ${vec.size}") diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 7ac9a288d285..7ac680ec1183 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -83,6 +83,9 @@ private[shared] object SharedParamsCodeGen { isValid = "ParamValidators.inRange(0, 1)"), ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms (>= 0)", isValid = "ParamValidators.gtEq(0)"), + ParamDesc[Double]("relativeError", "the relative target precision for the approximate " + + "quantile algorithm. Must be in the range [0, 1]", + Some("0.001"), isValid = "ParamValidators.inRange(0, 1)", isExpertParam = true), ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization (>" + " 0)", isValid = "ParamValidators.gt(0)", finalFields = false), ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 6eeeb57e08fb..44c993eeafdd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -418,6 +418,25 @@ trait HasTol extends Params { final def getTol: Double = $(tol) } +/** + * Trait for shared param relativeError (default: 0.001). This trait may be changed or + * removed between minor versions. + */ +@DeveloperApi +trait HasRelativeError extends Params { + + /** + * Param for the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]. + * @group expertParam + */ + final val relativeError: DoubleParam = new DoubleParam(this, "relativeError", "the relative target precision for the approximate quantile algorithm. Must be in the range [0, 1]", ParamValidators.inRange(0, 1)) + + setDefault(relativeError, 0.001) + + /** @group expertGetParam */ + final def getRelativeError: Double = $(relativeError) +} + /** * Trait for shared param stepSize. This trait may be changed or * removed between minor versions. 
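(Aside, illustrative only: the new shared HasRelativeError trait above, with default 0.001, is what the expert setRelativeError setters delegate to. A minimal sketch showing it on RobustScaler under assumed toy data; the object and column names are made up.)

import org.apache.spark.ml.feature.RobustScaler
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object RobustScalerRelativeErrorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("RobustScalerRelativeErrorSketch").getOrCreate()

    val df = spark.createDataFrame(Seq(
      (0, Vectors.dense(1.0, 10.0)),
      (1, Vectors.dense(2.0, 20.0)),
      (2, Vectors.dense(3.0, 30.0)),
      (3, Vectors.dense(100.0, 40.0))
    )).toDF("id", "features")

    val scaler = new RobustScaler()
      .setInputCol("features")
      .setOutputCol("scaledFeatures")
      .setRelativeError(1e-4)   // expert param from HasRelativeError; smaller means more precise quantiles at higher cost

    scaler.fit(df).transform(df).show(false)
    spark.stop()
  }
}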
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 004102103d52..49ac49339415 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -50,7 +50,7 @@ public void setUp() throws IOException { @Test public void logisticRegressionDefaultParams() { LogisticRegression lr = new LogisticRegression(); - Assert.assertEquals(lr.getLabelCol(), "label"); + Assert.assertEquals("label", lr.getLabelCol()); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).createOrReplaceTempView("prediction"); Dataset<Row> predictions = spark.sql("SELECT label, probability, prediction FROM prediction"); @@ -119,8 +119,8 @@ public void logisticRegressionPredictorClassifierMethods() { for (Row row : trans1.collectAsList()) { Vector raw = (Vector) row.get(0); Vector prob = (Vector) row.get(1); - Assert.assertEquals(raw.size(), 2); - Assert.assertEquals(prob.size(), 2); + Assert.assertEquals(2, raw.size()); + Assert.assertEquals(2, prob.size()); double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1))); Assert.assertEquals(0, Math.abs(prob.apply(1) - probFromRaw1), eps); Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 6194167bda35..62888b85a075 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -62,12 +62,12 @@ public void setUp() throws IOException { public void oneVsRestDefaultParams() { OneVsRest ova = new OneVsRest(); ova.setClassifier(new LogisticRegression()); - Assert.assertEquals(ova.getLabelCol(), "label"); - Assert.assertEquals(ova.getPredictionCol(), "prediction"); + Assert.assertEquals("label", ova.getLabelCol()); + Assert.assertEquals("prediction", ova.getPredictionCol()); OneVsRestModel ovaModel = ova.fit(dataset); Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); predictions.collectAsList(); - Assert.assertEquals(ovaModel.getLabelCol(), "label"); - Assert.assertEquals(ovaModel.getPredictionCol(), "prediction"); + Assert.assertEquals("label", ovaModel.getLabelCol()); + Assert.assertEquals("prediction", ovaModel.getPredictionCol()); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 57696d0150a8..71c644553c4a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -64,7 +64,7 @@ public void hashingTF() { Dataset<Row> rescaledData = idfModel.transform(featurizedData); for (Row r : rescaledData.select("features", "label").takeAsList(3)) { Vector features = r.getAs(0); - Assert.assertEquals(features.size(), numFeatures); + Assert.assertEquals(numFeatures, features.size()); } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index ca8fae3a48b9..cf5308bac3c3 100644 ---
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -47,9 +47,9 @@ public void vectorIndexerAPI() { .setOutputCol("indexed") .setMaxCategories(2); VectorIndexerModel model = indexer.fit(data); - Assert.assertEquals(model.numFeatures(), 2); + Assert.assertEquals(2, model.numFeatures()); Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps(); - Assert.assertEquals(categoryMaps.size(), 1); + Assert.assertEquals(1, categoryMaps.size()); Dataset<Row> indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 3dc2e1f89614..b9bca9d5a3be 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -63,7 +63,7 @@ public void vectorSlice() { for (Row r : output.select("userFeatures", "features").takeAsList(2)) { Vector features = r.getAs(1); - Assert.assertEquals(features.size(), 2); + Assert.assertEquals(2, features.size()); } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index d0a849fd11c7..f6041e052871 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -53,7 +53,7 @@ public void testJavaWord2Vec() { for (Row r : result.select("result").collectAsList()) { double[] polyFeatures = ((Vector) r.get(0)).toArray(); - Assert.assertEquals(polyFeatures.length, 3); + Assert.assertEquals(3, polyFeatures.length); } } } diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java index 1077e103a3b8..5dae65c6e50a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaParamsSuite.java @@ -30,10 +30,10 @@ public class JavaParamsSuite { @Test public void testParams() { JavaTestParams testParams = new JavaTestParams(); - Assert.assertEquals(testParams.getMyIntParam(), 1); + Assert.assertEquals(1, testParams.getMyIntParam()); testParams.setMyIntParam(2).setMyDoubleParam(0.4).setMyStringParam("a"); - Assert.assertEquals(testParams.getMyDoubleParam(), 0.4, 0.0); - Assert.assertEquals(testParams.getMyStringParam(), "a"); + Assert.assertEquals(0.4, testParams.getMyDoubleParam(), 0.0); + Assert.assertEquals("a", testParams.getMyStringParam()); Assert.assertArrayEquals(testParams.getMyDoubleArrayParam(), new double[]{1.0, 2.0}, 0.0); } diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java index bf7671993777..51313f4fb581 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaGaussianMixtureSuite.java @@ -42,7 +42,7 @@ public void runGaussianMixture() { JavaRDD<Vector> data = jsc.parallelize(points, 2); GaussianMixtureModel model = new GaussianMixture().setK(2).setMaxIterations(1).setSeed(1234) .run(data); - assertEquals(model.gaussians().length, 2); + assertEquals(2, model.gaussians().length); JavaRDD<Integer> predictions = model.predict(data); predictions.first();
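(Aside, illustrative only: the ImputerSuite changes that follow exercise the new single-column Imputer API alongside the existing multi-column one. A minimal usage sketch under assumed toy data; the object and column names are made up.)

import org.apache.spark.ml.feature.Imputer
import org.apache.spark.sql.SparkSession

object ImputerSingleColSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("ImputerSingleColSketch").getOrCreate()
    import spark.implicits._

    // Toy data with a missing entry encoded as the default missing value, NaN.
    val df = Seq(1.0, 11.0, 3.0, Double.NaN).toDF("value")

    val imputer = new Imputer()
      .setInputCol("value")            // single-column setters introduced by this patch
      .setOutputCol("value_imputed")
      .setStrategy("median")           // surrogate computed via approxQuantile with relativeError

    imputer.fit(df).transform(df).show(false)
    spark.stop()
  }
}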
} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 02ef261a6c06..dfee2b4029c8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkException +import org.apache.spark.ml.Pipeline +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} @@ -36,7 +38,31 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { val imputer = new Imputer() .setInputCols(Array("value1", "value2")) .setOutputCols(Array("out1", "out2")) - ImputerSuite.iterateStrategyTest(imputer, df) + ImputerSuite.iterateStrategyTest(true, imputer, df) + } + + test("Single Column: Imputer for Double with default missing Value NaN") { + val df1 = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 11.0, 11.0, 11.0), + (2, 3.0, 3.0, 3.0), + (3, Double.NaN, 5.0, 3.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer1 = new Imputer() + .setInputCol("value") + .setOutputCol("out") + ImputerSuite.iterateStrategyTest(false, imputer1, df1) + + val df2 = spark.createDataFrame( Seq( + (0, 4.0, 4.0, 4.0), + (1, 12.0, 12.0, 12.0), + (2, Double.NaN, 10.0, 12.0), + (3, 14.0, 14.0, 14.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer2 = new Imputer() + .setInputCol("value") + .setOutputCol("out") + ImputerSuite.iterateStrategyTest(false, imputer2, df2) } test("Imputer should handle NaNs when computing surrogate value, if missingValue is not NaN") { @@ -48,7 +74,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) - ImputerSuite.iterateStrategyTest(imputer, df) + ImputerSuite.iterateStrategyTest(true, imputer, df) + } + + test("Single Column: Imputer should handle NaNs when computing surrogate value," + + " if missingValue is not NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 3.0, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN), + (3, -1.0, 2.0, 1.0) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCol("value").setOutputCol("out") + .setMissingValue(-1.0) + ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer for Float with missing Value -1.0") { @@ -61,7 +100,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) - ImputerSuite.iterateStrategyTest(imputer, df) + ImputerSuite.iterateStrategyTest(true, imputer, df) + } + + test("Single Column: Imputer for Float with missing Value -1.0") { + val df = spark.createDataFrame( Seq( + (0, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + val imputer = new Imputer().setInputCol("value").setOutputCol("out") + .setMissingValue(-1) + 
ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer should impute null as well as 'missingValue'") { @@ -74,7 +126,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) - ImputerSuite.iterateStrategyTest(imputer, df) + ImputerSuite.iterateStrategyTest(true, imputer, df) + } + + test("Single Column: Imputer should impute null as well as 'missingValue'") { + val rawDf = spark.createDataFrame( Seq( + (0, 4.0, 4.0, 4.0), + (1, 10.0, 10.0, 10.0), + (2, 10.0, 10.0, 10.0), + (3, Double.NaN, 8.0, 10.0), + (4, -1.0, 8.0, 10.0) + )).toDF("id", "rawValue", "expected_mean_value", "expected_median_value") + val df = rawDf.selectExpr("*", "IF(rawValue=-1.0, null, rawValue) as value") + val imputer = new Imputer().setInputCol("value").setOutputCol("out") + ImputerSuite.iterateStrategyTest(false, imputer, df) } test("Imputer should work with Structured Streaming") { @@ -99,6 +164,28 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } } + test("Single Column: Imputer should work with Structured Streaming") { + val localSpark = spark + import localSpark.implicits._ + val df = Seq[(java.lang.Double, Double)]( + (4.0, 4.0), + (10.0, 10.0), + (10.0, 10.0), + (Double.NaN, 8.0), + (null, 8.0) + ).toDF("value", "expected_mean_value") + val imputer = new Imputer() + .setInputCol("value") + .setOutputCol("out") + .setStrategy("mean") + val model = imputer.fit(df) + testTransformer[(java.lang.Double, Double)](df, model, "expected_mean_value", "out") { + case Row(exp: java.lang.Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. 
Expected: $exp, actual: $out") + } + } + test("Imputer throws exception when surrogate cannot be computed") { val df = spark.createDataFrame( Seq( (0, Double.NaN, 1.0, 1.0), @@ -117,6 +204,24 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { } } + test("Single Column: Imputer throws exception when surrogate cannot be computed") { + val df = spark.createDataFrame( Seq( + (0, Double.NaN, 1.0, 1.0), + (1, Double.NaN, 3.0, 3.0), + (2, Double.NaN, Double.NaN, Double.NaN) + )).toDF("id", "value", "expected_mean_value", "expected_median_value") + Seq("mean", "median").foreach { strategy => + val imputer = new Imputer().setInputCol("value").setOutputCol("out") + .setStrategy(strategy) + withClue("Imputer should fail all the values are invalid") { + val e: SparkException = intercept[SparkException] { + val model = imputer.fit(df) + } + assert(e.getMessage.contains("surrogate cannot be computed")) + } + } + } + test("Imputer input & output column validation") { val df = spark.createDataFrame( Seq( (0, 1.0, 1.0, 1.0), @@ -164,6 +269,14 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(t) } + test("Single Column: Imputer read/write") { + val t = new Imputer() + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + .setMissingValue(-1.0) + testDefaultReadWrite(t) + } + test("ImputerModel read/write") { val spark = this.spark import spark.implicits._ @@ -178,6 +291,20 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) } + test("Single Column: ImputerModel read/write") { + val spark = this.spark + import spark.implicits._ + val surrogateDF = Seq(1.234).toDF("myInputCol") + + val instance = new ImputerModel( + "myImputer", surrogateDF) + .setInputCol("myInputCol") + .setOutputCol("myOutputCol") + val newInstance = testDefaultReadWrite(instance) + assert(newInstance.surrogateDF.columns === instance.surrogateDF.columns) + assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) + } + test("Imputer for IntegerType with default missing value null") { val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( @@ -195,7 +322,27 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { for (mType <- types) { // cast all columns to desired data type for testing val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) - ImputerSuite.iterateStrategyTest(imputer, df2) + ImputerSuite.iterateStrategyTest(true, imputer, df2) + } + } + + test("Single Column Imputer for IntegerType with default missing value null") { + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (null, 5, 3) + )).toDF("value", "expected_mean_value", "expected_median_value") + + val imputer = new Imputer() + .setInputCol("value") + .setOutputCol("out") + + val types = Seq(IntegerType, LongType) + for (mType <- types) { + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(false, imputer, df2) } } @@ -217,7 +364,85 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { for (mType <- types) { // cast all columns to desired data type for testing val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) - ImputerSuite.iterateStrategyTest(imputer, df2) + ImputerSuite.iterateStrategyTest(true, imputer, df2) + } + } + + test("Single Column: Imputer for IntegerType with missing value -1") { + val 
df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (-1, 5, 3) + )).toDF("value", "expected_mean_value", "expected_median_value") + + val imputer = new Imputer() + .setInputCol("value") + .setOutputCol("out") + .setMissingValue(-1.0) + + val types = Seq(IntegerType, LongType) + for (mType <- types) { + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(false, imputer, df2) + } + } + + test("assert exception is thrown if both multi-column and single-column params are set") { + import testImplicits._ + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new Imputer, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new Imputer, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("outputCols", Array("result1", "result2"))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new Imputer, df, ("outputCol", "feature1")) + } + + test("Compare single/multiple column(s) Imputer in pipeline") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 4.0), + (1, 11.0, 12.0), + (2, 3.0, Double.NaN), + (3, Double.NaN, 14.0) + )).toDF("id", "value1", "value2") + Seq("mean", "median").foreach { strategy => + val multiColsImputer = new Imputer() + .setInputCols(Array("value1", "value2")) + .setOutputCols(Array("result1", "result2")) + .setStrategy(strategy) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsImputer)) + .fit(df) + + val imputerForCol1 = new Imputer() + .setInputCol("value1") + .setOutputCol("result1") + .setStrategy(strategy) + val imputerForCol2 = new Imputer() + .setInputCol("value2") + .setOutputCol("result2") + .setStrategy(strategy) + + val plForSingleCol = new Pipeline() + .setStages(Array(imputerForCol1, imputerForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("result1", "result2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("result1", "result2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle.getDouble(0) == rowForMultiCols.getDouble(0) && + rowForSingle.getDouble(1) == rowForMultiCols.getDouble(1)) + } } } } @@ -228,34 +453,45 @@ object ImputerSuite { * Imputation strategy. Available options are ["mean", "median"]. * @param df DataFrame with columns "id", "value", "expected_mean", "expected_median" */ - def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = { + def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: DataFrame): Unit = { Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) val model = imputer.fit(df) val resultDF = model.transform(df) - imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => - - // check dataType is consistent between input and output - val inputType = resultDF.schema(inputCol).dataType - val outputType = resultDF.schema(outputCol).dataType - assert(inputType == outputType, "Output type is not the same as input type.") - - // check value - resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { - case Row(exp: Float, out: Float) => - assert((exp.isNaN && out.isNaN) || (exp == out), - s"Imputed values differ. 
Expected: $exp, actual: $out") - case Row(exp: Double, out: Double) => - assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), - s"Imputed values differ. Expected: $exp, actual: $out") - case Row(exp: Integer, out: Integer) => - assert(exp == out, - s"Imputed values differ. Expected: $exp, actual: $out") - case Row(exp: Long, out: Long) => - assert(exp == out, - s"Imputed values differ. Expected: $exp, actual: $out") + if (isMultiCol) { + imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + verifyTransformResult(strategy, inputCol, outputCol, resultDF) } + } else { + verifyTransformResult(strategy, imputer.getInputCol, imputer.getOutputCol, resultDF) } } } + + def verifyTransformResult( + strategy: String, + inputCol: String, + outputCol: String, + resultDF: DataFrame): Unit = { + // check dataType is consistent between input and output + val inputType = resultDF.schema(inputCol).dataType + val outputType = resultDF.schema(outputCol).dataType + assert(inputType == outputType, "Output type is not the same as input type.") + + // check value + resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { + case Row(exp: Float, out: Float) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Integer, out: Integer) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Long, out: Long) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala index 70f8c029a257..897251d9815c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/OneHotEncoderSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.ml.feature +import org.apache.spark.ml.Pipeline import org.apache.spark.ml.attribute.{AttributeGroup, BinaryAttribute, NominalAttribute} import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.ml.param.ParamsSuite @@ -62,6 +63,34 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { } } + test("Single Column: OneHotEncoder dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder = new OneHotEncoder() + .setInputCol("input") + .setOutputCol("output") + assert(encoder.getDropLast) + encoder.setDropLast(false) + assert(encoder.getDropLast === false) + val model = encoder.fit(df) + testTransformer[(Double, Vector)](df, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + } + } + test("OneHotEncoder dropLast = true") { val data = Seq( Row(0.0, Vectors.sparse(2, Seq((0, 1.0)))), @@ -104,6 +133,22 @@ class OneHotEncoderSuite extends MLTest with 
DefaultReadWriteTest { } } + test("Single Column: input column with ML attribute") { + val attr = NominalAttribute.defaultAttr.withValues("small", "medium", "large") + val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("size") + .select(col("size").as("size", attr.toMetadata())) + val encoder = new OneHotEncoder() + .setInputCol("size") + .setOutputCol("encoded") + val model = encoder.fit(df) + testTransformerByGlobalCheckFunc[(Double)](df, model, "encoded") { rows => + val group = AttributeGroup.fromStructField(rows.head.schema("encoded")) + assert(group.size === 2) + assert(group.getAttr(0) === BinaryAttribute.defaultAttr.withName("small").withIndex(0)) + assert(group.getAttr(1) === BinaryAttribute.defaultAttr.withName("medium").withIndex(1)) + } + } + test("input column without ML attribute") { val df = Seq(0.0, 1.0, 2.0, 1.0).map(Tuple1.apply).toDF("index") val encoder = new OneHotEncoder() @@ -125,6 +170,13 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { testDefaultReadWrite(encoder) } + test("Single Column: read/write") { + val encoder = new OneHotEncoder() + .setInputCol("index") + .setOutputCol("encoded") + testDefaultReadWrite(encoder) + } + test("OneHotEncoderModel read/write") { val instance = new OneHotEncoderModel("myOneHotEncoderModel", Array(1, 2, 3)) val newInstance = testDefaultReadWrite(instance) @@ -173,6 +225,48 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { } } + test("Single Column: OneHotEncoder with varying types") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input", DoubleType), + StructField("expected", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + class NumericTypeWithEncoder[A](val numericType: NumericType) + (implicit val encoder: Encoder[(A, Vector)]) + + val types = Seq( + new NumericTypeWithEncoder[Short](ShortType), + new NumericTypeWithEncoder[Long](LongType), + new NumericTypeWithEncoder[Int](IntegerType), + new NumericTypeWithEncoder[Float](FloatType), + new NumericTypeWithEncoder[Byte](ByteType), + new NumericTypeWithEncoder[Double](DoubleType), + new NumericTypeWithEncoder[Decimal](DecimalType(10, 0))(ExpressionEncoder())) + + for (t <- types) { + val dfWithTypes = df.select(col("input").cast(t.numericType), col("expected")) + val estimator = new OneHotEncoder() + .setInputCol("input") + .setOutputCol("output") + .setDropLast(false) + + val model = estimator.fit(dfWithTypes) + testTransformer(dfWithTypes, model, "output", "expected") { + case Row(output: Vector, expected: Vector) => + assert(output === expected) + }(t.encoder) + } + } + test("OneHotEncoder: encoding multiple columns and dropLast = false") { val data = Seq( Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), @@ -211,6 +305,58 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { } } + test("Single Column: OneHotEncoder: encoding multiple columns and dropLast = false") { + val data = Seq( + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0)))), + Row(1.0, Vectors.sparse(3, Seq((1, 1.0))), 3.0, Vectors.sparse(4, Seq((3, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(0.0, 
Vectors.sparse(3, Seq((0, 1.0))), 1.0, Vectors.sparse(4, Seq((1, 1.0)))), + Row(0.0, Vectors.sparse(3, Seq((0, 1.0))), 0.0, Vectors.sparse(4, Seq((0, 1.0)))), + Row(2.0, Vectors.sparse(3, Seq((2, 1.0))), 2.0, Vectors.sparse(4, Seq((2, 1.0))))) + + val schema = StructType(Array( + StructField("input1", DoubleType), + StructField("expected1", new VectorUDT), + StructField("input2", DoubleType), + StructField("expected2", new VectorUDT))) + + val df = spark.createDataFrame(sc.parallelize(data), schema) + + val encoder1 = new OneHotEncoder() + .setInputCol("input1") + .setOutputCol("output1") + assert(encoder1.getDropLast) + encoder1.setDropLast(false) + assert(encoder1.getDropLast === false) + + val model1 = encoder1.fit(df) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model1, + "output1", + "expected1") { + case Row(output1: Vector, expected1: Vector) => + assert(output1 === expected1) + } + + val encoder2 = new OneHotEncoder() + .setInputCol("input2") + .setOutputCol("output2") + assert(encoder2.getDropLast) + encoder2.setDropLast(false) + assert(encoder2.getDropLast === false) + + val model2 = encoder2.fit(df) + testTransformer[(Double, Vector, Double, Vector)]( + df, + model2, + "output2", + "expected2") { + case Row(output2: Vector, expected2: Vector) => + assert(output2 === expected2) + } + } + test("OneHotEncoder: encoding multiple columns and dropLast = true") { val data = Seq( Row(0.0, Vectors.sparse(2, Seq((0, 1.0))), 2.0, Vectors.sparse(3, Seq((2, 1.0)))), @@ -419,4 +565,52 @@ class OneHotEncoderSuite extends MLTest with DefaultReadWriteTest { expectedMessagePart = "OneHotEncoderModel expected 2 categorical values", firstResultCol = "encoded") } + + test("assert exception is thrown if both multi-column and single-column params are set") { + import testImplicits._ + val df = Seq((0.5, 0.3), (0.5, -0.4)).toDF("feature1", "feature2") + ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"), + ("inputCols", Array("feature1", "feature2"))) + ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("inputCol", "feature1"), + ("outputCol", "result1"), ("outputCols", Array("result1", "result2"))) + + // this should fail because at least one of inputCol and inputCols must be set + ParamsSuite.testExclusiveParams(new OneHotEncoder, df, ("outputCol", "feature1")) + } + + test("Compare single/multiple column(s) OneHotEncoder in pipeline") { + val df = Seq((0.0, 2.0), (1.0, 3.0), (2.0, 0.0), (0.0, 1.0), (0.0, 0.0), (2.0, 2.0)) + .toDF("input1", "input2") + + val multiColsEncoder = new OneHotEncoder() + .setInputCols(Array("input1", "input2")) + .setOutputCols(Array("output1", "output2")) + + val plForMultiCols = new Pipeline() + .setStages(Array(multiColsEncoder)) + .fit(df) + + val encoderForCol1 = new OneHotEncoder() + .setInputCol("input1") + .setOutputCol("output1") + val encoderForCol2 = new OneHotEncoder() + .setInputCol("input2") + .setOutputCol("output2") + + val plForSingleCol = new Pipeline() + .setStages(Array(encoderForCol1, encoderForCol2)) + .fit(df) + + val resultForSingleCol = plForSingleCol.transform(df) + .select("output1", "output2") + .collect() + val resultForMultiCols = plForMultiCols.transform(df) + .select("output1", "output2") + .collect() + + resultForSingleCol.zip(resultForMultiCols).foreach { + case (rowForSingle, rowForMultiCols) => + assert(rowForSingle === rowForMultiCols) + } + } } diff --git a/pom.xml b/pom.xml index f1a7cb3d106f..8f86ae3b6dfb 100644 --- a/pom.xml +++ b/pom.xml @@ -139,7 +139,7 @@ 2.3.1 10.12.1.1 1.10.1 - 
1.5.6 + 1.5.7 nohive com.twitter 1.6.0 @@ -148,10 +148,10 @@ 0.9.3 2.4.0 2.0.8 - 3.2.6 + 4.1.1 1.8.2 hadoop2 - 1.8.10 + 1.12.0 1.11.271 @@ -184,7 +184,7 @@ 3.2.10 3.0.15 2.29 - 2.9.3 + 2.10.5 3.5.2 3.0.0 0.12.0 @@ -333,7 +333,7 @@ org.apache.xbean xbean-asm7-shaded - 4.14 + 4.15