3 changes: 2 additions & 1 deletion _quarto.yml
@@ -153,7 +153,7 @@ quartodoc:
         - utils.distributed
         - utils.dict
         - utils.optimizers.adopt
-        - utils.data.pretraining
+        - utils.data.streaming
         - utils.data.sft
         - utils.quantization
     - title: Schemas
@@ -272,6 +272,7 @@ website:
           contents:
             - docs/batch_vs_grad.qmd
             - docs/dataset_preprocessing.qmd
+            - docs/streaming.qmd
             - docs/multipack.qmd
             - docs/mixed_precision.qmd
             - docs/optimizers.qmd
120 changes: 120 additions & 0 deletions docs/streaming.qmd
@@ -0,0 +1,120 @@
---
title: Streaming Datasets
description: How to use streaming mode for large-scale datasets and memory-efficient training
order: 10
---

Streaming enables memory-efficient training with large datasets by loading data
incrementally rather than loading the entire dataset into memory at once.

Use streaming when:

- Your dataset is too large to fit in memory (e.g. pretraining on massive text corpora)
- You want to start training immediately without preprocessing the entire dataset

Streaming works with both remote and locally stored datasets!
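
As a rough illustration of incremental loading, the Python sketch below reads records one at a time from a JSON Lines file. It is not Axolotl's implementation; `stream_jsonl` and the throwaway sample file are hypothetical names used only for this example.

```python
import json
import tempfile
from itertools import islice
from typing import Iterator


def stream_jsonl(path: str) -> Iterator[dict]:
    """Yield one record at a time; only the current line is held in memory."""
    with open(path, encoding="utf-8") as handle:
        for line in handle:
            if line.strip():
                yield json.loads(line)


# Write a small throwaway file so the sketch runs on its own.
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as tmp:
    for i in range(5):
        tmp.write(json.dumps({"text": f"document {i}"}) + "\n")

# Consume only the first two records; the rest of the file is never parsed.
first_two = list(islice(stream_jsonl(tmp.name), 2))
print(first_two)
```

Because the generator is lazy, training can begin as soon as the first records are available instead of waiting for the whole dataset to be read.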

::: {.callout-note}
Streaming currently only supports a single dataset. Multi-dataset support will be added soon.
:::


## Configuration

### Basic Streaming

Enable streaming mode by setting the `streaming` flag:

```yaml
streaming: true
```

### Pretraining with Streaming

For pretraining tasks, streaming is automatically enabled when using `pretraining_dataset`:

```yaml
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    type: pretrain
    text_column: text
    split: train

# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```

### SFT with Streaming

For supervised fine-tuning with streaming:

```yaml
streaming: true
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train

# Optionally, enable sample packing
streaming_multipack_buffer_size: 10000
sample_packing: true
```

## Configuration Options

### `streaming_multipack_buffer_size`

Controls the buffer size for multipack streaming (default: 10,000). This determines how
many samples are buffered before packing. Larger buffers can improve packing efficiency
but use more memory.
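
To see why a larger buffer can pack more tightly, here is a minimal Python sketch of buffered packing. It assumes a simple first-fit-decreasing strategy and that every sample fits within `sequence_len`; `pack_stream` is a hypothetical helper, and Axolotl's actual packing algorithm may differ.

```python
from itertools import islice
from typing import Iterable, Iterator, List


def pack_stream(
    lengths: Iterable[int], buffer_size: int, sequence_len: int
) -> Iterator[List[int]]:
    """Yield bins of sample lengths, each bin summing to at most sequence_len.

    Assumes every individual length is <= sequence_len.
    """
    stream = iter(lengths)
    while True:
        # Pull the next buffer_size samples from the stream.
        buffer = list(islice(stream, buffer_size))
        if not buffer:
            return
        bins: List[List[int]] = []
        # First-fit decreasing: place long samples first, then fill the gaps.
        for length in sorted(buffer, reverse=True):
            for current in bins:
                if sum(current) + length <= sequence_len:
                    current.append(length)
                    break
            else:
                bins.append([length])
        yield from bins


sample_lengths = [6, 5, 4, 3, 2, 2, 1, 1]  # hypothetical token counts
small = list(pack_stream(sample_lengths, buffer_size=2, sequence_len=8))
large = list(pack_stream(sample_lengths, buffer_size=8, sequence_len=8))
print(len(small), len(large))
```

With a buffer of 2 the packer only ever sees two candidates at a time and emits five sequences; with a buffer of 8 it sees all samples at once and fits them into three full sequences.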

### `shuffle_merged_datasets`

When enabled, shuffles the streaming dataset using a shuffle buffer. The buffer holds
samples in memory, so shuffling requires additional memory.
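
A streaming dataset cannot be shuffled globally because it is never fully in memory. The Python sketch below shows one common way to approximate a shuffle with a fixed-size buffer; `buffered_shuffle` is a hypothetical helper and is not necessarily how Axolotl or its underlying libraries implement shuffling.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(
    items: Iterable[T], buffer_size: int, seed: int = 0
) -> Iterator[T]:
    """Approximately shuffle a stream while holding at most buffer_size items."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # Drain what is left once the stream is exhausted.
    rng.shuffle(buffer)
    yield from buffer


shuffled = list(buffered_shuffle(range(10), buffer_size=4))
print(shuffled)
```

Memory use grows with the buffer size, and so does the quality of the shuffle: a buffer of size 1 returns the stream in its original order.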

## Sample Packing with Streaming

Sample packing is supported for streaming datasets. When enabled, multiple samples are
packed into a single sequence to maximize GPU utilization:

```yaml
sample_packing: true
streaming_multipack_buffer_size: 10000

# For SFT: attention is automatically isolated between packed samples
# For pretraining: control with pretrain_multipack_attn
pretrain_multipack_attn: true # prevent cross-attention between packed samples
```

For more information, see our [documentation](multipack.qmd) on multipacking.
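
The attention isolation between packed samples can be pictured as a block-diagonal causal mask. The following Python sketch builds such a mask from the lengths of the samples in one packed sequence; `packed_causal_mask` is a hypothetical helper for illustration, and real implementations typically pass sequence boundaries to an attention kernel rather than materializing a dense mask.

```python
from typing import List


def packed_causal_mask(sample_lengths: List[int]) -> List[List[bool]]:
    """mask[q][k] is True when query position q may attend to key position k."""
    # Label every token position with the index of the sample it came from.
    sample_ids = [
        sample_index
        for sample_index, length in enumerate(sample_lengths)
        for _ in range(length)
    ]
    total = len(sample_ids)
    # Allow attention only to earlier-or-equal positions within the same sample.
    return [
        [k <= q and sample_ids[q] == sample_ids[k] for k in range(total)]
        for q in range(total)
    ]


mask = packed_causal_mask([2, 3])
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```

In the printed grid, the first token of the second sample cannot see any token of the first sample, which is the behavior that `pretrain_multipack_attn: true` is described as enforcing for pretraining.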

## Important Considerations

### Memory Usage

While streaming reduces memory usage compared to loading entire datasets, some memory
is still needed for buffering:

- Sample packing buffers multiple samples before packing them
- Shuffling requires additional memory for the shuffle buffer
- You can control buffer memory usage by adjusting `streaming_multipack_buffer_size`

### Performance

- Streaming may have slightly higher latency than preprocessed datasets, as samples are processed on the fly
- Network speed matters when streaming from remote sources; disk read speed matters when streaming a local dataset
- Consider using `axolotl preprocess` for smaller or more frequently used datasets

### Evaluation Datasets

Evaluation datasets are not streamed to ensure consistent evaluation metrics. They're
loaded normally even when training uses streaming.

## Examples

See the `examples/streaming/` directory for complete configuration examples:

- `pretrain.yaml`: Pretraining with streaming dataset
- `sft.yaml`: Supervised fine-tuning with streaming
50 changes: 50 additions & 0 deletions examples/streaming/README.md
@@ -0,0 +1,50 @@
# Streaming Dataset Examples

This directory contains example configurations for using Axolotl's streaming dataset
functionality, which enables memory-efficient training with large datasets.

## Examples

Run these examples with, e.g., `axolotl train examples/streaming/sft.yaml`; no
`axolotl preprocess` step is required!

### Pretraining (`pretrain.yaml`)

Demonstrates streaming configuration for pretraining tasks using the fineweb-edu dataset
with SmolLM2-135M.

- Uses `pretraining_dataset` configuration for automatic streaming
- Multipack attention control to prevent cross-attention between packed sequences
- Buffer size configuration for memory management

### SFT (`sft.yaml`)

Shows how to use streaming for supervised fine-tuning with the Alpaca dataset.

- Explicit `streaming: true` flag for SFT datasets
- Memory-efficient training on instruction datasets
- Evaluation datasets are currently not streamed

## Key Configuration Options

### `streaming`
- Enables streaming mode for standard datasets
- Automatically enabled for `pretraining_dataset`

### `streaming_multipack_buffer_size`
- Controls buffer size for sample packing (default: 10,000)
- Larger values improve packing efficiency but use more memory
- Adjust based on available memory

### `shuffle_merged_datasets`
- Enables shuffling of streaming datasets
- Requires additional memory for shuffle buffer

### `sample_packing`
- Packs multiple samples into single sequences
- Minimizes per-step padding tokens
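
The effect on padding can be estimated with a few lines of Python. This is an illustrative calculation only: the sample lengths are made up, `padding_fraction` is a hypothetical helper, and it assumes every row is padded out to `sequence_len`.

```python
from typing import List


def padding_fraction(rows: List[List[int]], sequence_len: int) -> float:
    """Fraction of token slots holding padding when each row is padded to sequence_len."""
    used = sum(sum(row) for row in rows)
    return 1 - used / (len(rows) * sequence_len)


lengths = [300, 200, 500, 100, 400, 524]  # hypothetical tokenized sample lengths
unpacked = [[n] for n in lengths]  # one sample per row
packed = [[500, 400, 100], [524, 300, 200]]  # same samples packed into two rows

unpacked_fraction = padding_fraction(unpacked, 1024)
packed_fraction = padding_fraction(packed, 1024)
print(f"unpacked: {unpacked_fraction:.1%} padding")
print(f"packed:   {packed_fraction:.1%} padding")
```

Under these assumptions, most token slots are padding when each sample occupies its own row, while the packed layout wastes only a small fraction of them.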

## Performance Tips

- Download small or frequently used datasets locally for better performance
- Larger buffer sizes improve packing efficiency
57 changes: 57 additions & 0 deletions examples/streaming/pretrain.yaml
@@ -0,0 +1,57 @@
base_model: HuggingFaceTB/SmolLM2-135M

# Streaming pretraining configuration
pretraining_dataset:
  - path: HuggingFaceFW/fineweb-edu
    name: sample-10BT
    type: pretrain
    text_column: text
    split: train

# Streaming-specific settings
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true

# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-pretrain-streaming

# Sequence and packing settings
sequence_len: 1024
sample_packing: true
pretrain_multipack_attn: true # Prevent cross-attention between packed sequences
flash_attention: true

# Batch size settings
gradient_accumulation_steps: 8
micro_batch_size: 1

# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 5e-4
warmup_ratio: 0.1
weight_decay: 0.01

# Precision and performance
bf16: auto
tf32: true

# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 250
save_total_limit: 3

# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# Special tokens
special_tokens:
  pad_token: "<|endoftext|>"

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
55 changes: 55 additions & 0 deletions examples/streaming/sft.yaml
@@ -0,0 +1,55 @@
base_model: HuggingFaceTB/SmolLM2-135M

# Dataset configuration
datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
    split: train

# Streaming-specific settings
streaming: true
streaming_multipack_buffer_size: 10000
shuffle_merged_datasets: true

# Training configuration
max_steps: 1000
output_dir: ./outputs/smollm2-135m-sft-streaming

# Sequence and packing settings
sequence_len: 1024
sample_packing: true
flash_attention: true

# Batch size settings
gradient_accumulation_steps: 4
micro_batch_size: 1

# Optimizer and scheduler
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 2e-4
warmup_ratio: 0.1
weight_decay: 0.0

# Precision and performance
bf16: auto
tf32: true

# Logging and checkpointing
logging_steps: 10
save_strategy: steps
save_steps: 100
save_total_limit: 3

# Weights & Biases (optional)
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

# Special tokens
special_tokens:
  pad_token: "<|endoftext|>"

# save_first_step: true # uncomment this to validate checkpoint saving works with your config
8 changes: 6 additions & 2 deletions src/axolotl/cli/args.py
@@ -14,9 +14,13 @@ class PreprocessCliArgs:
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
     iterable: Optional[bool] = field(
-        default=None,
+        default=False,
         metadata={
-            "help": "Use IterableDataset for streaming processing of large datasets"
+            "help": (
+                "Deprecated in v0.13.0, will be removed in v0.14.0. For streaming "
+                "datasets, use 'axolotl train' and set 'streaming: true' in your YAML "
+                "config, or pass --streaming instead in the CLI."
+            )
         },
     )

12 changes: 11 additions & 1 deletion src/axolotl/cli/preprocess.py
@@ -35,10 +35,20 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     check_accelerate_default_config()
     check_user_token()

+    if cli_args.iterable:
+        LOG.error(
+            "The --iterable CLI argument for 'axolotl preprocess' is no longer "
+            "supported. For training, set 'streaming: true' in your YAML config or "
+            "pass '--streaming' in your 'axolotl train' command for on-the-fly "
+            "preprocessing."
+        )
+        return
+
     for key in ["skip_prepare_dataset", "pretraining_dataset"]:
         if cfg.get(key):
             LOG.error(
-                f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
+                f"You have set `{key}:`. `preprocess` is not needed. Run the 'axolotl "
+                "train' CLI directly instead."
             )
             return

2 changes: 0 additions & 2 deletions src/axolotl/common/datasets.py
@@ -55,13 +55,11 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
-    preprocess_iterable = getattr(cli_args, "iterable", False)

     train_dataset, eval_dataset, total_num_steps, prompters = prepare_datasets(
         cfg,
         tokenizer,
         processor=processor,
-        preprocess_iterable=preprocess_iterable,
     )

     if (