diff --git a/packages/@aws-cdk/aws-stepfunctions/README.md b/packages/@aws-cdk/aws-stepfunctions/README.md index 6f92f68646923..534ce7e1f3aa1 100644 --- a/packages/@aws-cdk/aws-stepfunctions/README.md +++ b/packages/@aws-cdk/aws-stepfunctions/README.md @@ -728,6 +728,10 @@ new stepfunctions.Parallel(this, 'All jobs') .branch(new MyJob(this, 'Slow', { jobFlavor: 'slow' }).prefixStates()); ``` +A few utility functions are available to parse state machine fragments. +* `State.findReachableStates`: Retrieve the list of states reachable from a given state. +* `State.findReachableEndStates`: Retrieve the list of end or terminal states reachable from a given state. + ## Activity **Activities** represent work that is done on some non-Lambda worker pool. The diff --git a/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts b/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts index e2df07b5d6fab..5f69d2af2d80b 100644 --- a/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts +++ b/packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts @@ -73,10 +73,29 @@ export abstract class State extends cdk.Construct implements IChainable { } } + /** + * Find the set of states reachable through transitions from the given start state. + * This does not retrieve states from within sub-graphs, such as states within a Parallel state's branch. + */ + public static findReachableStates(start: State, options: FindStateOptions = {}): State[] { + const visited = new Set(); + const ret = new Set(); + const queue = [start]; + while (queue.length > 0) { + const state = queue.splice(0, 1)[0]!; + if (visited.has(state)) { continue; } + visited.add(state); + const outgoing = state.outgoingTransitions(options); + queue.push(...outgoing); + ret.add(state); + } + return Array.from(ret); + } + /** * Find the set of end states states reachable through transitions from the given start state */ - public static findReachableEndStates(start: State, options: FindStateOptions = {}) { + public static findReachableEndStates(start: State, options: FindStateOptions = {}): State[] { const visited = new Set(); const ret = new Set(); const queue = [start]; diff --git a/packages/@aws-cdk/aws-stepfunctions/test/states-language.test.ts b/packages/@aws-cdk/aws-stepfunctions/test/states-language.test.ts index 55adf75666207..de232deffd531 100644 --- a/packages/@aws-cdk/aws-stepfunctions/test/states-language.test.ts +++ b/packages/@aws-cdk/aws-stepfunctions/test/states-language.test.ts @@ -614,6 +614,78 @@ describe('States Language', () => { expect(() => new stepfunctions.Parallel(stack, 'Parallel') .branch(state1.next(state2)) .branch(state2)).toThrow(); + }), + + describe('findReachableStates', () => { + + test('Can retrieve possible states from initial state', () => { + // GIVEN + const stack = new cdk.Stack(); + const state1 = new stepfunctions.Pass(stack, 'State1'); + const state2 = new stepfunctions.Pass(stack, 'State2'); + const state3 = new stepfunctions.Pass(stack, 'State3'); + + const definition = state1 + .next(state2) + .next(state3); + + // WHEN + const states = stepfunctions.State.findReachableStates(definition.startState); + + // THEN + expect(state1.id).toStrictEqual(states[0].id); + expect(state2.id).toStrictEqual(states[1].id); + expect(state3.id).toStrictEqual(states[2].id); + }); + + test('Does not retrieve unreachable states', () => { + // GIVEN + const stack = new cdk.Stack(); + const state1 = new stepfunctions.Pass(stack, 'State1'); + const state2 = new stepfunctions.Pass(stack, 'State2'); + const state3 = new stepfunctions.Pass(stack, 'State3'); + + state1.next(state2).next(state3); + + // WHEN + const states = stepfunctions.State.findReachableStates(state2); + + // THEN + expect(state2.id).toStrictEqual(states[0].id); + expect(state3.id).toStrictEqual(states[1].id); + expect(states.length).toStrictEqual(2); + }); + + test('Works with Choice and Parallel states', () => { + // GIVEN + const stack = new cdk.Stack(); + const state1 = new stepfunctions.Choice(stack, 'MainChoice'); + const stateCA = new stepfunctions.Pass(stack, 'StateA'); + const stateCB = new stepfunctions.Pass(stack, 'StateB'); + const statePA = new stepfunctions.Pass(stack, 'ParallelA'); + const statePB = new stepfunctions.Pass(stack, 'ParallelB'); + const state2 = new stepfunctions.Parallel(stack, 'RunParallel'); + const state3 = new stepfunctions.Pass(stack, 'FinalState'); + state2.branch(statePA); + state2.branch(statePB); + state1.when(stepfunctions.Condition.stringEquals('$.myInput', 'A' ), stateCA); + state1.when(stepfunctions.Condition.stringEquals('$.myInput', 'B'), stateCB); + stateCA.next(state2); + state2.next(state3); + + const definition = state1.otherwise(stateCA); + + // WHEN + const statesFromStateCB = stepfunctions.State.findReachableStates(stateCB); + const statesFromState1 = stepfunctions.State.findReachableStates(definition); + + // THEN + const expectedFromState1 = [state1, stateCA, stateCB, state2, state3]; + for (let i = 0; i < expectedFromState1.length; i++) { + expect(statesFromState1[i].id).toStrictEqual(expectedFromState1[i].id); + } + expect(statesFromStateCB[0].id).toStrictEqual(stateCB.id); + }); }); });