miri/
math.rs

1use std::ops::Neg;
2use std::{f32, f64};
3
4use rand::Rng as _;
5use rustc_apfloat::Float as _;
6use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS};
7use rustc_middle::ty::{self, FloatTy, ScalarInt};
8
9use crate::*;
10
11/// Disturbes a floating-point result by a relative error in the range (-2^scale, 2^scale).
12pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
13    ecx: &mut crate::MiriInterpCx<'_>,
14    val: F,
15    err_scale: i32,
16) -> F {
17    if !ecx.machine.float_nondet
18        || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
19        // relative errors don't do anything to zeros... avoid messing up the sign
20        || val.is_zero()
21        // The logic below makes no sense if the input is already non-finite.
22        || !val.is_finite()
23    {
24        return val;
25    }
26    let rng = ecx.machine.rng.get_mut();
27
28    // Generate a random integer in the range [0, 2^PREC).
29    // (When read as binary, the position of the first `1` determines the exponent,
30    // and the remaining bits fill the mantissa. `PREC` is one plus the size of the mantissa,
31    // so this all works out.)
32    let r = F::from_u128(match ecx.machine.float_rounding_error {
33        FloatRoundingErrorMode::Random => rng.random_range(0..(1 << F::PRECISION)),
34        FloatRoundingErrorMode::Max => (1 << F::PRECISION) - 1, // force max error
35        FloatRoundingErrorMode::None => unreachable!(),
36    })
37    .value;
38    // Multiply this with 2^(scale - PREC). The result is between 0 and
39    // 2^PREC * 2^(scale - PREC) = 2^scale.
40    let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
41    // give it a random sign
42    let err = if rng.random() { -err } else { err };
43    // Compute `val*(1+err)`, distributed out as `val + val*err` to avoid the imprecise addition
44    // error being amplified by multiplication.
45    (val + (val * err).value).value
46}
47
48/// Applies an error of `[-N, +N]` ULP to the given value.
49pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
50    ecx: &mut crate::MiriInterpCx<'_>,
51    val: F,
52    max_error: u32,
53) -> F {
54    // We could try to be clever and reuse `apply_random_float_error`, but that is hard to get right
55    // (see <https://github.com/rust-lang/miri/pull/4558#discussion_r2316838085> for why) so we
56    // implement the logic directly instead.
57    if !ecx.machine.float_nondet
58        || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
59        // FIXME: also disturb zeros? That requires a lot more cases in `fixed_float_value`
60        // and might make the std test suite quite unhappy.
61        || val.is_zero()
62        // The logic below makes no sense if the input is already non-finite.
63        || !val.is_finite()
64    {
65        return val;
66    }
67    let rng = ecx.machine.rng.get_mut();
68
69    let max_error = i64::from(max_error);
70    let error = match ecx.machine.float_rounding_error {
71        FloatRoundingErrorMode::Random => rng.random_range(-max_error..=max_error),
72        FloatRoundingErrorMode::Max =>
73            if rng.random() {
74                max_error
75            } else {
76                -max_error
77            },
78        FloatRoundingErrorMode::None => unreachable!(),
79    };
80    // If upwards ULP and downwards ULP differ, we take the average.
81    let ulp = (((val.next_up().value - val).value + (val - val.next_down().value).value).value
82        / F::from_u128(2).value)
83        .value;
84    // Shift the value by N times the ULP
85    (val + (ulp * F::from_i128(error.into()).value).value).value
86}
87
88/// Applies an error of `[-N, +N]` ULP to the given value.
89/// Will fail if `val` is not a floating point number.
90pub(crate) fn apply_random_float_error_to_imm<'tcx>(
91    ecx: &mut MiriInterpCx<'tcx>,
92    val: ImmTy<'tcx>,
93    max_error: u32,
94) -> InterpResult<'tcx, ImmTy<'tcx>> {
95    let scalar = val.to_scalar_int()?;
96    let res: ScalarInt = match val.layout.ty.kind() {
97        ty::Float(FloatTy::F16) =>
98            apply_random_float_error_ulp(ecx, scalar.to_f16(), max_error).into(),
99        ty::Float(FloatTy::F32) =>
100            apply_random_float_error_ulp(ecx, scalar.to_f32(), max_error).into(),
101        ty::Float(FloatTy::F64) =>
102            apply_random_float_error_ulp(ecx, scalar.to_f64(), max_error).into(),
103        ty::Float(FloatTy::F128) =>
104            apply_random_float_error_ulp(ecx, scalar.to_f128(), max_error).into(),
105        _ => bug!("intrinsic called with non-float input type"),
106    };
107
108    interp_ok(ImmTy::from_scalar_int(res, val.layout))
109}
110
111/// Given a floating-point operation and a floating-point value, clamps the result to the output
112/// range of the given operation according to the C standard, if any.
113pub(crate) fn clamp_float_value<S: Semantics>(
114    intrinsic_name: &str,
115    val: IeeeFloat<S>,
116) -> IeeeFloat<S>
117where
118    IeeeFloat<S>: IeeeExt,
119{
120    let zero = IeeeFloat::<S>::ZERO;
121    let one = IeeeFloat::<S>::one();
122    let two = IeeeFloat::<S>::two();
123    let pi = IeeeFloat::<S>::pi();
124    let pi_over_2 = (pi / two).value;
125
126    match intrinsic_name {
127        // sin, cos, tanh: [-1, 1]
128        #[rustfmt::skip]
129        | "sinf32"
130        | "sinf64"
131        | "cosf32"
132        | "cosf64"
133        | "tanhf"
134        | "tanh"
135         => val.clamp(one.neg(), one),
136
137        // exp: [0, +INF)
138        "expf32" | "exp2f32" | "expf64" | "exp2f64" => val.maximum(zero),
139
140        // cosh: [1, +INF)
141        "coshf" | "cosh" => val.maximum(one),
142
143        // acos: [0, π]
144        "acosf" | "acos" => val.clamp(zero, pi),
145
146        // asin: [-π, +π]
147        "asinf" | "asin" => val.clamp(pi.neg(), pi),
148
149        // atan: (-π/2, +π/2)
150        "atanf" | "atan" => val.clamp(pi_over_2.neg(), pi_over_2),
151
152        // erfc: (-1, 1)
153        "erff" | "erf" => val.clamp(one.neg(), one),
154
155        // erfc: (0, 2)
156        "erfcf" | "erfc" => val.clamp(zero, two),
157
158        // atan2(y, x): arctan(y/x) in [−π, +π]
159        "atan2f" | "atan2" => val.clamp(pi.neg(), pi),
160
161        _ => val,
162    }
163}
164
165/// For the intrinsics:
166/// - sinf32, sinf64, sinhf, sinh
167/// - cosf32, cosf64, coshf, cosh
168/// - tanhf, tanh, atanf, atan, atan2f, atan2
169/// - expf32, expf64, exp2f32, exp2f64
170/// - logf32, logf64, log2f32, log2f64, log10f32, log10f64
171/// - powf32, powf64
172/// - erff, erf, erfcf, erfc
173/// - hypotf, hypot
174///
175/// # Return
176///
177/// Returns `Some(output)` if the `intrinsic` results in a defined fixed `output` specified in the C standard
178/// (specifically, C23 annex F.10)  when given `args` as arguments. Outputs that are unaffected by a relative error
179/// (such as INF and zero) are not handled here, they are assumed to be handled by the underlying
180/// implementation. Returns `None` if no specific value is guaranteed.
181///
182/// # Note
183///
184/// For `powf*` operations of the form:
185///
186/// - `(SNaN)^(±0)`
187/// - `1^(SNaN)`
188///
189/// The result is implementation-defined:
190/// - musl returns for both `1.0`
191/// - glibc returns for both `NaN`
192///
193/// This discrepancy exists because SNaN handling is not consistently defined across platforms,
194/// and the C standard leaves behavior for SNaNs unspecified.
195///
196/// Miri chooses to adhere to both implementations and returns either one of them non-deterministically.
197pub(crate) fn fixed_float_value<S: Semantics>(
198    ecx: &mut MiriInterpCx<'_>,
199    intrinsic_name: &str,
200    args: &[IeeeFloat<S>],
201) -> Option<IeeeFloat<S>>
202where
203    IeeeFloat<S>: IeeeExt,
204{
205    let this = ecx.eval_context_mut();
206    let one = IeeeFloat::<S>::one();
207    let two = IeeeFloat::<S>::two();
208    let three = IeeeFloat::<S>::three();
209    let pi = IeeeFloat::<S>::pi();
210    let pi_over_2 = (pi / two).value;
211    let pi_over_4 = (pi_over_2 / two).value;
212
213    Some(match (intrinsic_name, args) {
214        // cos(±0) and cosh(±0)= 1
215        ("cosf32" | "cosf64" | "coshf" | "cosh", [input]) if input.is_zero() => one,
216
217        // e^0 = 1
218        ("expf32" | "expf64" | "exp2f32" | "exp2f64", [input]) if input.is_zero() => one,
219
220        // tanh(±INF) = ±1
221        ("tanhf" | "tanh", [input]) if input.is_infinite() => one.copy_sign(*input),
222
223        // atan(±INF) = ±π/2
224        ("atanf" | "atan", [input]) if input.is_infinite() => pi_over_2.copy_sign(*input),
225
226        // erf(±INF) = ±1
227        ("erff" | "erf", [input]) if input.is_infinite() => one.copy_sign(*input),
228
229        // erfc(-INF) = 2
230        ("erfcf" | "erfc", [input]) if input.is_neg_infinity() => (one + one).value,
231
232        // hypot(x, ±0) = abs(x), if x is not a NaN.
233        ("_hypotf" | "hypotf" | "_hypot" | "hypot", [x, y]) if !x.is_nan() && y.is_zero() =>
234            x.abs(),
235
236        // atan2(±0,−0) = ±π.
237        // atan2(±0, y) = ±π for y < 0.
238        // Must check for non NaN because `y.is_negative()` also applies to NaN.
239        ("atan2f" | "atan2", [x, y]) if (x.is_zero() && (y.is_negative() && !y.is_nan())) =>
240            pi.copy_sign(*x),
241
242        // atan2(±x,−∞) = ±π for finite x > 0.
243        ("atan2f" | "atan2", [x, y])
244            if (!x.is_zero() && !x.is_infinite()) && y.is_neg_infinity() =>
245            pi.copy_sign(*x),
246
247        // atan2(x, ±0) = −π/2 for x < 0.
248        // atan2(x, ±0) =  π/2 for x > 0.
249        ("atan2f" | "atan2", [x, y]) if !x.is_zero() && y.is_zero() => pi_over_2.copy_sign(*x),
250
251        //atan2(±∞, −∞) = ±3π/4
252        ("atan2f" | "atan2", [x, y]) if x.is_infinite() && y.is_neg_infinity() =>
253            (pi_over_4 * three).value.copy_sign(*x),
254
255        //atan2(±∞, +∞) = ±π/4
256        ("atan2f" | "atan2", [x, y]) if x.is_infinite() && y.is_pos_infinity() =>
257            pi_over_4.copy_sign(*x),
258
259        // atan2(±∞, y) returns ±π/2 for finite y.
260        ("atan2f" | "atan2", [x, y]) if x.is_infinite() && (!y.is_infinite() && !y.is_nan()) =>
261            pi_over_2.copy_sign(*x),
262
263        // (-1)^(±INF) = 1
264        ("powf32" | "powf64", [base, exp]) if *base == -one && exp.is_infinite() => one,
265
266        // 1^y = 1 for any y, even a NaN
267        ("powf32" | "powf64", [base, exp]) if *base == one => {
268            let rng = this.machine.rng.get_mut();
269            // SNaN exponents get special treatment: they might return 1, or a NaN.
270            let return_nan = exp.is_signaling() && this.machine.float_nondet && rng.random();
271            // Handle both the musl and glibc cases non-deterministically.
272            if return_nan { this.generate_nan(args) } else { one }
273        }
274
275        // x^(±0) = 1 for any x, even a NaN
276        ("powf32" | "powf64", [base, exp]) if exp.is_zero() => {
277            let rng = this.machine.rng.get_mut();
278            // SNaN bases get special treatment: they might return 1, or a NaN.
279            let return_nan = base.is_signaling() && this.machine.float_nondet && rng.random();
280            // Handle both the musl and glibc cases non-deterministically.
281            if return_nan { this.generate_nan(args) } else { one }
282        }
283
284        // There are a lot of cases for fixed outputs according to the C Standard, but these are
285        // mainly INF or zero which are not affected by the applied error.
286        _ => return None,
287    })
288}
289
290/// Returns `Some(output)` if `powi` (called `pown` in C) results in a fixed value specified in the
291/// C standard (specifically, C23 annex F.10.4.6) when doing `base^exp`. Otherwise, returns `None`.
292pub(crate) fn fixed_powi_value<S: Semantics>(
293    ecx: &mut MiriInterpCx<'_>,
294    base: IeeeFloat<S>,
295    exp: i32,
296) -> Option<IeeeFloat<S>>
297where
298    IeeeFloat<S>: IeeeExt,
299{
300    match exp {
301        0 => {
302            let one = IeeeFloat::<S>::one();
303            let rng = ecx.machine.rng.get_mut();
304            let return_nan = ecx.machine.float_nondet && rng.random() && base.is_signaling();
305            // For SNaN treatment, we are consistent with `powf`above.
306            // (We wouldn't have two, unlike powf all implementations seem to agree for powi,
307            // but for now we are maximally conservative.)
308            Some(if return_nan { ecx.generate_nan(&[base]) } else { one })
309        }
310
311        _ => return None,
312    }
313}
314
/// Computes the square root of `x` in software, rounding to nearest.
///
/// Special cases:
/// - `±0` is returned unchanged, so the sign of zero is preserved.
/// - NaN inputs are returned unchanged.
/// - Negative inputs, including `-∞`, yield NaN.
/// - `+∞` yields `+∞`.
///
/// Other (positive, finite) inputs are handled by an exact bit-by-bit integer square root of the
/// mantissa, followed by a single rounding step.
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
    match x.category() {
        // preserve zero sign
        rustc_apfloat::Category::Zero => x,
        // propagate NaN
        rustc_apfloat::Category::NaN => x,
        // sqrt of negative number is NaN
        _ if x.is_negative() => IeeeFloat::NAN,
        // sqrt(∞) = ∞
        rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
        rustc_apfloat::Category::Normal => {
            // Floating point precision, excluding the integer bit
            let prec = i32::try_from(S::PRECISION).unwrap() - 1;

            // x = 2^(exp - prec) * mant
            // where mant is an integer with prec+1 bits
            // mant is a u128, which should be large enough for the largest prec (112 for f128)
            let mut exp = x.ilogb();
            let mut mant = x.scalbn(prec - exp).to_u128(128).value;

            if exp % 2 != 0 {
                // Make exponent even, so it can be divided by 2
                exp -= 1;
                mant <<= 1;
            }

            // Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
            // mant is treated here as a fixed point number with prec fractional bits.
            // mant will be shifted left by one bit to have an extra fractional bit, which
            // will be used to determine the rounding direction.

            // res is the truncated sqrt of mant, where one bit is added at each iteration.
            let mut res = 0u128;
            // rem is the remainder with the current res
            // rem_i = 2^i * ((mant<<1) - res_i^2)
            // starting with res = 0, rem = mant<<1
            let mut rem = mant << 1;
            // s_i = 2*res_i
            let mut s = 0u128;
            // d is used to iterate over bits, from high to low (d_i = 2^(-i))
            let mut d = 1u128 << (prec + 1);

            // For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
            //  (res_i + b_j * 2^(-j))^2 <= mant<<1
            // Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
            //  res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
            // And rearranging the terms:
            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
            //  b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i

            while d != 0 {
                // Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
                // t = 2*res_i + 2^(-j)
                let t = s + d;
                if rem >= t {
                    // b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
                    res += d;
                    s += d + d;
                    rem -= t;
                }
                // Adjust rem for next iteration
                rem <<= 1;
                // Shift iterator
                d >>= 1;
            }

            // Remove extra fractional bit from result, rounding to nearest.
            // If the last bit is 0, then the nearest neighbor is definitely the lower one.
            // If the last bit is 1, it sounds like this may either be a tie (if there's
            // infinitely many 0s after this 1), or the nearest neighbor is the upper one.
            // However, since square roots are either exact or irrational, and an exact root
            // would lead to the last "extra" bit being 0, we can exclude a tie in this case.
            // We therefore always round up if the last bit is 1. When the last bit is 0,
            // adding 1 will not do anything since the shift will discard it.
            res = (res + 1) >> 1;

            // Build resulting value with res as mantissa and exp/2 as exponent
            IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
        }
    }
}
396
397/// Extend functionality of `rustc_apfloat` softfloats for IEEE float types.
398pub trait IeeeExt: rustc_apfloat::Float {
399    // Some values we use:
400
401    #[inline]
402    fn one() -> Self {
403        Self::from_u128(1).value
404    }
405
406    #[inline]
407    fn two() -> Self {
408        Self::from_u128(2).value
409    }
410
411    #[inline]
412    fn three() -> Self {
413        Self::from_u128(3).value
414    }
415
416    fn pi() -> Self;
417
418    #[inline]
419    fn clamp(self, min: Self, max: Self) -> Self {
420        self.maximum(min).minimum(max)
421    }
422}
423
424macro_rules! impl_ieee_pi {
425    ($float_ty:ident, $semantic:ty) => {
426        impl IeeeExt for IeeeFloat<$semantic> {
427            #[inline]
428            fn pi() -> Self {
429                // We take the value from the standard library as the most reasonable source for an exact π here.
430                Self::from_bits($float_ty::consts::PI.to_bits().into())
431            }
432        }
433    };
434}
435
436impl_ieee_pi!(f32, SingleS);
437impl_ieee_pi!(f64, DoubleS);
438
#[cfg(test)]
mod tests {
    use rustc_apfloat::Float as _;
    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

    use super::sqrt;

    #[test]
    fn test_sqrt() {
        #[track_caller]
        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
            let x: IeeeFloat<S> = x.parse().unwrap();
            let expected: IeeeFloat<S> = expected.parse().unwrap();
            let result = sqrt(x);
            assert_eq!(result, expected);
        }

        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
            test::<S>("0", "0");
            test::<S>("1", "1");
            test::<S>("1.5625", "1.25");
            test::<S>("2.25", "1.5");
            test::<S>("4", "2");
            test::<S>("5.0625", "2.25");
            test::<S>("9", "3");
            test::<S>("16", "4");
            test::<S>("25", "5");
            test::<S>("36", "6");
            test::<S>("49", "7");
            test::<S>("64", "8");
            test::<S>("81", "9");
            test::<S>("100", "10");

            test::<S>("0.5625", "0.75");
            test::<S>("0.25", "0.5");
            test::<S>("0.0625", "0.25");
            test::<S>("0.00390625", "0.0625");
        }

        exact_tests::<HalfS>();
        exact_tests::<SingleS>();
        exact_tests::<DoubleS>();
        exact_tests::<QuadS>();

        test::<SingleS>("2", "1.4142135");
        test::<DoubleS>("2", "1.4142135623730951");

        test::<SingleS>("1.1", "1.0488088");
        test::<DoubleS>("1.1", "1.0488088481701516");

        test::<SingleS>("2.2", "1.4832398");
        test::<DoubleS>("2.2", "1.4832396974191326");

        test::<SingleS>("1.22101e-40", "1.10499205e-20");
        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

        test::<SingleS>("3.4028235e38", "1.8446743e19");
        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
    }

    /// Covers the special-value branches of `sqrt`. Of these, `test_sqrt` only reaches `+0`.
    #[test]
    fn test_sqrt_special_values() {
        #[track_caller]
        fn special_tests<S: rustc_apfloat::ieee::Semantics>() {
            let zero = IeeeFloat::<S>::ZERO;
            let inf = IeeeFloat::<S>::INFINITY;
            let minus_one: IeeeFloat<S> = "-1".parse().unwrap();

            // sqrt(±0) = ±0. Compare bits, because `-0 == +0` under IEEE equality.
            assert_eq!(sqrt(zero).to_bits(), zero.to_bits());
            assert_eq!(sqrt(-zero).to_bits(), (-zero).to_bits());
            // sqrt(+∞) = +∞
            assert_eq!(sqrt(inf), inf);
            // Negative inputs, including -∞, produce NaN.
            assert!(sqrt(minus_one).is_nan());
            assert!(sqrt(-inf).is_nan());
            // NaN propagates.
            assert!(sqrt(IeeeFloat::<S>::NAN).is_nan());
        }

        special_tests::<HalfS>();
        special_tests::<SingleS>();
        special_tests::<DoubleS>();
        special_tests::<QuadS>();
    }
}