|
| 1 | +import lbann |
| 2 | +import numpy as np |
| 3 | +import test_util |
| 4 | +import pytest |
| 5 | +import os |
| 6 | +import sys |
| 7 | +import lbann.contrib.launcher |
| 8 | +import lbann.contrib.args |
| 9 | + |
| 10 | +# Bamboo utilities |
| 11 | +current_file = os.path.realpath(__file__) |
| 12 | +current_dir = os.path.dirname(current_file) |
| 13 | +sys.path.insert(0, os.path.join(os.path.dirname(current_dir), 'common_python')) |
| 14 | +import tools |
| 15 | + |
| 16 | +@pytest.mark.parametrize('num_dims', [2, 3]) |
| 17 | +@test_util.lbann_test(check_gradients=True, |
| 18 | + environment=lbann.contrib.args.get_distconv_environment(), |
| 19 | + time_limit=10) |
| 20 | +def test_simple(num_dims): |
| 21 | + try: |
| 22 | + import torch |
| 23 | + import torch.nn as nn |
| 24 | + except: |
| 25 | + pytest.skip('PyTorch is required to run this test.') |
| 26 | + |
| 27 | + torch.manual_seed(20240216) |
| 28 | + # Two samples of 4x16x16 or 4x16x16x16 tensors |
| 29 | + shape = [2, 4] + [16] * num_dims |
| 30 | + x = torch.randn(shape) |
| 31 | + if num_dims == 2: |
| 32 | + ConvClass = nn.Conv2d |
| 33 | + kerenel_size = (3, 1) |
| 34 | + padding = (1, 0) |
| 35 | + group_name = 'height_groups' |
| 36 | + else: |
| 37 | + ConvClass = nn.Conv3d |
| 38 | + kerenel_size = (5, 3, 1) |
| 39 | + padding = (2, 1, 0) |
| 40 | + group_name = 'depth_groups' |
| 41 | + |
| 42 | + conv = ConvClass(4, 8, kerenel_size, padding=padding, bias=False) |
| 43 | + with torch.no_grad(): |
| 44 | + ref = conv(x) |
| 45 | + |
| 46 | + tester = test_util.ModelTester() |
| 47 | + x = tester.inputs(x.numpy()) |
| 48 | + ref = tester.make_reference(ref.numpy()) |
| 49 | + |
| 50 | + # Test layer |
| 51 | + kernel = conv.weight.detach().numpy() |
| 52 | + kernel_weights = lbann.Weights( |
| 53 | + initializer=lbann.ValueInitializer(values=np.nditer(kernel)), |
| 54 | + name=f'kernel_{num_dims}d' |
| 55 | + ) |
| 56 | + ps = {group_name: tools.gpus_per_node(lbann)} |
| 57 | + y = lbann.Convolution( |
| 58 | + x, |
| 59 | + weights=(kernel_weights,), |
| 60 | + num_dims=num_dims, |
| 61 | + out_channels=8, |
| 62 | + kernel_size=kerenel_size, |
| 63 | + stride=1, |
| 64 | + padding=padding, |
| 65 | + dilation=1, |
| 66 | + has_bias=False, |
| 67 | + parallel_strategy=ps, |
| 68 | + name=f'conv_{num_dims}d' |
| 69 | + ) |
| 70 | + y = lbann.Identity(y) |
| 71 | + tester.set_loss(lbann.MeanSquaredError(y, ref)) |
| 72 | + tester.set_check_gradients_tensor(lbann.Square(y)) |
| 73 | + return tester |
0 commit comments