rustc_ast/expand/autodiff_attrs.rs
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! we generate (with a name given by the user as the first autodiff argument).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::{Decodable, Encodable, HashStable_Generic};
10use crate::{Ty, TyKind};
11
12/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
13/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
14/// are a hack to support higher order derivatives. We need to compute first order derivatives
15/// before we compute second order derivatives, otherwise we would differentiate our placeholder
16/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
17/// as it's already done in the C++ and Julia frontend of Enzyme.
18///
19/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
20/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
21#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
22pub enum DiffMode {
23 /// No autodiff is applied (used during error handling).
24 Error,
25 /// The primal function which we will differentiate.
26 Source,
27 /// The target function, to be created using forward mode AD.
28 Forward,
29 /// The target function, to be created using reverse mode AD.
30 Reverse,
31}
32
33/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
34/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
35/// we add to the previous shadow value. To not surprise users, we picked different names.
36/// Dual numbers is also a quite well known name for forward mode AD types.
37#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38pub enum DiffActivity {
39 /// Implicit or Explicit () return type, so a special case of Const.
40 None,
41 /// Don't compute derivatives with respect to this input/output.
42 Const,
43 /// Reverse Mode, Compute derivatives for this scalar input/output.
44 Active,
45 /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
46 /// the original return value.
47 ActiveOnly,
48 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
49 /// with it.
50 Dual,
51 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
52 /// with it. It expects the shadow argument to be `width` times larger than the original
53 /// input/output.
54 Dualv,
55 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
56 /// with it. Drop the code which updates the original input/output for maximum performance.
57 DualOnly,
58 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
59 /// with it. Drop the code which updates the original input/output for maximum performance.
60 /// It expects the shadow argument to be `width` times larger than the original input/output.
61 DualvOnly,
62 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
63 Duplicated,
64 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
65 /// Drop the code which updates the original input for maximum performance.
66 DuplicatedOnly,
67 /// All Integers must be Const, but these are used to mark the integer which represents the
68 /// length of a slice/vec. This is used for safety checks on slices.
69 /// The integer (if given) specifies the size of the slice element in bytes.
70 FakeActivitySize(Option<u32>),
71}
72
73impl DiffActivity {
74 pub fn is_dual_or_const(&self) -> bool {
75 use DiffActivity::*;
76 matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
77 }
78}
79/// We generate one of these structs for each `#[autodiff(...)]` attribute.
80#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
81pub struct AutoDiffItem {
82 /// The name of the function getting differentiated
83 pub source: String,
84 /// The name of the function being generated
85 pub target: String,
86 pub attrs: AutoDiffAttrs,
87}
88
89#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
90pub struct AutoDiffAttrs {
91 /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
92 /// e.g. in the [JAX
93 /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
94 pub mode: DiffMode,
95 /// A user-provided, batching width. If not given, we will default to 1 (no batching).
96 /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
97 /// - Calling the function 50 times with a batch size of 2
98 /// - Calling the function 25 times with a batch size of 4,
99 /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
100 /// cache locality, better re-usal of primal values, and other optimizations.
101 /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
102 /// times, so this massively increases code size. As such, values like 1024 are unlikely to
103 /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
104 /// experiments for now and focus on documenting the implications of a large width.
105 pub width: u32,
106 pub ret_activity: DiffActivity,
107 pub input_activity: Vec<DiffActivity>,
108}
109
110impl AutoDiffAttrs {
111 pub fn has_primal_ret(&self) -> bool {
112 matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
113 }
114}
115
116impl DiffMode {
117 pub fn is_rev(&self) -> bool {
118 matches!(self, DiffMode::Reverse)
119 }
120 pub fn is_fwd(&self) -> bool {
121 matches!(self, DiffMode::Forward)
122 }
123}
124
125impl Display for DiffMode {
126 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127 match self {
128 DiffMode::Error => write!(f, "Error"),
129 DiffMode::Source => write!(f, "Source"),
130 DiffMode::Forward => write!(f, "Forward"),
131 DiffMode::Reverse => write!(f, "Reverse"),
132 }
133 }
134}
135
136/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
137/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
138/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
139/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
140/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
141pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
142 if activity == DiffActivity::None {
143 // Only valid if primal returns (), but we can't check that here.
144 return true;
145 }
146 match mode {
147 DiffMode::Error => false,
148 DiffMode::Source => false,
149 DiffMode::Forward => activity.is_dual_or_const(),
150 DiffMode::Reverse => {
151 activity == DiffActivity::Const
152 || activity == DiffActivity::Active
153 || activity == DiffActivity::ActiveOnly
154 }
155 }
156}
157
158/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
159/// for the given argument, but we generally can't know the size of such a type.
160/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
161/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
162/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
163/// users here from marking scalars as Duplicated, due to type aliases.
164pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
165 use DiffActivity::*;
166 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
167 // Dual variants also support all types.
168 if activity.is_dual_or_const() {
169 return true;
170 }
171 // FIXME(ZuseZ4) We should make this more robust to also
172 // handle type aliases. Once that is done, we can be more restrictive here.
173 if matches!(activity, Active | ActiveOnly) {
174 return true;
175 }
176 matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
177 && matches!(activity, Duplicated | DuplicatedOnly)
178}
179pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
180 use DiffActivity::*;
181 return match mode {
182 DiffMode::Error => false,
183 DiffMode::Source => false,
184 DiffMode::Forward => activity.is_dual_or_const(),
185 DiffMode::Reverse => {
186 matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
187 }
188 };
189}
190
191impl Display for DiffActivity {
192 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193 match self {
194 DiffActivity::None => write!(f, "None"),
195 DiffActivity::Const => write!(f, "Const"),
196 DiffActivity::Active => write!(f, "Active"),
197 DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
198 DiffActivity::Dual => write!(f, "Dual"),
199 DiffActivity::Dualv => write!(f, "Dualv"),
200 DiffActivity::DualOnly => write!(f, "DualOnly"),
201 DiffActivity::DualvOnly => write!(f, "DualvOnly"),
202 DiffActivity::Duplicated => write!(f, "Duplicated"),
203 DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
204 DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
205 }
206 }
207}
208
209impl FromStr for DiffMode {
210 type Err = ();
211
212 fn from_str(s: &str) -> Result<DiffMode, ()> {
213 match s {
214 "Error" => Ok(DiffMode::Error),
215 "Source" => Ok(DiffMode::Source),
216 "Forward" => Ok(DiffMode::Forward),
217 "Reverse" => Ok(DiffMode::Reverse),
218 _ => Err(()),
219 }
220 }
221}
222impl FromStr for DiffActivity {
223 type Err = ();
224
225 fn from_str(s: &str) -> Result<DiffActivity, ()> {
226 match s {
227 "None" => Ok(DiffActivity::None),
228 "Active" => Ok(DiffActivity::Active),
229 "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
230 "Const" => Ok(DiffActivity::Const),
231 "Dual" => Ok(DiffActivity::Dual),
232 "Dualv" => Ok(DiffActivity::Dualv),
233 "DualOnly" => Ok(DiffActivity::DualOnly),
234 "DualvOnly" => Ok(DiffActivity::DualvOnly),
235 "Duplicated" => Ok(DiffActivity::Duplicated),
236 "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
237 _ => Err(()),
238 }
239 }
240}
241
242impl AutoDiffAttrs {
243 pub fn has_ret_activity(&self) -> bool {
244 self.ret_activity != DiffActivity::None
245 }
246 pub fn has_active_only_ret(&self) -> bool {
247 self.ret_activity == DiffActivity::ActiveOnly
248 }
249
250 pub const fn error() -> Self {
251 AutoDiffAttrs {
252 mode: DiffMode::Error,
253 width: 0,
254 ret_activity: DiffActivity::None,
255 input_activity: Vec::new(),
256 }
257 }
258 pub fn source() -> Self {
259 AutoDiffAttrs {
260 mode: DiffMode::Source,
261 width: 0,
262 ret_activity: DiffActivity::None,
263 input_activity: Vec::new(),
264 }
265 }
266
267 pub fn is_active(&self) -> bool {
268 self.mode != DiffMode::Error
269 }
270
271 pub fn is_source(&self) -> bool {
272 self.mode == DiffMode::Source
273 }
274 pub fn apply_autodiff(&self) -> bool {
275 !matches!(self.mode, DiffMode::Error | DiffMode::Source)
276 }
277
278 pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279 AutoDiffItem { source, target, attrs: self }
280 }
281}
282
283impl fmt::Display for AutoDiffItem {
284 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285 write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286 write!(f, " with attributes: {:?}", self.attrs)
287 }
288}