@@ -18,8 +18,14 @@ pub enum TransactionError {
18
18
File ( String ) ,
19
19
}
20
20
21
+ #[ derive( Clone , Debug , PartialEq ) ]
22
+ enum QueryType {
23
+ Query ,
24
+ Execute ,
25
+ }
26
+
21
27
pub struct TransactionLog {
22
- log : Vec < String > ,
28
+ log : Vec < ( QueryType , String ) > ,
23
29
}
24
30
25
31
impl TransactionLog {
@@ -42,7 +48,7 @@ impl TransactionLog {
42
48
let sql_string: String = self
43
49
. log
44
50
. iter ( )
45
- . filter_map ( |stmt| match stmt. as_str ( ) {
51
+ . filter_map ( |( _ , stmt) | match stmt. as_str ( ) {
46
52
"" => None ,
47
53
x if x. ends_with ( ";" ) => Some ( stmt. clone ( ) ) ,
48
54
x => Some ( format ! ( "{x};" ) ) ,
@@ -83,13 +89,40 @@ impl TransactionLog {
83
89
84
90
return Ok ( report) ;
85
91
}
92
+
93
+ #[ allow( unused) ]
94
+ pub ( crate ) async fn commit (
95
+ self ,
96
+ conn : & trailbase_sqlite:: Connection ,
97
+ ) -> Result < ( ) , trailbase_sqlite:: Error > {
98
+ conn
99
+ . call ( |conn : & mut rusqlite:: Connection | {
100
+ let tx = conn. transaction ( ) ?;
101
+ for ( query_type, stmt) in self . log {
102
+ match query_type {
103
+ QueryType :: Query => {
104
+ tx. query_row ( & stmt, ( ) , |_row| Ok ( ( ) ) ) ?;
105
+ }
106
+ QueryType :: Execute => {
107
+ tx. execute ( & stmt, ( ) ) ?;
108
+ }
109
+ }
110
+ }
111
+ tx. commit ( ) ?;
112
+
113
+ return Ok ( ( ) ) ;
114
+ } )
115
+ . await ?;
116
+
117
+ return Ok ( ( ) ) ;
118
+ }
86
119
}
87
120
88
121
/// A recorder for table migrations, i.e.: create, alter, drop, as opposed to data migrations.
89
122
pub struct TransactionRecorder < ' a > {
90
123
tx : rusqlite:: Transaction < ' a > ,
91
124
92
- log : Vec < String > ,
125
+ log : Vec < ( QueryType , String ) > ,
93
126
}
94
127
95
128
impl < ' a > TransactionRecorder < ' a > {
@@ -105,10 +138,18 @@ impl<'a> TransactionRecorder<'a> {
105
138
// Note that we cannot take any sql params for recording purposes.
106
139
#[ allow( unused) ]
107
140
pub fn query ( & mut self , sql : & str , params : impl rusqlite:: Params ) -> Result < ( ) , rusqlite:: Error > {
108
- let mut stmt = self . tx . prepare ( sql) ?;
109
- let mut rows = stmt. query ( params) ?;
141
+ let mut stmt = self . tx . prepare_cached ( sql) ?;
142
+ params. __bind_in ( & mut stmt) ?;
143
+ let Some ( expanded_sql) = stmt. expanded_sql ( ) else {
144
+ return Err ( rusqlite:: Error :: ToSqlConversionFailure (
145
+ "failed to get expanded query" . into ( ) ,
146
+ ) ) ;
147
+ } ;
148
+
149
+ let mut rows = stmt. raw_query ( ) ;
110
150
rows. next ( ) ?;
111
- self . log . push ( sql. to_string ( ) ) ;
151
+
152
+ self . log . push ( ( QueryType :: Query , expanded_sql) ) ;
112
153
113
154
return Ok ( ( ) ) ;
114
155
}
@@ -118,8 +159,18 @@ impl<'a> TransactionRecorder<'a> {
118
159
sql : & str ,
119
160
params : impl rusqlite:: Params ,
120
161
) -> Result < usize , rusqlite:: Error > {
121
- let rows_affected = self . tx . execute ( sql, params) ?;
122
- self . log . push ( sql. to_string ( ) ) ;
162
+ // let rows_affected = self.tx.execute(sql, params)?;
163
+ let mut stmt = self . tx . prepare_cached ( sql) ?;
164
+ params. __bind_in ( & mut stmt) ?;
165
+ let Some ( expanded_sql) = stmt. expanded_sql ( ) else {
166
+ return Err ( rusqlite:: Error :: ToSqlConversionFailure (
167
+ "failed to get expanded query" . into ( ) ,
168
+ ) ) ;
169
+ } ;
170
+
171
+ let rows_affected = stmt. raw_execute ( ) ?;
172
+
173
+ self . log . push ( ( QueryType :: Execute , expanded_sql) ) ;
123
174
return Ok ( rows_affected) ;
124
175
}
125
176
@@ -136,8 +187,11 @@ impl<'a> TransactionRecorder<'a> {
136
187
}
137
188
}
138
189
139
- #[ cfg( not( test) ) ]
140
190
async fn write_migration_file ( path : PathBuf , sql : & str ) -> std:: io:: Result < ( ) > {
191
+ if cfg ! ( test) {
192
+ return Ok ( ( ) ) ;
193
+ }
194
+
141
195
use tokio:: io:: AsyncWriteExt ;
142
196
143
197
let mut migration_file = tokio:: fs:: File :: create_new ( path) . await ?;
@@ -146,6 +200,67 @@ async fn write_migration_file(path: PathBuf, sql: &str) -> std::io::Result<()> {
146
200
}
147
201
148
202
#[ cfg( test) ]
149
- async fn write_migration_file ( _path : PathBuf , _sql : & str ) -> std:: io:: Result < ( ) > {
150
- return Ok ( ( ) ) ;
203
+ mod tests {
204
+ use super :: * ;
205
+
206
+ #[ tokio:: test]
207
+ async fn test_transaction_log ( ) {
208
+ let mut conn = rusqlite:: Connection :: open_in_memory ( ) . unwrap ( ) ;
209
+ conn
210
+ . execute_batch (
211
+ r#"
212
+ CREATE TABLE 'table' (
213
+ id INTEGER PRIMARY KEY NOT NULL,
214
+ name TEXT NOT NULL,
215
+ age INTEGER
216
+ ) STRICT;
217
+
218
+ INSERT INTO 'table' (id, name, age) VALUES (0, 'Alice', 21), (1, 'Bob', 18);
219
+ "# ,
220
+ )
221
+ . unwrap ( ) ;
222
+
223
+ // Just double checking that rusqlite's query and execute ignore everything but the first
224
+ // statement.
225
+ let name: String = conn
226
+ . query_row (
227
+ r#"
228
+ SELECT name FROM 'table' WHERE id = 0;
229
+ SELECT name FROM 'table' WHERE id = 1;
230
+ DROP TABLE 'table';
231
+ "# ,
232
+ ( ) ,
233
+ |row| row. get ( 0 ) ,
234
+ )
235
+ . unwrap ( ) ;
236
+ assert_eq ! ( name, "Alice" ) ;
237
+
238
+ let mut recorder = TransactionRecorder :: new ( & mut conn) . unwrap ( ) ;
239
+
240
+ recorder
241
+ . execute ( "DELETE FROM 'table' WHERE age < ?1" , rusqlite:: params!( 20 ) )
242
+ . unwrap ( ) ;
243
+ let log = recorder. rollback ( ) . unwrap ( ) . unwrap ( ) ;
244
+
245
+ assert_eq ! ( log. log. len( ) , 1 ) ;
246
+ assert_eq ! ( log. log[ 0 ] . 0 , QueryType :: Execute ) ;
247
+ assert_eq ! ( log. log[ 0 ] . 1 , "DELETE FROM 'table' WHERE age < 20" ) ;
248
+
249
+ let conn = trailbase_sqlite:: Connection :: from_connection_test_only ( conn) ;
250
+ let count: i64 = conn
251
+ . query_row_f ( "SELECT COUNT(*) FROM 'table'" , ( ) , |row| row. get ( 0 ) )
252
+ . await
253
+ . unwrap ( )
254
+ . unwrap ( ) ;
255
+ assert_eq ! ( count, 2 ) ;
256
+
257
+ log. commit ( & conn) . await . unwrap ( ) ;
258
+
259
+ let count: i64 = conn
260
+ . query_row_f ( "SELECT COUNT(*) FROM 'table'" , ( ) , |row| row. get ( 0 ) )
261
+ . await
262
+ . unwrap ( )
263
+ . unwrap ( ) ;
264
+ assert_eq ! ( count, 1 ) ;
265
+ }
151
266
}
0 commit comments