Skip to content

Commit 7f079cf

Browse files
authored
codegen: Permit autogenerating types aliases for endpoints (#4213)
1 parent c15fb3d commit 7f079cf

File tree

14 files changed

+293
-102
lines changed

14 files changed

+293
-102
lines changed

openapi-codegen/cli/src/main/scala/sttp/tapir/codegen/GenScala.scala

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ object GenScala {
6565
private val streamingImplementationOpt: Opts[Option[String]] =
6666
Opts.option[String]("streamingImplementation", "Capability to use for binary streams", "s").orNone
6767

68+
private val generateEndpointTypesOpt: Opts[Boolean] =
69+
Opts.flag("generateEndpointTypes", "Whether to emit explicit type aliases for endpoint declarations", "e").orFalse
70+
6871
private val destDirOpt: Opts[File] =
6972
Opts
7073
.option[String]("destdir", "Destination directory", "d")
@@ -88,7 +91,8 @@ object GenScala {
8891
jsonLibOpt,
8992
validateNonDiscriminatedOneOfsOpt,
9093
maxSchemasPerFileOpt,
91-
streamingImplementationOpt
94+
streamingImplementationOpt,
95+
generateEndpointTypesOpt
9296
)
9397
.mapN {
9498
case (
@@ -101,7 +105,8 @@ object GenScala {
101105
jsonLib,
102106
validateNonDiscriminatedOneOfs,
103107
maxSchemasPerFile,
104-
streamingImplementation
108+
streamingImplementation,
109+
generateEndpointTypes
105110
) =>
106111
val objectName = maybeObjectName.getOrElse(DefaultObjectName)
107112

@@ -116,7 +121,8 @@ object GenScala {
116121
jsonLib.getOrElse("circe"),
117122
streamingImplementation.getOrElse("fs2"),
118123
validateNonDiscriminatedOneOfs,
119-
maxSchemasPerFile.getOrElse(400)
124+
maxSchemasPerFile.getOrElse(400),
125+
generateEndpointTypes
120126
)
121127
)
122128
destFiles <- contents.toVector.traverse { case (fileName, content) => writeGeneratedFile(destDir, fileName, content) }

openapi-codegen/core/src/main/scala/sttp/tapir/codegen/BasicGenerator.scala

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ object StreamingImplementation extends Enumeration {
2525
val Akka, FS2, Pekko, Zio = Value
2626
type StreamingImplementation = Value
2727
}
28+
object EndpointCapabilites extends Enumeration {
29+
val Akka, FS2, Nothing, Pekko, Zio = Value
30+
type EndpointCapabilites = Value
31+
}
2832

2933
object BasicGenerator {
3034

@@ -40,7 +44,8 @@ object BasicGenerator {
4044
jsonSerdeLib: String,
4145
streamingImplementation: String,
4246
validateNonDiscriminatedOneOfs: Boolean,
43-
maxSchemasPerFile: Int
47+
maxSchemasPerFile: Int,
48+
generateEndpointTypes: Boolean
4449
): Map[String, String] = {
4550
val normalisedJsonLib = jsonSerdeLib.toLowerCase match {
4651
case "circe" => JsonSerdeLib.Circe
@@ -65,7 +70,14 @@ object BasicGenerator {
6570
}
6671

6772
val EndpointDefs(endpointsByTag, queryOrPathParamRefs, jsonParamRefs, enumsDefinedOnEndpointParams) =
68-
endpointGenerator.endpointDefs(doc, useHeadTagForObjectNames, targetScala3, normalisedJsonLib, normalisedStreamingImplementation)
73+
endpointGenerator.endpointDefs(
74+
doc,
75+
useHeadTagForObjectNames,
76+
targetScala3,
77+
normalisedJsonLib,
78+
normalisedStreamingImplementation,
79+
generateEndpointTypes
80+
)
6981
val GeneratedClassDefinitions(classDefns, jsonSerdes, schemas) =
7082
classGenerator
7183
.classDefs(

openapi-codegen/core/src/main/scala/sttp/tapir/codegen/EndpointGenerator.scala

Lines changed: 173 additions & 80 deletions
Large diffs are not rendered by default.

openapi-codegen/core/src/test/scala/sttp/tapir/codegen/BasicGeneratorSpec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class BasicGeneratorSpec extends CompileCheckTestBase {
1818
jsonSerdeLib = jsonSerdeLib,
1919
validateNonDiscriminatedOneOfs = true,
2020
maxSchemasPerFile = 400,
21-
streamingImplementation = "fs2"
21+
streamingImplementation = "fs2",
22+
generateEndpointTypes = false
2223
)
2324
}
2425
def gen(

openapi-codegen/core/src/test/scala/sttp/tapir/codegen/ClassDefinitionGeneratorSpec.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ class ClassDefinitionGeneratorSpec extends CompileCheckTestBase {
396396
useHeadTagForObjectNames = false,
397397
targetScala3 = false,
398398
jsonSerdeLib = JsonSerdeLib.Circe,
399-
streamingImplementation = StreamingImplementation.FS2
399+
streamingImplementation = StreamingImplementation.FS2,
400+
generateEndpointTypes = false
400401
)
401402
.endpointDecls(None)
402403
}

openapi-codegen/core/src/test/scala/sttp/tapir/codegen/EndpointGeneratorSpec.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
6868
useHeadTagForObjectNames = false,
6969
targetScala3 = false,
7070
jsonSerdeLib = JsonSerdeLib.Circe,
71-
streamingImplementation = StreamingImplementation.FS2
71+
streamingImplementation = StreamingImplementation.FS2,
72+
generateEndpointTypes = false
7273
)
7374
.endpointDecls(None)
7475
generatedCode should include("val getTestAsdId =")
@@ -153,7 +154,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
153154
useHeadTagForObjectNames = false,
154155
targetScala3 = false,
155156
jsonSerdeLib = JsonSerdeLib.Circe,
156-
streamingImplementation = StreamingImplementation.FS2
157+
streamingImplementation = StreamingImplementation.FS2,
158+
generateEndpointTypes = false
157159
)
158160
.endpointDecls(None) shouldCompile ()
159161
}
@@ -205,7 +207,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
205207
useHeadTagForObjectNames = false,
206208
targetScala3 = false,
207209
jsonSerdeLib = JsonSerdeLib.Circe,
208-
streamingImplementation = StreamingImplementation.FS2
210+
streamingImplementation = StreamingImplementation.FS2,
211+
generateEndpointTypes = false
209212
)
210213
.endpointDecls(None)
211214
generatedCode should include(
@@ -272,7 +275,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
272275
jsonSerdeLib = "circe",
273276
validateNonDiscriminatedOneOfs = true,
274277
maxSchemasPerFile = 400,
275-
streamingImplementation = "fs2"
278+
streamingImplementation = "fs2",
279+
generateEndpointTypes = false
276280
)("TapirGeneratedEndpoints")
277281
generatedCode should include(
278282
"""file: sttp.model.Part[java.io.File]"""
@@ -294,7 +298,8 @@ class EndpointGeneratorSpec extends CompileCheckTestBase {
294298
jsonSerdeLib = "circe",
295299
validateNonDiscriminatedOneOfs = true,
296300
maxSchemasPerFile = 400,
297-
streamingImplementation = "fs2"
301+
streamingImplementation = "fs2",
302+
generateEndpointTypes = false
298303
)("TapirGeneratedEndpoints")
299304
generatedCode shouldCompile ()
300305
val expectedAttrDecls = Seq(

openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenKeys.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ case class OpenApiConfiguration(
1111
streamingImplementation: String,
1212
validateNonDiscriminatedOneOfs: Boolean,
1313
maxSchemasPerFile: Int,
14+
generateEndpointTypes: Boolean,
1415
additionalPackages: List[(String, File)]
1516
)
1617

@@ -27,6 +28,7 @@ trait OpenapiCodegenKeys {
2728
lazy val openapiMaxSchemasPerFile = settingKey[Int]("Maximum number of schemas to generate for a single file")
2829
lazy val openapiAdditionalPackages = settingKey[List[(String, File)]]("Addition package -> spec mappings to generate.")
2930
lazy val openapiStreamingImplementation = settingKey[String]("Implementation for streamTextBody. Supports: akka, fs2, pekko, zio.")
31+
lazy val openapiGenerateEndpointTypes = settingKey[Boolean]("Whether to emit explicit types for endpoint denfs")
3032
lazy val openapiOpenApiConfiguration =
3133
settingKey[OpenApiConfiguration]("Aggregation of other settings. Manually set value will be disregarded.")
3234

openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenPlugin.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
3232
openapiStreamingImplementation.value,
3333
openapiValidateNonDiscriminatedOneOfs.value,
3434
openapiMaxSchemasPerFile.value,
35+
openapiGenerateEndpointTypes.value,
3536
openapiAdditionalPackages.value
3637
)
3738
def openapiCodegenDefaultSettings: Seq[Setting[_]] = Seq(
@@ -44,6 +45,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
4445
openapiMaxSchemasPerFile := 400,
4546
openapiAdditionalPackages := Nil,
4647
openapiStreamingImplementation := "fs2",
48+
openapiGenerateEndpointTypes := false,
4749
standardParamSetting
4850
)
4951

@@ -73,6 +75,7 @@ object OpenapiCodegenPlugin extends AutoPlugin {
7375
c.streamingImplementation,
7476
c.validateNonDiscriminatedOneOfs,
7577
c.maxSchemasPerFile,
78+
c.generateEndpointTypes,
7679
srcDir,
7780
taskStreams.cacheDirectory,
7881
sv.startsWith("3"),

openapi-codegen/sbt-plugin/src/main/scala/sttp/tapir/sbt/OpenapiCodegenTask.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ case class OpenapiCodegenTask(
1414
streamingImplementation: String,
1515
validateNonDiscriminatedOneOfs: Boolean,
1616
maxSchemasPerFile: Int,
17+
generateEndpointTypes: Boolean,
1718
dir: File,
1819
cacheDir: File,
1920
targetScala3: Boolean,
@@ -59,7 +60,8 @@ case class OpenapiCodegenTask(
5960
jsonSerdeLib,
6061
streamingImplementation,
6162
validateNonDiscriminatedOneOfs,
62-
maxSchemasPerFile
63+
maxSchemasPerFile,
64+
generateEndpointTypes
6365
)
6466
.map { case (objectName, fileBody) =>
6567
val file = directory / s"$objectName.scala"

openapi-codegen/sbt-plugin/src/sbt-test/sbt-openapi-codegen/oneOf-json-roundtrip/Expected.scala.txt

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ object TapirGeneratedEndpoints {
4444
support.mapDecode(l => DecodeResult.Value(ExplodedValues(l)))(_.values)
4545
}
4646

47-
4847
case class EnumExtraParamSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]) extends ExtraParamSupport[T] {
4948
// Case-insensitive mapping
5049
def decode(s: String): sttp.tapir.DecodeResult[T] =
@@ -63,9 +62,16 @@ object TapirGeneratedEndpoints {
6362
}
6463
def extraCodecSupport[T <: enumeratum.EnumEntry](enumName: String, T: enumeratum.Enum[T]): ExtraParamSupport[T] =
6564
EnumExtraParamSupport(enumName, T)
65+
sealed trait Error
6666
sealed trait ADTWithoutDiscriminator
6767
sealed trait ADTWithDiscriminator
6868
sealed trait ADTWithDiscriminatorNoMapping
69+
case class SimpleError (
70+
message: String
71+
) extends Error
72+
case class NotFoundError (
73+
reason: String
74+
) extends Error
6975
case class SubtypeWithoutD1 (
7076
s: String,
7177
i: Option[Int] = None,
@@ -119,34 +125,39 @@ object TapirGeneratedEndpoints {
119125

120126

121127

122-
lazy val getBinaryTest =
128+
type GetBinaryTestEndpoint = Endpoint[Unit, Unit, Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, sttp.capabilities.pekko.PekkoStreams]
129+
lazy val getBinaryTest: GetBinaryTestEndpoint =
123130
endpoint
124131
.get
125132
.in(("binary" / "test"))
126133
.out(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()).description("Response CSV body"))
127134

128-
lazy val postBinaryTest =
135+
type PostBinaryTestEndpoint = Endpoint[Unit, sttp.capabilities.pekko.PekkoStreams.BinaryStream, Unit, String, sttp.capabilities.pekko.PekkoStreams]
136+
lazy val postBinaryTest: PostBinaryTestEndpoint =
129137
endpoint
130138
.post
131139
.in(("binary" / "test"))
132140
.in(streamBody(sttp.capabilities.pekko.PekkoStreams)(Schema.binary[Array[Byte]], CodecFormat.OctetStream()))
133141
.out(jsonBody[String].description("successful operation"))
134142

135-
lazy val putAdtTest =
143+
type PutAdtTestEndpoint = Endpoint[Unit, ADTWithoutDiscriminator, Unit, ADTWithoutDiscriminator, Any]
144+
lazy val putAdtTest: PutAdtTestEndpoint =
136145
endpoint
137146
.put
138147
.in(("adt" / "test"))
139148
.in(jsonBody[ADTWithoutDiscriminator])
140149
.out(jsonBody[ADTWithoutDiscriminator].description("successful operation"))
141150

142-
lazy val postAdtTest =
151+
type PostAdtTestEndpoint = Endpoint[Unit, ADTWithDiscriminatorNoMapping, Unit, ADTWithDiscriminator, Any]
152+
lazy val postAdtTest: PostAdtTestEndpoint =
143153
endpoint
144154
.post
145155
.in(("adt" / "test"))
146156
.in(jsonBody[ADTWithDiscriminatorNoMapping])
147157
.out(jsonBody[ADTWithDiscriminator].description("successful operation"))
148158

149-
lazy val postInlineEnumTest =
159+
type PostInlineEnumTestEndpoint = Endpoint[Unit, (PostInlineEnumTestQueryEnum, Option[PostInlineEnumTestQueryOptEnum], List[PostInlineEnumTestQuerySeqEnum], Option[List[PostInlineEnumTestQueryOptSeqEnum]], ObjectWithInlineEnum), Unit, Unit, Any]
160+
lazy val postInlineEnumTest: PostInlineEnumTestEndpoint =
150161
endpoint
151162
.post
152163
.in(("inline" / "enum" / "test"))
@@ -197,7 +208,14 @@ object TapirGeneratedEndpoints {
197208
extraCodecSupport[PostInlineEnumTestQueryOptSeqEnum]("PostInlineEnumTestQueryOptSeqEnum", PostInlineEnumTestQueryOptSeqEnum)
198209
}
199210

211+
type GetOneofErrorTestEndpoint = Endpoint[Unit, Unit, Error, Unit, Any]
212+
lazy val getOneofErrorTest: GetOneofErrorTestEndpoint =
213+
endpoint
214+
.get
215+
.in(("oneof" / "error" / "test"))
216+
.errorOut(oneOf[Error](oneOfVariant(sttp.model.StatusCode(404), jsonBody[NotFoundError].description("Not found")), oneOfVariant(sttp.model.StatusCode(400), jsonBody[SimpleError].description("Not found"))))
217+
.out(statusCode(sttp.model.StatusCode(204)).description("No response"))
200218

201-
lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest)
219+
lazy val generatedEndpoints = List(getBinaryTest, postBinaryTest, putAdtTest, postAdtTest, postInlineEnumTest, getOneofErrorTest)
202220

203221
}

0 commit comments

Comments
 (0)