Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class SqlFunctionProperties
private final boolean fieldNamesInJsonCastEnabled;
private final boolean legacyJsonCast;
private final Map<String, String> extraCredentials;
private final boolean warnOnCommonNanPatterns;

private SqlFunctionProperties(
boolean parseDecimalLiteralAsDouble,
Expand All @@ -48,7 +49,8 @@ private SqlFunctionProperties(
String sessionUser,
boolean fieldNamesInJsonCastEnabled,
boolean legacyJsonCast,
Map<String, String> extraCredentials)
Map<String, String> extraCredentials,
boolean warnOnCommonNanPatterns)
{
this.parseDecimalLiteralAsDouble = parseDecimalLiteralAsDouble;
this.legacyRowFieldOrdinalAccessEnabled = legacyRowFieldOrdinalAccessEnabled;
Expand All @@ -61,6 +63,7 @@ private SqlFunctionProperties(
this.fieldNamesInJsonCastEnabled = fieldNamesInJsonCastEnabled;
this.legacyJsonCast = legacyJsonCast;
this.extraCredentials = requireNonNull(extraCredentials, "extraCredentials is null");
this.warnOnCommonNanPatterns = warnOnCommonNanPatterns;
}

public boolean isParseDecimalLiteralAsDouble()
Expand Down Expand Up @@ -119,6 +122,11 @@ public boolean isLegacyJsonCast()
return legacyJsonCast;
}

public boolean shouldWarnOnCommonNanPatterns()
{
return warnOnCommonNanPatterns;
}

@Override
public boolean equals(Object o)
{
Expand Down Expand Up @@ -167,6 +175,7 @@ public static class Builder
private boolean fieldNamesInJsonCastEnabled;
private boolean legacyJsonCast;
private Map<String, String> extraCredentials = emptyMap();
private boolean warnOnCommonNanPatterns;

private Builder() {}

Expand Down Expand Up @@ -236,6 +245,12 @@ public Builder setLegacyJsonCast(boolean legacyJsonCast)
return this;
}

public Builder setWarnOnCommonNanPatterns(boolean warnOnCommonNanPatterns)
{
this.warnOnCommonNanPatterns = warnOnCommonNanPatterns;
return this;
}

public SqlFunctionProperties build()
{
return new SqlFunctionProperties(
Expand All @@ -249,7 +264,8 @@ public SqlFunctionProperties build()
sessionUser,
fieldNamesInJsonCastEnabled,
legacyJsonCast,
extraCredentials);
extraCredentials,
warnOnCommonNanPatterns);
}
}
}
2 changes: 2 additions & 0 deletions presto-main/src/main/java/com/facebook/presto/Session.java
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
import static com.facebook.presto.SystemSessionProperties.isLegacyRowFieldOrdinalAccessEnabled;
import static com.facebook.presto.SystemSessionProperties.isLegacyTimestamp;
import static com.facebook.presto.SystemSessionProperties.isParseDecimalLiteralsAsDouble;
import static com.facebook.presto.SystemSessionProperties.warnOnCommonNanPatterns;
import static com.facebook.presto.spi.ConnectorId.createInformationSchemaConnectorId;
import static com.facebook.presto.spi.ConnectorId.createSystemTablesConnectorId;
import static com.facebook.presto.spi.StandardErrorCode.NOT_FOUND;
Expand Down Expand Up @@ -525,6 +526,7 @@ public SqlFunctionProperties getSqlFunctionProperties()
.setFieldNamesInJsonCastEnabled(isFieldNameInJsonCastEnabled(this))
.setLegacyJsonCast(legacyJsonCast)
.setExtraCredentials(identity.getExtraCredentials())
.setWarnOnCommonNanPatterns(warnOnCommonNanPatterns(this))
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ public final class SystemSessionProperties
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";
public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms";
public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1937,6 +1938,10 @@ public SystemSessionProperties(
booleanProperty(OPTIMIZER_USE_HISTOGRAMS,
"whether or not to use histograms in the CBO",
featuresConfig.isUseHistograms(),
false),
booleanProperty(WARN_ON_COMMON_NAN_PATTERNS,
"Whether to give a warning for some common issues relating to NaNs",
featuresConfig.getWarnOnCommonNanPatterns(),
false));
}

Expand Down Expand Up @@ -3229,4 +3234,9 @@ public static boolean shouldOptimizerUseHistograms(Session session)
{
return session.getSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.class);
}

public static boolean warnOnCommonNanPatterns(Session session)
{
return session.getSystemProperty(WARN_ON_COMMON_NAN_PATTERNS, Boolean.class);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
import static com.facebook.presto.sql.NodeUtils.getSortItemsFromOrderBy;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoAggregateWindowOrGroupingFunctions;
import static com.facebook.presto.sql.analyzer.Analyzer.verifyNoExternalFunctions;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isConstant;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isNonNullConstant;
import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.tryResolveEnumLiteralType;
import static com.facebook.presto.sql.analyzer.FunctionArgumentCheckerForAccessControlUtils.getResolvedLambdaArguments;
Expand Down Expand Up @@ -601,7 +602,16 @@ protected Type visitLogicalBinaryExpression(LogicalBinaryExpression node, Stacka
protected Type visitComparisonExpression(ComparisonExpression node, StackableAstVisitorContext<Context> context)
{
OperatorType operatorType = OperatorType.valueOf(node.getOperator().name());
return getOperator(context, node, operatorType, node.getLeft(), node.getRight());
Type outputType = getOperator(context, node, operatorType, node.getLeft(), node.getRight());
// this needs to be checked after the call to getOperator(), because that's where the argument types get analyzed
if (sqlFunctionProperties.shouldWarnOnCommonNanPatterns() &&
(TypeUtils.isApproximateNumericType(getExpressionType(node.getLeft())) || TypeUtils.isApproximateNumericType(getExpressionType(node.getRight())))) {
warningCollector.add(new PrestoWarning(
SEMANTIC_WARNING,
"Comparison operations involving DOUBLE or REAL types may include NaNs in the input. " +
"Consider filtering out NaN values from your comparison input using the is_nan() function."));
}
return outputType;
}

@Override
Expand Down Expand Up @@ -732,7 +742,17 @@ protected Type visitArithmeticUnary(ArithmeticUnaryExpression node, StackableAst
@Override
protected Type visitArithmeticBinary(ArithmeticBinaryExpression node, StackableAstVisitorContext<Context> context)
{
return getOperator(context, node, OperatorType.valueOf(node.getOperator().name()), node.getLeft(), node.getRight());
Type returnType = getOperator(context, node, OperatorType.valueOf(node.getOperator().name()), node.getLeft(), node.getRight());
if (sqlFunctionProperties.shouldWarnOnCommonNanPatterns() &&
node.getOperator() == ArithmeticBinaryExpression.Operator.DIVIDE &&
TypeUtils.isApproximateNumericType(returnType) &&
!isConstant(node.getLeft()) &&
!isConstant(node.getRight())) {
warningCollector.add(new PrestoWarning(SEMANTIC_WARNING,
"Division operations on DOUBLE/REAL types may produce NaNs or infinities if there are zeros in the denominator. " +
"Consider checking the denominator of your division operation for zeros."));
}
return returnType;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ public class FeaturesConfig
private boolean useHistograms;

private boolean useNewNanDefinition = true;
private boolean warnOnPossibleNans;

public enum PartitioningPrecisionStrategy
{
Expand Down Expand Up @@ -3132,4 +3133,17 @@ public FeaturesConfig setUseNewNanDefinition(boolean useNewNanDefinition)
this.useNewNanDefinition = useNewNanDefinition;
return this;
}

public boolean getWarnOnCommonNanPatterns()
{
return warnOnPossibleNans;
}

@Config("warn-on-common-nan-patterns")
@ConfigDescription("Give warnings for operations on DOUBLE/REAL types where NaN issues are common")
public FeaturesConfig setWarnOnCommonNanPatterns(boolean warnOnPossibleNans)
{
this.warnOnPossibleNans = warnOnPossibleNans;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ public void testDefaults()
.setLegacyJsonCast(true)
.setPrintEstimatedStatsFromCache(false)
.setUseHistograms(false)
.setUseNewNanDefinition(true));
.setUseNewNanDefinition(true)
.setWarnOnCommonNanPatterns(false));
}

@Test
Expand Down Expand Up @@ -489,6 +490,7 @@ public void testExplicitPropertyMappings()
.put("optimizer.print-estimated-stats-from-cache", "true")
.put("optimizer.use-histograms", "true")
.put("use-new-nan-definition", "false")
.put("warn-on-common-nan-patterns", "true")
.build();

FeaturesConfig expected = new FeaturesConfig()
Expand Down Expand Up @@ -702,7 +704,8 @@ public void testExplicitPropertyMappings()
.setLegacyJsonCast(false)
.setPrintEstimatedStatsFromCache(true)
.setUseHistograms(true)
.setUseNewNanDefinition(false);
.setUseNewNanDefinition(false)
.setWarnOnCommonNanPatterns(true);
assertFullMapping(properties, expected);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Set;

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.SystemSessionProperties.WARN_ON_COMMON_NAN_PATTERNS;
import static com.facebook.presto.execution.TestQueryRunnerUtil.createQueryRunner;
import static com.facebook.presto.spi.StandardWarningCode.MULTIPLE_ORDER_BY;
import static com.facebook.presto.spi.StandardWarningCode.PARSER_WARNING;
Expand All @@ -40,6 +41,9 @@
public class TestWarnings
{
private static final int STAGE_COUNT_WARNING_THRESHOLD = 20;
private static final Session ALL_WARININGS_SESSION = Session.builder(TEST_SESSION)
.setSystemProperty(WARN_ON_COMMON_NAN_PATTERNS, "true")
.build();
private QueryRunner queryRunner;

@BeforeClass
Expand Down Expand Up @@ -202,6 +206,80 @@ public void testMapWithDecimalKeyProducesNoWarnings()
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of());
}

/**
* The below tests check warnings for nan on DOUBLE/REAL types. Because we usually don't know whether any input values are nan or will produce nan,
* the warnings only check that the type of the input can be affected by nans.
*/
@Test
public void testDoubleDivisionNanWarning()
{
String query = "SELECT x /y FROM (VALUES (DOUBLE '1.0', DOUBLE '2.0')) t(x, y)";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
}

@Test
public void testRealDivisionNanWarning()
{
String query = "SELECT x/y FROM (VALUES (REAL '1.0' , REAL '2.0')) t(x,y)";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
}

@Test
public void testConstantDivisionProducesNoWarnings()
{
String query = "SELECT DOUBLE '1.0' / DOUBLE '2.0'";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of());
}

@Test
public void testIntegerDivisionProducesNoWarnings()
{
String query = "SELECT 4 / 2";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of());
}

@Test
public void testNoWarningsForDivisionWhenDisabled()
{
String query = "SELECT DOUBLE '1.0' / DOUBLE '2.0'";
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of());
}

@Test
public void testOtherArithmeticOperationsProducesNoWarnings()
{
String query = "SELECT DOUBLE '1.0' * DOUBLE '2.0'";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of());
}

@Test
public void testDoubleComparisonNaNWarning()
{
String query = "SELECT DOUBLE '1.0' > DOUBLE '2.0'";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
}

@Test
public void testRealComparisonNaNWarning()
{
String query = "SELECT REAL '1.0' > REAL '2.0'";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of(SEMANTIC_WARNING.toWarningCode()));
}

@Test
public void testIntegerComparisonProducesNoWarnings()
{
String query = "SELECT 1 > 2";
assertWarnings(queryRunner, ALL_WARININGS_SESSION, query, ImmutableSet.of());
}

@Test
public void testNoWarningsForComparisonWhenDisabled()
{
String query = "SELECT DOUBLE '1.0' > DOUBLE '2.0'";
assertWarnings(queryRunner, TEST_SESSION, query, ImmutableSet.of());
}

private static void assertWarnings(QueryRunner queryRunner, Session session, @Language("SQL") String sql, Set<WarningCode> expectedWarnings)
{
Set<WarningCode> warnings = queryRunner.execute(session, sql).getWarnings().stream()
Expand Down