Skip to content

Commit 122dd1e

Browse files
committed
Allow applying recorded transaction as migration or plain transaction. Add tests.
1 parent 4ce7297 commit 122dd1e

File tree

2 files changed

+144
-11
lines changed

2 files changed

+144
-11
lines changed

trailbase-core/src/transaction.rs

Lines changed: 126 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,14 @@ pub enum TransactionError {
1818
File(String),
1919
}
2020

21+
#[derive(Clone, Debug, PartialEq)]
22+
enum QueryType {
23+
Query,
24+
Execute,
25+
}
26+
2127
pub struct TransactionLog {
22-
log: Vec<String>,
28+
log: Vec<(QueryType, String)>,
2329
}
2430

2531
impl TransactionLog {
@@ -42,7 +48,7 @@ impl TransactionLog {
4248
let sql_string: String = self
4349
.log
4450
.iter()
45-
.filter_map(|stmt| match stmt.as_str() {
51+
.filter_map(|(_, stmt)| match stmt.as_str() {
4652
"" => None,
4753
x if x.ends_with(";") => Some(stmt.clone()),
4854
x => Some(format!("{x};")),
@@ -83,13 +89,40 @@ impl TransactionLog {
8389

8490
return Ok(report);
8591
}
92+
93+
#[allow(unused)]
94+
pub(crate) async fn commit(
95+
self,
96+
conn: &trailbase_sqlite::Connection,
97+
) -> Result<(), trailbase_sqlite::Error> {
98+
conn
99+
.call(|conn: &mut rusqlite::Connection| {
100+
let tx = conn.transaction()?;
101+
for (query_type, stmt) in self.log {
102+
match query_type {
103+
QueryType::Query => {
104+
tx.query_row(&stmt, (), |_row| Ok(()))?;
105+
}
106+
QueryType::Execute => {
107+
tx.execute(&stmt, ())?;
108+
}
109+
}
110+
}
111+
tx.commit()?;
112+
113+
return Ok(());
114+
})
115+
.await?;
116+
117+
return Ok(());
118+
}
86119
}
87120

88121
/// A recorder for table migrations, i.e.: create, alter, drop, as opposed to data migrations.
89122
pub struct TransactionRecorder<'a> {
90123
tx: rusqlite::Transaction<'a>,
91124

92-
log: Vec<String>,
125+
log: Vec<(QueryType, String)>,
93126
}
94127

95128
impl<'a> TransactionRecorder<'a> {
@@ -105,10 +138,18 @@ impl<'a> TransactionRecorder<'a> {
105138
// Note that we cannot take any sql params for recording purposes.
106139
#[allow(unused)]
107140
pub fn query(&mut self, sql: &str, params: impl rusqlite::Params) -> Result<(), rusqlite::Error> {
108-
let mut stmt = self.tx.prepare(sql)?;
109-
let mut rows = stmt.query(params)?;
141+
let mut stmt = self.tx.prepare_cached(sql)?;
142+
params.__bind_in(&mut stmt)?;
143+
let Some(expanded_sql) = stmt.expanded_sql() else {
144+
return Err(rusqlite::Error::ToSqlConversionFailure(
145+
"failed to get expanded query".into(),
146+
));
147+
};
148+
149+
let mut rows = stmt.raw_query();
110150
rows.next()?;
111-
self.log.push(sql.to_string());
151+
152+
self.log.push((QueryType::Query, expanded_sql));
112153

113154
return Ok(());
114155
}
@@ -118,8 +159,18 @@ impl<'a> TransactionRecorder<'a> {
118159
sql: &str,
119160
params: impl rusqlite::Params,
120161
) -> Result<usize, rusqlite::Error> {
121-
let rows_affected = self.tx.execute(sql, params)?;
122-
self.log.push(sql.to_string());
162+
// let rows_affected = self.tx.execute(sql, params)?;
163+
let mut stmt = self.tx.prepare_cached(sql)?;
164+
params.__bind_in(&mut stmt)?;
165+
let Some(expanded_sql) = stmt.expanded_sql() else {
166+
return Err(rusqlite::Error::ToSqlConversionFailure(
167+
"failed to get expanded query".into(),
168+
));
169+
};
170+
171+
let rows_affected = stmt.raw_execute()?;
172+
173+
self.log.push((QueryType::Execute, expanded_sql));
123174
return Ok(rows_affected);
124175
}
125176

@@ -136,8 +187,11 @@ impl<'a> TransactionRecorder<'a> {
136187
}
137188
}
138189

139-
#[cfg(not(test))]
140190
async fn write_migration_file(path: PathBuf, sql: &str) -> std::io::Result<()> {
191+
if cfg!(test) {
192+
return Ok(());
193+
}
194+
141195
use tokio::io::AsyncWriteExt;
142196

143197
let mut migration_file = tokio::fs::File::create_new(path).await?;
@@ -146,6 +200,67 @@ async fn write_migration_file(path: PathBuf, sql: &str) -> std::io::Result<()> {
146200
}
147201

148202
#[cfg(test)]
149-
async fn write_migration_file(_path: PathBuf, _sql: &str) -> std::io::Result<()> {
150-
return Ok(());
203+
mod tests {
204+
use super::*;
205+
206+
#[tokio::test]
207+
async fn test_transaction_log() {
208+
let mut conn = rusqlite::Connection::open_in_memory().unwrap();
209+
conn
210+
.execute_batch(
211+
r#"
212+
CREATE TABLE 'table' (
213+
id INTEGER PRIMARY KEY NOT NULL,
214+
name TEXT NOT NULL,
215+
age INTEGER
216+
) STRICT;
217+
218+
INSERT INTO 'table' (id, name, age) VALUES (0, 'Alice', 21), (1, 'Bob', 18);
219+
"#,
220+
)
221+
.unwrap();
222+
223+
// Just double checking that rusqlite's query and execute ignore everything but the first
224+
// statement.
225+
let name: String = conn
226+
.query_row(
227+
r#"
228+
SELECT name FROM 'table' WHERE id = 0;
229+
SELECT name FROM 'table' WHERE id = 1;
230+
DROP TABLE 'table';
231+
"#,
232+
(),
233+
|row| row.get(0),
234+
)
235+
.unwrap();
236+
assert_eq!(name, "Alice");
237+
238+
let mut recorder = TransactionRecorder::new(&mut conn).unwrap();
239+
240+
recorder
241+
.execute("DELETE FROM 'table' WHERE age < ?1", rusqlite::params!(20))
242+
.unwrap();
243+
let log = recorder.rollback().unwrap().unwrap();
244+
245+
assert_eq!(log.log.len(), 1);
246+
assert_eq!(log.log[0].0, QueryType::Execute);
247+
assert_eq!(log.log[0].1, "DELETE FROM 'table' WHERE age < 20");
248+
249+
let conn = trailbase_sqlite::Connection::from_connection_test_only(conn);
250+
let count: i64 = conn
251+
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
252+
.await
253+
.unwrap()
254+
.unwrap();
255+
assert_eq!(count, 2);
256+
257+
log.commit(&conn).await.unwrap();
258+
259+
let count: i64 = conn
260+
.query_row_f("SELECT COUNT(*) FROM 'table'", (), |row| row.get(0))
261+
.await
262+
.unwrap()
263+
.unwrap();
264+
assert_eq!(count, 1);
265+
}
151266
}

trailbase-sqlite/src/connection.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,24 @@ impl Connection {
158158
});
159159
}
160160

161+
pub fn from_connection_test_only(conn: rusqlite::Connection) -> Self {
162+
use parking_lot::lock_api::RwLock;
163+
164+
let (shared_write_sender, shared_write_receiver) = kanal::unbounded::<Message>();
165+
std::thread::spawn(move || {
166+
event_loop(
167+
0,
168+
Arc::new(LockedConnections(RwLock::new(vec![conn]))),
169+
shared_write_receiver,
170+
)
171+
});
172+
173+
return Self {
174+
reader: shared_write_sender.clone(),
175+
writer: shared_write_sender,
176+
};
177+
}
178+
161179
/// Open a new connection to an in-memory SQLite database.
162180
///
163181
/// # Failure

0 commit comments

Comments
 (0)