@@ -64,10 +64,16 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
6464 fun parseSchemaObjects (): SchemaObjects {
6565
6666 // Create GraphQL objects
67- val interfaces = interfaceDefinitions.map { createInterfaceObject(it) }
68- val objects = objectDefinitions.map { createObject(it, interfaces) }
67+ // val inputObjects = inputObjectDefinitions.map { createInputObject(it, listOf())}
68+ val inputObjects: MutableList <GraphQLInputObjectType > = mutableListOf ()
69+ inputObjectDefinitions.forEach {
70+ if (inputObjects.none { io -> io.name == it.name }) {
71+ inputObjects.add(createInputObject(it, inputObjects))
72+ }
73+ }
74+ val interfaces = interfaceDefinitions.map { createInterfaceObject(it, inputObjects) }
75+ val objects = objectDefinitions.map { createObject(it, interfaces, inputObjects) }
6976 val unions = unionDefinitions.map { createUnionObject(it, objects) }
70- val inputObjects = inputObjectDefinitions.map { createInputObject(it) }
7177 val enums = enumDefinitions.map { createEnumObject(it) }
7278
7379 // Assign type resolver to interfaces now that we know all of the object types
@@ -103,7 +109,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
103109 @Suppress(" unused" )
104110 fun getUnusedDefinitions (): Set <TypeDefinition <* >> = unusedDefinitions
105111
106- private fun createObject (objectDefinition : ObjectTypeDefinition , interfaces : List <GraphQLInterfaceType >): GraphQLObjectType {
112+ private fun createObject (objectDefinition : ObjectTypeDefinition , interfaces : List <GraphQLInterfaceType >, inputObjects : List < GraphQLInputObjectType > ): GraphQLObjectType {
107113 val name = objectDefinition.name
108114 val builder = GraphQLObjectType .newObject()
109115 .name(name)
@@ -121,7 +127,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
121127 objectDefinition.getExtendedFieldDefinitions(extensionDefinitions).forEach { fieldDefinition ->
122128 fieldDefinition.description
123129 builder.field { field ->
124- createField(field, fieldDefinition)
130+ createField(field, fieldDefinition, inputObjects )
125131 codeRegistryBuilder.dataFetcher(
126132 FieldCoordinates .coordinates(objectDefinition.name, fieldDefinition.name),
127133 fieldResolversByType[objectDefinition]?.get(fieldDefinition)?.createDataFetcher()
@@ -153,7 +159,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
153159 return output.toTypedArray()
154160 }
155161
156- private fun createInputObject (definition : InputObjectTypeDefinition ): GraphQLInputObjectType {
162+ private fun createInputObject (definition : InputObjectTypeDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLInputObjectType {
157163 val builder = GraphQLInputObjectType .newInputObject()
158164 .name(definition.name)
159165 .definition(definition)
@@ -167,7 +173,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
167173 .definition(inputDefinition)
168174 .description(if (inputDefinition.description != null ) inputDefinition.description.content else getDocumentation(inputDefinition))
169175 .defaultValue(buildDefaultValue(inputDefinition.defaultValue))
170- .type(determineInputType(inputDefinition.type))
176+ .type(determineInputType(inputDefinition.type, inputObjects ))
171177 .withDirectives(* buildDirectives(inputDefinition.directives, setOf (), Introspection .DirectiveLocation .INPUT_FIELD_DEFINITION ))
172178 builder.field(fieldBuilder.build())
173179 }
@@ -210,7 +216,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
210216 return directiveGenerator.onEnum(builder.build(), DirectiveBehavior .Params (runtimeWiring, codeRegistryBuilder))
211217 }
212218
213- private fun createInterfaceObject (interfaceDefinition : InterfaceTypeDefinition ): GraphQLInterfaceType {
219+ private fun createInterfaceObject (interfaceDefinition : InterfaceTypeDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLInterfaceType {
214220 val name = interfaceDefinition.name
215221 val builder = GraphQLInterfaceType .newInterface()
216222 .name(name)
@@ -220,7 +226,7 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
220226 builder.withDirectives(* buildDirectives(interfaceDefinition.directives, setOf (), Introspection .DirectiveLocation .INTERFACE ))
221227
222228 interfaceDefinition.fieldDefinitions.forEach { fieldDefinition ->
223- builder.field { field -> createField(field, fieldDefinition) }
229+ builder.field { field -> createField(field, fieldDefinition, inputObjects ) }
224230 }
225231
226232 return directiveGenerator.onInterface(builder.build(), DirectiveBehavior .Params (runtimeWiring, codeRegistryBuilder))
@@ -259,19 +265,19 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
259265 return leafObjects
260266 }
261267
262- private fun createField (field : GraphQLFieldDefinition .Builder , fieldDefinition : FieldDefinition ): GraphQLFieldDefinition .Builder {
268+ private fun createField (field : GraphQLFieldDefinition .Builder , fieldDefinition : FieldDefinition , inputObjects : List < GraphQLInputObjectType > ): GraphQLFieldDefinition .Builder {
263269 field.name(fieldDefinition.name)
264270 field.description(if (fieldDefinition.description != null ) fieldDefinition.description.content else getDocumentation(fieldDefinition))
265271 field.definition(fieldDefinition)
266272 getDeprecated(fieldDefinition.directives)?.let { field.deprecate(it) }
267- field.type(determineOutputType(fieldDefinition.type))
273+ field.type(determineOutputType(fieldDefinition.type, inputObjects ))
268274 fieldDefinition.inputValueDefinitions.forEach { argumentDefinition ->
269275 val argumentBuilder = GraphQLArgument .newArgument()
270276 .name(argumentDefinition.name)
271277 .definition(argumentDefinition)
272278 .description(if (argumentDefinition.description != null ) argumentDefinition.description.content else getDocumentation(argumentDefinition))
273279 .defaultValue(buildDefaultValue(argumentDefinition.defaultValue))
274- .type(determineInputType(argumentDefinition.type))
280+ .type(determineInputType(argumentDefinition.type, inputObjects ))
275281 .withDirectives(* buildDirectives(argumentDefinition.directives, setOf (), Introspection .DirectiveLocation .ARGUMENT_DEFINITION ))
276282 field.argument(argumentBuilder.build())
277283 }
@@ -293,16 +299,17 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
293299 }
294300 }
295301
296- private fun determineOutputType (typeDefinition : Type <* >) =
297- determineType(GraphQLOutputType ::class , typeDefinition, permittedTypesForObject) as GraphQLOutputType
302+ private fun determineOutputType (typeDefinition : Type <* >, inputObjects : List < GraphQLInputObjectType > ) =
303+ determineType(GraphQLOutputType ::class , typeDefinition, permittedTypesForObject, inputObjects ) as GraphQLOutputType
298304
299- private fun determineInputType (typeDefinition : Type <* >) =
300- determineType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject) as GraphQLInputType
301-
302- private fun <T : Any > determineType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >): GraphQLType =
305+ private fun <T : Any > determineType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
303306 when (typeDefinition) {
304- is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences))
305- is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences))
307+ is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
308+ is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
309+ is InputObjectTypeDefinition -> {
310+ log.info(" Create input object" )
311+ createInputObject(typeDefinition, inputObjects)
312+ }
306313 is TypeName -> {
307314 val scalarType = customScalars[typeDefinition.name] ? : graphQLScalars[typeDefinition.name]
308315 if (scalarType != null ) {
@@ -318,6 +325,45 @@ class SchemaParser internal constructor(scanResult: ScannedSchemaObjects, privat
318325 else -> throw SchemaError (" Unknown type: $typeDefinition " )
319326 }
320327
328+ private fun determineInputType (typeDefinition : Type <* >, inputObjects : List <GraphQLInputObjectType >) =
329+ determineInputType(GraphQLInputType ::class , typeDefinition, permittedTypesForInputObject, inputObjects) as GraphQLInputType
330+
331+ private fun <T : Any > determineInputType (expectedType : KClass <T >, typeDefinition : Type <* >, allowedTypeReferences : Set <String >, inputObjects : List <GraphQLInputObjectType >): GraphQLType =
332+ when (typeDefinition) {
333+ is ListType -> GraphQLList (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
334+ is NonNullType -> GraphQLNonNull (determineType(expectedType, typeDefinition.type, allowedTypeReferences, inputObjects))
335+ is InputObjectTypeDefinition -> {
336+ log.info(" Create input object" )
337+ createInputObject(typeDefinition, inputObjects)
338+ }
339+ is TypeName -> {
340+ val scalarType = customScalars[typeDefinition.name] ? : graphQLScalars[typeDefinition.name]
341+ if (scalarType != null ) {
342+ scalarType
343+ } else {
344+ if (! allowedTypeReferences.contains(typeDefinition.name)) {
345+ throw SchemaError (" Expected type '${typeDefinition.name} ' to be a ${expectedType.simpleName} , but it wasn't! " +
346+ " Was a type only permitted for object types incorrectly used as an input type, or vice-versa?" )
347+ }
348+ val found = inputObjects.filter { it.name == typeDefinition.name }
349+ if (found.size == 1 ) {
350+ found[0 ]
351+ } else {
352+ val filteredDefinitions = inputObjectDefinitions.filter { it.name == typeDefinition.name }
353+ if (filteredDefinitions.isNotEmpty()) {
354+ val inputObject = createInputObject(filteredDefinitions[0 ], inputObjects)
355+ (inputObjects as MutableList ).add(inputObject)
356+ inputObject
357+ } else {
358+ // todo: handle enum type
359+ GraphQLTypeReference (typeDefinition.name)
360+ }
361+ }
362+ }
363+ }
364+ else -> throw SchemaError (" Unknown type: $typeDefinition " )
365+ }
366+
321367 /* *
322368 * Returns an optional [String] describing a deprecated field/enum.
323369 * If a deprecation directive was defined using the @deprecated directive,
0 commit comments