Skip to content

Commit

Permalink
fix experimental_ops_override_groupnorm_test.py (#2723)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianyizh authored Jun 14, 2024
1 parent 0cd4c03 commit e027c47
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
23 changes: 21 additions & 2 deletions itex/python/experimental_ops_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from intel_extension_for_tensorflow.python.ops.recurrent import gpu_lstm
from intel_extension_for_tensorflow.python.ops.recurrent import is_itex_supported_inputs
from intel_extension_for_tensorflow.python.ops.group_norm import GroupNormalization
from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library

format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=format_str)
Expand Down Expand Up @@ -210,6 +211,7 @@ def experimental_ops_override():
from tf_keras import backend # pylint: disable=import-outside-toplevel
from tf_keras.utils import tf_utils # pylint: disable=import-outside-toplevel
import tf_keras
tf_gn_call = copy_func(tf_keras.layers.GroupNormalization.call)
tf_ln_call = copy_func(tf_keras.layers.LayerNormalization.call)
tf_lstm_call = copy_func(tf_keras.layers.LSTM.call)
tf_lstm_build = copy_func(tf_keras.layers.LSTM.build)
Expand Down Expand Up @@ -481,6 +483,23 @@ def itex_adamw_update_step(self, gradient, variable):
else:
AdamWithWeightDecayOptimizer.update_step(self, gradient, variable)

def itex_group_norm_call(self, inputs, mask=None):
    """Replacement for tf_keras GroupNormalization.call using ITEX kernels.

    Dispatch order:
      1. Fused ITEX group-norm op when the layer enabled it and no mask
         is supplied.
      2. The ITEX Python implementation on non-GPU devices (no mask).
      3. The original tf_keras ``call`` (captured in ``tf_gn_call``)
         otherwise — i.e. whenever a mask is given or neither ITEX path
         applies.

    Args:
        inputs: Input tensor to normalize.
        mask: Optional mask tensor; a non-None mask forces the stock
            Keras implementation, which is the only path that honors it.

    Returns:
        The group-normalized tensor.
    """
    if self.use_fused_group_norm and mask is None:
        # Fused path: one ITEX op does the whole normalization. The op
        # also returns mean/variance, which are not needed here.
        normalized_inputs, _, _ = load_ops_library.itex_group_norm(
            inputs,
            # Cast to the input dtype so the kernel sees matching dtypes
            # (gamma/beta are stored in the variable dtype).
            tf.cast(self.gamma, inputs.dtype),
            tf.cast(self.beta, inputs.dtype),
            num_groups=self.groups,
            epsilon=self.epsilon,
            use_scale=self.scale,
            use_center=self.center,
        )
        return normalized_inputs
    if (not self.use_gpu) and mask is None:
        # CPU path without a mask: ITEX's Python implementation.
        return GroupNormalization.itex_group_norm_call(self, inputs)
    # Masked input or unsupported configuration: defer to original Keras.
    return tf_gn_call(self, inputs)

try:
import tensorflow_addons as tfa # pylint: disable=import-outside-toplevel
tfa.layers.InstanceNormalization.call = itex_instance_norm_call
Expand Down Expand Up @@ -512,7 +531,7 @@ def itex_adamw_update_step(self, gradient, variable):
tf_keras.src.layers.rnn.lstm.LSTM.call = itex_lstm_call
tf_keras.src.layers.rnn.lstm.LSTM.build = itex_lstm_build
tf_keras.layers.GroupNormalization.build = GroupNormalization.build
tf_keras.layers.GroupNormalization.call = GroupNormalization.call
tf_keras.layers.GroupNormalization.call = itex_group_norm_call
tf_keras.optimizers.Adam.update_step = itex_adam_update_step
tf_keras.optimizers.AdamW.apply_gradients = itex_adamw_apply_gradients
tf_keras.optimizers.AdamW.update_step = itex_adamw_update_step
Expand All @@ -523,7 +542,7 @@ def itex_adamw_update_step(self, gradient, variable):
tf_keras.layers.LSTM.call = itex_lstm_call
tf_keras.layers.LSTM.build = itex_lstm_build
tf_keras.layers.GroupNormalization.build = GroupNormalization.build
tf_keras.layers.GroupNormalization.call = GroupNormalization.call
tf_keras.layers.GroupNormalization.call = itex_group_norm_call

except BaseException: # pylint: disable=broad-except
logger.warning("itex experimental ops override: Keras is not installed.") # pylint: disable=line-too-long
27 changes: 25 additions & 2 deletions itex/python/experimental_ops_override_k3.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from intel_extension_for_tensorflow.python.ops.layer_norm_k3 import _layer_norm
from intel_extension_for_tensorflow.python.ops.group_norm_k3 import GroupNormalization
from intel_extension_for_tensorflow.python.ops.optimizers_k3 import Adam
from intel_extension_for_tensorflow.python.ops.load_ops_library import load_ops_library


format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=format_str)
Expand Down Expand Up @@ -95,6 +97,7 @@ def experimental_ops_override():
tf_ln_call = copy_func(keras.layers.LayerNormalization.call)
tf_bn_call = copy_func(keras.layers.BatchNormalization.call)
tf_bn_build = copy_func(keras.layers.BatchNormalization.build)
tf_gn_call = copy_func(keras.layers.GroupNormalization.call)
tf_mean = copy_func(keras.src.backend.numpy.mean)

except BaseException: # pylint: disable=broad-except
Expand Down Expand Up @@ -335,6 +338,26 @@ def _fused_batch_norm_inference():
self.moving_variance.assign(variance)
return output

def itex_group_norm_call(self, inputs):
    """Replacement for Keras 3 GroupNormalization.call using ITEX kernels.

    Dispatch order:
      1. Fused ITEX group-norm op when the layer enabled it.
      2. The ITEX Python implementation on non-GPU devices.
      3. The original Keras ``call`` (captured in ``tf_gn_call``)
         otherwise.

    Args:
        inputs: Input tensor to normalize.

    Returns:
        The group-normalized tensor.
    """
    if self.use_fused_group_norm:
        channels = inputs.shape[-1]
        inputs = ops.cast(inputs, self.compute_dtype)
        # gamma/beta are None when scale/center are disabled; pass zero
        # tensors as placeholders in that case. NOTE(review): this assumes
        # the fused op ignores the placeholder when use_scale/use_center
        # is False — confirm against the kernel implementation.
        gamma = (ops.zeros((channels,), self.compute_dtype)
                 if self.gamma is None else self.gamma)
        beta = (ops.zeros((channels,), self.compute_dtype)
                if self.beta is None else self.beta)
        normalized_inputs, _, _ = load_ops_library.itex_group_norm(
            inputs,
            gamma,
            beta,
            num_groups=self.groups,
            epsilon=self.epsilon,
            use_scale=self.scale,
            use_center=self.center,
        )
        return normalized_inputs
    if not self.use_gpu:
        # Non-GPU path: ITEX's Python implementation.
        return GroupNormalization.itex_group_norm_call(self, inputs)
    # GPU without fused support: defer to the original Keras call.
    return tf_gn_call(self, inputs)

def itex_mean(x, axis=None, keepdims=False):
if isinstance(x, tf.IndexedSlices):
return tf_mean(x, axis, keepdims)
Expand Down Expand Up @@ -370,7 +393,7 @@ def itex_var(x, axis=None, keepdims=False):
keras.layers.LayerNormalization.build = itex_layer_norm_build
keras.layers.BatchNormalization.call = itex_batch_norm_call
keras.layers.BatchNormalization.build = itex_batch_norm_build
keras.layers.GroupNormalization.call = GroupNormalization.call
keras.layers.GroupNormalization.call = itex_group_norm_call
keras.layers.GroupNormalization.build = GroupNormalization.build
keras.optimizers.Adam.update_step = Adam.update_step

Expand All @@ -382,7 +405,7 @@ def itex_var(x, axis=None, keepdims=False):
keras.src.layers.normalization.layer_normalization.LayerNormalization.build = itex_layer_norm_build
keras.src.layers.normalization.batch_normalization.BatchNormalization.call = itex_batch_norm_call
keras.src.layers.normalization.batch_normalization.BatchNormalization.build = itex_batch_norm_build
keras.src.layers.normalization.group_normalization.GroupNormalization.call = GroupNormalization.call
keras.src.layers.normalization.group_normalization.GroupNormalization.call = itex_group_norm_call
keras.src.layers.normalization.group_normalization.GroupNormalization.build = GroupNormalization.build
keras.src.backend.numpy.mean = itex_mean
keras.src.backend.numpy.var = itex_var
Expand Down
1 change: 1 addition & 0 deletions itex/python/ops/group_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def build(self, input_shape):
self.beta = tf.zeros((dim,))

self.built = True
self._build_input_shape = input_shape

def call(self, inputs, training=False):
input_shape = tf.shape(inputs)
Expand Down

0 comments on commit e027c47

Please sign in to comment.