Skip to content

Commit 2f87698

Browse files
committed
Improved TypeMap and Dynamic Field to optionally include the schema
1 parent ecb1edd commit 2f87698

File tree

3 files changed

+57
-52
lines changed

3 files changed

+57
-52
lines changed

graphene/types/dynamic.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ class Dynamic(MountedType):
99
the schema. So we can have lazy fields.
1010
'''
1111

12-
def __init__(self, type, _creation_counter=None):
12+
def __init__(self, type, with_schema=False, _creation_counter=None):
1313
super(Dynamic, self).__init__(_creation_counter=_creation_counter)
1414
assert inspect.isfunction(type)
1515
self.type = type
16+
self.with_schema = with_schema
1617

17-
def get_type(self):
18+
def get_type(self, schema=None):
19+
if schema and self.with_schema:
20+
return self.type(schema=schema)
1821
return self.type()

graphene/types/schema.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,4 @@ def build_typemap(self):
9494
]
9595
if self.types:
9696
initial_types += self.types
97-
self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase)
97+
self._type_map = TypeMap(initial_types, auto_camelcase=self.auto_camelcase, schema=self)

graphene/types/typemap.py

+51-49
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ def resolve_type(resolve_type_func, map, type_name, root, context, info):
5151

5252
class TypeMap(GraphQLTypeMap):
5353

54-
def __init__(self, types, auto_camelcase=True):
54+
def __init__(self, types, auto_camelcase=True, schema=None):
5555
self.auto_camelcase = auto_camelcase
56+
self.schema = schema
5657
super(TypeMap, self).__init__(types)
5758

5859
def reducer(self, map, type):
@@ -72,21 +73,25 @@ def graphene_reducer(self, map, type):
7273
if isinstance(_type, GrapheneGraphQLType):
7374
assert _type.graphene_type == type
7475
return map
76+
7577
if issubclass(type, ObjectType):
76-
return self.construct_objecttype(map, type)
78+
internal_type = self.construct_objecttype(map, type)
7779
if issubclass(type, InputObjectType):
78-
return self.construct_inputobjecttype(map, type)
80+
internal_type = self.construct_inputobjecttype(map, type)
7981
if issubclass(type, Interface):
80-
return self.construct_interface(map, type)
82+
internal_type = self.construct_interface(map, type)
8183
if issubclass(type, Scalar):
82-
return self.construct_scalar(map, type)
84+
internal_type = self.construct_scalar(map, type)
8385
if issubclass(type, Enum):
84-
return self.construct_enum(map, type)
86+
internal_type = self.construct_enum(map, type)
8587
if issubclass(type, Union):
86-
return self.construct_union(map, type)
87-
return map
88+
internal_type = self.construct_union(map, type)
89+
90+
return GraphQLTypeMap.reducer(map, internal_type)
8891

8992
def construct_scalar(self, map, type):
93+
# We have a mapping to the original GraphQL types
94+
# so there are no collisions.
9095
_scalars = {
9196
String: GraphQLString,
9297
Int: GraphQLInt,
@@ -95,18 +100,17 @@ def construct_scalar(self, map, type):
95100
ID: GraphQLID
96101
}
97102
if type in _scalars:
98-
map[type._meta.name] = _scalars[type]
99-
else:
100-
map[type._meta.name] = GrapheneScalarType(
101-
graphene_type=type,
102-
name=type._meta.name,
103-
description=type._meta.description,
104-
105-
serialize=getattr(type, 'serialize', None),
106-
parse_value=getattr(type, 'parse_value', None),
107-
parse_literal=getattr(type, 'parse_literal', None),
108-
)
109-
return map
103+
return _scalars[type]
104+
105+
return GrapheneScalarType(
106+
graphene_type=type,
107+
name=type._meta.name,
108+
description=type._meta.description,
109+
110+
serialize=getattr(type, 'serialize', None),
111+
parse_value=getattr(type, 'parse_value', None),
112+
parse_literal=getattr(type, 'parse_literal', None),
113+
)
110114

111115
def construct_enum(self, map, type):
112116
values = OrderedDict()
@@ -117,78 +121,76 @@ def construct_enum(self, map, type):
117121
description=getattr(value, 'description', None),
118122
deprecation_reason=getattr(value, 'deprecation_reason', None)
119123
)
120-
map[type._meta.name] = GrapheneEnumType(
124+
return GrapheneEnumType(
121125
graphene_type=type,
122126
values=values,
123127
name=type._meta.name,
124128
description=type._meta.description,
125129
)
126-
return map
127130

128131
def construct_objecttype(self, map, type):
129132
if type._meta.name in map:
130133
_type = map[type._meta.name]
131134
if isinstance(_type, GrapheneGraphQLType):
132135
assert _type.graphene_type == type
133-
return map
134-
map[type._meta.name] = GrapheneObjectType(
136+
return _type
137+
138+
def interfaces():
139+
interfaces = []
140+
for interface in type._meta.interfaces:
141+
i = self.construct_interface(map, interface)
142+
interfaces.append(i)
143+
return interfaces
144+
145+
return GrapheneObjectType(
135146
graphene_type=type,
136147
name=type._meta.name,
137148
description=type._meta.description,
138-
fields=None,
149+
fields=partial(self.construct_fields_for_type, map, type),
139150
is_type_of=type.is_type_of,
140-
interfaces=None
151+
interfaces=interfaces
141152
)
142-
interfaces = []
143-
for i in type._meta.interfaces:
144-
map = self.reducer(map, i)
145-
interfaces.append(map[i._meta.name])
146-
map[type._meta.name]._provided_interfaces = interfaces
147-
map[type._meta.name]._fields = self.construct_fields_for_type(map, type)
148-
# self.reducer(map, map[type._meta.name])
149-
return map
150153

151154
def construct_interface(self, map, type):
155+
if type._meta.name in map:
156+
_type = map[type._meta.name]
157+
if isinstance(_type, GrapheneInterfaceType):
158+
assert _type.graphene_type == type
159+
return _type
160+
152161
_resolve_type = None
153162
if type.resolve_type:
154163
_resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name)
155-
map[type._meta.name] = GrapheneInterfaceType(
164+
return GrapheneInterfaceType(
156165
graphene_type=type,
157166
name=type._meta.name,
158167
description=type._meta.description,
159-
fields=None,
168+
fields=partial(self.construct_fields_for_type, map, type),
160169
resolve_type=_resolve_type,
161170
)
162-
map[type._meta.name]._fields = self.construct_fields_for_type(map, type)
163-
# self.reducer(map, map[type._meta.name])
164-
return map
165171

166172
def construct_inputobjecttype(self, map, type):
167-
map[type._meta.name] = GrapheneInputObjectType(
173+
return GrapheneInputObjectType(
168174
graphene_type=type,
169175
name=type._meta.name,
170176
description=type._meta.description,
171-
fields=None,
177+
fields=partial(self.construct_fields_for_type, map, type, is_input_type=True),
172178
)
173-
map[type._meta.name]._fields = self.construct_fields_for_type(map, type, is_input_type=True)
174-
return map
175179

176180
def construct_union(self, map, type):
177181
_resolve_type = None
178182
if type.resolve_type:
179183
_resolve_type = partial(resolve_type, type.resolve_type, map, type._meta.name)
180184
types = []
181185
for i in type._meta.types:
182-
map = self.construct_objecttype(map, i)
183-
types.append(map[i._meta.name])
184-
map[type._meta.name] = GrapheneUnionType(
186+
internal_type = self.construct_objecttype(map, i)
187+
types.append(internal_type)
188+
return GrapheneUnionType(
185189
graphene_type=type,
186190
name=type._meta.name,
187191
types=types,
188192
resolve_type=_resolve_type,
189193
)
190-
map[type._meta.name].types = types
191-
return map
192194

193195
def get_name(self, name):
194196
if self.auto_camelcase:
@@ -202,7 +204,7 @@ def construct_fields_for_type(self, map, type, is_input_type=False):
202204
fields = OrderedDict()
203205
for name, field in type._meta.fields.items():
204206
if isinstance(field, Dynamic):
205-
field = get_field_as(field.get_type(), _as=Field)
207+
field = get_field_as(field.get_type(self.schema), _as=Field)
206208
if not field:
207209
continue
208210
map = self.reducer(map, field.type)

0 commit comments

Comments
 (0)