
ExplicitEnum subclass str (JSON dump compatible)#17933

Merged: sgugger merged 2 commits into huggingface:main from BramVanroy:patch-1 on Jun 29, 2022

@BramVanroy (Collaborator) commented Jun 29, 2022

When I tried to write the parsed dataclasses returned by HfArgumentParser.parse_args_into_dataclasses() to JSON, serialization failed with TypeError: Object of type IntervalStrategy is not JSON serializable. While this is understandable (plain Enum members are not JSON serializable by default), it is not ideal within transformers.

I checked all classes in transformers that subclass ExplicitEnum, and they all appear to be str-only Enums. That allows us to have ExplicitEnum inherit from str as well, which solves the JSON issue: the json module can then serialize the members as ordinary strings. Below is a minimal but complete example showing how this works:

from enum import Enum
from json import dump, loads
from pathlib import Path


class ExplicitEnum(str, Enum):  # If you remove `str` you'll get a serialization error
    """
    Enum with more explicit error message for missing values.
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )


class IntervalStrategy(ExplicitEnum):
    NO = "no"
    STEPS = "steps"
    EPOCH = "epoch"


if __name__ == "__main__":
    strat = IntervalStrategy("no")
    print(strat)

    p = Path("strat_dump.json")
    with p.open("w", encoding="utf-8") as out:
        dump({"strategy": strat}, out, indent=4, sort_keys=True)

    loaded = loads(p.read_text(encoding="utf-8"))
    strat = IntervalStrategy(loaded["strategy"])
    print(strat)
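To make the difference concrete, here is a minimal sketch that contrasts a plain Enum with a str-mixin Enum under json.dumps. It is not taken from the transformers code base: `PlainStrategy`, `StrStrategy`, and `try_dump` are hypothetical names used only for illustration.

```python
import json
from enum import Enum


class PlainStrategy(Enum):  # hypothetical: plain Enum, no str mixin
    NO = "no"


class StrStrategy(str, Enum):  # hypothetical: same member, with the str mixin
    NO = "no"


def try_dump(value):
    """Return the JSON text, or the TypeError message if encoding fails."""
    try:
        return json.dumps({"strategy": value})
    except TypeError as err:
        return f"TypeError: {err}"


plain_result = try_dump(PlainStrategy.NO)  # fails: the encoder does not know this type
str_result = try_dump(StrStrategy.NO)      # works: the member is also a str

print(plain_result)
print(str_result)

# The dumped value is a bare string, so it can be turned back into a member.
restored = StrStrategy(json.loads(str_result)["strategy"])
print(restored is StrStrategy.NO)
```

Because the member is an instance of str, the encoder writes only its value ("no"), and the Enum class has to be re-applied by the caller when loading.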

A consequence is that these ExplicitEnums now have a Union type, which originally led to issues in HfArgumentParser._parse_dataclass_field. I therefore added an exception to _parse_dataclass_field that allows a Union if one of its types is str, on the assumption that a string value passed to the argument parser will be resolved correctly because str is one of the accepted types.
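The kind of check described above can be sketched as follows. This is only an illustration under stated assumptions, not the code in HfArgumentParser._parse_dataclass_field: `union_accepts_str` is a hypothetical helper, and it relies only on the standard typing introspection functions.

```python
from typing import Union, get_args, get_origin


def union_accepts_str(annotation) -> bool:
    """Hypothetical helper: True if `annotation` is a typing.Union that lists str."""
    if get_origin(annotation) is not Union:
        return False  # not a Union at all, so the exception does not apply
    return str in get_args(annotation)


# A few annotations a dataclass field might carry.
checks = {
    "Union[int, str]": union_accepts_str(Union[int, str]),
    "Union[int, float]": union_accepts_str(Union[int, float]),
    "str": union_accepts_str(str),
}
for name, accepted in checks.items():
    print(name, accepted)
```

Under this sketch, a Union that includes str is let through and the raw command-line string is handed on unchanged, while any other Union is left to whatever handling the parser already applies.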

Who can review?

@sgugger

@HuggingFaceDocBuilderDev commented Jun 29, 2022

The documentation is not available anymore as the PR was closed or merged.

@BramVanroy (Collaborator, Author) commented

The following tests are failing, but that seems unrelated:

tests/pipelines/test_pipelines_object_detection.py::ObjectDetectionPipelineTests::test_small_model_pt
tests/pipelines/test_pipelines_image_segmentation.py::ImageSegmentationPipelineTests::test_small_model_pt

@BramVanroy BramVanroy requested a review from sgugger June 29, 2022 13:04
@sgugger (Collaborator) commented Jun 29, 2022

Yes, I skipped those tests on main for now. Let me play with this a little. It seems like a good idea, but I want to make sure it doesn't break anything before merging.

@sgugger (Collaborator) commented Jun 29, 2022

Tested and it all looks good, thanks a lot!

@sgugger sgugger merged commit bc019b0 into huggingface:main Jun 29, 2022
viclzhu pushed a commit to viclzhu/transformers that referenced this pull request Jul 18, 2022
* ExplicitEnum subclass str (JSON dump compatible)

* allow union if one of the types is str