Commit

Fix PR comments
BulatVakhitov committed Jun 28, 2024
1 parent f3d5886 commit fbfab2b
Showing 1 changed file with 181 additions and 123 deletions.
304 changes: 181 additions & 123 deletions batchflow/models/torch/base_mixins.py
@@ -223,153 +223,200 @@ class OptimalBatchSizeMixin:
        For stable measurements, we make `n` iterations of `train`/`predict`, until the memory consumption stabilizes.
        """

-    def is_cuda_out_of_memory(self, exception):
-        """ Check if exception is CUDA OOM """
-        return (
-            isinstance(exception, RuntimeError)
-            and len(exception.args) == 1
-            and "CUDA" in exception.args[0]
-            and "out of memory" in exception.args[0]
-        )
-
-    # not sure if it's necessary
-    def is_cudnn_snafu(self, exception):
-        return (
-            isinstance(exception, RuntimeError)
-            and len(exception.args) == 1
-            and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
-        )
-
-    def is_out_of_cpu_memory(self, exception):
-        return (
-            isinstance(exception, RuntimeError)
-            and len(exception.args) == 1
-            and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
-        )
-
-    def is_oom_error(self, exception):
-        return (self.is_cuda_out_of_memory(exception) or
-                self.is_cudnn_snafu(exception) or
-                self.is_out_of_cpu_memory(exception))
+    def is_oom_error(self, exception):
+        """ Check whether the exception is an OOM error. """
+        if not (isinstance(exception, RuntimeError) and len(exception.args) == 1):
+            return False
+        return (
+            "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
+            or "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
+            or ("CUDA" in exception.args[0] and "out of memory" in exception.args[0])
+        )

    def garbage_collection_cuda(self):
        """ Garbage collection of Torch (CUDA) memory. """
        gc.collect()
        try:
            # This is the last thing that should cause an OOM error, but seemingly it can.
            torch.cuda.empty_cache()
        except RuntimeError as exception:
            if not self.is_oom_error(exception):
                # Only handle OOM errors
                raise
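
A minimal usage sketch for these two helpers, assuming a `model` built on this mixin; the retry pattern is illustrative, not part of the commit:

    # Hypothetical caller: retry a train step at half the batch size after an OOM.
    try:
        model.train(inputs=inputs, targets=targets, microbatch_size=False)
    except RuntimeError as exception:
        if not model.is_oom_error(exception):
            raise                            # not memory-related, propagate
        model.garbage_collection_cuda()      # release cached CUDA blocks
        batch_size //= 2                     # and try again with a smaller batch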


-    def _compute_optimal_batch_size(self, method='train', inputs=None, targets=None,
-                                    start_batch_size=4, max_iters=25, factor=2,
-                                    spread=0.2, use_estimation=False, n=4,
-                                    max_memory=100, pbar='n', tail_size=20,
-                                    frequency=0.05, delta_batch_size=16,
-                                    max_batch_size=1024, max_iters_estimation=2):
-        """ Compute memory usage for multiple batch sizes. """
+    def compute_optimal_batch_size(self, method='train', inputs=None, targets=None,
+                                   start_batch_size=4, max_iters=25, factor=2,
+                                   spread=0.2, estimation_method='bruteforce', n=4,
+                                   max_memory=100, pbar='n', tail_size=20,
+                                   frequency=0.05, delta_batch_size=16,
+                                   max_batch_size=1024, max_iters_estimation=2):
""" Computes optimal batch size in two steps:
1. Calculate batch size estimation using `predictive` or `bruteforce` method.
2. Calculate exact optimal batch size using binary search.
Parameters
----------
method: str
Defines in which method (`train` or `predict`) the optimal batch size will
be computed. Default: `train`.
estimation_method: str
Whether `bruteforce` or `predictive` estimation method will be used.
inputs: np.array
The inputs to the model, that will be used in optimal batch computation.
If none, then the placeholder will be created
targets: np.array
The targets to the model, that will be used in optimal batch computation.
If none, then we run model.predict() to get targets. Only used in method=`train`
start_batch_size: int
Batch size to start batch_size estimation. If your model is small, you should use larger
start_batch_size in order to get accurate optimal batch_size
max_iters: int
Maximum number of binary search iterations.
factor: int
Value by which we multiply batch_size at each iteration. Uses in bruteforce estimation.
spread: int
Defines how wide the interval for binary search will be. Uses in predictive estimation.
n: int
For stable measurements, we make `n` iterations of `train`/`predict`,
until the memory consumption stabilizes. Used only in estimation_method='predictive'.
max_memory: int
In percent. Defines which portion of memory will be occupied by the model with optimal batch size.
Default: 100
pbar: str
The same as bar in Notifier
tail_size: int
How many items of gpu memory data will be used to compute mean memory utilization.
frequency: int
How often do we collect gpu memory data.
delta_batch_size: int
Step size of the change in batch_size at each iteration.
max_batch_size: int
Maximum batch_size
max_iters_estimation: int
Maximum number of batch_size estimation iterations.
Returns
-------
optimal_batch_size: int
Batch size that perfectly fits in max_memory
"""

        # first calculate optimal batch_size estimation
-        if use_estimation:
-            batch_size_estimation = self.compute_optimal_batch_size(method=method, max_memory=max_memory,
-                                                                    inputs=inputs, targets=targets, pbar='n',
-                                                                    start_batch_size=start_batch_size,
-                                                                    delta_batch_size=delta_batch_size, n=n,
-                                                                    max_iters=max_iters_estimation,
-                                                                    max_batch_size=max_batch_size)
+        if estimation_method == 'predictive':
+            batch_size_estimation = self._compute_optimal_batch_size_predictive(method=method, max_memory=max_memory,
+                                                                                inputs=inputs, targets=targets,
+                                                                                start_batch_size=start_batch_size,
+                                                                                delta_batch_size=delta_batch_size,
+                                                                                max_iters=max_iters_estimation,
+                                                                                max_batch_size=max_batch_size,
+                                                                                pbar='n', n=n)
            batch_size_estimation = batch_size_estimation['batch_size']
            low = int(batch_size_estimation * (1 - spread))
            high = int(batch_size_estimation * (1 + spread))
-        else:
-            low, high = self._run_power_scaling(inputs=inputs, targets=targets, factor=factor,
-                                                start_batch_size=start_batch_size, method=method,
-                                                max_memory=max_memory, pbar=pbar, tail_size=10,
-                                                frequency=frequency)
+        elif estimation_method == 'bruteforce':
+            low = self._compute_optimal_batch_size(inputs=inputs, targets=targets, factor=factor,
+                                                   start_batch_size=start_batch_size, method=method,
+                                                   max_memory=max_memory, pbar=pbar, tail_size=10,
+                                                   frequency=frequency, update_method='bruteforce')
+            high = low * factor
+            batch_size_estimation = (low + high) // 2
+        else:
+            raise ValueError("Wrong estimation method! It could be `predictive` or `bruteforce`.")

        # then run precise method in neighbourhood of batch_size_estimation
-        return self._run_binary_scaling(inputs=inputs, targets=targets,
-                                        start_batch_size=batch_size_estimation,
-                                        low=low, high=high, method=method,
-                                        max_iters=max_iters, tail_size=tail_size,
-                                        max_memory=max_memory, pbar=pbar,
-                                        frequency=frequency)
-
-    def _run_power_scaling(self, inputs=None, targets=None, factor=2,
-                           start_batch_size=4, method='train', pbar='n',
-                           tail_size=10, max_memory=100, frequency=0.05):
-        """ Returns `batch_size` and `batch_size * factor`, so that `batch_size`
-        fits in max_memory, while `batch_size * factor` does not."""
+        return self._compute_optimal_batch_size(inputs=inputs, targets=targets,
+                                                start_batch_size=batch_size_estimation,
+                                                low=low, high=high, method=method,
+                                                max_iters=max_iters, tail_size=tail_size,
+                                                max_memory=max_memory, pbar=pbar,
+                                                frequency=frequency, factor=factor,
+                                                update_method='binary')
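
A hedged sketch of calling the new entry point; the `model` instance and the argument values are assumptions for illustration:

    # Rough estimate by doubling, then exact value by binary search.
    optimal = model.compute_optimal_batch_size(method='train', estimation_method='bruteforce',
                                               start_batch_size=4, factor=2, max_memory=90)

    # Alternatively, seed the binary search from the predictive estimate +/- spread.
    optimal = model.compute_optimal_batch_size(method='train', estimation_method='predictive',
                                               spread=0.2, max_memory=90)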


-        batch_size = start_batch_size if isinstance(start_batch_size, list) else [start_batch_size]
-        notifier = Notifier(n_iters=None, bar=pbar,
-                            monitors=[{'source': batch_size, 'name': 'batch_size'}])
-        while True:
-            try:
-                notifier.update()
-                input = inputs or self.make_placeholder_data(batch_size[-1], to_device=False)
-                input = list(input) if isinstance(input, (tuple, list)) else [input]
-                input = [item[:batch_size[-1]] for item in input]
-
-                with GPUMemoryMonitor(frequency=frequency) as monitor:
-                    if method == 'train':
-                        target = targets or self.predict(inputs=input, outputs='predictions')
-                        target = list(target) if isinstance(target, (tuple, list)) else [target]
-                        target = [item[:batch_size[-1]] for item in target]
-
-                        _ = self.train(inputs=input, targets=target, microbatch_size=False)
-                    else:
-                        _ = self.predict(inputs=input, microbatch_size=False)
-
-                data = monitor.data
-                consumed_memory = np.mean(np.sort(data, axis=0)[-tail_size:])
-
-                # exit if max_memory was exceeded
-                if consumed_memory > max_memory:
-                    batch_size.append(batch_size[-1] // factor)
-                    break
-
-                batch_size.append(batch_size[-1] * factor)
-            except RuntimeError as exception:
-                if self.is_oom_error(exception):
-                    batch_size.append(batch_size[-1] // factor)
-                    break
-                raise # some other error not memory related
-
-            self.garbage_collection_cuda()
-        self.garbage_collection_cuda()
-        return batch_size[-1], batch_size[-1] * factor
+    def _bruteforce_batch_size_generator(self, factor, max_memory):
+        """ Calculates the next batch size for the bruteforce estimation method. If consumed memory is lower
+        than `max_memory`, the batch size is multiplied by `factor`, otherwise it is divided by `factor`.
+
+        Yields
+        ------
+        new_batch_size, exit: tuple(int, bool)
+            New batch size to check, and an exit flag telling whether the optimal
+            batch size computation is finished.
+        """
+        while True:
+            batch_size, consumed_memory = yield
+
+            if consumed_memory > max_memory:
+                yield batch_size // factor, True
+            else:
+                yield batch_size * factor, False
+
+    def _binary_batch_size_generator(self, low, high, max_memory):
+        """ Calculates the next batch size for the binary search method. If consumed memory is lower
+        than `max_memory`, the lower bound is increased, otherwise the upper bound is decreased.
+
+        Yields
+        ------
+        new_batch_size, exit: tuple(int, bool)
+            New batch size to check, and an exit flag telling whether the optimal
+            batch size computation is finished.
+        """
+        while True:
+            batch_size, consumed_memory = yield
+
+            if consumed_memory > max_memory:
+                high = batch_size
+            else:
+                low = batch_size
+
+            exit = high - low <= 1
+            yield (high + low) // 2, exit
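
A small sketch of the two-phase send protocol these generators expect, shown for the binary variant; the numbers are made up:

    gen = model._binary_batch_size_generator(low=2, high=64, max_memory=90)
    next(gen)                            # run to the first `yield`, the receive point
    new_bs, done = gen.send((32, 95.0))  # 95% > 90% => high=32, try (2 + 32) // 2 = 17
    next(gen)                            # loop back to the receive point
    new_bs, done = gen.send((17, 60.0))  # 60% < 90% => low=17, try (17 + 32) // 2 = 24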


-    def _run_binary_scaling(self, inputs=None, targets=None, low=2, high=None,
-                            start_batch_size=4, max_iters=25, pbar='n',
-                            method='train', max_memory=100, frequency=0.01,
-                            tail_size=20):
+    def _compute_optimal_batch_size(self, inputs=None, targets=None, low=2, high=None,
+                                    start_batch_size=4, max_iters=15, pbar='n',
+                                    method='train', max_memory=100, frequency=0.01,
+                                    tail_size=20, update_method='bruteforce',
+                                    factor=2):
        count = 0

        # if None => make equal distance between low, start_batch_size and high
        high = high if high is not None else 2 * start_batch_size - low

-        low = low if isinstance(low, list) else [low]
-        high = high if isinstance(high, list) else [high]
+        # The list is used to show current batch_size in notifier. Current value is the last one.
        batch_size = start_batch_size if isinstance(start_batch_size, list) else [start_batch_size]

-        n_iters = int(np.ceil(np.log2(high[-1] - low[-1])))
+        if update_method == 'binary':
+            n_iters = int(np.ceil(np.log2(high - low)))
+            generator = self._binary_batch_size_generator(low=low, high=high, max_memory=max_memory)
+        elif update_method == 'bruteforce':
+            n_iters = None
+            generator = self._bruteforce_batch_size_generator(factor=factor, max_memory=max_memory)
+        else:
+            raise ValueError("Wrong update method! Could be `bruteforce` or `binary`.")

        notifier = Notifier(n_iters=n_iters, bar=pbar,
-                            monitors=[{'source': low, 'name': 'low'},
-                                      {'source': high, 'name': 'high'},
-                                      {'source': batch_size, 'name': 'batch_size'}])
+                            monitors=[{'source': batch_size, 'name': 'batch_size'}])

        while True:
            try:
                notifier.update()
-                # exit when >= max_iters
-                if count >= max_iters:
-                    batch_size.append(low[-1])
-                    break
                count += 1

                # monitor consumed memory
@@ -391,32 +438,39 @@ def _run_binary_scaling(self, inputs=None, targets=None, low=2, high=None,
                # take mean of top_k memory measures
                consumed_memory = np.mean(np.sort(data, axis=0)[-tail_size:])

-                # update borders
-                if consumed_memory > max_memory:
-                    high.append(batch_size[-1])
-                else:
-                    low.append(batch_size[-1])
-
-                batch_size.append((high[-1] + low[-1]) // 2)
+                next(generator)
+                new_batch_size, exit = generator.send((batch_size[-1], consumed_memory))
+                batch_size.append(new_batch_size)
            except RuntimeError as exception:
                if self.is_oom_error(exception):
-                    high.append(batch_size[-1])
-                    batch_size.append((high[-1] + low[-1]) // 2)
+                    next(generator)
+                    new_batch_size, exit = generator.send((batch_size[-1], max_memory * 2))
+                    batch_size.append(new_batch_size)
                else:
                    raise # some other error not memory related

-            if high[-1] - low[-1] <= 1:
-                break
-
-            self.garbage_collection_cuda()
+            self.garbage_collection_cuda()
+            if exit or count >= max_iters:
+                break

        self.garbage_collection_cuda()
        return batch_size[-1]
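
For the `consumed_memory` line above, a tiny self-contained example of the tail averaging; the sample values are invented:

    import numpy as np

    data = [10.2, 55.7, 88.9, 90.1, 89.6, 90.3]   # % of GPU memory, sampled during one step
    tail_size = 3
    consumed = np.mean(np.sort(data, axis=0)[-tail_size:])
    print(consumed)                               # 90.0: the plateau, not the warm-up ramp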


-    def compute_optimal_batch_size(self, method='train', max_memory=90, inputs=None, targets=None, pbar='n',
-                                   start_batch_size=4, delta_batch_size=4, max_batch_size=128, max_iters=16,
-                                   n=20, frequency=0.05, time_threshold=3, tail_size=20, std_threshold=0.1):
-        """ Compute memory usage for multiple batch sizes. """
+    def _compute_optimal_batch_size_predictive(self, method='train', max_memory=90,
+                                               inputs=None, targets=None, pbar='n',
+                                               max_iters=16, start_batch_size=4,
+                                               delta_batch_size=4, max_batch_size=128,
+                                               n=20, frequency=0.05, time_threshold=3,
+                                               tail_size=20, std_threshold=0.1):
+        """ Compute the optimal batch size for training/inference so as to maximize GPU memory usage.
+        Works by running `train`/`predict` with different batch sizes and measuring how much memory is taken.
+        Then, we solve the system of `measured_memory = batch_size * item_size + model_size + eps` equations
+        for both `item_size` and `model_size`.
+        For stable measurements, we make `n` iterations of `train`/`predict`, until the memory consumption
+        stabilizes.
+        """
        #pylint: disable=consider-iterating-dictionary
        table = {}
        batch_size = start_batch_size

@@ -429,7 +483,7 @@ def compute_optimal_batch_size(self, method='train', max_memory=90, inputs=None,

            # Exit condition
            batch_size += delta_batch_size
-            if info['memory'] > max_memory or batch_size > max_batch_size :
+            if info['memory'] > max_memory or batch_size > max_batch_size:
                break

        # Make and solve a system of equations for `item_size`, `model_size`
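
The diff hides the lines that fit this linear model; a hedged sketch of one way to do it, assuming `table` maps each tried batch size to its measured info dict (the variable names are illustrative, not the file's actual code):

    sizes = sorted(table.keys())
    mem = np.array([table[b]['memory'] for b in sizes])
    A = np.stack([np.array(sizes, dtype=float), np.ones(len(sizes))], axis=1)
    (item_size, model_size), *_ = np.linalg.lstsq(A, mem, rcond=None)     # least-squares fit
    optimal_batch_size = int((max_memory - model_size) / item_size)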
@@ -473,6 +527,10 @@ def _get_memory_utilization(self, method, inputs, targets, n, frequency,
        elif method == 'predict':
            _ = self.predict(inputs=inputs, microbatch_size=False)

+        if not self.config.get('benchmark'):
+            data = monitor.data
+            return {'memory': np.mean(data[-tail_size:]), 'n': n, 'monitor': monitor}
+
        # Check if the measurement is stable. If so, return the value and confidence
        data = monitor.data
        time = len(data) * frequency # in seconds
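
The rest of this method is hidden by the diff; its docstring and parameters suggest a stability check along these lines (a sketch under that assumption, not the file's actual code):

    tail = data[-tail_size:]
    stable = time > time_threshold and np.std(tail) < std_threshold   # flat for long enough
    memory = np.mean(tail)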
