rustc_ast/expand/autodiff_attrs.rs

//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! that we generate (with a name given by the user as the first autodiff arg).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::{Decodable, Encodable, HashStable_Generic};
10use crate::{Ty, TyKind};
11
12/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
13/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
14/// are a hack to support higher order derivatives. We need to compute first order derivatives
15/// before we compute second order derivatives, otherwise we would differentiate our placeholder
16/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
17/// as it's already done in the C++ and Julia frontend of Enzyme.
18///
19/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
20/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
21#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
22pub enum DiffMode {
23    /// No autodiff is applied (used during error handling).
24    Error,
25    /// The primal function which we will differentiate.
26    Source,
27    /// The target function, to be created using forward mode AD.
28    Forward,
29    /// The target function, to be created using reverse mode AD.
30    Reverse,
31}
32
33/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
34/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
35/// we add to the previous shadow value. To not surprise users, we picked different names.
36/// Dual numbers is also a quite well known name for forward mode AD types.
37#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
38pub enum DiffActivity {
39    /// Implicit or Explicit () return type, so a special case of Const.
40    None,
41    /// Don't compute derivatives with respect to this input/output.
42    Const,
43    /// Reverse Mode, Compute derivatives for this scalar input/output.
44    Active,
45    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
46    /// the original return value.
47    ActiveOnly,
48    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
49    /// with it.
50    Dual,
51    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
52    /// with it. It expects the shadow argument to be `width` times larger than the original
53    /// input/output.
54    Dualv,
55    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
56    /// with it. Drop the code which updates the original input/output for maximum performance.
57    DualOnly,
58    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
59    /// with it. Drop the code which updates the original input/output for maximum performance.
60    /// It expects the shadow argument to be `width` times larger than the original input/output.
61    DualvOnly,
62    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
63    Duplicated,
64    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
65    /// Drop the code which updates the original input for maximum performance.
66    DuplicatedOnly,
67    /// All Integers must be Const, but these are used to mark the integer which represents the
68    /// length of a slice/vec. This is used for safety checks on slices.
69    /// The integer (if given) specifies the size of the slice element in bytes.
70    FakeActivitySize(Option<u32>),
71}
72
73impl DiffActivity {
74    pub fn is_dual_or_const(&self) -> bool {
75        use DiffActivity::*;
76        matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
77    }
78}
79/// We generate one of these structs for each `#[autodiff(...)]` attribute.
80#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
81pub struct AutoDiffItem {
82    /// The name of the function getting differentiated
83    pub source: String,
84    /// The name of the function being generated
85    pub target: String,
86    pub attrs: AutoDiffAttrs,
87}
88
89#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
90pub struct AutoDiffAttrs {
91    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
92    /// e.g. in the [JAX
93    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
94    pub mode: DiffMode,
95    /// A user-provided, batching width. If not given, we will default to 1 (no batching).
96    /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
97    /// - Calling the function 50 times with a batch size of 2
98    /// - Calling the function 25 times with a batch size of 4,
99    /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
100    /// cache locality, better re-usal of primal values, and other optimizations.
101    /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
102    /// times, so this massively increases code size. As such, values like 1024 are unlikely to
103    /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
104    /// experiments for now and focus on documenting the implications of a large width.
105    pub width: u32,
106    pub ret_activity: DiffActivity,
107    pub input_activity: Vec<DiffActivity>,
108}
109
110impl AutoDiffAttrs {
111    pub fn has_primal_ret(&self) -> bool {
112        matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
113    }
114}
115
116impl DiffMode {
117    pub fn is_rev(&self) -> bool {
118        matches!(self, DiffMode::Reverse)
119    }
120    pub fn is_fwd(&self) -> bool {
121        matches!(self, DiffMode::Forward)
122    }
123}
124
125impl Display for DiffMode {
126    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
127        match self {
128            DiffMode::Error => write!(f, "Error"),
129            DiffMode::Source => write!(f, "Source"),
130            DiffMode::Forward => write!(f, "Forward"),
131            DiffMode::Reverse => write!(f, "Reverse"),
132        }
133    }
134}
135
136/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
137/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
138/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
139/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
140/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
141pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
142    if activity == DiffActivity::None {
143        // Only valid if primal returns (), but we can't check that here.
144        return true;
145    }
146    match mode {
147        DiffMode::Error => false,
148        DiffMode::Source => false,
149        DiffMode::Forward => activity.is_dual_or_const(),
150        DiffMode::Reverse => {
151            activity == DiffActivity::Const
152                || activity == DiffActivity::Active
153                || activity == DiffActivity::ActiveOnly
154        }
155    }
156}
157
158/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
159/// for the given argument, but we generally can't know the size of such a type.
160/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
161/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
162/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
163/// users here from marking scalars as Duplicated, due to type aliases.
164pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
165    use DiffActivity::*;
166    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
167    // Dual variants also support all types.
168    if activity.is_dual_or_const() {
169        return true;
170    }
171    // FIXME(ZuseZ4) We should make this more robust to also
172    // handle type aliases. Once that is done, we can be more restrictive here.
173    if matches!(activity, Active | ActiveOnly) {
174        return true;
175    }
176    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
177        && matches!(activity, Duplicated | DuplicatedOnly)
178}
179pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
180    use DiffActivity::*;
181    return match mode {
182        DiffMode::Error => false,
183        DiffMode::Source => false,
184        DiffMode::Forward => activity.is_dual_or_const(),
185        DiffMode::Reverse => {
186            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
187        }
188    };
189}
190
191impl Display for DiffActivity {
192    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
193        match self {
194            DiffActivity::None => write!(f, "None"),
195            DiffActivity::Const => write!(f, "Const"),
196            DiffActivity::Active => write!(f, "Active"),
197            DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
198            DiffActivity::Dual => write!(f, "Dual"),
199            DiffActivity::Dualv => write!(f, "Dualv"),
200            DiffActivity::DualOnly => write!(f, "DualOnly"),
201            DiffActivity::DualvOnly => write!(f, "DualvOnly"),
202            DiffActivity::Duplicated => write!(f, "Duplicated"),
203            DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
204            DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
205        }
206    }
207}
208
209impl FromStr for DiffMode {
210    type Err = ();
211
212    fn from_str(s: &str) -> Result<DiffMode, ()> {
213        match s {
214            "Error" => Ok(DiffMode::Error),
215            "Source" => Ok(DiffMode::Source),
216            "Forward" => Ok(DiffMode::Forward),
217            "Reverse" => Ok(DiffMode::Reverse),
218            _ => Err(()),
219        }
220    }
221}
222impl FromStr for DiffActivity {
223    type Err = ();
224
225    fn from_str(s: &str) -> Result<DiffActivity, ()> {
226        match s {
227            "None" => Ok(DiffActivity::None),
228            "Active" => Ok(DiffActivity::Active),
229            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
230            "Const" => Ok(DiffActivity::Const),
231            "Dual" => Ok(DiffActivity::Dual),
232            "Dualv" => Ok(DiffActivity::Dualv),
233            "DualOnly" => Ok(DiffActivity::DualOnly),
234            "DualvOnly" => Ok(DiffActivity::DualvOnly),
235            "Duplicated" => Ok(DiffActivity::Duplicated),
236            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
237            _ => Err(()),
238        }
239    }
240}
241
242impl AutoDiffAttrs {
243    pub fn has_ret_activity(&self) -> bool {
244        self.ret_activity != DiffActivity::None
245    }
246    pub fn has_active_only_ret(&self) -> bool {
247        self.ret_activity == DiffActivity::ActiveOnly
248    }
249
250    pub const fn error() -> Self {
251        AutoDiffAttrs {
252            mode: DiffMode::Error,
253            width: 0,
254            ret_activity: DiffActivity::None,
255            input_activity: Vec::new(),
256        }
257    }
258    pub fn source() -> Self {
259        AutoDiffAttrs {
260            mode: DiffMode::Source,
261            width: 0,
262            ret_activity: DiffActivity::None,
263            input_activity: Vec::new(),
264        }
265    }
266
267    pub fn is_active(&self) -> bool {
268        self.mode != DiffMode::Error
269    }
270
271    pub fn is_source(&self) -> bool {
272        self.mode == DiffMode::Source
273    }
274    pub fn apply_autodiff(&self) -> bool {
275        !matches!(self.mode, DiffMode::Error | DiffMode::Source)
276    }
277
278    pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
279        AutoDiffItem { source, target, attrs: self }
280    }
281}
282
283impl fmt::Display for AutoDiffItem {
284    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
285        write!(f, "Differentiating {} -> {}", self.source, self.target)?;
286        write!(f, " with attributes: {:?}", self.attrs)
287    }
288}