Train loop takes exponentially longer as number of layers increase #8824

Open
rplsbo opened this issue Mar 12, 2025 · 5 comments

Comments

@rplsbo

rplsbo commented Mar 12, 2025

🐛 Bug

Training time increases exponentially with increase in number of layers.

To Reproduce

The following code takes around 16 seconds per step at layers=1000, whereas it takes 0.38 seconds per step at layers=500.

import os

os.environ["XLA_REGISTER_INSTALLED_PLUGINS"] = "1"
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch
import torch_xla.core.xla_model as xm
import time

class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        
        self.output_layer = torch.nn.Linear(100, 1)
    
    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

def train():
    device = xm.xla_device()
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()

        step_times = []

        for i in range(10):
            start_time = time.time()
            input = torch.randn(10).to(device)
            y = torch.randn(1).to(device)
            output = model(input)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            xm.mark_step()
            xm.wait_device_ops()
            step_time = time.time() - start_time
            step_times.append(step_time)
        
        median_time = sorted(step_times)[len(step_times)//2]
        print(f"Median step time: {median_time:.4f}s")
        print(step_times)
    

if __name__ == "__main__":
    train()

Expected behavior

Is it expected that the training step time increases exponentially (rather than linearly) with more layers?

Environment

  • Reproducible on XLA backend [CPU/TPU/CUDA]: CUDA
  • torch_xla version: 2.5.1

Additional context

@ysiraichi
Collaborator

Thank you for filing this issue.
Could you try using nightly PyTorch/XLA? Even better if you could also compare it with PyTorch CUDA device.

@rplsbo
Author

rplsbo commented Mar 12, 2025

The issue does NOT happen if I use the PyTorch CUDA device. The code below runs the same benchmark without XLA and does not show the slowdown:

import torch
import time

class SimpleModel(torch.nn.Module):
    def __init__(self, layers=1000):
        super().__init__()
        self.input_layer = torch.nn.Linear(10, 100)
        self.input_activation = torch.nn.ReLU()
        
        self.hidden_layers = torch.nn.ModuleList()
        for _ in range(layers):
            self.hidden_layers.append(torch.nn.Linear(100, 100))
            self.hidden_layers.append(torch.nn.ReLU())
        
        self.output_layer = torch.nn.Linear(100, 1)
    
    def forward(self, x):
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        return x

def train():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    for layers in [500, 1000, 1500, 2000, 2500, 3000]:
        model = SimpleModel(layers).to(device)

        optimizer = torch.optim.Adam(model.parameters())
        loss_fn = torch.nn.MSELoss()

        step_times = []

        for i in range(10):
            start_time = time.time()
            input = torch.randn(10).to(device)
            y = torch.randn(1).to(device)
            output = model(input)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            torch.cuda.synchronize()
            step_time = time.time() - start_time
            step_times.append(step_time)
        
        median_time = sorted(step_times)[len(step_times)//2]
        print(f"Layers: {layers}, Median step time: {median_time:.4f}s")
        print(f"All times: {step_times}")
    

if __name__ == "__main__":
    train()

The issue is slightly different when running on a TPU v6e. With the nightly build (2.7.0), compile times get quite large in the initial steps, but the execution time does not increase exponentially. The exponential increase in execution time only shows up in the stable build (2.6.0).

That seems to suggest this is fixed in a later build. Any idea what the fix could be?

@ysiraichi
Collaborator

If I understood it correctly, you said that the issue seems to be solved in the nightly build, using TPU v6e, correct? Is this also true for CUDA (accelerator used in the original post)?

> Any idea what the fix could be?

Unfortunately, I don't know.
@miladm @tengyifei @bhavya01 @ManfeiBai Do you know what could have caused this?

@tengyifei
Collaborator

The easiest way to tell what's going on is to take a profile by following one of our profiling guides. That can show you which operations are executing on the TPU and why they are taking so long.

@qihqi
Collaborator

qihqi commented Mar 17, 2025

This benchmark actually includes the compile time. If you want to time only the runtime, you can exclude the first few runs of the inner for loop (the snippet below skips the first three). Note that I also moved the host-to-device transfer out of the timed region.

i.e.

        for i in range(10):
            input = torch.randn(10).to(device)
            y = torch.randn(1).to(device)
            xm.mark_step(True)
            xm.wait_device_ops() # make sure transfer is complete
            start_time = time.time()
            output = model(input)
            total_loss = loss_fn(output, y)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            xm.mark_step()
            xm.wait_device_ops()
            step_time = time.time() - start_time
            # Skip the warm-up iterations, which include tracing and compilation.
            if i >= 3:
                step_times.append(step_time)
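
The same idea can be wrapped in a small reusable helper. This is only a sketch: `time_steps`, `step_fn`, and `sync_fn` are hypothetical names, not part of PyTorch/XLA, and the helper assumes the caller supplies a function that blocks until the device is idle.

```python
import time


def time_steps(step_fn, sync_fn, iters=10, warmup=3):
    """Time `iters` calls of step_fn, keeping warm-up and steady-state separate.

    step_fn: runs one training step (hypothetical callable supplied by the caller).
    sync_fn: blocks until all queued device work has finished.
    """
    warmup_times, steady_times = [], []
    for i in range(iters):
        start = time.perf_counter()
        step_fn()
        sync_fn()  # without this, only the time to enqueue work is measured
        elapsed = time.perf_counter() - start
        # The first `warmup` iterations may include tracing and compilation.
        (warmup_times if i < warmup else steady_times).append(elapsed)
    return warmup_times, steady_times
```

On XLA, `sync_fn` could be a function that calls `xm.mark_step()` followed by `xm.wait_device_ops()`, as in the snippet above. On plain CUDA it could be `torch.cuda.synchronize`.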

The first time a graph runs, PyTorch/XLA traces it with LazyTensor, generates XLA HLO, and then hands it to XLA to compile.
XLA compile time is indeed super-linear with respect to the number of nodes in the graph (I believe it's quadratic, not exponential).
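
One way to check whether a slow step is a compilation or an execution is to look at the PyTorch/XLA debug metrics. The sketch below assumes `torch_xla.debug.metrics.metric_data(name)` returns a tuple whose first element is the number of recorded samples (or `None` if nothing has been recorded yet), and that the `CompileTime` and `ExecuteTime` metrics are present in your build. The helper name `xla_sample_counts` is hypothetical.

```python
def xla_sample_counts(met=None, names=("CompileTime", "ExecuteTime")):
    """Return how many samples each named XLA metric has recorded so far."""
    if met is None:
        # Imported lazily so the helper can be exercised without torch_xla.
        import torch_xla.debug.metrics as met
    counts = {}
    for name in names:
        data = met.metric_data(name)  # assumed: (total_samples, accumulator, samples) or None
        counts[name] = 0 if data is None else data[0]
    return counts
```

Calling this after every step and printing the result should show the `CompileTime` count levelling off after the first few steps. If it keeps growing, the graph is being recompiled on each step, and the step time is not a pure execution time.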

If you want to time compilation alone, a good approximation is the time to run a model for the first time minus the time to run the same model for the 10th time. If you plot the time curve, it should be quadratic.
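
A rough way to test the "quadratic, not exponential" guess without plotting is to fit a straight line to log(time) against log(layers): a power law with exponent k shows up as a slope of about k. The sketch below is plain Python, and both function names are hypothetical helpers, not part of any library.

```python
import math


def estimate_compile_time(step_times):
    """Approximate compile time as first-step time minus last-step time."""
    if len(step_times) < 2:
        raise ValueError("need at least two step times")
    return step_times[0] - step_times[-1]


def loglog_slope(sizes, times):
    """Least-squares slope of log(times) vs log(sizes); about k if time ~ size**k."""
    xs = [math.log(s) for s in sizes]
    ys = [math.log(t) for t in times]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    num = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    den = sum((x - mean_x) ** 2 for x in xs)
    return num / den
```

For the two numbers in the original report, `loglog_slope([500, 1000], [0.38, 16.0])` comes out at roughly 5.4, but two points cannot separate a power law from an exponential. Feeding in the per-size compile-time estimates for all six layer counts would give a more meaningful slope.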

4 participants