diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py index 65f1611730..3c638417e6 100644 --- a/examples/text-generation/run_generation.py +++ b/examples/text-generation/run_generation.py @@ -49,21 +49,6 @@ def setup_parser(parser): - class StoreTrueFalseAction(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - if isinstance(values, bool) or values is None: - # Flag passed without any value -> set to True - setattr(namespace, self.dest, True) - else: - # Flag passed with value -> pattern match and set accordingly - value_str = values.lower() - if value_str in ("true", "1", "yes"): - setattr(namespace, self.dest, True) - elif value_str in ("false", "0", "no"): - setattr(namespace, self.dest, False) - else: - raise ValueError(f"Invalid value for {option_string}: {values}") - # Arguments management parser.add_argument("--device", "-d", type=str, choices=["hpu"], help="Device to run", default="hpu") parser.add_argument( @@ -299,7 +284,7 @@ def __call__(self, parser, namespace, values, option_string=None): nargs="?", const=True, default=False, - action=StoreTrueFalseAction, + action=SetTrueOrFalseOrNone, help="Whether to enable Habana Flash Attention, provided that the model supports it.", ) parser.add_argument( @@ -307,7 +292,7 @@ def __call__(self, parser, namespace, values, option_string=None): nargs="?", const=True, default=False, - action=StoreTrueFalseAction, + action=SetTrueOrFalseOrNone, help="Whether to enable Habana Flash Attention in recompute mode on first token generation. This gives an opportunity of splitting graph internally which helps reduce memory consumption.", ) parser.add_argument( @@ -315,7 +300,7 @@ def __call__(self, parser, namespace, values, option_string=None): nargs="?", const=True, default=False, - action=StoreTrueFalseAction, + action=SetTrueOrFalseOrNone, help="Whether to enable Habana Flash Attention in causal mode on first token generation.", ) parser.add_argument(