@@ -1,5 +1,8 @@
+import warnings
 from typing import Any, Dict, List, Union
 
+from pydantic import field_validator
+
 from llmcompressor.core import Event, EventType, ModelParameterizedLayer, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.pruning.helpers import (
@@ -25,7 +28,7 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     update_scheduler: str = "cubic"
     scheduler_args: Dict[str, Any] = {}
     mask_structure: str = "unstructured"
-    leave_enabled: bool = True
+    leave_enabled: bool = False
     apply_globally: bool = False
 
     parameterized_layers_: Dict[str, ModelParameterizedLayer] = None
@@ -35,6 +38,14 @@ class MagnitudePruningModifier(Modifier, LayerParamMasking):
     mask_creator_function_: MaskCreatorType = None
     current_sparsity_: float = None
 
+    @field_validator("leave_enabled")
+    def validate_leave_enabled(value: bool) -> bool:
+        warnings.warn(
+            "MagnitudePruningModifier.leave_enabled has been deprecated",
+            DeprecationWarning,
+        )
+        return False
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.apply_globally:
             raise NotImplementedError("global pruning not implemented yet for PyTorch")
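
A minimal, self-contained sketch of the deprecation pattern added above, assuming pydantic v2 (`PruningConfig` is a hypothetical stand-in for the modifier, not part of this change). The validator runs whenever the field is supplied, warns, and pins the value back to `False`:

```python
import warnings

from pydantic import BaseModel, field_validator


class PruningConfig(BaseModel):
    # Deprecated knob: kept so old configs still load, but always coerced to False.
    leave_enabled: bool = False

    @field_validator("leave_enabled")
    def validate_leave_enabled(value: bool) -> bool:
        # Warn on any supplied value, then force the new default so downstream
        # code never observes leave_enabled=True.
        warnings.warn(
            "PruningConfig.leave_enabled has been deprecated",
            DeprecationWarning,
        )
        return False


config = PruningConfig(leave_enabled=True)  # emits a DeprecationWarning
assert config.leave_enabled is False
```

Returning `False` unconditionally, rather than raising, keeps existing recipes loading while guaranteeing the cleanup paths below never see `leave_enabled=True`.
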
@@ -75,9 +86,8 @@ def on_initialize(self, state: State, **kwargs) -> bool:
         return True
 
     def on_finalize(self, state: State, **kwargs) -> bool:
-        if not self.leave_enabled:
-            for layer_param_name, _ in self.parameterized_layers_.items():
-                self.remove_mask(layer_param_name)
+        for layer_param_name, _ in self.parameterized_layers_.items():
+            self.remove_mask(layer_param_name)
 
         return True
 
@@ -119,12 +129,7 @@ def on_update(self, state: State, event: Event, **kwargs):
         self._update_masks(event)
 
     def on_end(self, state: State, event: Event, **kwargs):
-        if not self.leave_enabled:
-            self.disable_masks()
-
-    def on_event(self, state: State, event: Event, **kwargs):
-        if event.current_index >= self.end and self.leave_enabled:
-            self._update_masks(event)
+        self.disable_masks()
 
     def _update_masks(self, event: Event):
         if event.type_ == EventType.OPTIM_PRE_STEP and not self._use_hooks:
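
To make the behavioral change concrete, here is a toy sketch of the lifecycle as it now stands (all names below are illustrative stand-ins, not the real `LayerParamMasking` API): `on_end` always disables masks and `on_finalize` always removes them, since `leave_enabled` no longer gates cleanup and the `on_event` re-masking path is gone.

```python
from typing import Dict


class ToyMaskedModifier:
    """Illustrative stand-in for the modifier's masking lifecycle."""

    def __init__(self) -> None:
        self.parameterized_layers_: Dict[str, str] = {"linear1.weight": "mask"}
        self.masks_enabled = True

    def disable_masks(self) -> None:
        self.masks_enabled = False

    def remove_mask(self, layer_param_name: str) -> None:
        del self.parameterized_layers_[layer_param_name]

    def on_end(self) -> None:
        # Previously gated on `not self.leave_enabled`; now unconditional.
        self.disable_masks()

    def on_finalize(self) -> bool:
        # Snapshot the keys, since remove_mask mutates the dict in this toy.
        for layer_param_name in list(self.parameterized_layers_):
            self.remove_mask(layer_param_name)
        return True


modifier = ToyMaskedModifier()
modifier.on_end()       # masks turned off at the end of the pruning stage
modifier.on_finalize()  # masks removed entirely at teardown
assert not modifier.masks_enabled and not modifier.parameterized_layers_
```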