Skip to content

Commit 9d7a615

Browse files
Redesign the library: Unified interpolator (#8)
Changes: - Integrate unified interpolator model into the main branch - Recover tests functionality
1 parent 2a198e7 commit 9d7a615

27 files changed

+698
-650
lines changed

project/metals.sbt

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
// format: off
22
// DO NOT EDIT! This file is auto-generated.
33

4-
// This plugin enables semantic information to be produced by sbt.
5-
// It also adds support for debugging using the Debug Adapter Protocol
4+
// This file enables sbt-bloop to create bloop config files.
65

7-
addSbtPlugin("org.scalameta" % "sbt-metals" % "1.6.2")
8-
9-
// This plugin adds the BSP debug capability to sbt server.
10-
11-
addSbtPlugin("ch.epfl.scala" % "sbt-debug-adapter" % "4.2.8")
6+
addSbtPlugin("ch.epfl.scala" % "sbt-bloop" % "2.0.13")
127

138
// format: on

src/main/scala/com/hivemind/llmsdsl/CanBeEncodedAsPrompt.scala

Lines changed: 0 additions & 26 deletions
This file was deleted.

src/main/scala/com/hivemind/llmsdsl/Prompt.scala

Lines changed: 13 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -4,60 +4,21 @@ import cats.*
44
import cats.implicits.*
55
import cats.effect.*
66

7-
sealed trait Prompt
8-
case object Empty extends Prompt
9-
case object NewLine extends Prompt
10-
case class Literal private[llmsdsl] (text: String, marginChar: Option[Char]) extends Prompt
11-
case class Value[A] private[llmsdsl] (value: A, validator: PromptValidator[A], encoder: PromptEncoder[A]) extends Prompt
12-
case class Combine private[llmsdsl] (left: Prompt, right: Prompt) extends Prompt
13-
14-
object Literal:
15-
def apply(text: String): Literal = Literal(text, None)
16-
def withMargin(text: String, char: Char = '|'): Literal = Literal(text, Some(char))
7+
final case class Prompt(text: String)
178

189
object Prompt:
19-
given monoid: Monoid[Prompt] with
20-
def empty: Prompt = Empty
21-
def combine(x: Prompt, y: Prompt): Prompt = Combine(x, y)
10+
val NewLine = Prompt("\n")
11+
val Empty = Prompt("")
2212

23-
extension (x: Prompt)
24-
// todo: consider rename to "renderUnsafe"
25-
// todo: consider adding a "def render:IO[String]" after rename which does validated and render in one go
26-
def render: String = x match {
27-
case Literal(text, marginChar) =>
28-
marginChar match {
29-
case None => text // Preserve original formatting
30-
case Some(char) => text.stripMargin(char) // Apply margin stripping at render time
31-
}
32-
case _: NewLine.type => "\n"
33-
case _: Empty.type => ""
34-
case Value(value, _, encoder) => encoder.encode(value).render
35-
case Combine(left, right) => s"${left.render}${right.render}"
36-
}
13+
extension (prompt: Prompt)
14+
def validate(using validator: Validate) = validator(prompt)
15+
def modify(f: String => String): Prompt = Prompt(f(prompt.text))
16+
def trim: Prompt = modify(_.trim)
17+
def stripMargin: Prompt = stripMargin('|')
18+
def stripMargin(marginChar: Char): Prompt = modify(_.stripMargin(marginChar))
3719

38-
def validate(using strValidator: PromptValidator[String]): IO[Boolean] = x match {
39-
case Literal(text, marginChar) =>
40-
val processedText = marginChar.fold(text)(text.stripMargin)
41-
strValidator.validate(processedText)
42-
case _: NewLine.type => IO.pure(true) // nothing to validate
43-
case _: Empty.type => IO.pure(true) // nothing to validate
44-
case Value(value, validator, encoder) => validator.validate(value) // validate the value
45-
case Combine(left, right) =>
46-
for {
47-
leftValid <- left.validate
48-
rightValid <- right.validate
49-
combined = Combine(left, right).render
50-
combinedValid <- strValidator.validate(combined)
51-
} yield leftValid && rightValid && combinedValid
52-
}
20+
extension (prompts: Seq[Prompt]) def mkPrompt: Prompt = prompts.intercalate(Empty)
5321

54-
// stripMargin extension methods for Prompt
55-
def stripMargin: Prompt = x match {
56-
case Literal(text, _) => Literal(text, Some('|'))
57-
case other => other // No-op for non-literals
58-
}
59-
60-
def stripMargin(marginChar: Char): Prompt = x match {
61-
case Literal(text, _) => Literal(text, Some(marginChar))
62-
case other => other // No-op for non-literals
63-
}
22+
given Monoid[Prompt] = new Monoid[Prompt]:
23+
def empty: Prompt = Empty
24+
def combine(l: Prompt, r: Prompt): Prompt = Prompt(l.text + r.text) // only time we use .text to combine, all other times we use |+|

src/main/scala/com/hivemind/llmsdsl/PromptEncoder.scala

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/main/scala/com/hivemind/llmsdsl/PromptValidator.scala

Lines changed: 0 additions & 31 deletions
This file was deleted.

src/main/scala/com/hivemind/llmsdsl/Section.scala

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,10 @@ package com.hivemind.llmsdsl
33
case class Section(title: String, value: String)
44

55
object Section:
6-
def safeValidator(both: PromptValidator[String]): PromptValidator[Section] = safeValidator(both, both)
7-
8-
def safeValidator(titleValidator: PromptValidator[String], valueValidator: PromptValidator[String]): PromptValidator[Section] =
9-
import cats.syntax.all.*
10-
import cats.instances.all.*
11-
import PromptValidator.contravariantSemigroupal
12-
((titleValidator, valueValidator)).contramapN((s: Section) => (s.title, s.value))
13-
14-
def safeEncoder(
15-
combinator: (String, String) => String,
16-
)(using resultEncoder: PromptEncoder[String], resultValidator: PromptValidator[String]): PromptEncoder[Section] =
17-
PromptEncoder.from(section => resultEncoder.encode(combinator(section.title, section.value)))
18-
196
object xml:
20-
val combinator: (String, String) => String =
21-
(t, v) => s"<$t>$v</$t>"
22-
23-
def safe(using resultEncoder: PromptEncoder[String], resultValidator: PromptValidator[String]) =
24-
Section.safeEncoder(combinator)
25-
object unsafe:
26-
given sectionEncoder: PromptEncoder[Section] = safe(using PromptEncoder.unsafe.string, PromptValidator.accept)
27-
given sectionValidator: PromptValidator[Section] = safeValidator(PromptValidator.accept[String])
7+
given template: Template[Section] =
8+
section => prompt"<${section.title}>${section.value}</${section.title}>"
289

2910
object allcaps:
30-
val combinator: (String, String) => String =
31-
(t, v) => s"${t.toUpperCase()}: $v"
32-
33-
def safe(using resultEncoder: PromptEncoder[String], resultValidator: PromptValidator[String]) =
34-
Section.safeEncoder(combinator)
35-
36-
object unsafe:
37-
given sectionEncoder: PromptEncoder[Section] = safe(using PromptEncoder.unsafe.string, PromptValidator.accept)
38-
given sectionValidator: PromptValidator[Section] = PromptValidator.accept
11+
given template: Template[Section] =
12+
section => prompt"${section.title.toUpperCase()}: ${section.value}"
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package com.hivemind.llmsdsl
2+
3+
import cats.*
4+
import cats.implicits.*
5+
import cats.effect.*
6+
7+
trait Template[A]:
8+
def apply(arg: A): Prompt
9+
10+
object Template:
11+
inline def apply[A](using ev: Template[A]): Template[A] = ev
12+
13+
given promptTemplate: Template[Prompt] = x => Prompt(x.text)
14+
given stringTemplate: Template[String] = x => Prompt(x)
15+
given intTemplate: Template[Int] = x => Prompt(x.toString)
16+
given booleanTemplate: Template[Boolean] = x => Prompt(x.toString)
17+
18+
given optional[A](using template: Template[A]): Template[Option[A]] =
19+
(option: Option[A]) => option.map(template.apply).getOrElse(Prompt.Empty)
20+
21+
given list[A](using template: Template[A]): Template[List[A]] =
22+
(list: List[A]) => list.map(template.apply).mkPrompt
23+
24+
given ContravariantMonoidal[Template] with
25+
def unit: Template[Unit] = _ => Prompt.Empty
26+
27+
def contramap[A, B](fa: Template[A])(f: B => A): Template[B] =
28+
(b: B) => fa(f(b))
29+
30+
def product[A, B](fa: Template[A], fb: Template[B]): Template[(A, B)] =
31+
(a, b) => fa(a) |+| fb(b)
32+
33+
def from[A](f: A => String): Template[A] = (a: A) => Prompt(f(a))
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package com.hivemind.llmsdsl
2+
3+
import cats.*
4+
import cats.implicits.*
5+
import cats.effect.*
6+
7+
trait Validate:
8+
def apply(p: Prompt): IO[Boolean]
9+
10+
object Validate:
11+
def accept: Validate = prompt => IO.pure(true)
12+
def reject: Validate = prompt => IO.pure(false)
13+
def from(validator: String => IO[Boolean]): Validate = prompt => validator(prompt.text)
14+
def pure(validator: String => Boolean): Validate = prompt => IO.pure(validator(prompt.text))
15+
16+
given Monoid[Validate] = new Monoid[Validate]:
17+
def empty: Validate = _ => IO.pure(true)
18+
def combine(leftValidate: Validate, rightValidate: Validate): Validate = (prompt: Prompt) =>
19+
for {
20+
leftResult <- leftValidate(prompt)
21+
rightResult <- rightValidate(prompt)
22+
} yield leftResult && rightResult
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package com.hivemind.llmsdsl.interpolator
2+
3+
import scala.Tuple.*
4+
import scala.util.NotGiven
5+
import com.hivemind.llmsdsl.*
6+
import scala.annotation.implicitNotFound
7+
8+
@implicitNotFound(
9+
"No implicit instance of Template[${Arguments}] available.\n" +
10+
"Make sure to provide an implicit instance of Template for every type in ${Arguments}.\n" +
11+
"It will fail from left to right, starting looking for the first missing instance in the error message.",
12+
)
13+
// Filter out the values that are not templates from the hlist
14+
// (A0, Template[B0], A1, Template[B1], ...) => (Template[B0], Template[B1], ...)
15+
trait ArgumentReducer[Arguments, O]:
16+
def apply(args: Arguments, lits: Seq[Prompt]): (O, Seq[Prompt])
17+
18+
// uses + variants to filter and reduce the hlist
19+
object ArgumentReducer:
20+
given emptyTuple: ArgumentReducer[EmptyTuple, EmptyTuple] = (args, lits) => (EmptyTuple, lits)
21+
22+
given singleTemplate[A]: ArgumentReducer[Template[A], Template[A]] =
23+
(template, lits) => (lits.head + template, lits.tail)
24+
25+
given singleNonTemplate[A](using
26+
notATemplate: NotGiven[A <:< Template[?]],
27+
notATuple: NotGiven[A <:< Tuple],
28+
template: Template[A],
29+
): ArgumentReducer[A, EmptyTuple] =
30+
(args, lits) =>
31+
val rendered: Prompt = template(args)
32+
val newHead: Prompt = lits.head + rendered
33+
val lastLit: Prompt = newHead + lits.tail.mkPrompt
34+
(EmptyTuple, Seq(lastLit))
35+
36+
given recursiveTemplateHeadToSingle[H, T <: Tuple](using
37+
tailReducer: ArgumentReducer[T, EmptyTuple],
38+
): ArgumentReducer[Template[H] *: T, Template[H]] =
39+
(args, lits) =>
40+
val (_, tailLits) = tailReducer(args.tail, lits.tail)
41+
val template: Template[H] = lits.head + args.head
42+
(template, tailLits)
43+
44+
given recursiveTemplateHeadToTuple[H, T <: Tuple, TOut](using
45+
tailReducer: ArgumentReducer[T, Template[TOut]],
46+
notEmptyTuple: NotGiven[Template[TOut] =:= EmptyTuple],
47+
): ArgumentReducer[Template[H] *: T, (Template[H], Template[TOut])] =
48+
(args, lits) =>
49+
val (tailOut, tailLits) = tailReducer(args.tail, lits.tail)
50+
val newHead: Template[H] = lits.head + args.head
51+
((newHead, tailOut), tailLits)
52+
53+
given recursiveTemplateHeadPrependToTuple[H, T <: Tuple, TOut <: Tuple](using
54+
tailReducer: ArgumentReducer[T, TOut],
55+
notSingle: NotGiven[TOut <:< Template[?]],
56+
notEmptyTuple: NotGiven[TOut =:= EmptyTuple],
57+
): ArgumentReducer[Template[H] *: T, Template[H] *: TOut] =
58+
(args, lits) =>
59+
val (tailOut, tailLits) = tailReducer(args.tail, lits.tail)
60+
val template: Template[H] = lits.head + args.head
61+
(template *: tailOut, tailLits)
62+
63+
given recursiveNonTemplateHead[H, T <: Tuple, TOut](using
64+
tailReducer: ArgumentReducer[T, TOut],
65+
notTemplate: NotGiven[H <:< Template[?]],
66+
template: Template[H],
67+
): ArgumentReducer[H *: T, TOut] =
68+
(args, lits) =>
69+
val rendered: Prompt = template(args.head)
70+
val combinedWithNext: Prompt = lits.head + rendered + lits.tail.head
71+
val modifiedLits = combinedWithNext +: lits.tail.tail
72+
tailReducer(args.tail, modifiedLits)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package com.hivemind.llmsdsl.interpolator
2+
3+
case class MarginStripper(marginChar: Option[Char])
4+
object MarginStripper extends LowPriorityMarginStripper:
5+
val none = MarginStripper(Option.empty[Char])
6+
val pipe = MarginStripper(Some('|'))
7+
val hash = MarginStripper(Some('#'))
8+
object None extends LowPriorityMarginStripper
9+
object Pipe:
10+
given MarginStripper = pipe
11+
object Hash:
12+
given MarginStripper = hash
13+
14+
trait LowPriorityMarginStripper:
15+
given MarginStripper = MarginStripper.none

0 commit comments

Comments
 (0)