|
2 | 2 | import os
|
3 | 3 | import sys
|
4 | 4 |
|
from .internal_utils import *

CUDA_ALLOCATOR_ENV_WARNING_STR = """
An experimental feature for CUDA allocations is turned on for better allocation
pattern resulting in better memory usage for minibatch GNN training workloads.
See https://pytorch.org/docs/stable/notes/cuda.html#optimizing-memory-usage-with-pytorch-cuda-alloc-conf,
and set the environment variable `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:False`
if you want to disable it.
"""
# Opt into PyTorch's expandable-segments CUDA allocator for lower memory usage
# in minibatch GNN training, unless the user has configured the allocator
# themselves via PYTORCH_CUDA_ALLOC_CONF.
cuda_allocator_env = os.getenv("PYTORCH_CUDA_ALLOC_CONF")
if cuda_allocator_env is None:
    # No user configuration: enable the feature and explain how to opt out.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    gb_warning(CUDA_ALLOCATOR_ENV_WARNING_STR)
else:
    # Parse the existing "key:value,key:value" configuration. Use partition()
    # so (a) a malformed entry without ":" maps to an empty value instead of
    # raising IndexError, and (b) a value that itself contains ":" is kept
    # intact rather than truncated at the second colon.
    configs = {}
    for kv_pair in cuda_allocator_env.split(","):
        key, _, value = kv_pair.partition(":")
        configs[key] = value
    if "expandable_segments" in configs:
        # Respect an explicit user setting; only suggest enabling it.
        if configs["expandable_segments"] != "True":
            gb_warning(
                "You should consider `expandable_segments:True` in the"
                " environment variable `PYTORCH_CUDA_ALLOC_CONF` for lower"
                " memory usage. See "
                "https://pytorch.org/docs/stable/notes/cuda.html"
                "#optimizing-memory-usage-with-pytorch-cuda-alloc-conf"
            )
    else:
        # Append our setting while preserving the user's other options.
        configs["expandable_segments"] = "True"
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(
            k + ":" + v for k, v in configs.items()
        )
        gb_warning(CUDA_ALLOCATOR_ENV_WARNING_STR)


# pylint: disable=wrong-import-position, wrong-import-order
5 | 41 | import torch
|
6 | 42 |
|
7 | 43 | ### FROM DGL @todo
|
@@ -47,7 +83,6 @@ def load_graphbolt():
|
47 | 83 | from .itemset import *
|
48 | 84 | from .item_sampler import *
|
49 | 85 | from .minibatch_transformer import *
|
50 |
| -from .internal_utils import * |
51 | 86 | from .negative_sampler import *
|
52 | 87 | from .sampled_subgraph import *
|
53 | 88 | from .subgraph_sampler import *
|
|
0 commit comments