Merged
32 commits
c9fb0c9
remove transformer layer ID from the top module
Jun 12, 2020
6254e46
updating docstring
Jun 12, 2020
3eab150
add inject
jeffra Jun 12, 2020
f6baecb
update inject PoC
jeffra Jun 17, 2020
b5acaca
fix the preln injection
Jun 18, 2020
b3f99b7
fix the preln injection
Jun 18, 2020
04a4d35
backward-test fixed
Jun 19, 2020
fca500f
backward-test fixed
Jun 19, 2020
f3ff21e
update with replace module style
jeffra Jun 25, 2020
344b016
Merge branch 'master' into jeffra/inject
jeffra Jun 25, 2020
c208bdf
add function to revert from ds kernel -> orig layer
jeffra Jul 16, 2020
c278562
add code from Elton to do ds kernel -> orig layer conversion
jeffra Jul 17, 2020
68d8c13
formatting
jeffra Jul 17, 2020
3161565
update replace to fix runtime errors
jeffra Jul 22, 2020
798e6d3
remove pillow
jeffra Jul 29, 2020
66f590d
remove transformer layer ID from the top module
Jun 12, 2020
e4b46fb
updating docstring
Jun 12, 2020
25ee5e7
add inject
jeffra Jun 12, 2020
d5d10e9
update inject PoC
jeffra Jun 17, 2020
e090049
fix the preln injection
Jun 18, 2020
3df72f8
fix the preln injection
Jun 18, 2020
41cc4e6
backward-test fixed
Jun 19, 2020
24a3d24
backward-test fixed
Jun 19, 2020
66b4e63
update with replace module style
jeffra Jun 25, 2020
ee40034
add function to revert from ds kernel -> orig layer
jeffra Jul 16, 2020
e982c65
add code from Elton to do ds kernel -> orig layer conversion
jeffra Jul 17, 2020
fd4d0bc
formatting
jeffra Jul 17, 2020
e332d61
update replace to fix runtime errors
jeffra Jul 22, 2020
5814254
rebase-complete
Dec 4, 2020
55ffb88
merging
Dec 4, 2020
f48c52a
resolve conflict
Dec 4, 2020
0df72a0
remove dup line and add local-rank parameter to replace function
Dec 4, 2020
1 change: 1 addition & 0 deletions deepspeed/__init__.py
@@ -13,6 +13,7 @@
from .runtime.config import DeepSpeedConfig
from .runtime.activation_checkpointing import checkpointing
from .ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig
from .module_inject.replace_module import replace_transformer_layer, revert_transformer_layer
from .utils import log_dist

from .pipe import PipelineModule
Empty file.
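With this import added to deepspeed/__init__.py, replace_transformer_layer and revert_transformer_layer become available directly on the deepspeed package. A minimal usage sketch, assuming a HuggingFace BERT model (the model construction and all argument values below are illustrative, not part of this PR):

import deepspeed
from transformers import BertForSequenceClassification
from transformers.modeling_bert import BertLayer

# Illustrative model; any nn.Module containing BertLayer instances would do.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Swap every BertLayer for a kernel-backed DeepSpeedTransformerLayer.
model = deepspeed.replace_transformer_layer(orig_layer_impl=BertLayer,
                                            model=model,
                                            micro_batch_size=8,
                                            bert_config=model.config,
                                            seed=1234,
                                            max_seq_length=128,
                                            preln=False,
                                            fp16=True,
                                            huggingface=True)

# ... train or run inference ...

# Convert back to the original layer implementation, e.g. before saving
# a checkpoint in the stock HuggingFace format.
model = deepspeed.revert_transformer_layer(orig_layer_impl=BertLayer,
                                           model=model,
                                           bert_config=model.config)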
122 changes: 122 additions & 0 deletions deepspeed/module_inject/inject.py
@@ -0,0 +1,122 @@
import copy
import torch
from deepspeed.ops.transformer import DeepSpeedTransformerLayer, DeepSpeedTransformerConfig


def module_inject(layer_obj,
                  model,
                  config,
                  micro_batch_size,
                  max_seq_length,
                  seed,
                  preln,
                  fp16=True):
    for name, child in model.named_children():
        if isinstance(child, layer_obj):
            print('REPLACING BertLayer')

            cuda_config = DeepSpeedTransformerConfig(
                batch_size=micro_batch_size,
                max_seq_length=max_seq_length,
                hidden_size=config.hidden_size,
                heads=config.num_attention_heads,
                attn_dropout_ratio=config.attention_probs_dropout_prob,
                hidden_dropout_ratio=config.hidden_dropout_prob,
                num_hidden_layers=config.num_hidden_layers,
                initializer_range=config.initializer_range,
                seed=seed,
                fp16=fp16,
                pre_layer_norm=preln)

            new_module = DeepSpeedTransformerLayer(cuda_config)

            # copy relevant state from child -> new module
            qw = child.attention.self.query.weight
            qb = child.attention.self.query.bias
            kw = child.attention.self.key.weight
            kb = child.attention.self.key.bias
            vw = child.attention.self.value.weight
            vb = child.attention.self.value.bias

            qkvw = torch.cat((qw, kw, vw), 0)
            qkvb = torch.cat((qb, kb, vb), 0)

            new_module.attn_qkvw.data = qkvw
            new_module.attn_qkvb.data = qkvb
            new_module.attn_ow.data = child.attention.output.dense.weight
            new_module.attn_ob.data = child.attention.output.dense.bias
            if preln:
                attention_layerNorm = child.PostAttentionLayerNorm
            else:
                attention_layerNorm = child.attention.output.LayerNorm
            new_module.attn_nw.data = attention_layerNorm.weight
            new_module.attn_nb.data = attention_layerNorm.bias
            if preln:
                intermediate_FF = child.intermediate.dense_act
            else:
                intermediate_FF = child.intermediate.dense
            new_module.inter_w.data = intermediate_FF.weight
            new_module.inter_b.data = intermediate_FF.bias
            new_module.output_w.data = child.output.dense.weight
            new_module.output_b.data = child.output.dense.bias
            if preln:
                transformer_LayerNorm = child.PreAttentionLayerNorm
            else:
                transformer_LayerNorm = child.output.LayerNorm
            new_module.norm_w.data = transformer_LayerNorm.weight
            new_module.norm_b.data = transformer_LayerNorm.bias

            setattr(model, name, copy.deepcopy(new_module))

        else:
            module_inject(layer_obj,
                          child,
                          config,
                          micro_batch_size,
                          max_seq_length,
                          seed,
                          preln,
                          fp16)

    return model


def test_hi():
    from turing.nvidia_modelingpreln import BertConfig as BertConfigPreLN
    from turing.nvidia_modelingpreln import BertForQuestionAnswering as BertForQuestionAnsweringPreLN
    from turing.nvidia_modelingpreln import BertLayer
    bert_model_config = {
        "vocab_size_or_config_json_file": 119547,
        "hidden_size": 1024,
        "num_hidden_layers": 1,
        "num_attention_heads": 16,
        "intermediate_size": 4096,
        "hidden_act": "gelu",
        "hidden_dropout_prob": 0.1,
        "attention_probs_dropout_prob": 0.1,
        "max_position_embeddings": 512,
        "type_vocab_size": 2,
        "initializer_range": 0.02
    }
    bert_config = BertConfigPreLN(**bert_model_config)
    base_model = BertForQuestionAnsweringPreLN(bert_config, args=None)

    #base_model = LinearStack()

    test_model = copy.deepcopy(base_model)
    # preln is a required argument; True assumed here since the PreLN model variant is used above
    test_model = module_inject(BertLayer, test_model, bert_config, 4, 384, 1234, preln=True)

    print('BASE', base_model)
    print('TEST', test_model)

    #base_model.eval()
    #test_model.eval()

    #test_input = torch.rand(1, base_model.input_dim)

    #base_output = base_model(test_input)
    #test_output = test_model(test_input)
    #
    #assert torch.allclose(base_output, test_output, atol=3e-8)
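The core state transfer above fuses the separate query/key/value projections into single attn_qkvw/attn_qkvb tensors. A small self-contained sketch of that layout (the hidden size and layer names below are illustrative), showing that concatenating along dim 0 and chunking back recovers the original parameters:

import torch
import torch.nn as nn

hidden = 8  # illustrative hidden size
query, key, value = (nn.Linear(hidden, hidden) for _ in range(3))

# Fuse q/k/v the same way module_inject does: concatenate along dim 0,
# giving a [3 * hidden, hidden] weight and a [3 * hidden] bias.
qkvw = torch.cat((query.weight, key.weight, value.weight), 0)
qkvb = torch.cat((query.bias, key.bias, value.bias), 0)
assert qkvw.shape == (3 * hidden, hidden)

# Chunking along the same dimension recovers the original tensors,
# which is what the revert path in replace_module.py relies on.
qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
assert torch.equal(qw, query.weight) and torch.equal(vw, value.weight)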
193 changes: 193 additions & 0 deletions deepspeed/module_inject/replace_module.py
@@ -0,0 +1,193 @@
import copy
import torch
import deepspeed


def replace_transformer_layer(orig_layer_impl,
                              model,
                              micro_batch_size,
                              bert_config,
                              seed,
                              max_seq_length,
                              preln=False,
                              fp16=True,
                              huggingface=False,
                              local_rank=-1):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        micro_batch_size (int): micro batch size per gpu used during training/eval
        bert_config (dict): model config containing hidden size, attention heads, etc.
        seed (int): random seed value
        max_seq_length (int): max sequence length for training
        preln (bool): does the original layer implementation do pre or post layer norm?
        fp16 (bool): fp16 or fp32
        huggingface (bool): huggingface implementation is unique (supports both encoder/decoder modes)
        local_rank (int): local GPU rank, defaults to -1 (not set)

    Returns:
        Updated nn.module with replaced transformer layers
    """
    def replace_fn(child):
        transformer_config = deepspeed.DeepSpeedTransformerConfig(
            batch_size=micro_batch_size,
            max_seq_length=max_seq_length,
            hidden_size=bert_config.hidden_size,
            heads=bert_config.num_attention_heads,
            attn_dropout_ratio=bert_config.attention_probs_dropout_prob,
            hidden_dropout_ratio=bert_config.hidden_dropout_prob,
            num_hidden_layers=bert_config.num_hidden_layers,
            initializer_range=bert_config.initializer_range,
            seed=seed,
            fp16=fp16,
            pre_layer_norm=preln,
            huggingface=huggingface,
            local_rank=local_rank)
        new_module = deepspeed.DeepSpeedTransformerLayer(transformer_config)

        # copy relevant state from child -> new module
        qw = child.attention.self.query.weight
        qb = child.attention.self.query.bias
        kw = child.attention.self.key.weight
        kb = child.attention.self.key.bias
        vw = child.attention.self.value.weight
        vb = child.attention.self.value.bias

        qkvw = torch.cat((qw, kw, vw), 0)
        qkvb = torch.cat((qb, kb, vb), 0)

        #qw.data,kw.data,vw.data = torch.chunk(qkvw, 3, axis=0)
        #qb.data,kb.data,vb.data = torch.chunk(qkvb, 3, axis=0)

        new_module.attn_qkvw.data = qkvw
        new_module.attn_qkvb.data = qkvb
        new_module.attn_ow.data = child.attention.output.dense.weight
        new_module.attn_ob.data = child.attention.output.dense.bias
        if preln:
            attention_layernorm = child.PostAttentionLayerNorm
        else:
            attention_layernorm = child.attention.output.LayerNorm
        new_module.attn_nw.data = attention_layernorm.weight
        new_module.attn_nb.data = attention_layernorm.bias
        if preln:
            intermediate_ff = child.intermediate.dense_act
        else:
            intermediate_ff = child.intermediate.dense
        new_module.inter_w.data = intermediate_ff.weight
        new_module.inter_b.data = intermediate_ff.bias
        new_module.output_w.data = child.output.dense.weight
        new_module.output_b.data = child.output.dense.bias
        if preln:
            transformer_layernorm = child.PreAttentionLayerNorm
        else:
            transformer_layernorm = child.output.LayerNorm
        new_module.norm_w.data = transformer_layernorm.weight
        new_module.norm_b.data = transformer_layernorm.bias
        return new_module

    return replace_module(model=model, orig_class=orig_layer_impl, replace_fn=replace_fn)


def revert_transformer_layer(orig_layer_impl, model, bert_config, preln=False):
    """ Revert DeepSpeed's transformer layer back to original bert-style transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation that was replaced,
            e.g., transformers.modeling_bert.BertLayer.
        model (torch.nn.Module): user's nn.module representing their model
        bert_config (dict): model config containing hidden size, attention heads, etc.
        preln (bool): does the original layer implementation do pre or post layer norm?

    Returns:
        Updated nn.module with original bert-style transformer layers
    """
    def replace_fn(child):
        #from turing.nvidia_modelingpreln import BertLayer
        orig_module = orig_layer_impl(bert_config)

        # copy relevant state from child -> original module
        qkvw = child.attn_qkvw.data
        qkvb = child.attn_qkvb.data

        # split the fused QKV parameters back into separate q/k/v tensors
        qw, kw, vw = torch.chunk(qkvw, 3, dim=0)
        qb, kb, vb = torch.chunk(qkvb, 3, dim=0)

        orig_module.attention.self.query.weight.data = qw
        orig_module.attention.self.query.bias.data = qb
        orig_module.attention.self.key.weight.data = kw
        orig_module.attention.self.key.bias.data = kb
        orig_module.attention.self.value.weight.data = vw
        orig_module.attention.self.value.bias.data = vb

        orig_module.attention.output.dense.weight.data = child.attn_ow.data
        orig_module.attention.output.dense.bias.data = child.attn_ob.data

        attn_ln_w = child.attn_nw.data
        attn_ln_b = child.attn_nb.data
        if preln:
            orig_module.PostAttentionLayerNorm.weight.data = attn_ln_w
            orig_module.PostAttentionLayerNorm.bias.data = attn_ln_b
        else:
            orig_module.attention.output.LayerNorm.weight.data = attn_ln_w
            orig_module.attention.output.LayerNorm.bias.data = attn_ln_b

        inter_ff_w = child.inter_w.data
        inter_ff_b = child.inter_b.data
        if preln:
            orig_module.intermediate.dense_act.weight.data = inter_ff_w
            orig_module.intermediate.dense_act.bias.data = inter_ff_b
        else:
            orig_module.intermediate.dense.weight.data = inter_ff_w
            orig_module.intermediate.dense.bias.data = inter_ff_b

        orig_module.output.dense.weight.data = child.output_w.data
        orig_module.output.dense.bias.data = child.output_b.data

        transformer_ln_w = child.norm_w.data
        transformer_ln_b = child.norm_b.data
        if preln:
            orig_module.PreAttentionLayerNorm.weight.data = transformer_ln_w
            orig_module.PreAttentionLayerNorm.bias.data = transformer_ln_b
        else:
            orig_module.output.LayerNorm.weight.data = transformer_ln_w
            orig_module.output.LayerNorm.bias.data = transformer_ln_b
        return orig_module

    return replace_module(model=model,
                          orig_class=deepspeed.DeepSpeedTransformerLayer,
                          replace_fn=replace_fn)


def replace_module(model, orig_class, replace_fn):
    """ Scan the model for instances of ``orig_class`` to replace using ``replace_fn``.
    Arguments:
        model (torch.nn.Module): the model to augment
        orig_class (torch.nn.Module): the module to search for
        replace_fn (method): a method to convert instances of ``orig_class`` to the
            desired type and return a new instance.

    Returns:
        A modified ``model``.
    """
    policy = {orig_class: replace_fn}
    return _replace_module(model, policy)


def _replace_module(model, policies):
    """ Traverse model's children recursively and apply any transformations in ``policies``.
    Arguments:
        model (torch.nn.Module): model to augment
        policies (dict): Mapping of source class to replacement function.

    Returns:
        Modified ``model``.
    """
    for name, child in model.named_children():
        if child.__class__ in policies:
            orig = repr(child)
            setattr(model, name, policies[child.__class__](child))
            new = getattr(model, name)
            print(f'{orig} -> {new}')
        else:
            _replace_module(child, policies)

    return model
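replace_module is a generic class-for-class swap driven by a {class: replace_fn} policy, so it is not tied to transformer layers. A minimal sketch of that mechanism on a toy model (assuming the file is importable as deepspeed.module_inject.replace_module, as added by this PR; the toy model and the ReLU-to-GELU swap are purely illustrative):

import torch.nn as nn
from deepspeed.module_inject.replace_module import replace_module

# Toy model containing nn.ReLU instances at different depths.
model = nn.Sequential(
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Sequential(nn.Linear(16, 16), nn.ReLU()),
)

# Every nn.ReLU found during the recursive traversal is handed to
# replace_fn and substituted in place via setattr on its parent.
model = replace_module(model=model,
                       orig_class=nn.ReLU,
                       replace_fn=lambda child: nn.GELU())

print(model)  # both ReLU activations are now GELU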