diff --git a/graphene/relay/connection.py b/graphene/relay/connection.py index b8f492078..20f3cceb9 100644 --- a/graphene/relay/connection.py +++ b/graphene/relay/connection.py @@ -118,29 +118,34 @@ def type(self): ).format(str(self), connection_type) return connection_type + @classmethod + def resolve_connection(cls, connection_type, args, resolved): + if isinstance(resolved, connection_type): + return resolved + + assert isinstance(resolved, Iterable), ( + 'Resolved value from the connection field have to be iterable or instance of {}. ' + 'Received "{}"' + ).format(connection_type, resolved) + connection = connection_from_list( + resolved, + args, + connection_type=connection_type, + edge_type=connection_type.Edge, + pageinfo_type=PageInfo + ) + connection.iterable = resolved + return connection + @classmethod def connection_resolver(cls, resolver, connection_type, root, args, context, info): - p = Promise.resolve(resolver(root, args, context, info)) - - def resolve_connection(resolved): - if isinstance(resolved, connection_type): - return resolved - - assert isinstance(resolved, Iterable), ( - 'Resolved value from the connection field have to be iterable or instance of {}. ' - 'Received "{}"' - ).format(connection_type, resolved) - connection = connection_from_list( - resolved, - args, - connection_type=connection_type, - edge_type=connection_type.Edge, - pageinfo_type=PageInfo - ) - connection.iterable = resolved - return connection - - return p.then(resolve_connection) + resolved = resolver(root, args, context, info) + + on_resolve = partial(cls.resolve_connection, connection_type, args) + if isinstance(resolved, Promise): + return resolved.then(on_resolve) + + return on_resolve(resolved) def get_resolver(self, parent_resolver): resolver = super(IterableConnectionField, self).get_resolver(parent_resolver)