@@ -301,7 +301,11 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
301301
302302
303303def update_tensor_lifetime (
304- node : torch .fx .Node , spec : TensorSpec , node_idx : int
304+ node : torch .fx .Node ,
305+ spec : TensorSpec ,
306+ node_idx : int ,
307+ max_node_idx : int ,
308+ gs : Optional [ExportGraphSignature ] = None ,
305309) -> None :
306310 r"""
307311 Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
@@ -317,7 +321,12 @@ def update_tensor_lifetime(
317321 start = 0
318322 else :
319323 start = node_idx if start is None or start > node_idx else start
320- end = node_idx if end is None or end < node_idx else end
324+
325+ if node .op == "placeholder" and _is_mutable_buffer (node , gs ):
326+ # mutable buffers are never freed
327+ end = max_node_idx
328+ else :
329+ end = node_idx if end is None or end < node_idx else end
321330 spec .lifetime = [start , end ]
322331
323332
@@ -497,7 +506,7 @@ def update_all_tensors_lifetime(
497506 Set the lifetime for all the tensors encountered in the Fx graph.
498507 """
499508 specs = set ()
500-
509+ max_node_idx = len ( graph_module . graph . nodes ) - 1
501510 for node_idx , node in enumerate (graph_module .graph .nodes ):
502511 for spec in collect_specs_from_nodes (
503512 filter_nodes (itertools .chain ([node ], node .args , node .kwargs .values ())),
@@ -509,7 +518,7 @@ def update_all_tensors_lifetime(
509518 do_assertion = False ,
510519 ignore_dynamic_unbound_tensor = False ,
511520 ):
512- update_tensor_lifetime (node , spec , node_idx )
521+ update_tensor_lifetime (node , spec , node_idx , max_node_idx , graph_signature )
513522 specs .add (spec )
514523 return specs
515524
0 commit comments