Skip to content

Commit

Permalink
Add expression complexity and call stack depth limits.
Browse files Browse the repository at this point in the history
git-svn-id: https://angleproject.googlecode.com/svn/branches/dx11proto@2254 736b8ea6-26fd-11df-bfd4-992fa37f6226
  • Loading branch information
[email protected] committed May 29, 2013
1 parent b0f1b48 commit da8ea02
Show file tree
Hide file tree
Showing 13 changed files with 805 additions and 152 deletions.
4 changes: 4 additions & 0 deletions build/common.gypi
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@
'LinkIncremental': '2',
},
},
'xcode_settings': {
'COPY_PHASE_STRIP': 'NO',
'GCC_OPTIMIZATION_LEVEL': '0',
},
}, # Debug
'Release': {
'inherit_from': ['Common'],
Expand Down
14 changes: 13 additions & 1 deletion include/GLSLANG/ShaderLang.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,13 @@ typedef enum {
// vec234, or mat234 type. The ShArrayIndexClampingStrategy enum,
// specified in the ShBuiltInResources when constructing the
// compiler, selects the strategy for the clamping implementation.
SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x1000
SH_CLAMP_INDIRECT_ARRAY_BOUNDS = 0x1000,

// This flag limits the complexity of an expression.
SH_LIMIT_EXPRESSION_COMPLEXITY = 0x2000,

// This flag limits the depth of the call stack.
SH_LIMIT_CALL_STACK_DEPTH = 0x4000,
} ShCompileOptions;

// Defines alternate strategies for implementing array index clamping.
Expand Down Expand Up @@ -225,6 +231,12 @@ typedef struct
// Selects a strategy to use when implementing array index clamping.
// Default is SH_CLAMP_WITH_CLAMP_INTRINSIC.
ShArrayIndexClampingStrategy ArrayIndexClampingStrategy;

// The maximum complexity an expression can be.
int MaxExpressionComplexity;

// The maximum depth a call stack can be.
int MaxCallStackDepth;
} ShBuiltInResources;

//
Expand Down
4 changes: 2 additions & 2 deletions src/build_angle.gypi
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@
'compiler/ConstantUnion.h',
'compiler/debug.cpp',
'compiler/debug.h',
'compiler/DetectRecursion.cpp',
'compiler/DetectRecursion.h',
'compiler/DetectCallDepth.cpp',
'compiler/DetectCallDepth.h',
'compiler/Diagnostics.h',
'compiler/Diagnostics.cpp',
'compiler/DirectiveHandler.h',
Expand Down
52 changes: 43 additions & 9 deletions src/compiler/Compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//

#include "compiler/BuiltInFunctionEmulator.h"
#include "compiler/DetectRecursion.h"
#include "compiler/DetectCallDepth.h"
#include "compiler/ForLoopUnroll.h"
#include "compiler/Initialize.h"
#include "compiler/InitializeParseContext.h"
Expand Down Expand Up @@ -104,6 +104,9 @@ TShHandleBase::~TShHandleBase() {
TCompiler::TCompiler(ShShaderType type, ShShaderSpec spec)
: shaderType(type),
shaderSpec(spec),
maxUniformVectors(0),
maxExpressionComplexity(0),
maxCallStackDepth(0),
fragmentPrecisionHigh(false),
clampingStrategy(SH_CLAMP_WITH_CLAMP_INTRINSIC),
builtInFunctionEmulator(type)
Expand All @@ -122,6 +125,8 @@ bool TCompiler::Init(const ShBuiltInResources& resources)
maxUniformVectors = (shaderType == SH_VERTEX_SHADER) ?
resources.MaxVertexUniformVectors :
resources.MaxFragmentUniformVectors;
maxExpressionComplexity = resources.MaxExpressionComplexity;
maxCallStackDepth = resources.MaxCallStackDepth;
TScopedPoolAllocator scopedAlloc(&allocator, false);

// Generate built-in symbol table.
Expand Down Expand Up @@ -185,7 +190,7 @@ bool TCompiler::compile(const char* const shaderStrings[],
success = intermediate.postProcess(root);

if (success)
success = detectRecursion(root);
success = detectCallDepth(root, infoSink, (compileOptions & SH_LIMIT_CALL_STACK_DEPTH) != 0);

if (success && (compileOptions & SH_VALIDATE_LOOP_INDEXING))
success = validateLimitations(root);
Expand All @@ -208,6 +213,10 @@ bool TCompiler::compile(const char* const shaderStrings[],
if (success && (compileOptions & SH_CLAMP_INDIRECT_ARRAY_BOUNDS))
arrayBoundsClamper.MarkIndirectArrayBoundsForClamping(root);

// Disallow expressions deemed too complex.
if (success && (compileOptions & SH_LIMIT_EXPRESSION_COMPLEXITY))
success = limitExpressionComplexity(root);

// Call mapLongVariableNames() before collectAttribsUniforms() so in
// collectAttribsUniforms() we already have the mapped symbol names and
// we could composite mapped and original variable names.
Expand Down Expand Up @@ -268,24 +277,27 @@ void TCompiler::clearResults()
nameMap.clear();
}

bool TCompiler::detectRecursion(TIntermNode* root)
bool TCompiler::detectCallDepth(TIntermNode* root, TInfoSink& infoSink, bool limitCallStackDepth)
{
DetectRecursion detect;
DetectCallDepth detect(infoSink, limitCallStackDepth, maxCallStackDepth);
root->traverse(&detect);
switch (detect.detectRecursion()) {
case DetectRecursion::kErrorNone:
switch (detect.detectCallDepth()) {
case DetectCallDepth::kErrorNone:
return true;
case DetectRecursion::kErrorMissingMain:
case DetectCallDepth::kErrorMissingMain:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Missing main()";
return false;
case DetectRecursion::kErrorRecursion:
case DetectCallDepth::kErrorRecursion:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Function recursion detected";
return false;
case DetectCallDepth::kErrorMaxDepthExceeded:
infoSink.info.prefix(EPrefixError);
infoSink.info << "Function call stack too deep";
return false;
default:
UNREACHABLE();
return false;
}
}

Expand Down Expand Up @@ -327,6 +339,28 @@ bool TCompiler::enforceTimingRestrictions(TIntermNode* root, bool outputGraph)
}
}

bool TCompiler::limitExpressionComplexity(TIntermNode* root)
{
TIntermTraverser traverser;
root->traverse(&traverser);
TDependencyGraph graph(root);

for (TFunctionCallVector::const_iterator iter = graph.beginUserDefinedFunctionCalls();
iter != graph.endUserDefinedFunctionCalls();
++iter)
{
TGraphFunctionCall* samplerSymbol = *iter;
TDependencyGraphTraverser graphTraverser;
samplerSymbol->traverse(&graphTraverser);
}

if (traverser.getMaxDepth() > maxExpressionComplexity) {
infoSink.info << "Expression too complex.";
return false;
}
return true;
}

bool TCompiler::enforceFragmentShaderTimingRestrictions(const TDependencyGraph& graph)
{
RestrictFragmentShaderTiming restrictor(infoSink.info);
Expand Down
187 changes: 187 additions & 0 deletions src/compiler/DetectCallDepth.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//
// Copyright (c) 2002-2011 The ANGLE Project Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//

#include "compiler/DetectCallDepth.h"
#include "compiler/InfoSink.h"

const int DetectCallDepth::FunctionNode::kInfiniteCallDepth;

DetectCallDepth::FunctionNode::FunctionNode(const TString& fname)
: name(fname),
visit(PreVisit)
{
}

const TString& DetectCallDepth::FunctionNode::getName() const
{
return name;
}

void DetectCallDepth::FunctionNode::addCallee(
DetectCallDepth::FunctionNode* callee)
{
for (size_t i = 0; i < callees.size(); ++i) {
if (callees[i] == callee)
return;
}
callees.push_back(callee);
}

int DetectCallDepth::FunctionNode::detectCallDepth(DetectCallDepth* detectCallDepth, int depth)
{
ASSERT(visit == PreVisit);
ASSERT(detectCallDepth);

int maxDepth = depth;
visit = InVisit;
for (size_t i = 0; i < callees.size(); ++i) {
switch (callees[i]->visit) {
case InVisit:
// cycle detected, i.e., recursion detected.
return kInfiniteCallDepth;
case PostVisit:
break;
case PreVisit: {
// Check before we recurse so we don't go too depth
if (detectCallDepth->checkExceedsMaxDepth(depth))
return depth;
int callDepth = callees[i]->detectCallDepth(detectCallDepth, depth + 1);
// Check after we recurse so we can exit immediately and provide info.
if (detectCallDepth->checkExceedsMaxDepth(callDepth)) {
detectCallDepth->getInfoSink().info << "<-" << callees[i]->getName();
return callDepth;
}
maxDepth = std::max(callDepth, maxDepth);
break;
}
default:
UNREACHABLE();
break;
}
}
visit = PostVisit;
return maxDepth;
}

void DetectCallDepth::FunctionNode::reset()
{
visit = PreVisit;
}

DetectCallDepth::DetectCallDepth(TInfoSink& infoSink, bool limitCallStackDepth, int maxCallStackDepth)
: TIntermTraverser(true, false, true, false),
currentFunction(NULL),
infoSink(infoSink),
maxDepth(limitCallStackDepth ? maxCallStackDepth : FunctionNode::kInfiniteCallDepth)
{
}

DetectCallDepth::~DetectCallDepth()
{
for (size_t i = 0; i < functions.size(); ++i)
delete functions[i];
}

bool DetectCallDepth::visitAggregate(Visit visit, TIntermAggregate* node)
{
switch (node->getOp())
{
case EOpPrototype:
// Function declaration.
// Don't add FunctionNode here because node->getName() is the
// unmangled function name.
break;
case EOpFunction: {
// Function definition.
if (visit == PreVisit) {
currentFunction = findFunctionByName(node->getName());
if (currentFunction == NULL) {
currentFunction = new FunctionNode(node->getName());
functions.push_back(currentFunction);
}
} else if (visit == PostVisit) {
currentFunction = NULL;
}
break;
}
case EOpFunctionCall: {
// Function call.
if (visit == PreVisit) {
FunctionNode* func = findFunctionByName(node->getName());
if (func == NULL) {
func = new FunctionNode(node->getName());
functions.push_back(func);
}
if (currentFunction)
currentFunction->addCallee(func);
}
break;
}
default:
break;
}
return true;
}

bool DetectCallDepth::checkExceedsMaxDepth(int depth)
{
return depth >= maxDepth;
}

void DetectCallDepth::resetFunctionNodes()
{
for (size_t i = 0; i < functions.size(); ++i) {
functions[i]->reset();
}
}

DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepthForFunction(FunctionNode* func)
{
currentFunction = NULL;
resetFunctionNodes();

int maxCallDepth = func->detectCallDepth(this, 1);

if (maxCallDepth == FunctionNode::kInfiniteCallDepth)
return kErrorRecursion;

if (maxCallDepth >= maxDepth)
return kErrorMaxDepthExceeded;

return kErrorNone;
}

DetectCallDepth::ErrorCode DetectCallDepth::detectCallDepth()
{
if (maxDepth != FunctionNode::kInfiniteCallDepth) {
// Check all functions because the driver may fail on them
// TODO: Before detectingRecursion, strip unused functions.
for (size_t i = 0; i < functions.size(); ++i) {
ErrorCode error = detectCallDepthForFunction(functions[i]);
if (error != kErrorNone)
return error;
}
} else {
FunctionNode* main = findFunctionByName("main(");
if (main == NULL)
return kErrorMissingMain;

return detectCallDepthForFunction(main);
}

return kErrorNone;
}

DetectCallDepth::FunctionNode* DetectCallDepth::findFunctionByName(
const TString& name)
{
for (size_t i = 0; i < functions.size(); ++i) {
if (functions[i]->getName() == name)
return functions[i];
}
return NULL;
}

Loading

0 comments on commit da8ea02

Please sign in to comment.