rustc_monomorphize/partitioning/
autodiff.rs

1use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
2use rustc_hir::def_id::LOCAL_CRATE;
3use rustc_middle::bug;
4use rustc_middle::mir::mono::MonoItem;
5use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
6use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
7use tracing::{debug, trace};
8
9use crate::partitioning::UsageMap;
10
11fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
12    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
13        bug!("expected fn def for autodiff, got {:?}", fn_ty);
14    }
15
16    // We don't actually pass the types back into the type system.
17    // All we do is decide how to handle the arguments.
18    let sig = fn_ty.fn_sig(tcx).skip_binder();
19
20    let mut new_activities = vec![];
21    let mut new_positions = vec![];
22    for (i, ty) in sig.inputs().iter().enumerate() {
23        if let Some(inner_ty) = ty.builtin_deref(true) {
24            if inner_ty.is_slice() {
25                // Now we need to figure out the size of each slice element in memory to allow
26                // safety checks and usability improvements in the backend.
27                let sty = match inner_ty.builtin_index() {
28                    Some(sty) => sty,
29                    None => {
30                        panic!("slice element type unknown");
31                    }
32                };
33                let pci = PseudoCanonicalInput {
34                    typing_env: TypingEnv::fully_monomorphized(),
35                    value: sty,
36                };
37
38                let layout = tcx.layout_of(pci);
39                let elem_size = match layout {
40                    Ok(layout) => layout.size,
41                    Err(_) => {
42                        bug!("autodiff failed to compute slice element size");
43                    }
44                };
45                let elem_size: u32 = elem_size.bytes() as u32;
46
47                // We know that the length will be passed as extra arg.
48                if !da.is_empty() {
49                    // We are looking at a slice. The length of that slice will become an
50                    // extra integer on llvm level. Integers are always const.
51                    // However, if the slice get's duplicated, we want to know to later check the
52                    // size. So we mark the new size argument as FakeActivitySize.
53                    // There is one FakeActivitySize per slice, so for convenience we store the
54                    // slice element size in bytes in it. We will use the size in the backend.
55                    let activity = match da[i] {
56                        DiffActivity::DualOnly
57                        | DiffActivity::Dual
58                        | DiffActivity::Dualv
59                        | DiffActivity::DuplicatedOnly
60                        | DiffActivity::Duplicated => {
61                            DiffActivity::FakeActivitySize(Some(elem_size))
62                        }
63                        DiffActivity::Const => DiffActivity::Const,
64                        _ => bug!("unexpected activity for ptr/ref"),
65                    };
66                    new_activities.push(activity);
67                    new_positions.push(i + 1);
68                }
69
70                continue;
71            }
72        }
73    }
74    // now add the extra activities coming from slices
75    // Reverse order to not invalidate the indices
76    for _ in 0..new_activities.len() {
77        let pos = new_positions.pop().unwrap();
78        let activity = new_activities.pop().unwrap();
79        da.insert(pos, activity);
80    }
81}
82
83pub(crate) fn find_autodiff_source_functions<'tcx>(
84    tcx: TyCtxt<'tcx>,
85    usage_map: &UsageMap<'tcx>,
86    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
87) -> Vec<AutoDiffItem> {
88    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
89    for (item, instance) in autodiff_mono_items {
90        let target_id = instance.def_id();
91        let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
92        let Some(target_attrs) = cg_fn_attr else {
93            continue;
94        };
95        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
96        if target_attrs.is_source() {
97            trace!("source found: {:?}", target_id);
98        }
99        if !target_attrs.apply_autodiff() {
100            continue;
101        }
102
103        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
104
105        let source =
106            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
107                MonoItem::Fn(ref instance_s) => {
108                    let source_id = instance_s.def_id();
109                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
110                        && ad.is_active()
111                    {
112                        return Some(instance_s);
113                    }
114                    None
115                }
116                _ => None,
117            });
118        let inst = match source {
119            Some(source) => source,
120            None => continue,
121        };
122
123        debug!("source_id: {:?}", inst.def_id());
124        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
125        assert!(fn_ty.is_fn());
126        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
127        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
128
129        let mut new_target_attrs = target_attrs.clone();
130        new_target_attrs.input_activity = input_activities;
131        let itm = new_target_attrs.into_item(symb, target_symbol);
132        autodiff_items.push(itm);
133    }
134
135    if !autodiff_items.is_empty() {
136        trace!("AUTODIFF ITEMS EXIST");
137        for item in &mut *autodiff_items {
138            trace!("{}", &item);
139        }
140    }
141
142    autodiff_items
143}