@@ -88,25 +88,20 @@ mod llvm_enzyme {
8888 has_ret : bool ,
8989 ) -> AutoDiffAttrs {
9090 let dcx = ecx. sess . dcx ( ) ;
91- let mode = name ( & meta_item[ 1 ] ) ;
92- let Ok ( mode) = DiffMode :: from_str ( & mode) else {
93- dcx. emit_err ( errors:: AutoDiffInvalidMode { span : meta_item[ 1 ] . span ( ) , mode } ) ;
94- return AutoDiffAttrs :: error ( ) ;
95- } ;
9691
9792 // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
9893 // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
99- let mut first_activity = 2 ;
94+ let mut first_activity = 1 ;
10095
101- let width = if let [ _, _ , x, ..] = & meta_item[ ..]
96+ let width = if let [ _, x, ..] = & meta_item[ ..]
10297 && let Some ( x) = width ( x)
10398 {
104- first_activity = 3 ;
99+ first_activity = 2 ;
105100 match x. try_into ( ) {
106101 Ok ( x) => x,
107102 Err ( _) => {
108103 dcx. emit_err ( errors:: AutoDiffInvalidWidth {
109- span : meta_item[ 2 ] . span ( ) ,
104+ span : meta_item[ 1 ] . span ( ) ,
110105 width : x,
111106 } ) ;
112107 return AutoDiffAttrs :: error ( ) ;
@@ -150,7 +145,7 @@ mod llvm_enzyme {
150145 } ;
151146
152147 AutoDiffAttrs {
153- mode,
148+ mode : DiffMode :: Error ,
154149 width,
155150 ret_activity : * ret_activity,
156151 input_activity : input_activity. to_vec ( ) ,
@@ -165,6 +160,24 @@ mod llvm_enzyme {
165160 ts. push ( TokenTree :: Token ( comma. clone ( ) , Spacing :: Alone ) ) ;
166161 }
167162
163+ pub ( crate ) fn expand_forward (
164+ ecx : & mut ExtCtxt < ' _ > ,
165+ expand_span : Span ,
166+ meta_item : & ast:: MetaItem ,
167+ item : Annotatable ,
168+ ) -> Vec < Annotatable > {
169+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Forward )
170+ }
171+
172+ pub ( crate ) fn expand_reverse (
173+ ecx : & mut ExtCtxt < ' _ > ,
174+ expand_span : Span ,
175+ meta_item : & ast:: MetaItem ,
176+ item : Annotatable ,
177+ ) -> Vec < Annotatable > {
178+ expand_with_mode ( ecx, expand_span, meta_item, item, DiffMode :: Reverse )
179+ }
180+
168181 /// We expand the autodiff macro to generate a new placeholder function which passes
169182 /// type-checking and can be called by users. The function body of the placeholder function will
170183 /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
@@ -198,11 +211,12 @@ mod llvm_enzyme {
198211 /// ```
199212 /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
200213 /// in CI.
201- pub ( crate ) fn expand (
214+ pub ( crate ) fn expand_with_mode (
202215 ecx : & mut ExtCtxt < ' _ > ,
203216 expand_span : Span ,
204217 meta_item : & ast:: MetaItem ,
205218 mut item : Annotatable ,
219+ mode : DiffMode ,
206220 ) -> Vec < Annotatable > {
207221 if cfg ! ( not( llvm_enzyme) ) {
208222 ecx. sess . dcx ( ) . emit_err ( errors:: AutoDiffSupportNotBuild { span : meta_item. span } ) ;
@@ -289,7 +303,8 @@ mod llvm_enzyme {
289303 ts. pop ( ) ;
290304 let ts: TokenStream = TokenStream :: from_iter ( ts) ;
291305
292- let x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
306+ let mut x: AutoDiffAttrs = from_ast ( ecx, & meta_item_vec, has_ret) ;
307+ x. mode = mode;
293308 if !x. is_active ( ) {
294309 // We encountered an error, so we return the original item.
295310 // This allows us to potentially parse other attributes.
@@ -1017,4 +1032,4 @@ mod llvm_enzyme {
10171032 }
10181033}
10191034
1020- pub ( crate ) use llvm_enzyme:: expand ;
1035+ pub ( crate ) use llvm_enzyme:: { expand_forward , expand_reverse } ;
0 commit comments