rustc_builtin_macros/
autodiff.rs

1//! This module contains the implementation of the `#[autodiff]` attribute.
2//! Currently our linter isn't smart enough to see that each import is used in one of the two
3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
5
6mod llvm_enzyme {
7    use std::str::FromStr;
8    use std::string::String;
9
10    use rustc_ast::expand::autodiff_attrs::{
11        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12        valid_ty_for_activity,
13    };
14    use rustc_ast::ptr::P;
15    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
16    use rustc_ast::tokenstream::*;
17    use rustc_ast::visit::AssocCtxt::*;
18    use rustc_ast::{
19        self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
20        MetaItemInner, PatKind, QSelf, TyKind, Visibility,
21    };
22    use rustc_expand::base::{Annotatable, ExtCtxt};
23    use rustc_span::{Ident, Span, Symbol, kw, sym};
24    use thin_vec::{ThinVec, thin_vec};
25    use tracing::{debug, trace};
26
27    use crate::errors;
28
29    pub(crate) fn outer_normal_attr(
30        kind: &P<rustc_ast::NormalAttr>,
31        id: rustc_ast::AttrId,
32        span: Span,
33    ) -> rustc_ast::Attribute {
34        let style = rustc_ast::AttrStyle::Outer;
35        let kind = rustc_ast::AttrKind::Normal(kind.clone());
36        rustc_ast::Attribute { kind, id, style, span }
37    }
38
39    // If we have a default `()` return type or explicitley `()` return type,
40    // then we often can skip doing some work.
41    fn has_ret(ty: &FnRetTy) -> bool {
42        match ty {
43            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44            FnRetTy::Default(_) => false,
45        }
46    }
47    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48        if let Some(l) = x.lit() {
49            match l.kind {
50                ast::LitKind::Int(val, _) => {
51                    // get an Ident from a lit
52                    return rustc_span::Ident::from_str(val.get().to_string().as_str());
53                }
54                _ => {}
55            }
56        }
57
58        let segments = &x.meta_item().unwrap().path.segments;
59        assert!(segments.len() == 1);
60        segments[0].ident
61    }
62
63    fn name(x: &MetaItemInner) -> String {
64        first_ident(x).name.to_string()
65    }
66
67    fn width(x: &MetaItemInner) -> Option<u128> {
68        let lit = x.lit()?;
69        match lit.kind {
70            ast::LitKind::Int(x, _) => Some(x.get()),
71            _ => return None,
72        }
73    }
74
75    // Get information about the function the macro is applied to
76    fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77        match &iitem.kind {
78            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80            }
81            _ => None,
82        }
83    }
84
85    pub(crate) fn from_ast(
86        ecx: &mut ExtCtxt<'_>,
87        meta_item: &ThinVec<MetaItemInner>,
88        has_ret: bool,
89        mode: DiffMode,
90    ) -> AutoDiffAttrs {
91        let dcx = ecx.sess.dcx();
92
93        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
94        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
95        let mut first_activity = 1;
96
97        let width = if let [_, x, ..] = &meta_item[..]
98            && let Some(x) = width(x)
99        {
100            first_activity = 2;
101            match x.try_into() {
102                Ok(x) => x,
103                Err(_) => {
104                    dcx.emit_err(errors::AutoDiffInvalidWidth {
105                        span: meta_item[1].span(),
106                        width: x,
107                    });
108                    return AutoDiffAttrs::error();
109                }
110            }
111        } else {
112            1
113        };
114
115        let mut activities: Vec<DiffActivity> = vec![];
116        let mut errors = false;
117        for x in &meta_item[first_activity..] {
118            let activity_str = name(&x);
119            let res = DiffActivity::from_str(&activity_str);
120            match res {
121                Ok(x) => activities.push(x),
122                Err(_) => {
123                    dcx.emit_err(errors::AutoDiffUnknownActivity {
124                        span: x.span(),
125                        act: activity_str,
126                    });
127                    errors = true;
128                }
129            };
130        }
131        if errors {
132            return AutoDiffAttrs::error();
133        }
134
135        // If a return type exist, we need to split the last activity,
136        // otherwise we return None as placeholder.
137        let (ret_activity, input_activity) = if has_ret {
138            let Some((last, rest)) = activities.split_last() else {
139                unreachable!(
140                    "should not be reachable because we counted the number of activities previously"
141                );
142            };
143            (last, rest)
144        } else {
145            (&DiffActivity::None, activities.as_slice())
146        };
147
148        AutoDiffAttrs {
149            mode,
150            width,
151            ret_activity: *ret_activity,
152            input_activity: input_activity.to_vec(),
153        }
154    }
155
156    fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157        let comma: Token = Token::new(TokenKind::Comma, Span::default());
158        let val = first_ident(t);
159        let t = Token::from_ast_ident(val);
160        ts.push(TokenTree::Token(t, Spacing::Joint));
161        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162    }
163
164    pub(crate) fn expand_forward(
165        ecx: &mut ExtCtxt<'_>,
166        expand_span: Span,
167        meta_item: &ast::MetaItem,
168        item: Annotatable,
169    ) -> Vec<Annotatable> {
170        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171    }
172
173    pub(crate) fn expand_reverse(
174        ecx: &mut ExtCtxt<'_>,
175        expand_span: Span,
176        meta_item: &ast::MetaItem,
177        item: Annotatable,
178    ) -> Vec<Annotatable> {
179        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180    }
181
182    /// We expand the autodiff macro to generate a new placeholder function which passes
183    /// type-checking and can be called by users. The function body of the placeholder function will
184    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
185    /// should just prevent early inlining and optimizations which alter the function signature.
186    /// The exact signature of the generated function depends on the configuration provided by the
187    /// user, but here is an example:
188    ///
189    /// ```
190    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
191    /// fn sin(x: &Box<f32>) -> f32 {
192    ///     f32::sin(**x)
193    /// }
194    /// ```
195    /// which becomes expanded to:
196    /// ```
197    /// #[rustc_autodiff]
198    /// #[inline(never)]
199    /// fn sin(x: &Box<f32>) -> f32 {
200    ///     f32::sin(**x)
201    /// }
202    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
203    /// #[inline(never)]
204    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
205    ///     unsafe {
206    ///         asm!("NOP");
207    ///     };
208    ///     ::core::hint::black_box(sin(x));
209    ///     ::core::hint::black_box((dx, dret));
210    ///     ::core::hint::black_box(sin(x))
211    /// }
212    /// ```
213    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
214    /// in CI.
215    pub(crate) fn expand_with_mode(
216        ecx: &mut ExtCtxt<'_>,
217        expand_span: Span,
218        meta_item: &ast::MetaItem,
219        mut item: Annotatable,
220        mode: DiffMode,
221    ) -> Vec<Annotatable> {
222        if cfg!(not(llvm_enzyme)) {
223            ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
224            return vec![item];
225        }
226        let dcx = ecx.sess.dcx();
227
228        // first get information about the annotable item: visibility, signature, name and generic
229        // parameters.
230        // these will be used to generate the differentiated version of the function
231        let Some((vis, sig, primal, generics)) = (match &item {
232            Annotatable::Item(iitem) => extract_item_info(iitem),
233            Annotatable::Stmt(stmt) => match &stmt.kind {
234                ast::StmtKind::Item(iitem) => extract_item_info(iitem),
235                _ => None,
236            },
237            Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
238                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
239                    Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
240                }
241                _ => None,
242            },
243            _ => None,
244        }) else {
245            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
246            return vec![item];
247        };
248
249        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
250            ast::MetaItemKind::List(ref vec) => vec.clone(),
251            _ => {
252                dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
253                return vec![item];
254            }
255        };
256
257        let has_ret = has_ret(&sig.decl.output);
258        let sig_span = ecx.with_call_site_ctxt(sig.span);
259
260        // create TokenStream from vec elemtents:
261        // meta_item doesn't have a .tokens field
262        let mut ts: Vec<TokenTree> = vec![];
263        if meta_item_vec.len() < 1 {
264            // At the bare minimum, we need a fnc name.
265            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
266            return vec![item];
267        }
268
269        let mode_symbol = match mode {
270            DiffMode::Forward => sym::Forward,
271            DiffMode::Reverse => sym::Reverse,
272            _ => unreachable!("Unsupported mode: {:?}", mode),
273        };
274
275        // Insert mode token
276        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
277        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
278        ts.insert(
279            1,
280            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
281        );
282
283        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
284        // If it is not given, we default to 1 (scalar mode).
285        let start_position;
286        let kind: LitKind = LitKind::Integer;
287        let symbol;
288        if meta_item_vec.len() >= 2
289            && let Some(width) = width(&meta_item_vec[1])
290        {
291            start_position = 2;
292            symbol = Symbol::intern(&width.to_string());
293        } else {
294            start_position = 1;
295            symbol = sym::integer(1);
296        }
297
298        let l: Lit = Lit { kind, symbol, suffix: None };
299        let t = Token::new(TokenKind::Literal(l), Span::default());
300        let comma = Token::new(TokenKind::Comma, Span::default());
301        ts.push(TokenTree::Token(t, Spacing::Joint));
302        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
303
304        for t in meta_item_vec.clone()[start_position..].iter() {
305            meta_item_inner_to_ts(t, &mut ts);
306        }
307
308        if !has_ret {
309            // We don't want users to provide a return activity if the function doesn't return anything.
310            // For simplicity, we just add a dummy token to the end of the list.
311            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
312            ts.push(TokenTree::Token(t, Spacing::Joint));
313            ts.push(TokenTree::Token(comma, Spacing::Alone));
314        }
315        // We remove the last, trailing comma.
316        ts.pop();
317        let ts: TokenStream = TokenStream::from_iter(ts);
318
319        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
320        if !x.is_active() {
321            // We encountered an error, so we return the original item.
322            // This allows us to potentially parse other attributes.
323            return vec![item];
324        }
325        let span = ecx.with_def_site_ctxt(expand_span);
326
327        let n_active: u32 = x
328            .input_activity
329            .iter()
330            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
331            .count() as u32;
332        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
333        let d_body = gen_enzyme_body(
334            ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
335            &generics,
336        );
337
338        // The first element of it is the name of the function to be generated
339        let asdf = Box::new(ast::Fn {
340            defaultness: ast::Defaultness::Final,
341            sig: d_sig,
342            ident: first_ident(&meta_item_vec[0]),
343            generics,
344            contract: None,
345            body: Some(d_body),
346            define_opaque: None,
347        });
348        let mut rustc_ad_attr =
349            P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
350
351        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
352            Token::new(TokenKind::Ident(sym::never, false.into()), span),
353            Spacing::Joint,
354        )];
355        let never_arg = ast::DelimArgs {
356            dspan: DelimSpan::from_single(span),
357            delim: ast::token::Delimiter::Parenthesis,
358            tokens: TokenStream::from_iter(ts2),
359        };
360        let inline_item = ast::AttrItem {
361            unsafety: ast::Safety::Default,
362            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
363            args: ast::AttrArgs::Delimited(never_arg),
364            tokens: None,
365        };
366        let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
367        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
368        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
369        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
370        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
371
372        // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
373        fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
374            match (attr, item) {
375                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
376                    let a = &a.item.path;
377                    let b = &b.item.path;
378                    a.segments.len() == b.segments.len()
379                        && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
380                }
381                _ => false,
382            }
383        }
384
385        // Don't add it multiple times:
386        let orig_annotatable: Annotatable = match item {
387            Annotatable::Item(ref mut iitem) => {
388                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
389                    iitem.attrs.push(attr);
390                }
391                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
392                    iitem.attrs.push(inline_never.clone());
393                }
394                Annotatable::Item(iitem.clone())
395            }
396            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
397                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
398                    assoc_item.attrs.push(attr);
399                }
400                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
401                    assoc_item.attrs.push(inline_never.clone());
402                }
403                Annotatable::AssocItem(assoc_item.clone(), i)
404            }
405            Annotatable::Stmt(ref mut stmt) => {
406                match stmt.kind {
407                    ast::StmtKind::Item(ref mut iitem) => {
408                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
409                            iitem.attrs.push(attr);
410                        }
411                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
412                        {
413                            iitem.attrs.push(inline_never.clone());
414                        }
415                    }
416                    _ => unreachable!("stmt kind checked previously"),
417                };
418
419                Annotatable::Stmt(stmt.clone())
420            }
421            _ => {
422                unreachable!("annotatable kind checked previously")
423            }
424        };
425        // Now update for d_fn
426        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
427            dspan: DelimSpan::dummy(),
428            delim: rustc_ast::token::Delimiter::Parenthesis,
429            tokens: ts,
430        });
431
432        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
433        let d_annotatable = match &item {
434            Annotatable::AssocItem(_, _) => {
435                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
436                let d_fn = P(ast::AssocItem {
437                    attrs: thin_vec![d_attr, inline_never],
438                    id: ast::DUMMY_NODE_ID,
439                    span,
440                    vis,
441                    kind: assoc_item,
442                    tokens: None,
443                });
444                Annotatable::AssocItem(d_fn, Impl { of_trait: false })
445            }
446            Annotatable::Item(_) => {
447                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
448                d_fn.vis = vis;
449
450                Annotatable::Item(d_fn)
451            }
452            Annotatable::Stmt(_) => {
453                let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
454                d_fn.vis = vis;
455
456                Annotatable::Stmt(P(ast::Stmt {
457                    id: ast::DUMMY_NODE_ID,
458                    kind: ast::StmtKind::Item(d_fn),
459                    span,
460                }))
461            }
462            _ => {
463                unreachable!("item kind checked previously")
464            }
465        };
466
467        return vec![orig_annotatable, d_annotatable];
468    }
469
470    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
471    // mutable references or ptrs, because Enzyme will write into them.
472    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
473        let mut ty = ty.clone();
474        match ty.kind {
475            TyKind::Ptr(ref mut mut_ty) => {
476                mut_ty.mutbl = ast::Mutability::Mut;
477            }
478            TyKind::Ref(_, ref mut mut_ty) => {
479                mut_ty.mutbl = ast::Mutability::Mut;
480            }
481            _ => {
482                panic!("unsupported type: {:?}", ty);
483            }
484        }
485        ty
486    }
487
488    // Will generate a body of the type:
489    // ```
490    // {
491    //   unsafe {
492    //   asm!("NOP");
493    //   }
494    //   ::core::hint::black_box(primal(args));
495    //   ::core::hint::black_box((args, ret));
496    //   <This part remains to be done by following function>
497    // }
498    // ```
499    fn init_body_helper(
500        ecx: &ExtCtxt<'_>,
501        span: Span,
502        primal: Ident,
503        new_names: &[String],
504        sig_span: Span,
505        new_decl_span: Span,
506        idents: &[Ident],
507        errored: bool,
508        generics: &Generics,
509    ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
510        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
511        let noop = ast::InlineAsm {
512            asm_macro: ast::AsmMacro::Asm,
513            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
514            template_strs: Box::new([]),
515            operands: vec![],
516            clobber_abis: vec![],
517            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
518            line_spans: vec![],
519        };
520        let noop_expr = ecx.expr_asm(span, P(noop));
521        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
522        let unsf_block = ast::Block {
523            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
524            id: ast::DUMMY_NODE_ID,
525            tokens: None,
526            rules: unsf,
527            span,
528        };
529        let unsf_expr = ecx.expr_block(P(unsf_block));
530        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
531        let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
532        let black_box_primal_call = ecx.expr_call(
533            new_decl_span,
534            blackbox_call_expr.clone(),
535            thin_vec![primal_call.clone()],
536        );
537        let tup_args = new_names
538            .iter()
539            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
540            .collect();
541
542        let black_box_remaining_args = ecx.expr_call(
543            sig_span,
544            blackbox_call_expr.clone(),
545            thin_vec![ecx.expr_tuple(sig_span, tup_args)],
546        );
547
548        let mut body = ecx.block(span, ThinVec::new());
549        body.stmts.push(ecx.stmt_semi(unsf_expr));
550
551        // This uses primal args which won't be available if we errored before
552        if !errored {
553            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
554        }
555        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
556
557        (body, primal_call, black_box_primal_call, blackbox_call_expr)
558    }
559
560    /// We only want this function to type-check, since we will replace the body
561    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
562    /// so instead we manually build something that should pass the type checker.
563    /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
564    /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
565    /// bug would ever try to accidentially differentiate this placeholder function body.
566    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
567    /// from optimizing any arguments away.
568    fn gen_enzyme_body(
569        ecx: &ExtCtxt<'_>,
570        x: &AutoDiffAttrs,
571        n_active: u32,
572        sig: &ast::FnSig,
573        d_sig: &ast::FnSig,
574        primal: Ident,
575        new_names: &[String],
576        span: Span,
577        sig_span: Span,
578        idents: Vec<Ident>,
579        errored: bool,
580        generics: &Generics,
581    ) -> P<ast::Block> {
582        let new_decl_span = d_sig.span;
583
584        // Just adding some default inline-asm and black_box usages to prevent early inlining
585        // and optimizations which alter the function signature.
586        //
587        // The bb_primal_call is the black_box call of the primal function. We keep it around,
588        // since it has the convenient property of returning the type of the primal function,
589        // Remember, we only care to match types here.
590        // No matter which return we pick, we always wrap it into a std::hint::black_box call,
591        // to prevent rustc from propagating it into the caller.
592        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
593            ecx,
594            span,
595            primal,
596            new_names,
597            sig_span,
598            new_decl_span,
599            &idents,
600            errored,
601            generics,
602        );
603
604        if !has_ret(&d_sig.decl.output) {
605            // there is no return type that we have to match, () works fine.
606            return body;
607        }
608
609        // Everything from here onwards just tries to fullfil the return type. Fun!
610
611        // having an active-only return means we'll drop the original return type.
612        // So that can be treated identical to not having one in the first place.
613        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
614
615        if primal_ret && n_active == 0 && x.mode.is_rev() {
616            // We only have the primal ret.
617            body.stmts.push(ecx.stmt_expr(bb_primal_call));
618            return body;
619        }
620
621        if !primal_ret && n_active == 1 {
622            // Again no tuple return, so return default float val.
623            let ty = match d_sig.decl.output {
624                FnRetTy::Ty(ref ty) => ty.clone(),
625                FnRetTy::Default(span) => {
626                    panic!("Did not expect Default ret ty: {:?}", span);
627                }
628            };
629            let arg = ty.kind.is_simple_path().unwrap();
630            let tmp = ecx.def_site_path(&[arg, kw::Default]);
631            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
632            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
633            body.stmts.push(ecx.stmt_expr(default_call_expr));
634            return body;
635        }
636
637        let mut exprs: P<ast::Expr> = primal_call;
638        let d_ret_ty = match d_sig.decl.output {
639            FnRetTy::Ty(ref ty) => ty.clone(),
640            FnRetTy::Default(span) => {
641                panic!("Did not expect Default ret ty: {:?}", span);
642            }
643        };
644        if x.mode.is_fwd() {
645            // Fwd mode is easy. If the return activity is Const, we support arbitrary types.
646            // Otherwise, we only support a scalar, a pair of scalars, or an array of scalars.
647            // We checked that (on a best-effort base) in the preceding gen_enzyme_decl function.
648            // In all three cases, we can return `std::hint::black_box(<T>::default())`.
649            if x.ret_activity == DiffActivity::Const {
650                // Here we call the primal function, since our dummy function has the same return
651                // type due to the Const return activity.
652                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
653            } else {
654                let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
655                let y =
656                    ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
657                let default_call_expr = ecx.expr(span, y);
658                let default_call_expr =
659                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
660                exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
661            }
662        } else if x.mode.is_rev() {
663            if x.width == 1 {
664                // We either have `-> ArbitraryType` or `-> (ArbitraryType, repeated_float_scalars)`.
665                match d_ret_ty.kind {
666                    TyKind::Tup(ref args) => {
667                        // We have a tuple return type. We need to create a tuple of the same size
668                        // and fill it with default values.
669                        let mut exprs2 = thin_vec![exprs];
670                        for arg in args.iter().skip(1) {
671                            let arg = arg.kind.is_simple_path().unwrap();
672                            let tmp = ecx.def_site_path(&[arg, kw::Default]);
673                            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
674                            let default_call_expr =
675                                ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
676                            exprs2.push(default_call_expr);
677                        }
678                        exprs = ecx.expr_tuple(new_decl_span, exprs2);
679                    }
680                    _ => {
681                        // Interestingly, even the `-> ArbitraryType` case
682                        // ends up getting matched and handled correctly above,
683                        // so we don't have to handle any other case for now.
684                        panic!("Unsupported return type: {:?}", d_ret_ty);
685                    }
686                }
687            }
688            exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
689        } else {
690            unreachable!("Unsupported mode: {:?}", x.mode);
691        }
692
693        body.stmts.push(ecx.stmt_expr(exprs));
694
695        body
696    }
697
698    fn gen_primal_call(
699        ecx: &ExtCtxt<'_>,
700        span: Span,
701        primal: Ident,
702        idents: &[Ident],
703        generics: &Generics,
704    ) -> P<ast::Expr> {
705        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
706
707        if has_self {
708            let args: ThinVec<_> =
709                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
710            let self_expr = ecx.expr_self(span);
711            ecx.expr_method_call(span, self_expr, primal, args)
712        } else {
713            let args: ThinVec<_> =
714                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
715            let mut primal_path = ecx.path_ident(span, primal);
716
717            let is_generic = !generics.params.is_empty();
718
719            match (is_generic, primal_path.segments.last_mut()) {
720                (true, Some(function_path)) => {
721                    let primal_generic_types = generics
722                        .params
723                        .iter()
724                        .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
725
726                    let generated_generic_types = primal_generic_types
727                        .map(|type_param| {
728                            let generic_param = TyKind::Path(
729                                None,
730                                ast::Path {
731                                    span,
732                                    segments: thin_vec![ast::PathSegment {
733                                        ident: type_param.ident,
734                                        args: None,
735                                        id: ast::DUMMY_NODE_ID,
736                                    }],
737                                    tokens: None,
738                                },
739                            );
740
741                            ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
742                                id: type_param.id,
743                                span,
744                                kind: generic_param,
745                                tokens: None,
746                            })))
747                        })
748                        .collect();
749
750                    function_path.args =
751                        Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
752                            span,
753                            args: generated_generic_types,
754                        })));
755                }
756                _ => {}
757            }
758
759            let primal_call_expr = ecx.expr_path(primal_path);
760            ecx.expr_call(span, primal_call_expr, args)
761        }
762    }
763
764    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
765    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
766    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
767    // zero-initialized by Enzyme).
768    // Each argument of the primal function (and the return type if existing) must be annotated with an
769    // activity.
770    //
771    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
772    // both), we emit an error and return the original signature. This allows us to continue parsing.
773    // FIXME(Sa4dUs): make individual activities' span available so errors
774    // can point to only the activity instead of the entire attribute
775    fn gen_enzyme_decl(
776        ecx: &ExtCtxt<'_>,
777        sig: &ast::FnSig,
778        x: &AutoDiffAttrs,
779        span: Span,
780    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
781        let dcx = ecx.sess.dcx();
782        let has_ret = has_ret(&sig.decl.output);
783        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
784        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
785        if sig_args != num_activities {
786            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
787                span,
788                expected: sig_args,
789                found: num_activities,
790            });
791            // This is not the right signature, but we can continue parsing.
792            return (sig.clone(), vec![], vec![], true);
793        }
794        assert!(sig.decl.inputs.len() == x.input_activity.len());
795        assert!(has_ret == x.has_ret_activity());
796        let mut d_decl = sig.decl.clone();
797        let mut d_inputs = Vec::new();
798        let mut new_inputs = Vec::new();
799        let mut idents = Vec::new();
800        let mut act_ret = ThinVec::new();
801
802        // We have two loops, a first one just to check the activities and types and possibly report
803        // multiple errors in one compilation session.
804        let mut errors = false;
805        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
806            if !valid_input_activity(x.mode, *activity) {
807                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
808                    span,
809                    mode: x.mode.to_string(),
810                    act: activity.to_string(),
811                });
812                errors = true;
813            }
814            if !valid_ty_for_activity(&arg.ty, *activity) {
815                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
816                    span: arg.ty.span,
817                    act: activity.to_string(),
818                });
819                errors = true;
820            }
821        }
822
823        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
824            dcx.emit_err(errors::AutoDiffInvalidRetAct {
825                span,
826                mode: x.mode.to_string(),
827                act: x.ret_activity.to_string(),
828            });
829            // We don't set `errors = true` to avoid annoying type errors relative
830            // to the expanded macro type signature
831        }
832
833        if errors {
834            // This is not the right signature, but we can continue parsing.
835            return (sig.clone(), new_inputs, idents, true);
836        }
837
838        let unsafe_activities = x
839            .input_activity
840            .iter()
841            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
842        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
843            d_inputs.push(arg.clone());
844            match activity {
845                DiffActivity::Active => {
846                    act_ret.push(arg.ty.clone());
847                    // if width =/= 1, then push [arg.ty; width] to act_ret
848                }
849                DiffActivity::ActiveOnly => {
850                    // We will add the active scalar to the return type.
851                    // This is handled later.
852                }
853                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
854                    for i in 0..x.width {
855                        let mut shadow_arg = arg.clone();
856                        // We += into the shadow in reverse mode.
857                        shadow_arg.ty = P(assure_mut_ref(&arg.ty));
858                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
859                            ident.name
860                        } else {
861                            debug!("{:#?}", &shadow_arg.pat);
862                            panic!("not an ident?");
863                        };
864                        let name: String = format!("d{}_{}", old_name, i);
865                        new_inputs.push(name.clone());
866                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
867                        shadow_arg.pat = P(ast::Pat {
868                            id: ast::DUMMY_NODE_ID,
869                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
870                            span: shadow_arg.pat.span,
871                            tokens: shadow_arg.pat.tokens.clone(),
872                        });
873                        d_inputs.push(shadow_arg.clone());
874                    }
875                }
876                DiffActivity::Dual
877                | DiffActivity::DualOnly
878                | DiffActivity::Dualv
879                | DiffActivity::DualvOnly => {
880                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
881                    // Enzyme to not expect N arguments, but one argument (which is instead larger).
882                    let iterations =
883                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
884                            1
885                        } else {
886                            x.width
887                        };
888                    for i in 0..iterations {
889                        let mut shadow_arg = arg.clone();
890                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
891                            ident.name
892                        } else {
893                            debug!("{:#?}", &shadow_arg.pat);
894                            panic!("not an ident?");
895                        };
896                        let name: String = format!("b{}_{}", old_name, i);
897                        new_inputs.push(name.clone());
898                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
899                        shadow_arg.pat = P(ast::Pat {
900                            id: ast::DUMMY_NODE_ID,
901                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
902                            span: shadow_arg.pat.span,
903                            tokens: shadow_arg.pat.tokens.clone(),
904                        });
905                        d_inputs.push(shadow_arg.clone());
906                    }
907                }
908                DiffActivity::Const => {
909                    // Nothing to do here.
910                }
911                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
912                    panic!("Should not happen");
913                }
914            }
915            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
916                idents.push(ident.clone());
917            } else {
918                panic!("not an ident?");
919            }
920        }
921
922        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
923        if active_only_ret {
924            assert!(x.mode.is_rev());
925        }
926
927        // If we return a scalar in the primal and the scalar is active,
928        // then add it as last arg to the inputs.
929        if x.mode.is_rev() {
930            match x.ret_activity {
931                DiffActivity::Active | DiffActivity::ActiveOnly => {
932                    let ty = match d_decl.output {
933                        FnRetTy::Ty(ref ty) => ty.clone(),
934                        FnRetTy::Default(span) => {
935                            panic!("Did not expect Default ret ty: {:?}", span);
936                        }
937                    };
938                    let name = "dret".to_string();
939                    let ident = Ident::from_str_and_span(&name, ty.span);
940                    let shadow_arg = ast::Param {
941                        attrs: ThinVec::new(),
942                        ty: ty.clone(),
943                        pat: P(ast::Pat {
944                            id: ast::DUMMY_NODE_ID,
945                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
946                            span: ty.span,
947                            tokens: None,
948                        }),
949                        id: ast::DUMMY_NODE_ID,
950                        span: ty.span,
951                        is_placeholder: false,
952                    };
953                    d_inputs.push(shadow_arg);
954                    new_inputs.push(name);
955                }
956                _ => {}
957            }
958        }
959        d_decl.inputs = d_inputs.into();
960
961        if x.mode.is_fwd() {
962            let ty = match d_decl.output {
963                FnRetTy::Ty(ref ty) => ty.clone(),
964                FnRetTy::Default(span) => {
965                    // We want to return std::hint::black_box(()).
966                    let kind = TyKind::Tup(ThinVec::new());
967                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
968                    d_decl.output = FnRetTy::Ty(ty.clone());
969                    assert!(matches!(x.ret_activity, DiffActivity::None));
970                    // this won't be used below, so any type would be fine.
971                    ty
972                }
973            };
974
975            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
976                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
977                    // Dual can only be used for f32/f64 ret.
978                    // In that case we return now a tuple with two floats.
979                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
980                } else {
981                    // We have to return [T; width+1], +1 for the primal return.
982                    let anon_const = rustc_ast::AnonConst {
983                        id: ast::DUMMY_NODE_ID,
984                        value: ecx.expr_usize(span, 1 + x.width as usize),
985                    };
986                    TyKind::Array(ty.clone(), anon_const)
987                };
988                let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
989                d_decl.output = FnRetTy::Ty(ty);
990            }
991            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
992                // No need to change the return type,
993                // we will just return the shadow in place of the primal return.
994                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
995                if x.width > 1 {
996                    let anon_const = rustc_ast::AnonConst {
997                        id: ast::DUMMY_NODE_ID,
998                        value: ecx.expr_usize(span, x.width as usize),
999                    };
1000                    let kind = TyKind::Array(ty.clone(), anon_const);
1001                    let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
1002                    d_decl.output = FnRetTy::Ty(ty);
1003                }
1004            }
1005        }
1006
1007        // If we use ActiveOnly, drop the original return value.
1008        d_decl.output =
1009            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
1010
1011        trace!("act_ret: {:?}", act_ret);
1012
1013        // If we have an active input scalar, add it's gradient to the
1014        // return type. This might require changing the return type to a
1015        // tuple.
1016        if act_ret.len() > 0 {
1017            let ret_ty = match d_decl.output {
1018                FnRetTy::Ty(ref ty) => {
1019                    if !active_only_ret {
1020                        act_ret.insert(0, ty.clone());
1021                    }
1022                    let kind = TyKind::Tup(act_ret);
1023                    P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
1024                }
1025                FnRetTy::Default(span) => {
1026                    if act_ret.len() == 1 {
1027                        act_ret[0].clone()
1028                    } else {
1029                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
1030                        P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
1031                    }
1032                }
1033            };
1034            d_decl.output = FnRetTy::Ty(ret_ty);
1035        }
1036
1037        let mut d_header = sig.header.clone();
1038        if unsafe_activities {
1039            d_header.safety = rustc_ast::Safety::Unsafe(span);
1040        }
1041        let d_sig = FnSig { header: d_header, decl: d_decl, span };
1042        trace!("Generated signature: {:?}", d_sig);
1043        (d_sig, new_inputs, idents, false)
1044    }
1045}
1046
1047pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};