diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 135775c7dc..7fb6702c81 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -1589,9 +1589,14 @@ def make_rotary_embedding_multi_cache(self, **kwargs): self.rope_attrs["save_caches"] = False cos_cache_small, sin_cache_small = self.make_rotary_embedding_caches(cos_cache_name=cos_cache_small_name, sin_cache_name=sin_cache_small_name) - if self.ep == "dml": - # Concat small and large cos/sin caches for DML EP - # DML EP doesn't support the If operator + # Determine which EPs don't support the If operator + self.eps_without_if_support = ["dml"] + if self.extra_options.get("enable_webgpu_graph", False): + self.eps_without_if_support.append("webgpu") + + if self.ep in self.eps_without_if_support: + # Concat small and large cos/sin caches for DML and WebGPU (when graph enabled) EPs + # These EPs don't support the If operator cos_cache = torch.cat((cos_cache_small, cos_cache_large), dim=0) sin_cache = torch.cat((sin_cache_small, sin_cache_large), dim=0) # Save cos/sin caches to disk