Cleaning up scheduling of static repeat #4325
Conversation
!test --diff
Review updated until commit e803fa0

Description

Changes walkthrough 📝

PR Reviewer Guide 🔍
Here are some key observations to aid the review process:
!test --diff
!test --diff

!test --diff

!test --diff
auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
auto tv2 = repeat(tv1, {2, 1});
fusion.addOutput(tv2);
auto tv3 = segment_set(tv2);
Just to make sure the transformation is applied while ignoring plain set ops.
Out of curiosity, is there any specific reason we are using a segment_set, instead of a plain set?
Is it to ensure this:
// It is especially important to recognize this pattern when it
// appears at the end of a pointwise fusion segment, where an output
// is used as the reference tensor of scheduling the segment.
If so, should the caching ops be mixed inside the repeat?
Though I guess this was the end of the fusion anyway - so that may not make much sense.
Just because I saw a segment ending with segment_set, preceded by a repetition. I think that's one segment of a LitGPT Llama forward.
!test --diff
jjsjann123 left a comment
LGTM in general. I have quite a few questions, but they are mostly just for my own curiosity.
const auto& [tvs_with_repeat_id, tvs_without_repeat_id] = partitionTvsById(
    all_tvs,
    repeat_info->factor_id,
    id_model->maybeBuildGraph(IdMappingMode::BROADCAST));
Naive question for my own curiosity: do we need to map BROADCAST?
In the added example:

std::vector<int64_t> shape1{3, 1, 200};

auto tv0 = makeContigConcreteTensor(shape1);
fusion.addInput(tv0);

auto tv1 = pad(tv0, {fusion.oneVal(), fusion.oneVal()});
auto tv2 = expand(
    tv1,
    {IrBuilder::create<Val>(-1),
     IrBuilder::create<Val>(2),
     IrBuilder::create<Val>(-1)});
auto tv3 =
    reshape(tv2, {IrBuilder::create<Val>(6), IrBuilder::create<Val>(-1)});
fusion.addOutput(tv3);
Say for tv1 [i0, b(1), i2], after the expand we would have tv2 [i0, b(2), i2].
The two broadcast IDs in tv1 and tv2 would have different extents.
Q1. IIUC, mapping with BROADCAST would allow us to map those two together?
Q2. Does it matter for us to group tv1 with the tvs_with_repeat_id, even though it only contains the non-expanded factor_id?
EXACT should work too. Previously, we scheduled tensors like tv1 and tv2 together with tv3, so using BROADCAST keeps the same behavior. I don't think there should be any actual difference in final performance.
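For intuition only, the extent change discussed above can be mimicked with NumPy. This is just an illustration of expand semantics on a size-1 (broadcast) dimension, not of nvFuser's ID-graph mapping, and the shapes are made up for the sketch:

```python
import numpy as np

# A tv1-like tensor [i0, b(1), i2]: the middle dimension is a size-1 broadcast.
t1 = np.zeros((3, 1, 4))

# Expand b(1) -> b(2): no data is copied; the expanded axis gets stride 0.
# The two broadcast dimensions have different extents (1 vs 2) but refer to
# the same underlying dimension, which is what a broadcast-aware mapping
# would group together.
t2 = np.broadcast_to(t1, (3, 2, 4))

assert t2.shape == (3, 2, 4)
assert t2.strides[1] == 0  # both "copies" alias the same memory
```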
reshape_out = ldst->in()->as<TensorView>();
repeat_tvs.insert(reshape_out);

TensorView* maybe_repeat_out_tv) {
  // Skip set ops if any (e.g., inserted by caching). Only Set
I thought the deleted comment here was helpful - the bit about skipping caching ops.
// output, it is likely there's a cache tv between expand_out and
// repeat_out, so the following pattern should also be detected.
//
// broadcast_out = broadcast(input)
These 4 lines were helpful.
Not sure which 4 lines, but broadcast is no longer required.
@@ -352,35 +397,33 @@ void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
  // detected. The repeat ID then just remains there with no
nit: move the definition of repeat_info here, near its use.
I'd keep it there, as ref_tv is going to be transformed after that. It shouldn't affect the analysis, but there's no need to introduce additional complexity.
!build

!build
Stacked on top of #4325. If a repeat is moved to the end of a segment, the resize scheduler will take advantage of it.
The resize scheduler automatically detects a sequence of ops to repeat a tensor at a certain ID and slightly modifies the scheduling to reduce redundant computations. This PR makes the analysis a little more flexible so that it also works with a pattern appearing in a Llama forward RoPE module.
Specifically, the special scheduling was previously applied only when a sequence of BroadcastOp, ExpandOp and ViewOp was detected in this exact order, just because that's how a repetition of a tensor is commonly represented. However, the only op that is absolutely necessary is the final reshape. As long as the pattern of a repetition is met, that should be sufficient to apply the scheduling. In fact, in a segment of a Llama RoPE forward, there's a segment input that has a broadcast ID, which is then expanded inside the segment and merged to realize a repetition. This case is not detectable in the current main, as the segment lacks a BroadcastOp, but it is detected with this PR.
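To make the detected pattern concrete, here is a small NumPy sketch (not nvFuser code) of how a static repeat is realized as broadcast, expand and reshape. Only the final reshape materializes the repetition; the broadcast and expand steps are copy-free views, which is why a tensor that already enters the segment with a broadcast ID can complete the pattern without a BroadcastOp inside the segment:

```python
import numpy as np

t = np.arange(6).reshape(3, 2)      # input of shape (3, 2)

b = t[:, None, :]                   # broadcast: insert b(1) -> (3, 1, 2)
e = np.broadcast_to(b, (3, 2, 2))   # expand: b(1) -> b(2), still no copy
r = e.reshape(6, 2)                 # reshape merges the two IDs -> (6, 2)

# The merged result repeats each row twice, i.e. a static repeat.
assert np.array_equal(r, np.repeat(t, 2, axis=0))
```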