1
1
import inspect
2
2
3
- from graphql import GraphQLSchema , graphql , is_type
3
+ from graphql import GraphQLSchema , graphql , is_type , GraphQLObjectType
4
4
from graphql .type .directives import (GraphQLDirective , GraphQLIncludeDirective ,
5
5
GraphQLSkipDirective )
6
6
from graphql .type .introspection import IntrospectionSchema
12
12
from .typemap import TypeMap , is_graphene_type
13
13
14
14
15
+ def assert_valid_root_type (_type ):
16
+ if _type is None :
17
+ return
18
+ is_graphene_objecttype = inspect .isclass (
19
+ _type ) and issubclass (_type , ObjectType )
20
+ is_graphql_objecttype = isinstance (_type , GraphQLObjectType )
21
+ assert is_graphene_objecttype or is_graphql_objecttype , (
22
+ "Type {} is not a valid ObjectType."
23
+ ).format (_type )
24
+
25
+
15
26
class Schema (GraphQLSchema ):
16
27
'''
17
28
Schema Definition
@@ -20,21 +31,23 @@ class Schema(GraphQLSchema):
20
31
query and mutation (optional).
21
32
'''
22
33
23
- def __init__ (self , query = None , mutation = None , subscription = None ,
24
- directives = None , types = None , auto_camelcase = True ):
25
- assert inspect .isclass (query ) and issubclass (query , ObjectType ), (
26
- 'Schema query must be Object Type but got: {}.'
27
- ).format (query )
34
+ def __init__ (self ,
35
+ query = None ,
36
+ mutation = None ,
37
+ subscription = None ,
38
+ directives = None ,
39
+ types = None ,
40
+ auto_camelcase = True ):
41
+ assert_valid_root_type (query )
42
+ assert_valid_root_type (mutation )
43
+ assert_valid_root_type (subscription )
28
44
self ._query = query
29
45
self ._mutation = mutation
30
46
self ._subscription = subscription
31
47
self .types = types
32
48
self .auto_camelcase = auto_camelcase
33
49
if directives is None :
34
- directives = [
35
- GraphQLIncludeDirective ,
36
- GraphQLSkipDirective
37
- ]
50
+ directives = [GraphQLIncludeDirective , GraphQLSkipDirective ]
38
51
39
52
assert all (isinstance (d , GraphQLDirective ) for d in directives ), \
40
53
'Schema directives must be List[GraphQLDirective] if provided but got: {}.' .format (
@@ -61,7 +74,8 @@ def __getattr__(self, type_name):
61
74
'''
62
75
_type = super (Schema , self ).get_type (type_name )
63
76
if _type is None :
64
- raise AttributeError ('Type "{}" not found in the Schema' .format (type_name ))
77
+ raise AttributeError (
78
+ 'Type "{}" not found in the Schema' .format (type_name ))
65
79
if isinstance (_type , GrapheneGraphQLType ):
66
80
return _type .graphene_type
67
81
return _type
@@ -73,7 +87,8 @@ def get_graphql_type(self, _type):
73
87
return _type
74
88
if is_graphene_type (_type ):
75
89
graphql_type = self .get_type (_type ._meta .name )
76
- assert graphql_type , "Type {} not found in this schema." .format (_type ._meta .name )
90
+ assert graphql_type , "Type {} not found in this schema." .format (
91
+ _type ._meta .name )
77
92
assert graphql_type .graphene_type == _type
78
93
return graphql_type
79
94
raise Exception ("{} is not a valid GraphQL type." .format (_type ))
@@ -102,4 +117,8 @@ def build_typemap(self):
102
117
]
103
118
if self .types :
104
119
initial_types += self .types
105
- self ._type_map = TypeMap (initial_types , auto_camelcase = self .auto_camelcase , schema = self )
120
+ self ._type_map = TypeMap (
121
+ initial_types ,
122
+ auto_camelcase = self .auto_camelcase ,
123
+ schema = self
124
+ )
0 commit comments