rustc_monomorphize/partitioning/
autodiff.rs1use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
2use rustc_hir::def_id::LOCAL_CRATE;
3use rustc_middle::bug;
4use rustc_middle::mir::mono::MonoItem;
5use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
6use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
7use tracing::{debug, trace};
8
9use crate::partitioning::UsageMap;
10
11fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
12 if !matches!(fn_ty.kind(), ty::FnDef(..)) {
13 bug!("expected fn def for autodiff, got {:?}", fn_ty);
14 }
15
16 let sig = fn_ty.fn_sig(tcx).skip_binder();
19
20 let mut new_activities = vec![];
21 let mut new_positions = vec![];
22 for (i, ty) in sig.inputs().iter().enumerate() {
23 if let Some(inner_ty) = ty.builtin_deref(true) {
24 if inner_ty.is_slice() {
25 let sty = match inner_ty.builtin_index() {
28 Some(sty) => sty,
29 None => {
30 panic!("slice element type unknown");
31 }
32 };
33 let pci = PseudoCanonicalInput {
34 typing_env: TypingEnv::fully_monomorphized(),
35 value: sty,
36 };
37
38 let layout = tcx.layout_of(pci);
39 let elem_size = match layout {
40 Ok(layout) => layout.size,
41 Err(_) => {
42 bug!("autodiff failed to compute slice element size");
43 }
44 };
45 let elem_size: u32 = elem_size.bytes() as u32;
46
47 if !da.is_empty() {
49 let activity = match da[i] {
56 DiffActivity::DualOnly
57 | DiffActivity::Dual
58 | DiffActivity::Dualv
59 | DiffActivity::DuplicatedOnly
60 | DiffActivity::Duplicated => {
61 DiffActivity::FakeActivitySize(Some(elem_size))
62 }
63 DiffActivity::Const => DiffActivity::Const,
64 _ => bug!("unexpected activity for ptr/ref"),
65 };
66 new_activities.push(activity);
67 new_positions.push(i + 1);
68 }
69
70 continue;
71 }
72 }
73 }
74 for _ in 0..new_activities.len() {
77 let pos = new_positions.pop().unwrap();
78 let activity = new_activities.pop().unwrap();
79 da.insert(pos, activity);
80 }
81}
82
83pub(crate) fn find_autodiff_source_functions<'tcx>(
84 tcx: TyCtxt<'tcx>,
85 usage_map: &UsageMap<'tcx>,
86 autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
87) -> Vec<AutoDiffItem> {
88 let mut autodiff_items: Vec<AutoDiffItem> = vec![];
89 for (item, instance) in autodiff_mono_items {
90 let target_id = instance.def_id();
91 let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
92 let Some(target_attrs) = cg_fn_attr else {
93 continue;
94 };
95 let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
96 if target_attrs.is_source() {
97 trace!("source found: {:?}", target_id);
98 }
99 if !target_attrs.apply_autodiff() {
100 continue;
101 }
102
103 let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
104
105 let source =
106 usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
107 MonoItem::Fn(ref instance_s) => {
108 let source_id = instance_s.def_id();
109 if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
110 && ad.is_active()
111 {
112 return Some(instance_s);
113 }
114 None
115 }
116 _ => None,
117 });
118 let inst = match source {
119 Some(source) => source,
120 None => continue,
121 };
122
123 debug!("source_id: {:?}", inst.def_id());
124 let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
125 assert!(fn_ty.is_fn());
126 adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
127 let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
128
129 let mut new_target_attrs = target_attrs.clone();
130 new_target_attrs.input_activity = input_activities;
131 let itm = new_target_attrs.into_item(symb, target_symbol);
132 autodiff_items.push(itm);
133 }
134
135 if !autodiff_items.is_empty() {
136 trace!("AUTODIFF ITEMS EXIST");
137 for item in &mut *autodiff_items {
138 trace!("{}", &item);
139 }
140 }
141
142 autodiff_items
143}