|
| 1 | +# Adopt from https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html |
| 2 | +import torch |
| 3 | +import torch.nn |
| 4 | +import torch.optim |
| 5 | +import torch.profiler |
| 6 | +import torch.utils.data |
| 7 | +import torchvision.datasets |
| 8 | +import torchvision.models |
| 9 | +import torchvision.transforms as T |
| 10 | +from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights |
| 11 | + |
| 12 | +transform = T.Compose( |
| 13 | + [T.Resize(224), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] |
| 14 | +) |
| 15 | +train_set = torchvision.datasets.CIFAR10( |
| 16 | + root="./data", train=True, download=True, transform=transform |
| 17 | +) |
| 18 | +train_loader = torch.utils.data.DataLoader(train_set, batch_size=2, shuffle=True) |
| 19 | + |
| 20 | +device = torch.device("cuda:0") |
| 21 | +model = efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT).to(device) |
| 22 | +criterion = torch.nn.CrossEntropyLoss().cuda(device) |
| 23 | +optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) |
| 24 | +model.train() |
| 25 | + |
| 26 | + |
| 27 | +def train(data): |
| 28 | + inputs, labels = data[0].to(device=device), data[1].to(device=device) |
| 29 | + outputs = model(inputs) |
| 30 | + loss = criterion(outputs, labels) |
| 31 | + optimizer.zero_grad() |
| 32 | + loss.backward() |
| 33 | + optimizer.step() |
| 34 | + |
| 35 | + |
| 36 | +with torch.profiler.profile( |
| 37 | + schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), |
| 38 | + on_trace_ready=torch.profiler.tensorboard_trace_handler( |
| 39 | + "/home/envd/log/efficientnet" |
| 40 | + ), |
| 41 | + record_shapes=True, |
| 42 | + profile_memory=True, |
| 43 | + with_stack=True, |
| 44 | +) as prof: |
| 45 | + for step, batch_data in enumerate(train_loader): |
| 46 | + if step >= (1 + 1 + 3) * 2: |
| 47 | + break |
| 48 | + train(batch_data) |
| 49 | + prof.step() # Need to call this at the end of each step to notify profiler of steps' boundary. |
| 50 | + |
| 51 | + |
| 52 | +prof = torch.profiler.profile( |
| 53 | + schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2), |
| 54 | + on_trace_ready=torch.profiler.tensorboard_trace_handler( |
| 55 | + "/home/envd/log/efficientnet" |
| 56 | + ), |
| 57 | + record_shapes=True, |
| 58 | + with_stack=True, |
| 59 | +) |
| 60 | +prof.start() |
| 61 | +for step, batch_data in enumerate(train_loader): |
| 62 | + if step >= (1 + 1 + 3) * 2: |
| 63 | + break |
| 64 | + train(batch_data) |
| 65 | + prof.step() |
| 66 | +prof.stop() |
0 commit comments