@@ -381,35 +381,29 @@ where
381381
382382 // Calculate the sum of elements (not including the empty element if there is one)
383383 for ( i, element) in self . elements . iter ( ) . enumerate ( ) {
384- match empty_amount_element {
385- Some ( empty_i) => {
386- if i != empty_i {
387- //TODO: perform commodity type conversion here if required
388- sum = match sum. add ( & element. amount . as_ref ( ) . unwrap ( ) ) {
389- Ok ( value) => value,
390- Err ( error) => return Err ( AccountingError :: Commodity ( error) ) ,
391- }
384+ if let Some ( empty_i) = empty_amount_element {
385+ if i != empty_i {
386+ //TODO: perform commodity type conversion here if required
387+ sum = match sum. add ( & element. amount . as_ref ( ) . unwrap ( ) ) {
388+ Ok ( value) => value,
389+ Err ( error) => return Err ( AccountingError :: Commodity ( error) ) ,
392390 }
393391 }
394- None => { }
395392 }
396393 }
397394
398395 // Calculate the value to use for the empty element (negate the sum of the other elements)
399- match empty_amount_element {
400- Some ( empty_i) => {
401- let modified_emtpy_element: & mut TransactionElement =
402- modified_elements. get_mut ( empty_i) . unwrap ( ) ;
403- let negated_sum = sum. neg ( ) ;
404- modified_emtpy_element. amount = Some ( negated_sum. clone ( ) ) ;
405-
406- sum = match sum. add ( & negated_sum) {
407- Ok ( value) => value,
408- Err ( error) => return Err ( AccountingError :: Commodity ( error) ) ,
409- }
396+ if let Some ( empty_i) = empty_amount_element {
397+ let modified_emtpy_element: & mut TransactionElement =
398+ modified_elements. get_mut ( empty_i) . unwrap ( ) ;
399+ let negated_sum = sum. neg ( ) ;
400+ modified_emtpy_element. amount = Some ( negated_sum) ;
401+
402+ sum = match sum. add ( & negated_sum) {
403+ Ok ( value) => value,
404+ Err ( error) => return Err ( AccountingError :: Commodity ( error) ) ,
410405 }
411- None => { }
412- } ;
406+ }
413407
414408 if sum. value != Decimal :: zero ( ) {
415409 return Err ( AccountingError :: InvalidTransaction (
@@ -421,12 +415,11 @@ where
421415 for transaction in & modified_elements {
422416 let mut account_state = program_state
423417 . get_account_state_mut ( & transaction. account_id )
424- . expect (
425- format ! (
418+ . unwrap_or_else ( ||
419+ panic ! (
426420 "unable to find state for account with id: {} please ensure this account was added to the program state before execution." ,
427421 transaction. account_id
428422 )
429- . as_ref ( ) ,
430423 ) ;
431424
432425 match account_state. status {
@@ -459,7 +452,7 @@ where
459452 }
460453 }
461454
462- return Ok ( ( ) ) ;
455+ Ok ( ( ) )
463456 }
464457}
465458
@@ -543,7 +536,7 @@ where
543536 . get_account_state_mut ( & self . account_id )
544537 . unwrap ( ) ;
545538 account_state. status = self . newstatus ;
546- return Ok ( ( ) ) ;
539+ Ok ( ( ) )
547540 }
548541}
549542
@@ -627,21 +620,27 @@ where
627620 }
628621
629622 fn perform ( & self , program_state : & mut ProgramState < AT , ATV > ) -> Result < ( ) , AccountingError > {
630- match program_state. get_account_state ( & self . account_id ) {
623+ let failed_assertion = match program_state. get_account_state ( & self . account_id ) {
631624 Some ( state) => {
632- if state
625+ if ! state
633626 . amount
634627 . eq_approx ( self . expected_balance , Commodity :: default_epsilon ( ) )
635628 {
629+ Some ( FailedBalanceAssertion :: new ( self . clone ( ) , state. amount ) )
636630 } else {
631+ None
637632 }
638633 }
639634 None => {
640635 return Err ( AccountingError :: MissingAccountState ( self . account_id ) ) ;
641636 }
637+ } ;
638+
639+ if let Some ( failed_assertion) = failed_assertion {
640+ program_state. record_failed_balance_assertion ( failed_assertion)
642641 }
643642
644- return Ok ( ( ) ) ;
643+ Ok ( ( ) )
645644 }
646645}
647646
@@ -654,7 +653,14 @@ impl ActionTypeFor<ActionType> for BalanceAssertion {
654653#[ cfg( test) ]
655654mod tests {
656655 use super :: ActionType ;
657- use std:: collections:: HashSet ;
656+ use crate :: {
657+ Account , AccountStatus , AccountingError , ActionTypeValue , ActionTypeValueEnum ,
658+ BalanceAssertion , Program , ProgramState , Transaction ,
659+ } ;
660+ use chrono:: NaiveDate ;
661+ use commodity:: { Commodity , CommodityType } ;
662+ use rust_decimal:: Decimal ;
663+ use std:: { collections:: HashSet , rc:: Rc } ;
658664
659665 #[ test]
660666 fn action_type_order ( ) {
@@ -690,6 +696,66 @@ mod tests {
690696
691697 assert_eq ! ( action_types_ordered, action_types_unordered) ;
692698 }
699+
700+ #[ test]
701+ fn balance_assertion ( ) {
702+ let aud = Rc :: from ( CommodityType :: from_currency_alpha3 ( "AUD" ) . unwrap ( ) ) ;
703+ let account1 = Rc :: from ( Account :: new_with_id ( Some ( "Account 1" ) , aud. id , None ) ) ;
704+ let account2 = Rc :: from ( Account :: new_with_id ( Some ( "Account 2" ) , aud. id , None ) ) ;
705+
706+ let date_1 = NaiveDate :: from_ymd ( 2020 , 01 , 01 ) ;
707+ let date_2 = NaiveDate :: from_ymd ( 2020 , 01 , 02 ) ;
708+ let actions: Vec < Rc < ActionTypeValue > > = vec ! [
709+ Rc :: new(
710+ Transaction :: new_simple:: <String >(
711+ None ,
712+ date_1. clone( ) ,
713+ account1. id,
714+ account2. id,
715+ Commodity :: new( Decimal :: new( 100 , 2 ) , & * aud) ,
716+ None ,
717+ )
718+ . into( ) ,
719+ ) ,
720+ // This assertion is expected to fail because it occurs at the start
721+ // of the day (before the transaction).
722+ Rc :: new(
723+ BalanceAssertion :: new(
724+ account2. id,
725+ date_1. clone( ) ,
726+ Commodity :: new( Decimal :: new( 100 , 2 ) , & * aud) ,
727+ )
728+ . into( ) ,
729+ ) ,
730+ // This assertion is expected to pass because it occurs at the end
731+ // of the day (after the transaction).
732+ Rc :: new(
733+ BalanceAssertion :: new(
734+ account2. id,
735+ date_2. clone( ) ,
736+ Commodity :: new( Decimal :: new( 100 , 2 ) , & * aud) ,
737+ )
738+ . into( ) ,
739+ ) ,
740+ ] ;
741+
742+ let program = Program :: new ( actions) ;
743+
744+ let accounts = vec ! [ account1, account2] ;
745+ let mut program_state = ProgramState :: new ( & accounts, AccountStatus :: Open ) ;
746+ match program_state. execute_program ( & program) {
747+ Err ( AccountingError :: BalanceAssertionFailed ( failure) ) => {
748+ assert_eq ! (
749+ Commodity :: new( Decimal :: new( 0 , 2 ) , & * aud) ,
750+ failure. actual_balance
751+ ) ;
752+ assert_eq ! ( date_1, failure. assertion. date) ;
753+ }
754+ _ => panic ! ( "Expected an AccountingError:BalanceAssertionFailed" ) ,
755+ }
756+
757+ assert_eq ! ( 1 , program_state. failed_balance_assertions. len( ) ) ;
758+ }
693759}
694760
695761#[ cfg( feature = "serde-support" ) ]
0 commit comments