Skip to content

Commit

Permalink
Actually fix the bug with dataset iterator processing.
Browse files Browse the repository at this point in the history
  • Loading branch information
khatchad committed Jul 2, 2024
1 parent 8ad5914 commit 02c2a6b
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,21 @@ public void testDataset35()
test("tf2_test_dataset35.py", "add", 2, 2, 2, 3);
}

/** Test a dataset that uses an iterator. */
@Test
public void testDataset36()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("tf2_test_dataset36.py", "id1", 1, 1, 2);
// test("tf2_test_dataset36.py", "id2", 1, 1, 2);
}

/** Test a dataset that uses an iterator. */
@Test
public void testDataset37()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test("tf2_test_dataset37.py", "add", 2, 2, 2, 3);
}

/**
* Test enumerating a dataset (https://github.com/wala/ML/issues/140). The first element of the
* tuple returned isn't a tensor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,42 +131,7 @@ private static Set<PointsToSetVariable> getDataflowSources(
if (inst instanceof SSAAbstractInvokeInstruction) {
// We potentially have a function call that generates a tensor.
SSAAbstractInvokeInstruction ni = (SSAAbstractInvokeInstruction) inst;

// don't consider exceptions as a data source.
if (ni.getException() == vn) continue;

if (ni.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)
&& ni.getException() != vn) {
sources.add(src);
logger.info("Added dataflow source from tensor generator: " + src + ".");
} else if (ni.getNumberOfUses() > 1) {
// Get the invoked function from the PA.
int target = ni.getUse(0);
PointerKey targetKey =
pointerAnalysis.getHeapModel().getPointerKeyForLocal(localPointerKeyNode, target);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) {
if (ik instanceof ConcreteTypeKey) {
ConcreteTypeKey ctk = (ConcreteTypeKey) ik;
IClass type = ctk.getType();
TypeReference reference = type.getReference();

if (reference.equals(NEXT.getDeclaringClass())) {
// it's a call to `next()`. Look up the call to `iter()`.
int iterator = ni.getUse(1);

// Use the original instruction. NOTE: We can only do this because `iter()` is
// currently just passing-through its argument.
processInstructionInterprocedurally(
ni, iterator, localPointerKeyNode, src, sources, pointerAnalysis);
}
}
}
}
processInstruction(ni, du, localPointerKeyNode, src, vn, sources, pointerAnalysis);
} else if (inst instanceof EachElementGetInstruction) {
// We are potentially pulling a tensor out of a tensor iterable.
EachElementGetInstruction eachElementGetInstruction = (EachElementGetInstruction) inst;
Expand Down Expand Up @@ -210,15 +175,112 @@ private static Set<PointsToSetVariable> getDataflowSources(
} else if (def instanceof EachElementGetInstruction
|| def instanceof PythonPropertyRead
|| def instanceof PythonInvokeInstruction) {
processInstruction(
def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis);
boolean added = false;
// we may be invoking `next()` on a dataset.
if (def instanceof SSAAbstractInvokeInstruction && def.getNumberOfUses() > 1) {
SSAAbstractInvokeInstruction invokeInstruction = (SSAAbstractInvokeInstruction) def;
added =
processInstruction(
invokeInstruction,
du,
localPointerKeyNode,
src,
vn,
sources,
pointerAnalysis);
}

if (!added)
processInstruction(
def, du, localPointerKeyNode, src, sources, callGraph, pointerAnalysis);
}
}
}
}
return sources;
}

/**
* Processes the given {@link SSAAbstractInvokeInstruction}, adding the given {@link PointsToSetVariable} to the given {@link Set} of {@link PointsToSetVariable}s as a dataflow source if the given {@link SSAAbstractInvokeInstruction} results in a tensor value.
*
* @param instruction The {@link SSAAbstractInvokeInstruction} to consider.
* @param du The {@link DefUse} for the given {@link SSAAbstractInvokeInstruction}.
* @param node The {@link CGNode} containing the given {@link SSAAbstractInvokeInstruction}.
* @param src The {@link PointsToSetVariable} to add to the given {@link Set} of {@link PointsToSetVariable}s if there a tensor flows from the given {@link SSAAbstractInvokeInstruction.
* @param vn The value number in the given {@link CGNode} corresponding to the given {@link PointsToSetVariable}.
* @param sources The {@link Set} of {@link PointsToSetVariable}s representing tensor dataflow sources.
* @param pointerAnalysis The {@link PointerAnalysis} for the given {@link CGNode}.
* @return True iff given the source was added to the set.
*/
private static boolean processInstruction(
SSAAbstractInvokeInstruction instruction,
DefUse du,
CGNode node,
PointsToSetVariable src,
int vn,
Set<PointsToSetVariable> sources,
PointerAnalysis<InstanceKey> pointerAnalysis) {
boolean ret = false;

// don't consider exceptions as a data source.
if (instruction.getException() != vn) {
if (instruction
.getCallSite()
.getDeclaredTarget()
.getName()
.toString()
.equals(TENSOR_GENERATOR_SYNTHETIC_FUNCTION_NAME)) {
sources.add(src);
logger.info("Added dataflow source from tensor generator: " + src + ".");
ret = true;
} else if (instruction.getNumberOfUses() > 1) {
// Get the invoked function from the PA.
int target = instruction.getUse(0);
PointerKey targetKey = pointerAnalysis.getHeapModel().getPointerKeyForLocal(node, target);

for (InstanceKey ik : pointerAnalysis.getPointsToSet(targetKey)) {
if (ik instanceof ConcreteTypeKey) {
ConcreteTypeKey ctk = (ConcreteTypeKey) ik;
IClass type = ctk.getType();
TypeReference reference = type.getReference();

if (reference.equals(NEXT.getDeclaringClass())) {
// it's a call to `next()`. Look up the iterator definition.
int iterator = instruction.getUse(1);
SSAInstruction iteratorDef = du.getDef(iterator);

// Let's see if the iterator is over a tensor dataset. First, check the iterator
// for a dataset source. NOTE: We can only do this because `iter()` is currently
// just passing-through its argument.
if (iteratorDef != null && iteratorDef.getNumberOfUses() > 1) {
boolean added =
processInstructionInterprocedurally(
iteratorDef, iteratorDef.getDef(), node, src, sources, pointerAnalysis);

ret |= added;

if (!added && iteratorDef instanceof SSAAbstractInvokeInstruction) {
// It may be a call to `iter()`. Get the argument.
int iterArg = iteratorDef.getUse(1);
ret |=
processInstructionInterprocedurally(
iteratorDef, iterArg, node, src, sources, pointerAnalysis);
}
} else
// Use the original instruction. NOTE: We can only do this because `iter()` is
// currently just passing-through its argument.
ret |=
processInstructionInterprocedurally(
instruction, iterator, node, src, sources, pointerAnalysis);
}
}
}
}
}

return ret;
}

/**
* Processes the given {@link SSAInstruction} to decide if the given {@link PointsToSetVariable}
* is added to the given {@link Set} of {@link PointsToSetVariable}s as tensor dataflow sources.
Expand Down
34 changes: 34 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset36.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import tensorflow as tf


class C:

def __init__(self, some_iter):
self.some_iter = some_iter

def __str__(self):
return str(self.some_iter)


def id1(a):
return a


def id2(a):
return a


def gen():
yield "42", tf.constant("43")


dataset = tf.data.Dataset.from_generator(gen, output_types=(tf.string, tf.string))

my_iter = iter(dataset)
c = C(my_iter)
length = 1

for _ in range(length):
x, y = next(c.some_iter)
id1(x)
id2(y)
28 changes: 28 additions & 0 deletions com.ibm.wala.cast.python.test/data/tf2_test_dataset37.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import tensorflow as tf


class C:

def __init__(self, some_iter):
self.some_iter = some_iter

def __str__(self):
return str(self.some_iter)


def add(a, b):
return a + b


def gen_iter(dataset):
my_iter = iter(dataset)
return C(my_iter)


dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
c = gen_iter(dataset)
length = len(dataset)

for _ in range(length):
element = next(c.some_iter)
add(element, element)

0 comments on commit 02c2a6b

Please sign in to comment.