1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12 valid_ty_for_activity,
13 };
14 use rustc_ast::ptr::P;
15 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
16 use rustc_ast::tokenstream::*;
17 use rustc_ast::visit::AssocCtxt::*;
18 use rustc_ast::{
19 self as ast, AssocItemKind, BindingMode, ExprKind, FnRetTy, FnSig, Generics, ItemKind,
20 MetaItemInner, PatKind, QSelf, TyKind, Visibility,
21 };
22 use rustc_expand::base::{Annotatable, ExtCtxt};
23 use rustc_span::{Ident, Span, Symbol, kw, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 pub(crate) fn outer_normal_attr(
30 kind: &P<rustc_ast::NormalAttr>,
31 id: rustc_ast::AttrId,
32 span: Span,
33 ) -> rustc_ast::Attribute {
34 let style = rustc_ast::AttrStyle::Outer;
35 let kind = rustc_ast::AttrKind::Normal(kind.clone());
36 rustc_ast::Attribute { kind, id, style, span }
37 }
38
39 fn has_ret(ty: &FnRetTy) -> bool {
42 match ty {
43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44 FnRetTy::Default(_) => false,
45 }
46 }
47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48 if let Some(l) = x.lit() {
49 match l.kind {
50 ast::LitKind::Int(val, _) => {
51 return rustc_span::Ident::from_str(val.get().to_string().as_str());
53 }
54 _ => {}
55 }
56 }
57
58 let segments = &x.meta_item().unwrap().path.segments;
59 assert!(segments.len() == 1);
60 segments[0].ident
61 }
62
63 fn name(x: &MetaItemInner) -> String {
64 first_ident(x).name.to_string()
65 }
66
67 fn width(x: &MetaItemInner) -> Option<u128> {
68 let lit = x.lit()?;
69 match lit.kind {
70 ast::LitKind::Int(x, _) => Some(x.get()),
71 _ => return None,
72 }
73 }
74
75 fn extract_item_info(iitem: &P<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77 match &iitem.kind {
78 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80 }
81 _ => None,
82 }
83 }
84
85 pub(crate) fn from_ast(
86 ecx: &mut ExtCtxt<'_>,
87 meta_item: &ThinVec<MetaItemInner>,
88 has_ret: bool,
89 mode: DiffMode,
90 ) -> AutoDiffAttrs {
91 let dcx = ecx.sess.dcx();
92
93 let mut first_activity = 1;
96
97 let width = if let [_, x, ..] = &meta_item[..]
98 && let Some(x) = width(x)
99 {
100 first_activity = 2;
101 match x.try_into() {
102 Ok(x) => x,
103 Err(_) => {
104 dcx.emit_err(errors::AutoDiffInvalidWidth {
105 span: meta_item[1].span(),
106 width: x,
107 });
108 return AutoDiffAttrs::error();
109 }
110 }
111 } else {
112 1
113 };
114
115 let mut activities: Vec<DiffActivity> = vec![];
116 let mut errors = false;
117 for x in &meta_item[first_activity..] {
118 let activity_str = name(&x);
119 let res = DiffActivity::from_str(&activity_str);
120 match res {
121 Ok(x) => activities.push(x),
122 Err(_) => {
123 dcx.emit_err(errors::AutoDiffUnknownActivity {
124 span: x.span(),
125 act: activity_str,
126 });
127 errors = true;
128 }
129 };
130 }
131 if errors {
132 return AutoDiffAttrs::error();
133 }
134
135 let (ret_activity, input_activity) = if has_ret {
138 let Some((last, rest)) = activities.split_last() else {
139 unreachable!(
140 "should not be reachable because we counted the number of activities previously"
141 );
142 };
143 (last, rest)
144 } else {
145 (&DiffActivity::None, activities.as_slice())
146 };
147
148 AutoDiffAttrs {
149 mode,
150 width,
151 ret_activity: *ret_activity,
152 input_activity: input_activity.to_vec(),
153 }
154 }
155
156 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157 let comma: Token = Token::new(TokenKind::Comma, Span::default());
158 let val = first_ident(t);
159 let t = Token::from_ast_ident(val);
160 ts.push(TokenTree::Token(t, Spacing::Joint));
161 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162 }
163
164 pub(crate) fn expand_forward(
165 ecx: &mut ExtCtxt<'_>,
166 expand_span: Span,
167 meta_item: &ast::MetaItem,
168 item: Annotatable,
169 ) -> Vec<Annotatable> {
170 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171 }
172
173 pub(crate) fn expand_reverse(
174 ecx: &mut ExtCtxt<'_>,
175 expand_span: Span,
176 meta_item: &ast::MetaItem,
177 item: Annotatable,
178 ) -> Vec<Annotatable> {
179 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180 }
181
182 pub(crate) fn expand_with_mode(
216 ecx: &mut ExtCtxt<'_>,
217 expand_span: Span,
218 meta_item: &ast::MetaItem,
219 mut item: Annotatable,
220 mode: DiffMode,
221 ) -> Vec<Annotatable> {
222 if cfg!(not(llvm_enzyme)) {
223 ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
224 return vec![item];
225 }
226 let dcx = ecx.sess.dcx();
227
228 let Some((vis, sig, primal, generics)) = (match &item {
232 Annotatable::Item(iitem) => extract_item_info(iitem),
233 Annotatable::Stmt(stmt) => match &stmt.kind {
234 ast::StmtKind::Item(iitem) => extract_item_info(iitem),
235 _ => None,
236 },
237 Annotatable::AssocItem(assoc_item, Impl { .. }) => match &assoc_item.kind {
238 ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
239 Some((assoc_item.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
240 }
241 _ => None,
242 },
243 _ => None,
244 }) else {
245 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
246 return vec![item];
247 };
248
249 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
250 ast::MetaItemKind::List(ref vec) => vec.clone(),
251 _ => {
252 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
253 return vec![item];
254 }
255 };
256
257 let has_ret = has_ret(&sig.decl.output);
258 let sig_span = ecx.with_call_site_ctxt(sig.span);
259
260 let mut ts: Vec<TokenTree> = vec![];
263 if meta_item_vec.len() < 1 {
264 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
266 return vec![item];
267 }
268
269 let mode_symbol = match mode {
270 DiffMode::Forward => sym::Forward,
271 DiffMode::Reverse => sym::Reverse,
272 _ => unreachable!("Unsupported mode: {:?}", mode),
273 };
274
275 let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
277 ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
278 ts.insert(
279 1,
280 TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
281 );
282
283 let start_position;
286 let kind: LitKind = LitKind::Integer;
287 let symbol;
288 if meta_item_vec.len() >= 2
289 && let Some(width) = width(&meta_item_vec[1])
290 {
291 start_position = 2;
292 symbol = Symbol::intern(&width.to_string());
293 } else {
294 start_position = 1;
295 symbol = sym::integer(1);
296 }
297
298 let l: Lit = Lit { kind, symbol, suffix: None };
299 let t = Token::new(TokenKind::Literal(l), Span::default());
300 let comma = Token::new(TokenKind::Comma, Span::default());
301 ts.push(TokenTree::Token(t, Spacing::Joint));
302 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
303
304 for t in meta_item_vec.clone()[start_position..].iter() {
305 meta_item_inner_to_ts(t, &mut ts);
306 }
307
308 if !has_ret {
309 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
312 ts.push(TokenTree::Token(t, Spacing::Joint));
313 ts.push(TokenTree::Token(comma, Spacing::Alone));
314 }
315 ts.pop();
317 let ts: TokenStream = TokenStream::from_iter(ts);
318
319 let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
320 if !x.is_active() {
321 return vec![item];
324 }
325 let span = ecx.with_def_site_ctxt(expand_span);
326
327 let n_active: u32 = x
328 .input_activity
329 .iter()
330 .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
331 .count() as u32;
332 let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
333 let d_body = gen_enzyme_body(
334 ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
335 &generics,
336 );
337
338 let asdf = Box::new(ast::Fn {
340 defaultness: ast::Defaultness::Final,
341 sig: d_sig,
342 ident: first_ident(&meta_item_vec[0]),
343 generics,
344 contract: None,
345 body: Some(d_body),
346 define_opaque: None,
347 });
348 let mut rustc_ad_attr =
349 P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
350
351 let ts2: Vec<TokenTree> = vec![TokenTree::Token(
352 Token::new(TokenKind::Ident(sym::never, false.into()), span),
353 Spacing::Joint,
354 )];
355 let never_arg = ast::DelimArgs {
356 dspan: DelimSpan::from_single(span),
357 delim: ast::token::Delimiter::Parenthesis,
358 tokens: TokenStream::from_iter(ts2),
359 };
360 let inline_item = ast::AttrItem {
361 unsafety: ast::Safety::Default,
362 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
363 args: ast::AttrArgs::Delimited(never_arg),
364 tokens: None,
365 };
366 let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
367 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
368 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
369 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
370 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
371
372 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
374 match (attr, item) {
375 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
376 let a = &a.item.path;
377 let b = &b.item.path;
378 a.segments.len() == b.segments.len()
379 && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
380 }
381 _ => false,
382 }
383 }
384
385 let orig_annotatable: Annotatable = match item {
387 Annotatable::Item(ref mut iitem) => {
388 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
389 iitem.attrs.push(attr);
390 }
391 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
392 iitem.attrs.push(inline_never.clone());
393 }
394 Annotatable::Item(iitem.clone())
395 }
396 Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
397 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
398 assoc_item.attrs.push(attr);
399 }
400 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
401 assoc_item.attrs.push(inline_never.clone());
402 }
403 Annotatable::AssocItem(assoc_item.clone(), i)
404 }
405 Annotatable::Stmt(ref mut stmt) => {
406 match stmt.kind {
407 ast::StmtKind::Item(ref mut iitem) => {
408 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
409 iitem.attrs.push(attr);
410 }
411 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
412 {
413 iitem.attrs.push(inline_never.clone());
414 }
415 }
416 _ => unreachable!("stmt kind checked previously"),
417 };
418
419 Annotatable::Stmt(stmt.clone())
420 }
421 _ => {
422 unreachable!("annotatable kind checked previously")
423 }
424 };
425 rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
427 dspan: DelimSpan::dummy(),
428 delim: rustc_ast::token::Delimiter::Parenthesis,
429 tokens: ts,
430 });
431
432 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
433 let d_annotatable = match &item {
434 Annotatable::AssocItem(_, _) => {
435 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
436 let d_fn = P(ast::AssocItem {
437 attrs: thin_vec![d_attr, inline_never],
438 id: ast::DUMMY_NODE_ID,
439 span,
440 vis,
441 kind: assoc_item,
442 tokens: None,
443 });
444 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
445 }
446 Annotatable::Item(_) => {
447 let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
448 d_fn.vis = vis;
449
450 Annotatable::Item(d_fn)
451 }
452 Annotatable::Stmt(_) => {
453 let mut d_fn = ecx.item(span, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
454 d_fn.vis = vis;
455
456 Annotatable::Stmt(P(ast::Stmt {
457 id: ast::DUMMY_NODE_ID,
458 kind: ast::StmtKind::Item(d_fn),
459 span,
460 }))
461 }
462 _ => {
463 unreachable!("item kind checked previously")
464 }
465 };
466
467 return vec![orig_annotatable, d_annotatable];
468 }
469
470 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
473 let mut ty = ty.clone();
474 match ty.kind {
475 TyKind::Ptr(ref mut mut_ty) => {
476 mut_ty.mutbl = ast::Mutability::Mut;
477 }
478 TyKind::Ref(_, ref mut mut_ty) => {
479 mut_ty.mutbl = ast::Mutability::Mut;
480 }
481 _ => {
482 panic!("unsupported type: {:?}", ty);
483 }
484 }
485 ty
486 }
487
488 fn init_body_helper(
500 ecx: &ExtCtxt<'_>,
501 span: Span,
502 primal: Ident,
503 new_names: &[String],
504 sig_span: Span,
505 new_decl_span: Span,
506 idents: &[Ident],
507 errored: bool,
508 generics: &Generics,
509 ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
510 let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
511 let noop = ast::InlineAsm {
512 asm_macro: ast::AsmMacro::Asm,
513 template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
514 template_strs: Box::new([]),
515 operands: vec![],
516 clobber_abis: vec![],
517 options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
518 line_spans: vec![],
519 };
520 let noop_expr = ecx.expr_asm(span, P(noop));
521 let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
522 let unsf_block = ast::Block {
523 stmts: thin_vec![ecx.stmt_semi(noop_expr)],
524 id: ast::DUMMY_NODE_ID,
525 tokens: None,
526 rules: unsf,
527 span,
528 };
529 let unsf_expr = ecx.expr_block(P(unsf_block));
530 let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
531 let primal_call = gen_primal_call(ecx, span, primal, idents, generics);
532 let black_box_primal_call = ecx.expr_call(
533 new_decl_span,
534 blackbox_call_expr.clone(),
535 thin_vec![primal_call.clone()],
536 );
537 let tup_args = new_names
538 .iter()
539 .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
540 .collect();
541
542 let black_box_remaining_args = ecx.expr_call(
543 sig_span,
544 blackbox_call_expr.clone(),
545 thin_vec![ecx.expr_tuple(sig_span, tup_args)],
546 );
547
548 let mut body = ecx.block(span, ThinVec::new());
549 body.stmts.push(ecx.stmt_semi(unsf_expr));
550
551 if !errored {
553 body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
554 }
555 body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
556
557 (body, primal_call, black_box_primal_call, blackbox_call_expr)
558 }
559
560 fn gen_enzyme_body(
569 ecx: &ExtCtxt<'_>,
570 x: &AutoDiffAttrs,
571 n_active: u32,
572 sig: &ast::FnSig,
573 d_sig: &ast::FnSig,
574 primal: Ident,
575 new_names: &[String],
576 span: Span,
577 sig_span: Span,
578 idents: Vec<Ident>,
579 errored: bool,
580 generics: &Generics,
581 ) -> P<ast::Block> {
582 let new_decl_span = d_sig.span;
583
584 let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
593 ecx,
594 span,
595 primal,
596 new_names,
597 sig_span,
598 new_decl_span,
599 &idents,
600 errored,
601 generics,
602 );
603
604 if !has_ret(&d_sig.decl.output) {
605 return body;
607 }
608
609 let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
614
615 if primal_ret && n_active == 0 && x.mode.is_rev() {
616 body.stmts.push(ecx.stmt_expr(bb_primal_call));
618 return body;
619 }
620
621 if !primal_ret && n_active == 1 {
622 let ty = match d_sig.decl.output {
624 FnRetTy::Ty(ref ty) => ty.clone(),
625 FnRetTy::Default(span) => {
626 panic!("Did not expect Default ret ty: {:?}", span);
627 }
628 };
629 let arg = ty.kind.is_simple_path().unwrap();
630 let tmp = ecx.def_site_path(&[arg, kw::Default]);
631 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
632 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
633 body.stmts.push(ecx.stmt_expr(default_call_expr));
634 return body;
635 }
636
637 let mut exprs: P<ast::Expr> = primal_call;
638 let d_ret_ty = match d_sig.decl.output {
639 FnRetTy::Ty(ref ty) => ty.clone(),
640 FnRetTy::Default(span) => {
641 panic!("Did not expect Default ret ty: {:?}", span);
642 }
643 };
644 if x.mode.is_fwd() {
645 if x.ret_activity == DiffActivity::Const {
650 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
653 } else {
654 let q = QSelf { ty: d_ret_ty, path_span: span, position: 0 };
655 let y =
656 ExprKind::Path(Some(P(q)), ecx.path_ident(span, Ident::from_str("default")));
657 let default_call_expr = ecx.expr(span, y);
658 let default_call_expr =
659 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
660 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![default_call_expr]);
661 }
662 } else if x.mode.is_rev() {
663 if x.width == 1 {
664 match d_ret_ty.kind {
666 TyKind::Tup(ref args) => {
667 let mut exprs2 = thin_vec![exprs];
670 for arg in args.iter().skip(1) {
671 let arg = arg.kind.is_simple_path().unwrap();
672 let tmp = ecx.def_site_path(&[arg, kw::Default]);
673 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
674 let default_call_expr =
675 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
676 exprs2.push(default_call_expr);
677 }
678 exprs = ecx.expr_tuple(new_decl_span, exprs2);
679 }
680 _ => {
681 panic!("Unsupported return type: {:?}", d_ret_ty);
685 }
686 }
687 }
688 exprs = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![exprs]);
689 } else {
690 unreachable!("Unsupported mode: {:?}", x.mode);
691 }
692
693 body.stmts.push(ecx.stmt_expr(exprs));
694
695 body
696 }
697
698 fn gen_primal_call(
699 ecx: &ExtCtxt<'_>,
700 span: Span,
701 primal: Ident,
702 idents: &[Ident],
703 generics: &Generics,
704 ) -> P<ast::Expr> {
705 let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
706
707 if has_self {
708 let args: ThinVec<_> =
709 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
710 let self_expr = ecx.expr_self(span);
711 ecx.expr_method_call(span, self_expr, primal, args)
712 } else {
713 let args: ThinVec<_> =
714 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
715 let mut primal_path = ecx.path_ident(span, primal);
716
717 let is_generic = !generics.params.is_empty();
718
719 match (is_generic, primal_path.segments.last_mut()) {
720 (true, Some(function_path)) => {
721 let primal_generic_types = generics
722 .params
723 .iter()
724 .filter(|param| matches!(param.kind, ast::GenericParamKind::Type { .. }));
725
726 let generated_generic_types = primal_generic_types
727 .map(|type_param| {
728 let generic_param = TyKind::Path(
729 None,
730 ast::Path {
731 span,
732 segments: thin_vec![ast::PathSegment {
733 ident: type_param.ident,
734 args: None,
735 id: ast::DUMMY_NODE_ID,
736 }],
737 tokens: None,
738 },
739 );
740
741 ast::AngleBracketedArg::Arg(ast::GenericArg::Type(P(ast::Ty {
742 id: type_param.id,
743 span,
744 kind: generic_param,
745 tokens: None,
746 })))
747 })
748 .collect();
749
750 function_path.args =
751 Some(P(ast::GenericArgs::AngleBracketed(ast::AngleBracketedArgs {
752 span,
753 args: generated_generic_types,
754 })));
755 }
756 _ => {}
757 }
758
759 let primal_call_expr = ecx.expr_path(primal_path);
760 ecx.expr_call(span, primal_call_expr, args)
761 }
762 }
763
764 fn gen_enzyme_decl(
776 ecx: &ExtCtxt<'_>,
777 sig: &ast::FnSig,
778 x: &AutoDiffAttrs,
779 span: Span,
780 ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
781 let dcx = ecx.sess.dcx();
782 let has_ret = has_ret(&sig.decl.output);
783 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
784 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
785 if sig_args != num_activities {
786 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
787 span,
788 expected: sig_args,
789 found: num_activities,
790 });
791 return (sig.clone(), vec![], vec![], true);
793 }
794 assert!(sig.decl.inputs.len() == x.input_activity.len());
795 assert!(has_ret == x.has_ret_activity());
796 let mut d_decl = sig.decl.clone();
797 let mut d_inputs = Vec::new();
798 let mut new_inputs = Vec::new();
799 let mut idents = Vec::new();
800 let mut act_ret = ThinVec::new();
801
802 let mut errors = false;
805 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
806 if !valid_input_activity(x.mode, *activity) {
807 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
808 span,
809 mode: x.mode.to_string(),
810 act: activity.to_string(),
811 });
812 errors = true;
813 }
814 if !valid_ty_for_activity(&arg.ty, *activity) {
815 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
816 span: arg.ty.span,
817 act: activity.to_string(),
818 });
819 errors = true;
820 }
821 }
822
823 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
824 dcx.emit_err(errors::AutoDiffInvalidRetAct {
825 span,
826 mode: x.mode.to_string(),
827 act: x.ret_activity.to_string(),
828 });
829 }
832
833 if errors {
834 return (sig.clone(), new_inputs, idents, true);
836 }
837
838 let unsafe_activities = x
839 .input_activity
840 .iter()
841 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
842 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
843 d_inputs.push(arg.clone());
844 match activity {
845 DiffActivity::Active => {
846 act_ret.push(arg.ty.clone());
847 }
849 DiffActivity::ActiveOnly => {
850 }
853 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
854 for i in 0..x.width {
855 let mut shadow_arg = arg.clone();
856 shadow_arg.ty = P(assure_mut_ref(&arg.ty));
858 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
859 ident.name
860 } else {
861 debug!("{:#?}", &shadow_arg.pat);
862 panic!("not an ident?");
863 };
864 let name: String = format!("d{}_{}", old_name, i);
865 new_inputs.push(name.clone());
866 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
867 shadow_arg.pat = P(ast::Pat {
868 id: ast::DUMMY_NODE_ID,
869 kind: PatKind::Ident(BindingMode::NONE, ident, None),
870 span: shadow_arg.pat.span,
871 tokens: shadow_arg.pat.tokens.clone(),
872 });
873 d_inputs.push(shadow_arg.clone());
874 }
875 }
876 DiffActivity::Dual
877 | DiffActivity::DualOnly
878 | DiffActivity::Dualv
879 | DiffActivity::DualvOnly => {
880 let iterations =
883 if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
884 1
885 } else {
886 x.width
887 };
888 for i in 0..iterations {
889 let mut shadow_arg = arg.clone();
890 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
891 ident.name
892 } else {
893 debug!("{:#?}", &shadow_arg.pat);
894 panic!("not an ident?");
895 };
896 let name: String = format!("b{}_{}", old_name, i);
897 new_inputs.push(name.clone());
898 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
899 shadow_arg.pat = P(ast::Pat {
900 id: ast::DUMMY_NODE_ID,
901 kind: PatKind::Ident(BindingMode::NONE, ident, None),
902 span: shadow_arg.pat.span,
903 tokens: shadow_arg.pat.tokens.clone(),
904 });
905 d_inputs.push(shadow_arg.clone());
906 }
907 }
908 DiffActivity::Const => {
909 }
911 DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
912 panic!("Should not happen");
913 }
914 }
915 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
916 idents.push(ident.clone());
917 } else {
918 panic!("not an ident?");
919 }
920 }
921
922 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
923 if active_only_ret {
924 assert!(x.mode.is_rev());
925 }
926
927 if x.mode.is_rev() {
930 match x.ret_activity {
931 DiffActivity::Active | DiffActivity::ActiveOnly => {
932 let ty = match d_decl.output {
933 FnRetTy::Ty(ref ty) => ty.clone(),
934 FnRetTy::Default(span) => {
935 panic!("Did not expect Default ret ty: {:?}", span);
936 }
937 };
938 let name = "dret".to_string();
939 let ident = Ident::from_str_and_span(&name, ty.span);
940 let shadow_arg = ast::Param {
941 attrs: ThinVec::new(),
942 ty: ty.clone(),
943 pat: P(ast::Pat {
944 id: ast::DUMMY_NODE_ID,
945 kind: PatKind::Ident(BindingMode::NONE, ident, None),
946 span: ty.span,
947 tokens: None,
948 }),
949 id: ast::DUMMY_NODE_ID,
950 span: ty.span,
951 is_placeholder: false,
952 };
953 d_inputs.push(shadow_arg);
954 new_inputs.push(name);
955 }
956 _ => {}
957 }
958 }
959 d_decl.inputs = d_inputs.into();
960
961 if x.mode.is_fwd() {
962 let ty = match d_decl.output {
963 FnRetTy::Ty(ref ty) => ty.clone(),
964 FnRetTy::Default(span) => {
965 let kind = TyKind::Tup(ThinVec::new());
967 let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
968 d_decl.output = FnRetTy::Ty(ty.clone());
969 assert!(matches!(x.ret_activity, DiffActivity::None));
970 ty
972 }
973 };
974
975 if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
976 let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
977 TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
980 } else {
981 let anon_const = rustc_ast::AnonConst {
983 id: ast::DUMMY_NODE_ID,
984 value: ecx.expr_usize(span, 1 + x.width as usize),
985 };
986 TyKind::Array(ty.clone(), anon_const)
987 };
988 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
989 d_decl.output = FnRetTy::Ty(ty);
990 }
991 if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
992 if x.width > 1 {
996 let anon_const = rustc_ast::AnonConst {
997 id: ast::DUMMY_NODE_ID,
998 value: ecx.expr_usize(span, x.width as usize),
999 };
1000 let kind = TyKind::Array(ty.clone(), anon_const);
1001 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
1002 d_decl.output = FnRetTy::Ty(ty);
1003 }
1004 }
1005 }
1006
1007 d_decl.output =
1009 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
1010
1011 trace!("act_ret: {:?}", act_ret);
1012
1013 if act_ret.len() > 0 {
1017 let ret_ty = match d_decl.output {
1018 FnRetTy::Ty(ref ty) => {
1019 if !active_only_ret {
1020 act_ret.insert(0, ty.clone());
1021 }
1022 let kind = TyKind::Tup(act_ret);
1023 P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
1024 }
1025 FnRetTy::Default(span) => {
1026 if act_ret.len() == 1 {
1027 act_ret[0].clone()
1028 } else {
1029 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
1030 P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
1031 }
1032 }
1033 };
1034 d_decl.output = FnRetTy::Ty(ret_ty);
1035 }
1036
1037 let mut d_header = sig.header.clone();
1038 if unsafe_activities {
1039 d_header.safety = rustc_ast::Safety::Unsafe(span);
1040 }
1041 let d_sig = FnSig { header: d_header, decl: d_decl, span };
1042 trace!("Generated signature: {:?}", d_sig);
1043 (d_sig, new_inputs, idents, false)
1044 }
1045}
1046
1047pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};