-
Notifications
You must be signed in to change notification settings - Fork 480
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
scan and apply_layers #7901
base: master
Are you sure you want to change the base?
scan and apply_layers #7901
Conversation
you can't just import it, you need to setup import dir correctly. Take a look at https://github.com/pytorch/xla/blob/master/test/dynamo/test_dynamo_dynamic_shape.py#L1-L6 |
@JackCaoG ty. i followed your example and got it working. |
import json | ||
hlo_json = json.loads(ctx.hlo_json()) | ||
num_parameters = len(hlo_json["hostProgramShape"]["parameters"]) | ||
self.assertEqual(len(mapping), num_parameters) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so you expect both value to be 10?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately not. It looks like some integer values (e.g. values <= 2) are shared when you put multiple copies into the HLO, but values above 2 are not shared. So we don't necessarily get 10. In any case, the precise number of parameters seems to be an implementation detail that we can't reliably test.
@@ -1077,7 +1076,9 @@ class PyLoweringContext { | |||
at::ScalarType dtype = | |||
MaybeUpcastToHostTorchType(literal.shape().element_type()); | |||
at::Tensor input = MakeTensorFromXlaLiteral(literal, dtype); | |||
results[param_ids[i]] = input; | |||
std::optional param_id = lowering_ctx.GetParameterId(device_data[i]); | |||
XLA_CHECK(param_id.has_value()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
when would it not has value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When GetParameterId
receives a BackendData
that is not a parameter in this lowering context, it will return std::nullopt
. However, this loop is only iterating over parameters (line 1071, const std::vector<torch::lazy::BackendDataPtr>& device_data = lowering_ctx.GetParametersData();
), so we will expect all BackendData
there to have an ID. Seems good to enforce this invariant.
return input_data | ||
|
||
# Extract and stack the parameters into a pytree. | ||
params = [_extract_weights_dict(layer) for layer in layers] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if it is a dropout layer that parameters are more than just tensors?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If there is a dropout layer that references tensors other than model parameters (for example, the dropout probability), then those tensors will be captured as an additional HLO parameter to the XlaComputation
object. As implemented now, apply_layers
and scan
will trace the first layer, and then use the same captured tensor for subsequent layers. This will be a problem if the user passes different dropout probabilities for say a sequence of dropout layers -- we'll instead incorrectly just keep using the first dropout's probability. I'll have to dig deeper and find a solution for this.
If there's a layer that references things other than tensors, then either that thing (e.g. a bool
field) will impact the traced HLO computation, in which case I need to add a verification that all layers trace to equivalent computations. Or that thing won't impact the traced computation, in which case it won't matter to us.
example_layer = deepcopy(next(iter(layers))) | ||
|
||
# Hollow out the weights and biases in the example layer. | ||
example_layer = example_layer.to_empty(device=None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this not going to impact the cloned arg?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you clarify this question -- I thought to_empty
is going to destroy the value inside example_layer
, so I deepcopy
it before to backup.
fn_output_carry_pytree, fn_output_y_pytree = flat_fn(*(fake_carry + fake_x)) | ||
|
||
# Later we'll use `fn_output_carry_spec` etc to turn flattened outputs back to a PyTree. | ||
fn_output_carry, fn_output_carry_spec = tree_flatten(fn_output_carry_pytree) | ||
assert fn_output_carry_spec == carry_spec | ||
fn_output_y, fn_output_y_spec = tree_flatten(fn_output_y_pytree) | ||
flat_y_len = len(fn_output_y) | ||
fn_outputs = fn_output_carry + fn_output_y |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what if there are in place updates to the tensor but it is not being return from the function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tested this and we give the wrong answer: https://github.com/tengyifei/playground/blob/master/scan_with_in_place_updates.ipynb
In the notebook, I wrote an approach to detect and prevent in place updates like that. TLDR is we'll have to trace every forward of each layer and verify that they're the same.
|
||
def step_fn(grad_carry, pytree: Tuple[torch.Tensor, torch.Tensor, | ||
torch.Tensor]): | ||
grad_y, carry, x = pytree |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is this a typo?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think so -- pytree
is a tuple of the output grad at current step (grad_y
), carry at the current step (carry
), and input at current step (x
)
carry, carry_history, ys = _scan_impl(fn, init, xs) | ||
flat_carry_history, carry_spec = tree_flatten(carry_history) | ||
flat_xs, xs_spec = tree_flatten(xs) | ||
ctx.save_for_backward(*flat_carry_history, *flat_xs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
are flat_carry_history
flat_xs
always everything we need to save for the backward?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. If the user fn
references other tensors, they'll be captured as additional inputs to the HLO computation. I'll update the documentation to mention that we'll explicitly checkpoint fn
in this iteration, and add the right barriers.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the barriers won't be needed. This is the recommendation from JAX, which says you don't want to wrap inputs into a barrier if the checkpointed function is to be used in a scan
.
outputs = fn(*detached_inputs) | ||
output_carry, output_y = outputs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wait... you are retracing the fwd here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct. I'm not sure if the CSE pass can combine this with the same fn in the fwd pass. If it can't, then the bwd of scan will be slower.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ehh, what you do here is pretty much force the gradident accumulation(through most likely cancel by the CSE since there is no optimization barrier), this sounds like a bad idea
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the barriers won't be needed. This is the recommendation from JAX, which says you don't want to wrap inputs into a barrier if the checkpointed function is to be used in a scan
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See prevent_cse
documentation in https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html
4c4d127
to
67cff9b
Compare
@tengyifei is this PR a 2.5 candidate? |
@miladm yes, I'd like to backport this to 2.5 after addressing the comments etc. |
ddd01a4
to
ea640ab
Compare
2f868fd
to
99d8e5a
Compare
Add the lowering of scan to HLO While op. Introduce apply_layers which can sequentially apply a bunch of layers using scan underneath. Beef up unit tests including linear layers and decoders. add regression test for parameter_id_tensor_mapping add test_apply_layers.py to test shell scripts correctly import decoder model from examples
99d8e5a
to
276cff6
Compare
Add the lowering of scan to HLO While op.
Introduce apply_layers which can sequentially apply a bunch of layers
using scan underneath.
Beef up unit tests including linear layers and decoders.