@@ -195,7 +195,6 @@ def compile(
195195 hash_str , file_path = None , None
196196 from torch ._inductor .codecache import (FxGraphCache ,
197197 compiled_fx_graph_hash )
198-
199198 if torch .__version__ .startswith ("2.5" ):
200199 original_load = FxGraphCache .load
201200 original_load_name = "torch._inductor.codecache.FxGraphCache.load"
@@ -280,6 +279,16 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
280279 patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
281280 _get_shape_env ))
282281
282+ from torch ._functorch ._aot_autograd .autograd_cache import (
283+ AOTAutogradCache )
284+
285+ # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
286+ if hasattr (AOTAutogradCache , "_get_shape_env" ):
287+ stack .enter_context (
288+ patch (
289+ "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env" ,
290+ _get_shape_env ))
291+
283292 # for forcing the graph to be cached
284293 stack .enter_context (
285294 patch (
@@ -325,11 +334,19 @@ def load(self,
325334 assert isinstance (handle [1 ], str )
326335 hash_str = handle [0 ]
327336
337+ from torch ._functorch ._aot_autograd .autograd_cache import (
338+ AOTAutogradCache )
328339 from torch ._inductor .codecache import FxGraphCache
329340 with ExitStack () as exit_stack :
330341 exit_stack .enter_context (
331342 patch ("torch._inductor.codecache.FxGraphCache._get_shape_env" ,
332343 lambda * args , ** kwargs : AlwaysHitShapeEnv ()))
344+ # torch 2.8+ on main uses _get_shape_env in AOTAutogradCache
345+ if hasattr (AOTAutogradCache , "_get_shape_env" ):
346+ exit_stack .enter_context (
347+ patch (
348+ "torch._functorch._aot_autograd.autograd_cache.AOTAutogradCache._get_shape_env" ,
349+ lambda * args , ** kwargs : AlwaysHitShapeEnv ()))
333350
334351 # Dynamo metrics context, see method for more details.
335352 exit_stack .enter_context (self .metrics_context ())
0 commit comments