Skip to content

Commit 0971a05

Browse files
committed
Improved support / assertion for graphql types in Schema
1 parent a7a4ba6 commit 0971a05

File tree

1 file changed

+32
-13
lines changed

1 file changed

+32
-13
lines changed

graphene/types/schema.py

+32-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import inspect
22

3-
from graphql import GraphQLSchema, graphql, is_type
3+
from graphql import GraphQLSchema, graphql, is_type, GraphQLObjectType
44
from graphql.type.directives import (GraphQLDirective, GraphQLIncludeDirective,
55
GraphQLSkipDirective)
66
from graphql.type.introspection import IntrospectionSchema
@@ -12,6 +12,17 @@
1212
from .typemap import TypeMap, is_graphene_type
1313

1414

15+
def assert_valid_root_type(_type):
16+
if _type is None:
17+
return
18+
is_graphene_objecttype = inspect.isclass(
19+
_type) and issubclass(_type, ObjectType)
20+
is_graphql_objecttype = isinstance(_type, GraphQLObjectType)
21+
assert is_graphene_objecttype or is_graphql_objecttype, (
22+
"Type {} is not a valid ObjectType."
23+
).format(_type)
24+
25+
1526
class Schema(GraphQLSchema):
1627
'''
1728
Schema Definition
@@ -20,21 +31,23 @@ class Schema(GraphQLSchema):
2031
query and mutation (optional).
2132
'''
2233

23-
def __init__(self, query=None, mutation=None, subscription=None,
24-
directives=None, types=None, auto_camelcase=True):
25-
assert inspect.isclass(query) and issubclass(query, ObjectType), (
26-
'Schema query must be Object Type but got: {}.'
27-
).format(query)
34+
def __init__(self,
35+
query=None,
36+
mutation=None,
37+
subscription=None,
38+
directives=None,
39+
types=None,
40+
auto_camelcase=True):
41+
assert_valid_root_type(query)
42+
assert_valid_root_type(mutation)
43+
assert_valid_root_type(subscription)
2844
self._query = query
2945
self._mutation = mutation
3046
self._subscription = subscription
3147
self.types = types
3248
self.auto_camelcase = auto_camelcase
3349
if directives is None:
34-
directives = [
35-
GraphQLIncludeDirective,
36-
GraphQLSkipDirective
37-
]
50+
directives = [GraphQLIncludeDirective, GraphQLSkipDirective]
3851

3952
assert all(isinstance(d, GraphQLDirective) for d in directives), \
4053
'Schema directives must be List[GraphQLDirective] if provided but got: {}.'.format(
@@ -61,7 +74,8 @@ def __getattr__(self, type_name):
6174
'''
6275
_type = super(Schema, self).get_type(type_name)
6376
if _type is None:
64-
raise AttributeError('Type "{}" not found in the Schema'.format(type_name))
77+
raise AttributeError(
78+
'Type "{}" not found in the Schema'.format(type_name))
6579
if isinstance(_type, GrapheneGraphQLType):
6680
return _type.graphene_type
6781
return _type
@@ -73,7 +87,8 @@ def get_graphql_type(self, _type):
7387
return _type
7488
if is_graphene_type(_type):
7589
graphql_type = self.get_type(_type._meta.name)
76-
assert graphql_type, "Type {} not found in this schema.".format(_type._meta.name)
90+
assert graphql_type, "Type {} not found in this schema.".format(
91+
_type._meta.name)
7792
assert graphql_type.graphene_type == _type
7893
return graphql_type
7994
raise Exception("{} is not a valid GraphQL type.".format(_type))
@@ -102,4 +117,8 @@ def build_typemap(self):
102117
]
103118
if self.types:
104119
initial_types += self.types
105-
self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self)
120+
self._type_map = TypeMap(
121+
initial_types,
122+
auto_camelcase=self.auto_camelcase,
123+
schema=self
124+
)

0 commit comments

Comments
 (0)