Skip to content

Commit bddd2f7

Browse files
committed
begin using the standard vector
1 parent d601424 commit bddd2f7

File tree

10 files changed

+188
-96
lines changed

10 files changed

+188
-96
lines changed

lake-manifest.json

Lines changed: 57 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,95 @@
11
{"version": "1.1.0",
22
"packagesDir": ".lake/packages",
33
"packages":
4-
[{"url": "https://github.com/leanprover-community/batteries",
4+
[{"url": "https://github.com/leanprover-community/aesop.git",
55
"type": "git",
66
"subDir": null,
7-
"scope": "leanprover-community",
8-
"rev": "e7897807913fafdab31b01b9f627550bcc96cff2",
9-
"name": "batteries",
7+
"scope": "",
8+
"rev": "56a2c80b209c253e0281ac4562a92122b457dcc0",
9+
"name": "aesop",
1010
"manifestFile": "lake-manifest.json",
11-
"inputRev": "main",
12-
"inherited": true,
13-
"configFile": "lakefile.lean"},
14-
{"url": "https://github.com/leanprover-community/quote4",
11+
"inputRev": "v4.17.0",
12+
"inherited": false,
13+
"configFile": "lakefile.toml"},
14+
{"url": "https://github.com/leanprover-community/mathlib4.git",
1515
"type": "git",
1616
"subDir": null,
17-
"scope": "leanprover-community",
18-
"rev": "01ad33937acd996ee99eb74eefb39845e4e4b9f5",
19-
"name": "Qq",
17+
"scope": "",
18+
"rev": "5269898d6a51d047931107c8d72d934d8d5d3753",
19+
"name": "mathlib",
2020
"manifestFile": "lake-manifest.json",
21-
"inputRev": "master",
22-
"inherited": true,
21+
"inputRev": "v4.17.0",
22+
"inherited": false,
2323
"configFile": "lakefile.lean"},
24-
{"url": "https://github.com/leanprover-community/aesop.git",
24+
{"url": "https://github.com/leanprover-community/batteries",
2525
"type": "git",
2626
"subDir": null,
2727
"scope": "",
28-
"rev": "79fb157c6a5061190d169535f8e5cb007914a82e",
29-
"name": "aesop",
28+
"rev": "efcc7d9bd9936ecdc625baf0d033b60866565cd5",
29+
"name": "batteries",
3030
"manifestFile": "lake-manifest.json",
31-
"inputRev": null,
32-
"inherited": false,
31+
"inputRev": "v4.17.0",
32+
"inherited": true,
3333
"configFile": "lakefile.toml"},
34-
{"url": "https://github.com/leanprover-community/ProofWidgets4",
34+
{"url": "https://github.com/leanprover-community/plausible",
3535
"type": "git",
3636
"subDir": null,
3737
"scope": "leanprover-community",
38-
"rev": "c87908619cccadda23f71262e6898b9893bffa36",
39-
"name": "proofwidgets",
38+
"rev": "c708be04267e3e995a14ac0d08b1530579c1525a",
39+
"name": "plausible",
4040
"manifestFile": "lake-manifest.json",
41-
"inputRev": "v0.0.40",
41+
"inputRev": "main",
4242
"inherited": true,
43-
"configFile": "lakefile.lean"},
44-
{"url": "https://github.com/leanprover/lean4-cli",
43+
"configFile": "lakefile.toml"},
44+
{"url": "https://github.com/leanprover-community/LeanSearchClient",
4545
"type": "git",
4646
"subDir": null,
47-
"scope": "",
48-
"rev": "a11566029bd9ec4f68a65394e8c3ff1af74c1a29",
49-
"name": "Cli",
47+
"scope": "leanprover-community",
48+
"rev": "0c169a0d55fef3763cfb3099eafd7b884ec7e41d",
49+
"name": "LeanSearchClient",
5050
"manifestFile": "lake-manifest.json",
5151
"inputRev": "main",
5252
"inherited": true,
53-
"configFile": "lakefile.lean"},
53+
"configFile": "lakefile.toml"},
5454
{"url": "https://github.com/leanprover-community/import-graph",
5555
"type": "git",
5656
"subDir": null,
5757
"scope": "leanprover-community",
58-
"rev": "68b518c9b352fbee16e6d632adcb7a6d0760e2b7",
58+
"rev": "0447b0a7b7f41f0a1749010db3f222e4a96f9d30",
5959
"name": "importGraph",
6060
"manifestFile": "lake-manifest.json",
6161
"inputRev": "main",
6262
"inherited": true,
6363
"configFile": "lakefile.toml"},
64-
{"url": "https://github.com/leanprover-community/mathlib4.git",
64+
{"url": "https://github.com/leanprover-community/ProofWidgets4",
6565
"type": "git",
6666
"subDir": null,
67-
"scope": "",
68-
"rev": "85db8c7fd2bcb8d447952bf124670d70e3815d10",
69-
"name": "mathlib",
67+
"scope": "leanprover-community",
68+
"rev": "799f6986de9f61b784ff7be8f6a8b101045b8ffd",
69+
"name": "proofwidgets",
7070
"manifestFile": "lake-manifest.json",
71-
"inputRev": null,
72-
"inherited": false,
73-
"configFile": "lakefile.lean"}],
74-
"name": "llm.lean",
71+
"inputRev": "v0.0.52",
72+
"inherited": true,
73+
"configFile": "lakefile.lean"},
74+
{"url": "https://github.com/leanprover-community/quote4",
75+
"type": "git",
76+
"subDir": null,
77+
"scope": "leanprover-community",
78+
"rev": "95561f7a5811fae6a309e4a1bbe22a0a4a98bf03",
79+
"name": "Qq",
80+
"manifestFile": "lake-manifest.json",
81+
"inputRev": "master",
82+
"inherited": true,
83+
"configFile": "lakefile.toml"},
84+
{"url": "https://github.com/leanprover/lean4-cli",
85+
"type": "git",
86+
"subDir": null,
87+
"scope": "leanprover",
88+
"rev": "e7fd1a415c80985ade02a021172834ca2139b0ca",
89+
"name": "Cli",
90+
"manifestFile": "lake-manifest.json",
91+
"inputRev": "main",
92+
"inherited": true,
93+
"configFile": "lakefile.toml"}],
94+
"name": "LLM.lean",
7595
"lakeDir": ".lake"}
File renamed without changes.

lakefile.toml

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
name = "LLM.lean"
2+
defaultTargets = ["Main"]
3+
4+
[[require]]
5+
name = "mathlib"
6+
git = "https://github.com/leanprover-community/mathlib4.git"
7+
rev = "v4.17.0"
8+
9+
[[require]]
10+
name = "aesop"
11+
git = "https://github.com/leanprover-community/aesop.git"
12+
rev = "v4.17.0"
13+
14+
# [[lean_lib]]
15+
# name = "LinearAlgebra"
16+
# srcDir = "lean"
17+
18+
[[lean_lib]]
19+
name = "Llm"
20+
srcDir = "lean"
21+
22+
[[lean_exe]]
23+
name = "Main"
24+
srcDir = "lean"
25+
supportInterpeter = true

lean-toolchain

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
leanprover/lean4:v4.10.0-rc2
1+
leanprover/lean4:v4.17.0

lean/LinearAlgebra/Vector.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import Mathlib.Algebra.Group.ZeroOne
21
import Mathlib.Tactic.Ring
32

43
/-- The base array type.-/

lean/Llm/Attention.lean

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import LinearAlgebra.Vector
21
import Llm.Matmul
2+
import Llm.FloatTensor
33
import Llm.Softmax
44

5-
def tril [Zero α] (fillValue: α) : Vector C (Vector R α) :=
5+
def tril [Zero α] (fillValue: α) : Vector (Vector α R) C :=
66
Vector.ofFn (fun c =>
77
Vector.ofFn (fun r =>
88
-- no ≤ to avoid diagonal. this should take an axis argument.
@@ -11,9 +11,9 @@ def tril [Zero α] (fillValue: α) : Vector C (Vector R α) :=
1111
)
1212

1313
def attention_forward
14-
(q k v : Vector T (Vector Dₖ Float))
15-
: Vector T (Vector Dₖ Float) :=
16-
let a := q * k.transpose
14+
(q k v : Vector (Vector Float Dₖ) T)
15+
: Vector (Vector Float Dₖ) T :=
16+
let a := q * (transpose k)
1717
let norm_factor := (Float.ofNat Dₖ).sqrt
1818
let a1 := a.map (λ x => x.map (λ y => y / norm_factor))
1919
let a2 := a1 + tril (-Float.inf)
@@ -22,9 +22,9 @@ def attention_forward
2222
a3 * v
2323

2424
def attention_backwards
25-
(dout q k v: Vector T (Vector Dₖ Float))
25+
(dout q k v: Vector (Vector Float Dₖ) T)
2626
-- dq, dk, dv
27-
: (Vector T (Vector Dₖ Float)) × (Vector T (Vector Dₖ Float) ) × (Vector T (Vector Dₖ Float) ) :=
27+
: (Vector (Vector Float Dₖ) T) × (Vector (Vector Float Dₖ) T) × (Vector (Vector Float Dₖ) T) :=
2828
let a := q * k.transpose
2929
let norm_factor := 1 / (Float.ofNat Dₖ).sqrt
3030
let a1 := a.map (λ x => x.map (λ y => y * norm_factor))

lean/Llm/FiniteDiff.lean

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,33 @@
1-
import LinearAlgebra.Vector
1+
import Llm.Matmul
2+
import Llm.FloatTensor
3+
4+
set_option diagnostics true
25

36
/-- centered finite difference approximation of the derivative of a function -/
4-
def finiteDiff (f : Vector n Float → Vector m Float) (x : Vector n Float) (ε := 1e-6) : Vector m Float :=
7+
def finiteDiff (f : Vector Float n → Vector Float m) (x : Vector Float n) (ε := 1e-6) : Vector Float m :=
58
let dx := ε * x
69

7-
(f (x + dx) - f (x - dx)) / (2*dx.norm)
10+
(f (x + dx) - f (x - dx)) / (2*norm dx)
811

912
/-- Coerce a scalar to a vector of length 1 -/
10-
instance : Coe a (Vector 1 a) where
11-
coe a := !v[a]
13+
instance : Coe a (Vector a 1) where
14+
coe a := #v[a]
1215

13-
#eval ((2.0:Float) : Vector 1 Float)
14-
#eval finiteDiff (f:=fun x => (x.dot x : Vector 1 Float)) (x:= !v[1,2,3])
16+
#eval ((2.0:Float) : Vector Float 1)
17+
#eval finiteDiff (f:=fun x => (Vector.singleton (dot x x) : Vector Float 1)) (x:= #v[1,2,3])
1518

1619
-- Test case for x^2
17-
def square (x: Vector n Float) : Vector n Float := x.hadamard x
18-
#eval square (Vector.replicate 5 2.0)
20+
def square (x: Vector Float n) : Vector Float n := x * x
21+
22+
#eval square (Vector.mkVector 5 2.0)
23+
1924

2025
def test_finiteDiff_square (n : Nat) : Bool := Id.run do
21-
let x := Vector.replicate n 2.0 -- Vector of 2.0s
26+
let x := Vector.mkVector n 2.0 -- Vector of 2.0s
2227
let df := finiteDiff square x
23-
let expected := Vector.replicate n 4.0 -- Derivative of x^2 is 2x, so at x=2, it's 4
24-
dbg_trace df
25-
dbg_trace expected
28+
let expected := Vector.mkVector n 4.0 -- Derivative of x^2 is 2x, so at x=2, it's 4
29+
-- dbg_trace df
30+
-- dbg_trace expected
2631
-- Check if the finite difference approximation is close to the expected value
2732
let tolerance := 1e-4
2833
let isClose := df.zipWith (λ a b => (Float.abs (a - b) < tolerance : Bool)) expected
@@ -36,4 +41,4 @@ def run_test_finiteDiff_square (n : Nat:=1) : IO Unit := do
3641
else
3742
IO.println "Test failed: finite difference of x^2 at x=2 is not within tolerance"
3843

39-
#eval run_test_finiteDiff_square 2
44+
#eval run_test_finiteDiff_square 1

lean/Llm/FloatTensor.lean

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
instance [Zero α] : Zero (Vector α n) := ⟨Vector.mkVector n 0
3+
4+
instance [Add α] : Add (Vector α n) := ⟨Vector.zipWith (· + ·)⟩
5+
instance [Sub α] : Sub (Vector α n) := ⟨Vector.zipWith (· - ·)⟩
6+
instance [Mul α] : Mul (Vector α n) := ⟨Vector.zipWith (· * ·)⟩
7+
instance [Div α] : Div (Vector α n) := ⟨Vector.zipWith (· / ·)⟩
8+
9+
-- scalar addition
10+
instance [Add α] : HAdd (Vector α n) α (Vector α n) := ⟨fun v a => v.map (· + a)⟩
11+
-- scalar subtraction
12+
instance [Sub α] : HSub (Vector α n) α (Vector α n) := ⟨fun v a => v.map (· - a)⟩
13+
-- scalar multiplication
14+
instance [Mul α] : HMul α (Vector α n) (Vector α n) := ⟨fun a v => v.map (· * a)⟩
15+
-- scalar division
16+
instance [Div α] : HDiv (Vector α n) α (Vector α n) := ⟨fun v a => v.map (· / a)⟩
17+
18+
19+
def transpose (a: Vector (Vector α N) M): Vector (Vector α M) N :=
20+
Vector.ofFn (fun i =>
21+
Vector.ofFn (fun j =>
22+
a[j][i]
23+
)
24+
)
25+
26+
def norm (a: Vector Float n) : Float :=
27+
(a * a).sum.sqrt
28+
29+
def normalize (a: Vector Float n) : Vector Float n :=
30+
a / norm a

lean/Llm/Matmul.lean

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,66 @@
1-
import LinearAlgebra.Vector
2-
31

42
-- #eval Vector.matmul !v[!v[1,2,3],!v[4,5,6]] !v[!v[7,8],!v[9,10],!v[11,12]]
53

6-
#check
7-
let a: Vector 2 (Vector 3 Float) := sorry
8-
let b: Vector 3 (Vector 2 Float) := sorry
9-
let c := Vector.matmul a b
10-
c
4+
import Llm.FloatTensor
115

6+
def dot [Add α] [Mul α] [Zero α]
7+
(a: Vector α N)
8+
(b: Vector α N)
9+
: α :=
10+
(a.zipWith (· * ·) b).sum
1211

12+
def matmul [Add α] [Mul α] [Zero α]
13+
(a: Vector (Vector α P) M)
14+
(b: Vector (Vector α N) P)
15+
: Vector (Vector α N) M :=
16+
let b_t := transpose b
17+
Vector.ofFn (fun i =>
18+
Vector.ofFn (fun j => dot a[i] b_t[j])
19+
)
20+
21+
#check
22+
let a: Vector (Vector Float 3) 2 := sorry
23+
let b: Vector (Vector Float 2) 3 := sorry
24+
let c := matmul a b
25+
c
1326

1427

1528
def matmul_batched [Add α] [Mul α] [Zero α]
16-
(a: Vector B (Vector M (Vector P α)))
17-
(b: Vector B (Vector P (Vector N α)))
18-
: Vector B (Vector M (Vector N α )) :=
19-
.zipWith (· * ·) a b
29+
(a: Vector (Vector (Vector α P) M) B)
30+
(b: Vector (Vector (Vector α N) P) B)
31+
: Vector (Vector (Vector α N) M) B :=
32+
.zipWith matmul a b
2033

2134
/--
2235
unbatched backward.
2336
returns dinp, dweight
2437
-/
2538
def matmul_backward
26-
(inp: Vector P (Vector N Float))
27-
(weight: Vector M (Vector P Float))
28-
(dout: Vector M (Vector N Float))
29-
: (Vector P (Vector N Float )) × (Vector M (Vector P Float ))
39+
(inp: Vector (Vector Float N) P)
40+
(weight: Vector (Vector Float P) M)
41+
(dout: Vector (Vector Float N) M)
42+
: (Vector (Vector Float N) P) × (Vector (Vector Float P) M)
3043
:=
31-
32-
let dinp := weight.transpose * dout
33-
let dweight := dout * inp.transpose
44+
let dinp := matmul (transpose weight) dout
45+
let dweight := matmul dout (transpose inp)
3446

3547
(dinp, dweight)
3648

3749
/--
3850
We reduce the weight but not the input.
3951
-/
4052
def matmul_backward_batched
41-
(inp: Vector B (Vector P (Vector N Float)))
42-
(weight: Vector B (Vector M (Vector P Float)))
43-
(dout: Vector B (Vector M (Vector N Float)))
44-
: (Vector B (Vector P (Vector N Float))) × (Vector M (Vector P Float))
53+
(inp: Vector (Vector (Vector Float N) P) B)
54+
(weight: Vector (Vector (Vector Float P) M) B)
55+
(dout: Vector (Vector (Vector Float N) M) B)
56+
: (Vector (Vector (Vector Float N) P) B) × (Vector (Vector Float P) M)
4557
:=
46-
let inp_t := inp.map (·.transpose)
47-
let weight_t := weight.map (·.transpose)
58+
let inp_t := inp.map transpose
59+
let weight_t := weight.map transpose
4860

4961
let dinp_b := matmul_batched weight_t dout
5062
let dweight_b := matmul_batched dout inp_t
5163

52-
let dweight := dweight_b.sum
64+
let dweight := dweight_b.sum
5365

5466
(dinp_b, dweight)

0 commit comments

Comments
 (0)