Commit 56aeb1a

add in example and run pre-commit
1 parent: 1e69c32

7 files changed (+23 additions, -20 deletions)


README.md

Lines changed: 3 additions & 0 deletions
@@ -164,6 +164,9 @@ bash ./examples/run_qwen3_moe_eagle3_online.sh
 
 # train Qwen3-8B
 bash ./examples/run_qwen3_dense_eagle3_online.sh
+
+# train QwQ-32B
+bash ./examples/run_qwq_dense_eagle3_online.sh
 ```
 
 ### 💨 Offline Training

scripts/train_eagle3_offline.py

Lines changed: 1 addition & 1 deletion
@@ -4,14 +4,14 @@
 
 import torch
 import torch.distributed as dist
-import wandb
 from accelerate.utils import set_seed
 from datasets import load_dataset
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
 from tqdm import tqdm
 from transformers import AutoTokenizer
 
+import wandb
 from specforge import AutoDraftModelConfig, AutoEagle3DraftModel, OfflineEagle3Model
 from specforge.data import (
     build_eagle3_dataset,

scripts/train_eagle3_online.py

Lines changed: 1 addition & 1 deletion
@@ -4,14 +4,14 @@
 
 import torch
 import torch.distributed as dist
-import wandb
 from accelerate.utils import set_seed
 from datasets import load_dataset
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision, ShardingStrategy, StateDictType
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
+import wandb
 from specforge import (
     AutoDistributedTargetModel,
     AutoDraftModelConfig,
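
In both training scripts the pre-commit run moves `import wandb` out of the alphabetized third-party block and into its own group just above the `specforge` imports. The sketch below illustrates the common rank-0-only wandb pattern for multi-process training; it is an illustration only, the helper names are hypothetical, and it is not asserted to be what `train_eagle3_online.py` or `train_eagle3_offline.py` actually do.

```python
import torch.distributed as dist

import wandb


def maybe_init_wandb(project: str, config: dict) -> None:
    # Only rank 0 creates a run so multi-GPU jobs do not spawn duplicate runs.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        wandb.init(project=project, config=config)


def maybe_log(metrics: dict, step: int) -> None:
    # wandb.run stays None on non-logging ranks, so this is a no-op there.
    if wandb.run is not None:
        wandb.log(metrics, step=step)
```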

specforge/modeling/auto.py

Lines changed: 3 additions & 4 deletions
@@ -10,18 +10,17 @@
     Llama4TextConfig,
     LlamaConfig,
     PretrainedConfig,
-    Qwen3MoeConfig,
     Qwen2Config,
+    Qwen3MoeConfig,
 )
 
 from specforge.utils import default_torch_dtype
 
 from .draft.llama3_eagle import LlamaForCausalLMEagle3
-from .target.llama4 import Llama4ForCausalLM
-from .target.qwen3_moe import Qwen3MoeForCausalLM
-
 from .draft.qwen2_eagle import Qwen2ForCausalLMEagle3
+from .target.llama4 import Llama4ForCausalLM
 from .target.qwen2 import Qwen2ForCausalLM
+from .target.qwen3_moe import Qwen3MoeForCausalLM
 
 
 class AutoEagle3DraftModel(AutoModelForCausalLMBase):

specforge/modeling/draft/qwen2_eagle.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import GenerationMixin, Qwen2Config, PreTrainedModel
+from transformers import GenerationMixin, PreTrainedModel, Qwen2Config
 from transformers.activations import ACT2FN
 from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
 

specforge/modeling/target/qwen2.py

Lines changed: 12 additions & 9 deletions
@@ -22,7 +22,10 @@
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
 from transformers.integrations import use_kernel_forward_from_hub
-from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_layers import GradientCheckpointingLayer
 from transformers.modeling_outputs import (
@@ -146,7 +149,9 @@ def forward(
         value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )
 
         if past_key_value is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
@@ -187,9 +192,7 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
         self.self_attn = Qwen2Attention(config=config, layer_idx=layer_idx)
 
         self.mlp = Qwen2MLP(config)
-        self.input_layernorm = Qwen2RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
-        )
+        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
@@ -381,9 +384,9 @@ def forward(
         }
         # The sliding window alternating layers are not always activated depending on the config
         if self.has_sliding_layers:
-            causal_mask_mapping[
-                "sliding_attention"
-            ] = create_sliding_window_causal_mask(**mask_kwargs)
+            causal_mask_mapping["sliding_attention"] = (
+                create_sliding_window_causal_mask(**mask_kwargs)
+            )
 
         hidden_states = inputs_embeds
 
@@ -843,4 +846,4 @@ def forward(
     "Qwen2ForSequenceClassification",
     "Qwen2ForTokenClassification",
     "Qwen2ForQuestionAnswering",
-]
+]

tests/test_qwen2_tp.py

Lines changed: 2 additions & 4 deletions
@@ -6,7 +6,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from accelerate.utils import set_seed
-from transformers import Qwen2ForCausalLM, Qwen2Config
+from transformers import Qwen2Config, Qwen2ForCausalLM
 
 from specforge.distributed import init_distributed
 
@@ -38,9 +38,7 @@ def test_qwen2_tp(rank, world_size, temp_dir):
     # create the single-gpu
     model = Qwen2ForCausalLM(config).cuda()
 
-    from specforge.modeling.target.qwen2 import (
-        Qwen2ForCausalLM as DistQwen2ForCausalLM,
-    )
+    from specforge.modeling.target.qwen2 import Qwen2ForCausalLM as DistQwen2ForCausalLM
 
     dist_model = DistQwen2ForCausalLM(config).cuda()
 
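
In `tests/test_qwen2_tp.py`, the collapsed import keeps the alias `DistQwen2ForCausalLM` for specforge's own `Qwen2ForCausalLM` reimplementation, which the test builds alongside the stock Hugging Face model, presumably to check numerical parity. The helper below is a minimal sketch of such a logit comparison; it is hypothetical (the real test also spawns worker processes, calls `init_distributed`, and prepares the weights) and assumes both models are already constructed from the same `config` and return Hugging Face style outputs with a `.logits` field.

```python
import torch


@torch.no_grad()
def assert_logits_match(ref_model, dist_model, vocab_size: int, atol: float = 1e-3) -> None:
    """Run identical random token ids through both models and compare their logits."""
    input_ids = torch.randint(0, vocab_size, (1, 16), device="cuda")
    ref_logits = ref_model(input_ids=input_ids).logits
    dist_logits = dist_model(input_ids=input_ids).logits
    torch.testing.assert_close(ref_logits, dist_logits, atol=atol, rtol=1e-3)
```

With the objects from the hunk above, a call would look roughly like `assert_logits_match(model, dist_model, config.vocab_size)`.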