diff --git a/graphene/types/definitions.py b/graphene/types/definitions.py index e5505fd3a..ac574bed5 100644 --- a/graphene/types/definitions.py +++ b/graphene/types/definitions.py @@ -20,6 +20,11 @@ def __init__(self, *args, **kwargs): self.graphene_type = kwargs.pop("graphene_type") super(GrapheneGraphQLType, self).__init__(*args, **kwargs) + def __copy__(self): + result = GrapheneGraphQLType(graphene_type=self.graphene_type) + result.__dict__.update(self.__dict__) + return result + class GrapheneInterfaceType(GrapheneGraphQLType, GraphQLInterfaceType): pass diff --git a/graphene/types/tests/test_definition.py b/graphene/types/tests/test_definition.py index 0d8a95dfa..898fac71b 100644 --- a/graphene/types/tests/test_definition.py +++ b/graphene/types/tests/test_definition.py @@ -1,4 +1,7 @@ +import copy + from ..argument import Argument +from ..definitions import GrapheneGraphQLType from ..enum import Enum from ..field import Field from ..inputfield import InputField @@ -312,3 +315,16 @@ class TestInputObject2(CommonFields, InputObjectType): pass assert TestInputObject1._meta.fields == TestInputObject2._meta.fields + + +def test_graphene_graphql_type_can_be_copied(): + class Query(ObjectType): + field = String() + + def resolve_field(self, info): + return "" + + schema = Schema(query=Query) + query_type_copy = copy.copy(schema.graphql_schema.query_type) + assert query_type_copy.__dict__ == schema.graphql_schema.query_type.__dict__ + assert isinstance(schema.graphql_schema.query_type, GrapheneGraphQLType)