1616from ..language import (
1717 DirectiveNode ,
1818 InputValueDefinitionNode ,
19+ InterfaceTypeDefinitionNode ,
1920 InterfaceTypeExtensionNode ,
2021 NamedTypeNode ,
2122 Node ,
23+ ObjectTypeDefinitionNode ,
2224 ObjectTypeExtensionNode ,
2325 OperationType ,
26+ SchemaDefinitionNode ,
2427 SchemaExtensionNode ,
28+ UnionTypeDefinitionNode ,
2529 UnionTypeExtensionNode ,
2630)
2731from .definition import (
4549)
4650from ..utilities .assert_valid_name import is_valid_name_error
4751from ..utilities .type_comparators import is_equal_type , is_type_sub_type_of
48- from .directives import is_directive , GraphQLDirective , GraphQLDeprecatedDirective
52+ from .directives import is_directive , GraphQLDeprecatedDirective
4953from .introspection import is_introspection_type
5054from .schema import GraphQLSchema , assert_schema
5155
@@ -252,7 +256,7 @@ def validate_fields(
252256 if not fields :
253257 self .report_error (
254258 f"Type { type_ .name } must define one or more fields." ,
255- get_all_nodes (type_ ) ,
259+ [ type_ . ast_node , * (type_ . extension_ast_nodes or ())] ,
256260 )
257261
258262 for field_name , field in fields .items ():
@@ -339,7 +343,11 @@ def validate_type_implements_interface(
339343 self .report_error (
340344 f"Interface field { iface .name } .{ field_name } "
341345 f" expected but { type_ .name } does not provide it." ,
342- [iface_field .ast_node , * get_all_nodes (type_ )],
346+ [
347+ iface_field .ast_node ,
348+ type_ .ast_node ,
349+ * (type_ .extension_ast_nodes or ()),
350+ ],
343351 )
344352 continue
345353
@@ -422,7 +430,7 @@ def validate_union_members(self, union: GraphQLUnionType) -> None:
422430 if not member_types :
423431 self .report_error (
424432 f"Union type { union .name } must define one or more member types." ,
425- get_all_nodes (union ) ,
433+ [ union . ast_node , * (union . extension_ast_nodes or ())] ,
426434 )
427435
428436 included_type_names : Set [str ] = set ()
@@ -449,7 +457,7 @@ def validate_enum_values(self, enum_type: GraphQLEnumType) -> None:
449457 if not enum_values :
450458 self .report_error (
451459 f"Enum type { enum_type .name } must define one or more values." ,
452- get_all_nodes (enum_type ) ,
460+ [ enum_type . ast_node , * (enum_type . extension_ast_nodes or ())] ,
453461 )
454462
455463 for value_name , enum_value in enum_values .items ():
@@ -469,7 +477,7 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
469477 self .report_error (
470478 f"Input Object type { input_obj .name } "
471479 " must define one or more fields." ,
472- get_all_nodes (input_obj ) ,
480+ [ input_obj . ast_node , * (input_obj . extension_ast_nodes or ())] ,
473481 )
474482
475483 # Ensure the arguments are valid
@@ -500,12 +508,14 @@ def validate_input_fields(self, input_obj: GraphQLInputObjectType) -> None:
500508def get_operation_type_node (
501509 schema : GraphQLSchema , operation : OperationType
502510) -> Optional [Node ]:
503- for extension_node in get_all_nodes (schema ):
504- operation_types = cast (SchemaExtensionNode , extension_node ).operation_types
505- if operation_types : # pragma: no cover else
506- for operation_type in operation_types :
507- if operation_type .operation == operation :
508- return operation_type .type
511+ ast_node : Optional [Union [SchemaDefinitionNode , SchemaExtensionNode ]]
512+ for ast_node in [schema .ast_node , * (schema .extension_ast_nodes or ())]:
513+ if ast_node :
514+ operation_types = ast_node .operation_types
515+ if operation_types : # pragma: no cover else
516+ for operation_type in operation_types :
517+ if operation_type .operation == operation :
518+ return operation_type .type
509519 return None
510520
511521
@@ -561,55 +571,44 @@ def __call__(self, input_obj: GraphQLInputObjectType) -> None:
561571 del self .field_path_index_by_type_name [name ]
562572
563573
564- SDLDefinedObject = Union [
565- GraphQLSchema ,
566- GraphQLDirective ,
567- GraphQLInterfaceType ,
568- GraphQLObjectType ,
569- GraphQLInputObjectType ,
570- GraphQLUnionType ,
571- GraphQLEnumType ,
572- ]
573-
574-
575- def get_all_nodes (obj : SDLDefinedObject ) -> List [Node ]:
576- node = obj .ast_node
577- nodes : List [Node ] = [node ] if node else []
578- extension_nodes = getattr (obj , "extension_ast_nodes" , None )
579- if extension_nodes :
580- nodes .extend (extension_nodes )
581- return nodes
582-
583-
584574def get_all_implements_interface_nodes (
585575 type_ : Union [GraphQLObjectType , GraphQLInterfaceType ], iface : GraphQLInterfaceType
586576) -> List [NamedTypeNode ]:
587577 implements_nodes : List [NamedTypeNode ] = []
588- for extension_node in get_all_nodes (type_ ):
589- iface_nodes = cast (
590- Union [ObjectTypeExtensionNode , InterfaceTypeExtensionNode ], extension_node
591- ).interfaces
592- if iface_nodes : # pragma: no cover else
593- implements_nodes .extend (
594- iface_node
595- for iface_node in iface_nodes
596- if iface_node .name .value == iface .name
597- )
578+ ast_node : Optional [
579+ Union [
580+ ObjectTypeDefinitionNode ,
581+ ObjectTypeExtensionNode ,
582+ InterfaceTypeDefinitionNode ,
583+ InterfaceTypeExtensionNode ,
584+ ]
585+ ]
586+ for ast_node in [type_ .ast_node , * (type_ .extension_ast_nodes or ())]:
587+ if ast_node :
588+ iface_nodes = ast_node .interfaces
589+ if iface_nodes : # pragma: no cover else
590+ implements_nodes .extend (
591+ iface_node
592+ for iface_node in iface_nodes
593+ if iface_node .name .value == iface .name
594+ )
598595 return implements_nodes
599596
600597
601598def get_union_member_type_nodes (
602599 union : GraphQLUnionType , type_name : str
603600) -> Optional [List [NamedTypeNode ]]:
604601 member_type_nodes : List [NamedTypeNode ] = []
605- for extension_node in get_all_nodes (union ):
606- type_nodes = cast (UnionTypeExtensionNode , extension_node ).types
607- if type_nodes : # pragma: no cover else
608- member_type_nodes .extend (
609- type_node
610- for type_node in type_nodes
611- if type_node .name .value == type_name
612- )
602+ ast_node : Optional [Union [UnionTypeDefinitionNode , UnionTypeExtensionNode ]]
603+ for ast_node in [union .ast_node , * (union .extension_ast_nodes or ())]:
604+ if ast_node :
605+ type_nodes = ast_node .types
606+ if type_nodes : # pragma: no cover else
607+ member_type_nodes .extend (
608+ type_node
609+ for type_node in type_nodes
610+ if type_node .name .value == type_name
611+ )
613612 return member_type_nodes
614613
615614
0 commit comments