
Conversation

kasparas-k
Contributor

In the current version, there is a race condition when Geometry objects are still in the global scope as the Python interpreter exits. If the torch module is torn down before the Geometry objects are deallocated, the deallocator fails because torch is already None at that point. This is not a critical error, since all the work has already been done and the program is exiting, but it is confusing. Here is the error:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/warpconvnet/utils/cupy_alloc.py", line 22, in _torch_free
AttributeError: 'NoneType' object has no attribute 'cuda'
Exception ignored in: 'cupy.cuda.memory.PythonFunctionAllocatorMemory.__dealloc__'
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/warpconvnet/utils/cupy_alloc.py", line 22, in _torch_free
AttributeError: 'NoneType' object has no attribute 'cuda'
(the same "Exception ignored" traceback repeats once per remaining allocation)

or alternatively, depending on how far interpreter shutdown has progressed:

Error in sys.excepthook:

Original exception was:
(repeated several times)
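For context, the deallocator in question comes from wiring CuPy's allocation hooks to PyTorch's caching allocator. The following is a hedged sketch of that kind of bridge; apart from the _torch_free name and the cupy_alloc.py location shown in the traceback, the function names and exact calls are assumptions, not warpconvnet's actual code.

import cupy
import torch


def _torch_alloc(size: int, device_id: int) -> int:
    # Let torch's caching allocator own the raw device memory handed to CuPy.
    return torch.cuda.caching_allocator_alloc(size, device_id)


def _torch_free(ptr: int, device_id: int) -> None:
    # Invoked from cupy.cuda.memory.PythonFunctionAllocatorMemory.__dealloc__;
    # at interpreter exit this can run after the torch module global is None.
    torch.cuda.caching_allocator_delete(ptr)


cupy.cuda.set_allocator(
    cupy.cuda.memory.PythonFunctionAllocator(_torch_alloc, _torch_free).malloc
)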
Minimal example with PointTransformerV3 from the examples (the pytest import in the original example makes torch get freed later, so I reproduce the example here with only the strictly necessary imports):

from typing import Literal, Optional, Tuple

import torch
import torch.nn as nn

from warpconvnet.geometry.base.geometry import Geometry
from warpconvnet.geometry.coords.ops.serialization import POINT_ORDERING
from warpconvnet.geometry.types.points import Points
from warpconvnet.nn.modules.activations import GELU, DropPath
from warpconvnet.nn.modules.attention import FeedForward, PatchAttention
from warpconvnet.nn.modules.base_module import BaseSpatialModel, BaseSpatialModule
from warpconvnet.nn.modules.mlp import Linear
from warpconvnet.nn.modules.normalizations import LayerNorm
from warpconvnet.nn.modules.sequential import Sequential, TupleSequential
from warpconvnet.nn.modules.sparse_conv import SparseConv3d
from warpconvnet.nn.modules.sparse_pool import SparseMaxPool, SparseUnpool


class PatchAttentionBlock(BaseSpatialModule):
    def __init__(
        self,
        in_channels: int,
        attention_channels: int,
        patch_size: int,
        num_heads: int,
        kernel_size: int = 3,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.0,
        norm_layer: type = LayerNorm,
        act_layer: type = GELU,
        attn_type: Literal["patch"] = "patch",
        order: POINT_ORDERING = POINT_ORDERING.MORTON_XYZ,
    ):
        super().__init__()
        self.order = order
        self.conv = Sequential(
            SparseConv3d(
                in_channels,
                in_channels,
                kernel_size=kernel_size,
                stride=1,
                bias=True,
            ),
            nn.Linear(in_channels, attention_channels),
            norm_layer(attention_channels),
        )
        self.conv_shortcut = (
            nn.Identity()
            if in_channels == attention_channels
            else Linear(in_channels, attention_channels)
        )

        self.norm1 = norm_layer(attention_channels)
        self.attention = PatchAttention(
            attention_channels,
            patch_size=patch_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            order=order,
        )
        self.norm2 = norm_layer(attention_channels)
        self.mlp = FeedForward(
            dim=attention_channels,
            hidden_dim=int(attention_channels * mlp_ratio),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: Geometry, order: Optional[POINT_ORDERING | str] = None) -> Geometry:
        x = self.conv(x) + self.conv_shortcut(x)

        # Attention block
        x = self.drop_path(self.attention(self.norm1(x), order)) + x

        # MLP block
        x = self.drop_path(self.mlp(self.norm2(x))) + x
        return x


class PointTransformerV3(BaseSpatialModel):
    def __init__(
        self,
        in_channels: int = 6,
        enc_depths: Tuple[int, ...] = (2, 2, 2, 6, 2),
        enc_channels: Tuple[int, ...] = (32, 64, 128, 256, 512),
        enc_num_head: Tuple[int, ...] = (2, 4, 8, 16, 32),
        enc_patch_size: Tuple[int, ...] = (1024, 1024, 1024, 1024, 1024),
        dec_depths: Tuple[int, ...] = (2, 2, 2, 2),
        dec_channels: Tuple[int, ...] = (64, 64, 128, 256),
        dec_num_head: Tuple[int, ...] = (4, 4, 8, 16),
        dec_patch_size: Tuple[int, ...] = (1024, 1024, 1024, 1024),
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        qk_scale: Optional[float] = None,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        drop_path: float = 0.2,
        orders: Tuple[POINT_ORDERING, ...] = tuple(POINT_ORDERING),
        shuffle_orders: bool = True,
        attn_type: Literal["patch"] = "patch",
        **kwargs,
    ):
        super().__init__()

        num_level = len(enc_depths)
        assert num_level == len(enc_channels)
        assert num_level == len(enc_num_head)
        assert num_level == len(enc_patch_size)

        assert num_level - 1 == len(dec_channels)
        assert num_level - 1 == len(dec_depths)
        assert num_level - 1 == len(dec_num_head)
        assert num_level - 1 == len(dec_patch_size)
        self.num_level = num_level
        self.shuffle_orders = shuffle_orders
        self.orders = orders

        self.conv = Sequential(
            SparseConv3d(
                in_channels,
                enc_channels[0],
                kernel_size=5,
            ),
            nn.BatchNorm1d(enc_channels[0]),
            nn.GELU(),
        )

        encs = nn.ModuleList()
        down_convs = nn.ModuleList()
        for i in range(num_level):
            level_blocks = nn.ModuleList(
                [
                    PatchAttentionBlock(
                        in_channels=enc_channels[i],
                        attention_channels=enc_channels[i],
                        patch_size=enc_patch_size[i],
                        num_heads=enc_num_head[i],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=drop_path,
                        order=self.orders[i % len(self.orders)],
                        attn_type=attn_type,
                    )
                    for _ in range(enc_depths[i])
                ]
            )
            encs.append(level_blocks)

            if i < num_level - 1:
                down_convs.append(
                    Sequential(
                        nn.Linear(enc_channels[i], enc_channels[i + 1]),
                        SparseMaxPool(
                            kernel_size=2,
                            stride=2,
                        ),
                        nn.BatchNorm1d(enc_channels[i + 1]),
                        nn.GELU(),
                    )
                )

        decs = nn.ModuleList()
        up_convs = nn.ModuleList()
        dec_channels_list = list(dec_channels) + [enc_channels[-1]]
        for i in reversed(range(num_level - 1)):
            up_convs.append(
                TupleSequential(
                    nn.Linear(dec_channels_list[i + 1], dec_channels_list[i]),
                    SparseUnpool(
                        kernel_size=2,
                        stride=2,
                        concat_unpooled_st=True,
                    ),
                    nn.Linear(dec_channels_list[i] + enc_channels[i], dec_channels_list[i]),
                    nn.BatchNorm1d(dec_channels_list[i]),
                    nn.GELU(),
                    tuple_layer=1,
                )
            )
            level_blocks = nn.ModuleList(
                [
                    PatchAttentionBlock(
                        in_channels=dec_channels_list[i],
                        attention_channels=dec_channels_list[i],
                        patch_size=dec_patch_size[i],
                        num_heads=dec_num_head[i],
                        mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias,
                        qk_scale=qk_scale,
                        attn_drop=attn_drop,
                        proj_drop=proj_drop,
                        drop_path=drop_path,
                        order=self.orders[i % len(self.orders)],
                        attn_type=attn_type,
                    )
                    for _ in range(dec_depths[i])
                ]
            )
            decs.append(level_blocks)

        self.encs = encs
        self.down_convs = down_convs
        self.decs = decs
        self.up_convs = up_convs

        out_channels = kwargs.get("out_channels")
        if out_channels is not None:
            self.out_channels = out_channels
            self.final = Linear(dec_channels_list[0], out_channels)
        else:
            self.final = nn.Identity()

    def _select_order(self, blk_idx: int) -> POINT_ORDERING:
        """Selects the point ordering for a block.

        Use `torch.manual_seed` to control randomness.
        """
        if self.shuffle_orders:
            idx = torch.randint(0, len(self.orders), (1,)).item()
            return self.orders[idx]
        return self.orders[blk_idx % len(self.orders)]

    def forward(self, x: Geometry) -> Geometry:
        x = self.conv(x)
        skips = []

        blk_idx = 0
        # Encoder
        for level in range(self.num_level):
            # Process each block individually in this level
            level_blocks = self.encs[level]
            for block in level_blocks.children():
                selected_order = self._select_order(blk_idx)
                x = block(x, selected_order)
                blk_idx += 1

            if level < self.num_level - 1:
                skips.append(x)
                x = self.down_convs[level](x)

        # Decoder
        for level in range(self.num_level - 1):
            x = self.up_convs[level](x, skips[-(level + 1)])

            level_blocks = self.decs[level]
            for block in level_blocks.children():
                selected_order = self._select_order(blk_idx)
                x = block(x, selected_order)
                blk_idx += 1

        return self.final(x)

torch.manual_seed(0)

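# NOTE: x, model, and out are intentionally left in the global scope; at
# interpreter exit their deallocation races with the torch module teardown.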
x = Points([torch.randn(100, 3)], [torch.randn(100, 6)]).to_voxels(0.05).to('cuda')
model = PointTransformerV3().to('cuda')

model = model.eval()
out = model(x)

Running the above Python script leads to the NoneType error.

As a sanity check, add

del x
del out

to the end of the file: the error disappears, since the deallocation is then guaranteed to happen before torch is freed.

Alternatively, adding a check in the deallocator for whether torch is None fixes this issue as well; skipping the free there is safe, since all memory is reclaimed as the program exits anyway.
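A minimal sketch of that guard (assuming the deallocator ultimately frees through torch's CUDA caching allocator; the exact body of warpconvnet's _torch_free is an assumption here):

def _torch_free(ptr: int, device_id: int) -> None:
    # During interpreter shutdown the module global `torch` may already have
    # been cleared to None. Skipping the free is safe then: the process is
    # exiting and the CUDA driver reclaims all device memory anyway.
    if torch is None:
        return
    torch.cuda.caching_allocator_delete(ptr)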

@chrischoy chrischoy merged commit 4048280 into NVlabs:main Oct 9, 2025
@kasparas-k kasparas-k deleted the fix_deallocator_on_exit_global branch October 9, 2025 18:06