@@ -10,6 +10,7 @@ import (
10
10
"io"
11
11
"net"
12
12
"runtime"
13
+ "strings"
13
14
"sync"
14
15
"time"
15
16
@@ -37,13 +38,12 @@ func init() {
37
38
38
39
type ClientInstance struct {
39
40
sync.RWMutex
40
- nfsEKey * mlkem.EncapsulationKey768
41
- nfsEKeySha256 [32 ]byte
42
- xor uint32
43
- minutes time.Duration
44
- expire time.Time
45
- baseKey []byte
46
- ticket []byte
41
+ nfsEKey * mlkem.EncapsulationKey768
42
+ xorKey []byte
43
+ minutes time.Duration
44
+ expire time.Time
45
+ baseKey []byte
46
+ ticket []byte
47
47
}
48
48
49
49
type ClientConn struct {
@@ -60,10 +60,17 @@ type ClientConn struct {
60
60
}
61
61
62
62
func (i * ClientInstance ) Init (nfsEKeyBytes []byte , xor uint32 , minutes time.Duration ) (err error ) {
63
+ if i .nfsEKey != nil {
64
+ err = errors .New ("already initialized" )
65
+ return
66
+ }
63
67
i .nfsEKey , err = mlkem .NewEncapsulationKey768 (nfsEKeyBytes )
68
+ if err != nil {
69
+ return
70
+ }
64
71
if xor > 0 {
65
- i . nfsEKeySha256 = sha256 .Sum256 (nfsEKeyBytes )
66
- i .xor = xor
72
+ xorKey : = sha256 .Sum256 (nfsEKeyBytes )
73
+ i .xorKey = xorKey [:]
67
74
}
68
75
i .minutes = minutes
69
76
return
@@ -73,8 +80,8 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
73
80
if i .nfsEKey == nil {
74
81
return nil , errors .New ("uninitialized" )
75
82
}
76
- if i .xor > 0 {
77
- conn = NewXorConn (conn , i .nfsEKeySha256 [:] )
83
+ if i .xorKey != nil {
84
+ conn = NewXorConn (conn , i .xorKey )
78
85
}
79
86
c := & ClientConn {Conn : conn }
80
87
@@ -110,14 +117,14 @@ func (i *ClientInstance) Handshake(conn net.Conn) (net.Conn, error) {
110
117
}
111
118
// client can send more padding / NFS AEAD messages if needed
112
119
113
- _ , t , l , err := ReadAndDecodeHeader (c .Conn )
120
+ _ , t , l , err := ReadAndDiscardPaddings (c .Conn )
114
121
if err != nil {
115
122
return nil , err
116
123
}
124
+
117
125
if t != 1 {
118
126
return nil , fmt .Errorf ("unexpected type %v, expect random hello" , t )
119
127
}
120
-
121
128
peerRandomHello := make ([]byte , 1088 + 21 )
122
129
if l != len (peerRandomHello ) {
123
130
return nil , fmt .Errorf ("unexpected length %v for random hello" , l )
@@ -194,34 +201,17 @@ func (c *ClientConn) Read(b []byte) (int, error) {
194
201
return 0 , nil
195
202
}
196
203
if c .peerAead == nil {
197
- var t byte
198
- var l int
199
- var err error
200
- if c .instance == nil { // from 1-RTT
201
- for {
202
- if _ , t , l , err = ReadAndDecodeHeader (c .Conn ); err != nil {
203
- return 0 , err
204
- }
205
- if t != 23 {
206
- break
207
- }
208
- if _ , err := io .ReadFull (c .Conn , make ([]byte , l )); err != nil {
209
- return 0 , err
210
- }
211
- }
212
- } else {
213
- h := make ([]byte , 5 )
214
- if _ , err := io .ReadFull (c .Conn , h ); err != nil {
215
- return 0 , err
216
- }
217
- if t , l , err = DecodeHeader (h ); err != nil {
204
+ _ , t , l , err := ReadAndDiscardPaddings (c .Conn )
205
+ if err != nil {
206
+ if c .instance != nil && strings .HasPrefix (err .Error (), "invalid header: " ) { // from 0-RTT
218
207
c .instance .Lock ()
219
208
if bytes .Equal (c .ticket , c .instance .ticket ) {
220
209
c .instance .expire = time .Now () // expired
221
210
}
222
211
c .instance .Unlock ()
223
212
return 0 , errors .New ("new handshake needed" )
224
213
}
214
+ return 0 , err
225
215
}
226
216
if t != 0 {
227
217
return 0 , fmt .Errorf ("unexpected type %v, expect server random" , t )
0 commit comments