Conversation

@cocotdf
Contributor

@cocotdf cocotdf commented Sep 19, 2025

Description

Free cuFFT plans and associated GPU workspaces when the CuFFTPlanCache is destroyed. A destructor (~CuFFTPlanCache()) now calls Clear(). Clear() destroys cuFFT plans and frees workspace memory (calls cufftDestroy), then clears the internal plan map.
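
For illustration, a minimal C++ sketch of the cleanup pattern described above, assuming the cache keeps a map from plan parameters to cufftHandle values (the key type and member names such as plan_map_ are placeholders for this sketch, not the actual onnxruntime fields):

```cpp
#include <cufft.h>
#include <string>
#include <unordered_map>

// Sketch only: illustrates the destructor -> Clear() -> cufftDestroy flow.
// The real cache keys plans by their FFT parameters and holds more per-plan state;
// locking is omitted in this sketch.
class CuFFTPlanCache {
 public:
  ~CuFFTPlanCache() { Clear(); }  // release every cached plan when the cache is destroyed

  void Clear() {
    for (auto& entry : plan_map_) {
      cufftDestroy(entry.second);  // destroy the plan and free its GPU resources
    }
    plan_map_.clear();             // drop the cached entries
  }

 private:
  std::unordered_map<std::string, cufftHandle> plan_map_;  // placeholder key type
};
```

For plans whose work area is allocated by cuFFT itself, cufftDestroy releases that workspace along with the plan; a workspace allocated separately (for example via cufftSetWorkArea) would need to be freed explicitly in addition.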

Motivation and Context

When creating and destroying ONNX Runtime sessions that use the CUDA execution provider and cuFFT-based nodes, cuFFT plans and their GPU workspaces could remain allocated across session lifetimes. This can produce an increasing GPU memory footprint when sessions are repeatedly opened and closed. The change ensures internal cuFFT resources are released during cache cleanup, preventing GPU memory leaks in multi-session or repeated create/destroy scenarios.

How to reproduce (minimal repro)

The following Python script builds a minimal ONNX model in memory (RFFT -> IRFFT round-trip), repeatedly creates and destroys ONNX Runtime CUDA sessions, and prints GPU memory after each session close. Use this to observe a memory increase before the fix and stable memory after the fix.

Dependencies

  • Python 3.8+
  • onnx
  • onnxruntime-gpu
  • cupy matching your CUDA version (example package names: cupy-cuda12x or cupy-cuda11x)
  • numpy
# leak_repro_fft.py
# Minimal repro: build an ONNX model (Rfft -> Irfft round-trip), run many sessions
# and print GPU memory used after each session close.

import gc
import numpy as np
import onnx
import onnx.helper as oh
import onnxruntime as ort

try:
    import cupy as cp
except Exception as e:
    raise RuntimeError("CuPy is required to measure GPU memory. Install cupy for your CUDA version.") from e

# ---------- helpers to create MS Rfft / Irfft nodes ----------
def make_ms_rfft_node(inp, out, signal_ndim=1):
    return oh.make_node(
        "Rfft", [inp], [out],
        domain="com.microsoft",
        onesided=1, normalized=0, signal_ndim=signal_ndim
    )

def make_ms_irfft_node(inp, out, signal_ndim=1):
    return oh.make_node(
        "Irfft", [inp], [out],
        domain="com.microsoft",
        onesided=1, normalized=0, signal_ndim=signal_ndim
    )

def build_model_fft_ifft_complex():
    """
    Input: X_ab [2, N] (float32)
    Graph: RFFT -> IRFFT (round-trip)
    Output: Y_ab [2, N] (float32)
    """
    X = oh.make_tensor_value_info("X_ab", onnx.TensorProto.FLOAT, [2, None])
    Y = oh.make_tensor_value_info("Y_ab", onnx.TensorProto.FLOAT, [2, None])

    nodes = []
    nodes.append(make_ms_rfft_node("X_ab", "R_ab", signal_ndim=1))   # [2, N//2+1, 2]
    nodes.append(make_ms_irfft_node("R_ab", "Y_ab", signal_ndim=1))  # [2, N]
    graph = oh.make_graph(nodes, "complex_fft_ifft", [X], [Y])
    model = oh.make_model(
        graph,
        opset_imports=[
            oh.make_operatorsetid("", 20),
            oh.make_operatorsetid("com.microsoft", 1),
        ],
        ir_version=10,
        producer_name="leak_repro_complex_fft_ifft"
    )
    return model

# ---------- utility to probe GPU memory ----------
def gpu_used_bytes():
    free, total = cp.cuda.runtime.memGetInfo()
    return int(total - free), int(total)

# ---------- main loop: create/close sessions ----------
def run_repro(iters=20, N=2**22, provider="CUDAExecutionProvider"):
    # prepare input (avoid host reallocation between iterations)
    rng = np.random.default_rng(1234)
    a = rng.standard_normal(N).astype(np.float32)
    b = rng.standard_normal(N).astype(np.float32)
    x_ab = np.stack((a, b), axis=0)  # shape [2, N]

    # check provider availability
    providers = ort.get_available_providers()
    if provider not in providers:
        raise RuntimeError(f"{provider} not available (providers: {providers})")

    model = build_model_fft_ifft_complex()
    model_bytes = model.SerializeToString()

    # baseline
    cp.cuda.Device().synchronize()
    used0, total0 = gpu_used_bytes()
    print(f"Baseline GPU used: {used0/1024**2:8.2f} MB / {total0/1024**2:8.2f} MB total")

    for i in range(1, iters + 1):
        # create session from bytes
        sess = ort.InferenceSession(model_bytes, sess_options=ort.SessionOptions(), providers=[provider])

        # run once
        _ = sess.run(None, {"X_ab": x_ab})

        # ensure device completed
        cp.cuda.Device().synchronize()

        # delete session and force GC
        del sess
        gc.collect()
        cp.cuda.Device().synchronize()

        used, _ = gpu_used_bytes()
        print(f"Iter {i:02d}: GPU used {used/1024**2:8.2f} MB")

    # final baseline
    cp.cuda.Device().synchronize()
    usedf, _ = gpu_used_bytes()
    print(f"Final GPU used: {usedf/1024**2:8.2f} MB")
    print("Done.")

if __name__ == "__main__":
    # tweak iters and N to show the leak on your machine
    run_repro(iters=5, N=2**22)

Example output (before fix)
Baseline GPU used:  3105.56 MB /  8191.56 MB total
Iter 01: GPU used  3173.56 MB
Iter 02: GPU used  3241.56 MB
Iter 03: GPU used  3309.56 MB
Iter 04: GPU used  3377.56 MB
Iter 05: GPU used  3445.56 MB
Final GPU used:  3445.56 MB
Done.

@yuslepukhin yuslepukhin requested a review from Copilot October 2, 2025 23:07
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds proper memory management to the CuFFTPlanCache class by implementing a destructor and Clear() method to release cuFFT plans and GPU workspace memory. This prevents GPU memory leaks when ONNX Runtime sessions are repeatedly created and destroyed.

  • Adds destructor that calls Clear() to ensure cleanup when cache is destroyed
  • Implements Clear() method to properly destroy cuFFT plans and free GPU workspace memory
  • Ensures thread-safe cleanup by using mutex protection
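
As a rough sketch of the thread-safety point in the last bullet, this is the same Clear() from the sketch in the description, now guarded by the cache's mutex (mutex_ and plan_map_ are assumed member names, not the actual onnxruntime fields):

```cpp
#include <mutex>

// Sketch only: take the cache's lock before destroying plans, so Clear() cannot
// race with threads that are still creating or looking up plans in the same cache.
void CuFFTPlanCache::Clear() {
  std::lock_guard<std::mutex> lock(mutex_);
  for (auto& entry : plan_map_) {
    cufftDestroy(entry.second);
  }
  plan_map_.clear();
}
```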

@yuslepukhin yuslepukhin requested a review from tianleiwu October 2, 2025 23:07
@yuslepukhin
Member

LGTM

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu
Contributor

@cocotdf, please merge latest main branch to pass CI checks.

@tianleiwu
Contributor

/azp run Linux QNN CI Pipeline,Win_TRT_Minimal_CUDA_Test_CI,Windows ARM64 QNN CI Pipeline,Windows GPU Doc Gen CI Pipeline

@azure-pipelines

Azure Pipelines successfully started running 4 pipeline(s).

@tianleiwu tianleiwu merged commit 50317ff into microsoft:main Oct 21, 2025
89 of 90 checks passed
JonathanC-ARM pushed a commit to JonathanC-ARM/onnxruntime that referenced this pull request Oct 24, 2025
fs-eire pushed a commit that referenced this pull request Oct 24, 2025
quic-tirupath pushed a commit to CodeLinaro/onnxruntime that referenced this pull request Oct 27, 2025
@cocotdf cocotdf deleted the fix_fft_memory_leak branch October 28, 2025 10:37
naomiOvad pushed a commit to naomiOvad/onnxruntime that referenced this pull request Nov 2, 2025