diff --git a/spectrum.py b/spectrum.py index 62fe2a4..da83d02 100644 --- a/spectrum.py +++ b/spectrum.py @@ -234,6 +234,8 @@ def main(): parser = argparse.ArgumentParser(description="Process SNR data for layers.") parser.add_argument('--model-name', type=str, required=True, help='Model name or path to the model') parser.add_argument('--top-percent', type=int, default=None, help='Top percentage of layers to select, overriding the default') + parser.add_argument('--all-layer-types', action='store_true', help='Whether to include all layers in selection') + parser.add_argument('--batch-size', type=int, default=None, help='Batch size to use in SnR calculation') args = parser.parse_args() # Check for existing SNR results file @@ -244,12 +246,29 @@ def main(): print(f"Found existing SNR results file for {args.model_name}") modifier = ModelModifier(top_percent=args.top_percent) modifier.generate_unfrozen_params_yaml(snr_file_path, args.top_percent) + + weight_types = modifier.get_weight_types() + print(weight_types) else: print(f"No existing SNR results file found for {args.model_name}. Proceeding with SNR calculation.") - batch_size = input_dialog(title="Batch Size", text="Enter the batch size:").run() + + if args.batch_size is None: + batch_size = input_dialog(title="Batch Size", text="Enter the batch size:").run() + else: + batch_size = args.batch_size + batch_size = int(batch_size) if batch_size else 1 - modifier = ModelModifier(model_name=args.model_name, batch_size=batch_size) - selected_weight_types = modifier.interactive_select_weights() + + modifier = ModelModifier(model_name=args.model_name, batch_size=batch_size, top_percent=args.top_percent) + + if args.all_layer_types: + weight_types = modifier.get_weight_types() + print(f"selecting all weight types: {weight_types}") + selected_weight_types = modifier.sort_weight_types(weight_types) + modifier.layer_types = selected_weight_types + else: + selected_weight_types = modifier.interactive_select_weights() + if selected_weight_types: modifier.assess_layers_snr(selected_weight_types) modifier.save_snr_to_json() @@ -258,4 +277,4 @@ def main(): print("No weight types selected.") if __name__ == "__main__": - main() + main() \ No newline at end of file