Skip to content

Commit ac3b330

Browse files
authored
feat(stepfunctions): retrieve all reachable states from a given state in a state machine definition (#7324)
A new method - `findReachableStates` - to retrieve all reachable states subsequent to a given initial state, including itself. Motivation With this method, developers can programatically modify their state definitions, given an initial state. For example, walk through the set of states and attach some of them to a 'ResumeTo' Choice state, so that state machine executions can jump directly to one of these states. closes #7256
1 parent 51aecde commit ac3b330

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

packages/@aws-cdk/aws-stepfunctions/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,10 @@ new stepfunctions.Parallel(this, 'All jobs')
728728
.branch(new MyJob(this, 'Slow', { jobFlavor: 'slow' }).prefixStates());
729729
```
730730

731+
A few utility functions are available to parse state machine fragments.
732+
* `State.findReachableStates`: Retrieve the list of states reachable from a given state.
733+
* `State.findReachableEndStates`: Retrieve the list of end or terminal states reachable from a given state.
734+
731735
## Activity
732736

733737
**Activities** represent work that is done on some non-Lambda worker pool. The

packages/@aws-cdk/aws-stepfunctions/lib/states/state.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,29 @@ export abstract class State extends cdk.Construct implements IChainable {
7373
}
7474
}
7575

76+
/**
77+
* Find the set of states reachable through transitions from the given start state.
78+
* This does not retrieve states from within sub-graphs, such as states within a Parallel state's branch.
79+
*/
80+
public static findReachableStates(start: State, options: FindStateOptions = {}): State[] {
81+
const visited = new Set<State>();
82+
const ret = new Set<State>();
83+
const queue = [start];
84+
while (queue.length > 0) {
85+
const state = queue.splice(0, 1)[0]!;
86+
if (visited.has(state)) { continue; }
87+
visited.add(state);
88+
const outgoing = state.outgoingTransitions(options);
89+
queue.push(...outgoing);
90+
ret.add(state);
91+
}
92+
return Array.from(ret);
93+
}
94+
7695
/**
7796
* Find the set of end states states reachable through transitions from the given start state
7897
*/
79-
public static findReachableEndStates(start: State, options: FindStateOptions = {}) {
98+
public static findReachableEndStates(start: State, options: FindStateOptions = {}): State[] {
8099
const visited = new Set<State>();
81100
const ret = new Set<State>();
82101
const queue = [start];

packages/@aws-cdk/aws-stepfunctions/test/states-language.test.ts

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,78 @@ describe('States Language', () => {
614614
expect(() => new stepfunctions.Parallel(stack, 'Parallel')
615615
.branch(state1.next(state2))
616616
.branch(state2)).toThrow();
617+
}),
618+
619+
describe('findReachableStates', () => {
620+
621+
test('Can retrieve possible states from initial state', () => {
622+
// GIVEN
623+
const stack = new cdk.Stack();
624+
const state1 = new stepfunctions.Pass(stack, 'State1');
625+
const state2 = new stepfunctions.Pass(stack, 'State2');
626+
const state3 = new stepfunctions.Pass(stack, 'State3');
627+
628+
const definition = state1
629+
.next(state2)
630+
.next(state3);
631+
632+
// WHEN
633+
const states = stepfunctions.State.findReachableStates(definition.startState);
634+
635+
// THEN
636+
expect(state1.id).toStrictEqual(states[0].id);
637+
expect(state2.id).toStrictEqual(states[1].id);
638+
expect(state3.id).toStrictEqual(states[2].id);
639+
});
640+
641+
test('Does not retrieve unreachable states', () => {
642+
// GIVEN
643+
const stack = new cdk.Stack();
644+
const state1 = new stepfunctions.Pass(stack, 'State1');
645+
const state2 = new stepfunctions.Pass(stack, 'State2');
646+
const state3 = new stepfunctions.Pass(stack, 'State3');
647+
648+
state1.next(state2).next(state3);
649+
650+
// WHEN
651+
const states = stepfunctions.State.findReachableStates(state2);
652+
653+
// THEN
654+
expect(state2.id).toStrictEqual(states[0].id);
655+
expect(state3.id).toStrictEqual(states[1].id);
656+
expect(states.length).toStrictEqual(2);
657+
});
658+
659+
test('Works with Choice and Parallel states', () => {
660+
// GIVEN
661+
const stack = new cdk.Stack();
662+
const state1 = new stepfunctions.Choice(stack, 'MainChoice');
663+
const stateCA = new stepfunctions.Pass(stack, 'StateA');
664+
const stateCB = new stepfunctions.Pass(stack, 'StateB');
665+
const statePA = new stepfunctions.Pass(stack, 'ParallelA');
666+
const statePB = new stepfunctions.Pass(stack, 'ParallelB');
667+
const state2 = new stepfunctions.Parallel(stack, 'RunParallel');
668+
const state3 = new stepfunctions.Pass(stack, 'FinalState');
669+
state2.branch(statePA);
670+
state2.branch(statePB);
671+
state1.when(stepfunctions.Condition.stringEquals('$.myInput', 'A' ), stateCA);
672+
state1.when(stepfunctions.Condition.stringEquals('$.myInput', 'B'), stateCB);
673+
stateCA.next(state2);
674+
state2.next(state3);
675+
676+
const definition = state1.otherwise(stateCA);
677+
678+
// WHEN
679+
const statesFromStateCB = stepfunctions.State.findReachableStates(stateCB);
680+
const statesFromState1 = stepfunctions.State.findReachableStates(definition);
681+
682+
// THEN
683+
const expectedFromState1 = [state1, stateCA, stateCB, state2, state3];
684+
for (let i = 0; i < expectedFromState1.length; i++) {
685+
expect(statesFromState1[i].id).toStrictEqual(expectedFromState1[i].id);
686+
}
687+
expect(statesFromStateCB[0].id).toStrictEqual(stateCB.id);
688+
});
617689
});
618690
});
619691

0 commit comments

Comments
 (0)