Skip to content

Commit e686389

Browse files
committed
Review
1 parent ef41edc commit e686389

File tree

9 files changed

+58
-123
lines changed

9 files changed

+58
-123
lines changed

src/main/scala/com/hivemind/llmsdsl/JsonSchema.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,14 @@ enum JsonSchema derives CanEqual:
1515
case Null
1616

1717
object JsonSchema:
18-
import PromptValidator.unsafe.acceptPrompt
1918

2019
def apply[A](using derivation: SchemaEncoder[A]): JsonSchema = derivation.jsonSchema
2120

22-
given Conversion[JsonSchema, Prompt] with
23-
def apply(value: JsonSchema): Prompt = prompt"""${value.asJson.noSpaces.unsafeLiteral}"""
21+
given PromptEncoder[JsonSchema] = PromptEncoder.from(_.asJson.noSpaces.unsafeLiteral)
2422

25-
given promptEncoder: PromptEncoder[JsonSchema] =
26-
PromptEncoder.from(schema => prompt"""${schema.asJson.noSpaces.unsafeLiteral}""")
23+
given PromptValidator[JsonSchema] = PromptValidator.accept[JsonSchema]
2724

28-
given schemaValidator: PromptValidator[JsonSchema] =
29-
PromptValidator.from(schema => cats.effect.IO.pure(true))
30-
31-
given circeEncoder: CirceEncoder[JsonSchema] = CirceEncoder.instance {
25+
given CirceEncoder[JsonSchema] = CirceEncoder.instance {
3226
case Str => Json.fromString("string")
3327
case Num => Json.fromString("number")
3428
case Bool => Json.fromString("boolean")

src/main/scala/com/hivemind/llmsdsl/Prompt.scala

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,30 +20,23 @@ object Prompt:
2020
// todo: consider rename to "renderUnsafe"
2121
// todo: consider adding a "def render:IO[String]" after rename which does validated and render in one go
2222
def render: String = x match {
23-
case Combine(left, right) =>
24-
s"${left.render}${right.render}"
25-
case Literal(text) =>
26-
text
27-
case _: NewLine.type =>
28-
"\n"
29-
case _: Empty.type =>
30-
""
31-
case Value(value, _, encoder) =>
32-
encoder.encode(value).render
23+
case Literal(text) => text
24+
case _: NewLine.type => "\n"
25+
case _: Empty.type => ""
26+
case Value(value, _, encoder) => encoder.encode(value).render
27+
case Combine(left, right) => s"${left.render}${right.render}"
3328
}
3429

3530
def validate(using strValidator: PromptValidator[String]): IO[Boolean] = x match {
31+
case Literal(_) => IO.pure(true) // no validation needed for literals, we control creation of literals
32+
case _: NewLine.type => IO.pure(true) // nothing to validate
33+
case _: Empty.type => IO.pure(true) // nothing to validate
34+
case Value(value, validator, encoder) => validator.validate(value) // validate the value
3635
case Combine(left, right) =>
3736
for {
38-
l <- left.validate
39-
r <- right.validate
40-
valid <- strValidator.validate(Combine(left, right).render)
41-
} yield l && r && valid
42-
case Literal(_) =>
43-
IO.pure(true) // no validation needed for literals, we control creation of literals
44-
case _: NewLine.type =>
45-
IO.pure(true) // nothing to validate
46-
case _: Empty.type =>
47-
IO.pure(true) // nothing to validate
48-
case Value(value, validator, encoder) => validator.validate(value) // validate the value
37+
leftValid <- left.validate
38+
rightValid <- right.validate
39+
combined = Combine(left, right).render
40+
combinedValid <- strValidator.validate(combined)
41+
} yield leftValid && rightValid && combinedValid
4942
}

src/main/scala/com/hivemind/llmsdsl/PromptEncoder.scala

Lines changed: 3 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,27 +2,22 @@ package com.hivemind.llmsdsl
22

33
import scala.annotation.implicitNotFound
44

5-
@implicitNotFound("No PromptEncoder available for type ${A}. Please provide an encoder or import PromptEncoder.unsafe._")
5+
@implicitNotFound("No PromptEncoder available for type ${A}. Please provide an encoder or import PromptEncoder.unsafe.*")
66
opaque type PromptEncoder[A] = A => Prompt
77

88
object PromptEncoder:
9-
def apply[A](using encoder: PromptEncoder[A]): PromptEncoder[A] = encoder
10-
11-
// Extension method to encode values
129
extension [A](encoder: PromptEncoder[A]) def encode(a: A): Prompt = encoder(a)
1310

14-
// Smart constructor for creating encoders
1511
inline def from[A](f: A => Prompt): PromptEncoder[A] = f
1612

1713
given PromptEncoder[Prompt] = identity
1814

19-
given contravariantSemigroupal: cats.ContravariantSemigroupal[PromptEncoder] with
15+
given cats.ContravariantSemigroupal[PromptEncoder] with
2016
def contramap[A, B](fa: PromptEncoder[A])(f: B => A): PromptEncoder[B] =
2117
(b: B) => fa(f(b))
2218

2319
def product[A, B](fa: PromptEncoder[A], fb: PromptEncoder[B]): PromptEncoder[(A, B)] =
24-
(pair: (A, B)) =>
25-
val (a, b) = pair
20+
case (a, b) =>
2621
import cats.syntax.semigroup.*
2722
fa(a) |+| fb(b)
2823

@@ -34,28 +29,3 @@ object PromptEncoder:
3429

3530
object unsafe:
3631
given string: PromptEncoder[String] = from(x => Literal(x))
37-
38-
// todo: decide which route to prefer...
39-
// call unsafeLiteral or use the import and toPrompt?
40-
// val option1 = prompt"${section.title.toUpperCase().unsafeLiteral}: ${section.value.unsafeLiteral}"
41-
42-
// val option2 = {
43-
// import PromptEncoder.unsafe.string
44-
// prompt"${section.title.toUpperCase().toPrompt(PromptValidator.accept)}: ${section.value.toPrompt(PromptValidator.accept)}"
45-
// }
46-
47-
// val option2b = {
48-
// import PromptEncoder.unsafe.string
49-
// given unsafeStringValidator: PromptValidator[String] = PromptValidator.accept // todo: provide a importable validator for this
50-
51-
// prompt"${section.title.toUpperCase()}: ${section.value}"
52-
// }
53-
54-
// val option3 = {
55-
// import cats.syntax.all.*
56-
// given option3Encoder: PromptEncoder[Section] = (unsafe.string, unsafe.string).contramapN(s => (s.title, s.value))
57-
// given option3Validator: PromptValidator[Section] = PromptValidator.accept // similar think can be done with PromptValidator
58-
// prompt"""$section""" // implicitly converts section to prompt, reason for the import of implicit conversions
59-
// }
60-
61-
// option2

src/main/scala/com/hivemind/llmsdsl/PromptValidator.scala

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,17 @@ package com.hivemind.llmsdsl
33
import cats.effect.IO
44
import scala.annotation.implicitNotFound
55

6-
@implicitNotFound("No PromptValidator available for type ${A}. Please provide a validator or import PromptValidator.unsafe._")
6+
@implicitNotFound("No PromptValidator available for type ${A}. Please provide a validator or import PromptValidator.unsafe.*")
77
opaque type PromptValidator[A] = A => IO[Boolean]
88

99
object PromptValidator:
10-
def apply[A](using validator: PromptValidator[A]): PromptValidator[A] = validator
11-
12-
// Extension method to validate values
1310
extension [A](validator: PromptValidator[A]) def validate(a: A): IO[Boolean] = validator(a)
1411

15-
// Smart constructor for creating validators
1612
inline def from[A](f: A => IO[Boolean]): PromptValidator[A] = f
13+
inline def pure[A](f: A => Boolean): PromptValidator[A] = from((a: A) => IO.pure(f(a)))
1714

18-
def accept[A]: PromptValidator[A] = from((a: A) => IO.pure(true))
19-
def reject[A]: PromptValidator[A] = from((a: A) => IO.pure(false))
15+
def accept[A]: PromptValidator[A] = pure(_ => true)
16+
def reject[A]: PromptValidator[A] = pure(_ => false)
2017

2118
object unsafe:
2219
given acceptString: PromptValidator[String] = accept[String]
@@ -27,8 +24,7 @@ object PromptValidator:
2724
(b: B) => fa(f(b))
2825

2926
def product[A, B](fa: PromptValidator[A], fb: PromptValidator[B]): PromptValidator[(A, B)] =
30-
(pair: (A, B)) =>
31-
val (a, b) = pair
27+
case (a, b) =>
3228
for {
3329
resultA <- fa(a)
3430
resultB <- fb(b)

src/main/scala/com/hivemind/llmsdsl/Section.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ object Section:
1414
def safeEncoder(
1515
combinator: (String, String) => String,
1616
)(using resultEncoder: PromptEncoder[String], resultValidator: PromptValidator[String]): PromptEncoder[Section] =
17-
PromptEncoder.from((section: Section) => resultEncoder.encode(combinator(section.title, section.value)))
17+
PromptEncoder.from(section => resultEncoder.encode(combinator(section.title, section.value)))
1818

1919
object xml:
2020
val combinator: (String, String) => String =

src/test/scala/com/hivemind/llmsdsl/OpaqueEncoderErgonomicsSpec.scala

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@ package com.hivemind.llmsdsl
22

33
import org.scalatest.flatspec.AnyFlatSpec
44
import org.scalatest.matchers.should.Matchers
5+
import scala.compiletime.testing.typeCheckErrors
56

6-
/** Test file to investigate ergonomics of opaque function-based PromptEncoder vs the original trait-based approach.
7-
*/
87
class OpaqueEncoderErgonomicsSpec extends AnyFlatSpec with Matchers {
98

109
case class TestType(value: String)
@@ -64,26 +63,6 @@ class OpaqueEncoderErgonomicsSpec extends AnyFlatSpec with Matchers {
6463
pairEncoder.encode(pair).render should include("42")
6564
}
6665

67-
"Error message quality" should "be tested manually by uncommenting the following" in {
68-
// Uncomment these lines one at a time to test error messages:
69-
70-
// Test 1: Missing encoder
71-
// val testNoEncoder = TestType("test")
72-
// val prompt1 = testNoEncoder.toPrompt(PromptValidator.accept)
73-
74-
// Test 2: Missing validator
75-
// given testEncoder: PromptEncoder[TestType] = PromptEncoder.from(_.value)
76-
// val testNoValidator = TestType("test")
77-
// val prompt2 = prompt"${testNoValidator}"
78-
79-
// Test 3: Complex type without encoder
80-
// case class ComplexType(a: String, b: Int, c: List[Double])
81-
// val complex = ComplexType("test", 1, List(1.0, 2.0))
82-
// val prompt3 = complex.toPrompt(PromptValidator.accept)
83-
84-
succeed // This test always passes - it's for manual error checking
85-
}
86-
8766
"Performance characteristics" should "demonstrate zero-cost abstraction" in {
8867
// This test verifies that the opaque type compiles to direct function calls
8968
given testEncoder: PromptEncoder[TestType] =
@@ -103,12 +82,12 @@ class OpaqueEncoderErgonomicsSpec extends AnyFlatSpec with Matchers {
10382
}
10483

10584
"@implicitNotFound annotation" should "work with opaque types" in {
106-
// This is a meta-test to verify our annotation works
107-
// If the annotation is working, uncommenting the next line should show our custom error message:
108-
109-
// val test = TestType("annotation-test")
110-
// val prompt = test.toPrompt(PromptValidator.accept)
111-
112-
succeed // Test passes when commented - uncomment to test error message
85+
val errors = typeCheckErrors("""
86+
val test = TestType("annotation-test")
87+
val prompt = test.toPrompt(PromptValidator.accept)
88+
""")
89+
errors should not be empty
90+
errors.head.message should include("No PromptEncoder available for type OpaqueEncoderErgonomicsSpec.this.TestType")
91+
errors.head.message should include("Please provide an encoder or import PromptEncoder.unsafe.*")
11392
}
11493
}

src/test/scala/com/hivemind/llmsdsl/OpaqueValidatorErgonomicsSpec.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import org.scalatest.matchers.should.Matchers
55
import cats.effect.IO
66
import cats.effect.unsafe.implicits.global
77

8-
/** Test file to investigate ergonomics of opaque function-based PromptValidator vs the original trait-based approach.
9-
*/
108
class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
119

1210
case class TestType(value: String)
@@ -24,7 +22,7 @@ class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
2422

2523
it should "work with explicit validator creation" in {
2624
given testValidator: PromptValidator[TestType] =
27-
PromptValidator.from(test => IO.pure(test.value.nonEmpty))
25+
PromptValidator.pure(_.value.nonEmpty)
2826

2927
val test = TestType("hello")
3028
val isValid = testValidator.validate(test).unsafeRunSync()
@@ -40,7 +38,7 @@ class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
4038
import PromptValidator.unsafe.acceptString
4139

4240
given testValidator: PromptValidator[TestType] =
43-
PromptValidator.from(test => IO.pure(test.value.length > 3))
41+
PromptValidator.pure(_.value.length > 3)
4442
given testEncoder: PromptEncoder[TestType] = PromptEncoder.unsafe.string.contramap(_.value)
4543

4644
val test = TestType("hello")
@@ -52,7 +50,7 @@ class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
5250
import cats.syntax.contravariant.*
5351

5452
given stringValidator: PromptValidator[String] =
55-
PromptValidator.from(s => IO.pure(s.nonEmpty))
53+
PromptValidator.pure(_.nonEmpty)
5654

5755
val testValidator: PromptValidator[TestType] =
5856
stringValidator.contramap(_.value)
@@ -69,9 +67,10 @@ class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
6967
import cats.implicits.*
7068

7169
given stringValidator: PromptValidator[String] =
72-
PromptValidator.from(s => IO.pure(s.nonEmpty))
73-
given intValidator: PromptValidator[Int] =
74-
PromptValidator.from(i => IO.pure(i > 0))
70+
PromptValidator.pure(_.nonEmpty)
71+
72+
given intValidator: PromptValidator[Int] =
73+
PromptValidator.pure(_ > 0)
7574

7675
val pairValidator: PromptValidator[(String, Int)] =
7776
(stringValidator, intValidator).tupled
@@ -118,7 +117,7 @@ class OpaqueValidatorErgonomicsSpec extends AnyFlatSpec with Matchers {
118117
"Performance characteristics" should "demonstrate zero-cost abstraction" in {
119118
// This test verifies that the opaque type compiles to direct function calls
120119
given testValidator: PromptValidator[TestType] =
121-
PromptValidator.from(test => IO.pure(test.value.nonEmpty))
120+
PromptValidator.pure(_.value.nonEmpty)
122121

123122
val test = TestType("performance")
124123

src/test/scala/com/hivemind/llmsdsl/PromptValidationSpec.scala

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,18 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
1212
case class Person(name: String, age: Int)
1313
object Person {
1414
given validator: PromptValidator[Person] =
15-
PromptValidator.from((person: Person) => IO.apply(!person.name.trim().isEmpty() && person.age >= 18))
16-
17-
import PromptValidator.unsafe.acceptPrompt
18-
import PromptValidator.unsafe.acceptString
15+
PromptValidator.pure(person => !person.name.trim().isEmpty() && person.age >= 18)
1916

2017
given promptEncoder: PromptEncoder[Person] =
21-
PromptEncoder.from((person: Person) => prompt"""Person name: ${person.name.unsafeLiteral} and age: ${person.age.toString.unsafeLiteral}""")
18+
given stringValidator: PromptValidator[String] = PromptValidator.unsafe.acceptString
19+
given intValidator: PromptValidator[Int] = PromptValidator.pure(_ => true)
20+
given stringEncoder: PromptEncoder[String] = PromptEncoder.unsafe.string
21+
given intEncoder: PromptEncoder[Int] = PromptEncoder.from(_.toString.unsafeLiteral)
22+
// TODO: rethink this
23+
// why do we need validators to create encoders?
24+
// we should capture the input arg types and expect the validators when we render the prompt later
25+
// feels a lot like ZIO like abstraction where we capture the input requirements and provide them later
26+
PromptEncoder.from(person => prompt"""Person name: ${person.name} and age: ${person.age}""")
2227
}
2328

2429
lazy val grownUpPerson: Person = Person("John", 30) // should pass
@@ -28,7 +33,7 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
2833
"Prompt.validate method" should {
2934
"pass when it is a constant (literal) string" in {
3035
given stringValidator: PromptValidator[String] =
31-
PromptValidator.from((str: String) => IO.apply(str.length > 0))
36+
PromptValidator.pure(_.length > 0)
3237

3338
val p: Prompt = prompt"""You are a helpful assistant."""
3439

@@ -37,7 +42,7 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
3742

3843
"fail when the string validator fails with a constant string" in {
3944
given stringValidator: PromptValidator[String] =
40-
PromptValidator.from((str: String) => IO.apply(!str.toLowerCase().contains("bomb")))
45+
PromptValidator.pure(!_.toLowerCase().contains("bomb"))
4146

4247
val p: Prompt = prompt"""You are a bomb assistant."""
4348

@@ -46,7 +51,7 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
4651

4752
"fail when the string validator fails when the parts are ok but not the sum of them" in {
4853
given stringValidator: PromptValidator[String] =
49-
PromptValidator.from((str: String) => IO.apply(!str.toLowerCase().contains("bomb")))
54+
PromptValidator.pure(!_.toLowerCase().contains("bomb"))
5055

5156
val p: Prompt = prompt"""|
5257
|You are an expert engineer. I have all the materials for the task in question.
@@ -99,7 +104,7 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
99104
import PromptValidator.unsafe.acceptPrompt
100105

101106
given promptEncoder: PromptEncoder[WithoutValidator] =
102-
PromptEncoder.from((withoutValidator: WithoutValidator) => prompt"""This is a string value: ${withoutValidator.value.unsafeLiteral}""")
107+
PromptEncoder.from(withoutValidator => prompt"""This is a string value: ${withoutValidator.value.unsafeLiteral}""")
103108
}
104109

105110
val withoutValidatorInstance: WithoutValidator = WithoutValidator("John")
@@ -110,8 +115,7 @@ class PromptValidationSpec extends AnyWordSpec with Matchers {
110115
"fail compilation if the type has no prompt encoder" in {
111116
case class WithoutPromptEncoder(value: String)
112117
object WithoutPromptEncoder {
113-
given validator: PromptValidator[WithoutPromptEncoder] =
114-
PromptValidator.from((withoutPromptEncoder: WithoutPromptEncoder) => IO.apply(true)) // dummy validator
118+
given validator: PromptValidator[WithoutPromptEncoder] = PromptValidator.accept[WithoutPromptEncoder]
115119
}
116120

117121
val withoutPromptEncoderInstance: WithoutPromptEncoder = WithoutPromptEncoder("John")

src/test/scala/com/hivemind/llmsdsl/SectionFormatterSpec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import org.scalatest.matchers.should.Matchers
99
class SectionFormatterSpec extends AnyWordSpec with Matchers {
1010
val inputFormat = Section("Input", "The input will be a JSON object")
1111
val outputFormat = Section("Output", "The output will be a JSON object")
12-
given accept: PromptValidator[Section] = PromptValidator.from((s: Section) => IO.pure(s.title.nonEmpty && s.value.nonEmpty))
12+
given accept: PromptValidator[Section] = PromptValidator.pure(s => s.title.nonEmpty && s.value.nonEmpty)
1313

1414
"Section formatter" should {
1515
"format a section as XML" in {

0 commit comments

Comments
 (0)