1use std::ops::Neg;
2use std::{f32, f64};
3
4use rand::Rng as _;
5use rustc_apfloat::Float as _;
6use rustc_apfloat::ieee::{DoubleS, IeeeFloat, Semantics, SingleS};
7use rustc_middle::ty::{self, FloatTy, ScalarInt};
8
9use crate::*;
10
11pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
13 ecx: &mut crate::MiriInterpCx<'_>,
14 val: F,
15 err_scale: i32,
16) -> F {
17 if !ecx.machine.float_nondet
18 || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
19 || val.is_zero()
21 || !val.is_finite()
23 {
24 return val;
25 }
26 let rng = ecx.machine.rng.get_mut();
27
28 let r = F::from_u128(match ecx.machine.float_rounding_error {
33 FloatRoundingErrorMode::Random => rng.random_range(0..(1 << F::PRECISION)),
34 FloatRoundingErrorMode::Max => (1 << F::PRECISION) - 1, FloatRoundingErrorMode::None => unreachable!(),
36 })
37 .value;
38 let err = r.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
41 let err = if rng.random() { -err } else { err };
43 (val + (val * err).value).value
46}
47
48pub(crate) fn apply_random_float_error_ulp<F: rustc_apfloat::Float>(
50 ecx: &mut crate::MiriInterpCx<'_>,
51 val: F,
52 max_error: u32,
53) -> F {
54 if !ecx.machine.float_nondet
58 || matches!(ecx.machine.float_rounding_error, FloatRoundingErrorMode::None)
59 || val.is_zero()
62 || !val.is_finite()
64 {
65 return val;
66 }
67 let rng = ecx.machine.rng.get_mut();
68
69 let max_error = i64::from(max_error);
70 let error = match ecx.machine.float_rounding_error {
71 FloatRoundingErrorMode::Random => rng.random_range(-max_error..=max_error),
72 FloatRoundingErrorMode::Max =>
73 if rng.random() {
74 max_error
75 } else {
76 -max_error
77 },
78 FloatRoundingErrorMode::None => unreachable!(),
79 };
80 let ulp = (((val.next_up().value - val).value + (val - val.next_down().value).value).value
82 / F::from_u128(2).value)
83 .value;
84 (val + (ulp * F::from_i128(error.into()).value).value).value
86}
87
88pub(crate) fn apply_random_float_error_to_imm<'tcx>(
91 ecx: &mut MiriInterpCx<'tcx>,
92 val: ImmTy<'tcx>,
93 max_error: u32,
94) -> InterpResult<'tcx, ImmTy<'tcx>> {
95 let scalar = val.to_scalar_int()?;
96 let res: ScalarInt = match val.layout.ty.kind() {
97 ty::Float(FloatTy::F16) =>
98 apply_random_float_error_ulp(ecx, scalar.to_f16(), max_error).into(),
99 ty::Float(FloatTy::F32) =>
100 apply_random_float_error_ulp(ecx, scalar.to_f32(), max_error).into(),
101 ty::Float(FloatTy::F64) =>
102 apply_random_float_error_ulp(ecx, scalar.to_f64(), max_error).into(),
103 ty::Float(FloatTy::F128) =>
104 apply_random_float_error_ulp(ecx, scalar.to_f128(), max_error).into(),
105 _ => bug!("intrinsic called with non-float input type"),
106 };
107
108 interp_ok(ImmTy::from_scalar_int(res, val.layout))
109}
110
111pub(crate) fn clamp_float_value<S: Semantics>(
114 intrinsic_name: &str,
115 val: IeeeFloat<S>,
116) -> IeeeFloat<S>
117where
118 IeeeFloat<S>: IeeeExt,
119{
120 let zero = IeeeFloat::<S>::ZERO;
121 let one = IeeeFloat::<S>::one();
122 let two = IeeeFloat::<S>::two();
123 let pi = IeeeFloat::<S>::pi();
124 let pi_over_2 = (pi / two).value;
125
126 match intrinsic_name {
127 #[rustfmt::skip]
129 | "sinf32"
130 | "sinf64"
131 | "cosf32"
132 | "cosf64"
133 | "tanhf"
134 | "tanh"
135 => val.clamp(one.neg(), one),
136
137 "expf32" | "exp2f32" | "expf64" | "exp2f64" => val.maximum(zero),
139
140 "coshf" | "cosh" => val.maximum(one),
142
143 "acosf" | "acos" => val.clamp(zero, pi),
145
146 "asinf" | "asin" => val.clamp(pi.neg(), pi),
148
149 "atanf" | "atan" => val.clamp(pi_over_2.neg(), pi_over_2),
151
152 "erff" | "erf" => val.clamp(one.neg(), one),
154
155 "erfcf" | "erfc" => val.clamp(zero, two),
157
158 "atan2f" | "atan2" => val.clamp(pi.neg(), pi),
160
161 _ => val,
162 }
163}
164
165pub(crate) fn fixed_float_value<S: Semantics>(
198 ecx: &mut MiriInterpCx<'_>,
199 intrinsic_name: &str,
200 args: &[IeeeFloat<S>],
201) -> Option<IeeeFloat<S>>
202where
203 IeeeFloat<S>: IeeeExt,
204{
205 let this = ecx.eval_context_mut();
206 let one = IeeeFloat::<S>::one();
207 let two = IeeeFloat::<S>::two();
208 let three = IeeeFloat::<S>::three();
209 let pi = IeeeFloat::<S>::pi();
210 let pi_over_2 = (pi / two).value;
211 let pi_over_4 = (pi_over_2 / two).value;
212
213 Some(match (intrinsic_name, args) {
214 ("cosf32" | "cosf64" | "coshf" | "cosh", [input]) if input.is_zero() => one,
216
217 ("expf32" | "expf64" | "exp2f32" | "exp2f64", [input]) if input.is_zero() => one,
219
220 ("tanhf" | "tanh", [input]) if input.is_infinite() => one.copy_sign(*input),
222
223 ("atanf" | "atan", [input]) if input.is_infinite() => pi_over_2.copy_sign(*input),
225
226 ("erff" | "erf", [input]) if input.is_infinite() => one.copy_sign(*input),
228
229 ("erfcf" | "erfc", [input]) if input.is_neg_infinity() => (one + one).value,
231
232 ("_hypotf" | "hypotf" | "_hypot" | "hypot", [x, y]) if !x.is_nan() && y.is_zero() =>
234 x.abs(),
235
236 ("atan2f" | "atan2", [x, y]) if (x.is_zero() && (y.is_negative() && !y.is_nan())) =>
240 pi.copy_sign(*x),
241
242 ("atan2f" | "atan2", [x, y])
244 if (!x.is_zero() && !x.is_infinite()) && y.is_neg_infinity() =>
245 pi.copy_sign(*x),
246
247 ("atan2f" | "atan2", [x, y]) if !x.is_zero() && y.is_zero() => pi_over_2.copy_sign(*x),
250
251 ("atan2f" | "atan2", [x, y]) if x.is_infinite() && y.is_neg_infinity() =>
253 (pi_over_4 * three).value.copy_sign(*x),
254
255 ("atan2f" | "atan2", [x, y]) if x.is_infinite() && y.is_pos_infinity() =>
257 pi_over_4.copy_sign(*x),
258
259 ("atan2f" | "atan2", [x, y]) if x.is_infinite() && (!y.is_infinite() && !y.is_nan()) =>
261 pi_over_2.copy_sign(*x),
262
263 ("powf32" | "powf64", [base, exp]) if *base == -one && exp.is_infinite() => one,
265
266 ("powf32" | "powf64", [base, exp]) if *base == one => {
268 let rng = this.machine.rng.get_mut();
269 let return_nan = exp.is_signaling() && this.machine.float_nondet && rng.random();
271 if return_nan { this.generate_nan(args) } else { one }
273 }
274
275 ("powf32" | "powf64", [base, exp]) if exp.is_zero() => {
277 let rng = this.machine.rng.get_mut();
278 let return_nan = base.is_signaling() && this.machine.float_nondet && rng.random();
280 if return_nan { this.generate_nan(args) } else { one }
282 }
283
284 _ => return None,
287 })
288}
289
290pub(crate) fn fixed_powi_value<S: Semantics>(
293 ecx: &mut MiriInterpCx<'_>,
294 base: IeeeFloat<S>,
295 exp: i32,
296) -> Option<IeeeFloat<S>>
297where
298 IeeeFloat<S>: IeeeExt,
299{
300 match exp {
301 0 => {
302 let one = IeeeFloat::<S>::one();
303 let rng = ecx.machine.rng.get_mut();
304 let return_nan = ecx.machine.float_nondet && rng.random() && base.is_signaling();
305 Some(if return_nan { ecx.generate_nan(&[base]) } else { one })
309 }
310
311 _ => return None,
312 }
313}
314
/// Computes the square root of `x` in software, rounding to nearest.
///
/// Special values:
/// - `±0` is returned as is, so the sign of zero is kept.
/// - NaN is returned as is.
/// - Any other negative input, including `-∞`, yields NaN.
/// - `+∞` yields `+∞`.
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
    match x.category() {
        // sqrt(±0) = ±0: return the input to keep the sign of zero.
        rustc_apfloat::Category::Zero => x,
        // Propagate NaN inputs unchanged.
        rustc_apfloat::Category::NaN => x,
        // Every remaining negative value (finite or -∞) has no real square root.
        _ if x.is_negative() => IeeeFloat::NAN,
        // sqrt(+∞) = +∞.
        rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
        rustc_apfloat::Category::Normal => {
            // Number of significand bits, excluding the integer bit.
            let prec = i32::try_from(S::PRECISION).unwrap() - 1;

            // Decompose x = mant * 2^(exp - prec), where `mant` is an integer with prec+1 bits.
            // u128 is assumed wide enough for the largest supported precision (f128).
            let mut exp = x.ilogb();
            let mut mant = x.scalbn(prec - exp).to_u128(128).value;

            if exp % 2 != 0 {
                // Make the exponent even so it can be halved exactly; compensate in `mant`.
                exp -= 1;
                mant <<= 1;
            }

            // Bit-by-bit (base-2 digit-by-digit) square root of `mant`, viewed as a fixed-point
            // number with `prec` fractional bits. `mant` is shifted left once more to get one
            // extra fractional result bit, which later decides the rounding direction.

            // Truncated square root computed so far; one bit is decided per iteration.
            let mut res = 0u128;
            // Remainder for the current `res`: rem_i = 2^i * ((mant << 1) - res_i^2),
            // starting from res = 0, i.e. rem = mant << 1.
            let mut rem = mant << 1;
            // s_i = 2 * res_i.
            let mut s = 0u128;
            // Weight of the bit being decided, walking from high to low (d_i = 2^(-i)).
            let mut d = 1u128 << (prec + 1);

            // In step j = i + 1, the bit b_j (0 or 1) is the largest value with
            //   (res_i + b_j * 2^(-j))^2 <= mant << 1,
            // which after expanding and rearranging becomes
            //   b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i.
            while d != 0 {
                // Probe b_j = 1: t = 2 * res_i + 2^(-j).
                let t = s + d;
                if rem >= t {
                    // The bit is set: res_j = res_i + 2^(-j); update s and the remainder.
                    res += d;
                    s += d + d;
                    rem -= t;
                }
                // Rescale the remainder for the next step.
                rem <<= 1;
                // Move on to the next lower bit.
                d >>= 1;
            }

            // Drop the extra fractional bit, rounding to nearest. If that bit is 0, adding 1 is
            // discarded by the shift (round down). If it is 1, the result is rounded up. A tie is
            // impossible: a square root is either exact (then this extra bit is 0) or
            // irrational (then infinitely many bits follow, not all zero).
            res = (res + 1) >> 1;

            // Rebuild the float with `res` as significand and exp / 2 as exponent.
            IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
        }
    }
}
396
397pub trait IeeeExt: rustc_apfloat::Float {
399 #[inline]
402 fn one() -> Self {
403 Self::from_u128(1).value
404 }
405
406 #[inline]
407 fn two() -> Self {
408 Self::from_u128(2).value
409 }
410
411 #[inline]
412 fn three() -> Self {
413 Self::from_u128(3).value
414 }
415
416 fn pi() -> Self;
417
418 #[inline]
419 fn clamp(self, min: Self, max: Self) -> Self {
420 self.maximum(min).minimum(max)
421 }
422}
423
424macro_rules! impl_ieee_pi {
425 ($float_ty:ident, $semantic:ty) => {
426 impl IeeeExt for IeeeFloat<$semantic> {
427 #[inline]
428 fn pi() -> Self {
429 Self::from_bits($float_ty::consts::PI.to_bits().into())
431 }
432 }
433 };
434}
435
436impl_ieee_pi!(f32, SingleS);
437impl_ieee_pi!(f64, DoubleS);
438
#[cfg(test)]
mod tests {
    use rustc_apfloat::Float as _;
    use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};

    use super::sqrt;

    /// Checks `sqrt` on finite positive inputs against exact or correctly rounded results.
    #[test]
    fn test_sqrt() {
        #[track_caller]
        fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
            let x: IeeeFloat<S> = x.parse().unwrap();
            let expected: IeeeFloat<S> = expected.parse().unwrap();
            let result = sqrt(x);
            assert_eq!(result, expected);
        }

        // Inputs whose square roots are exactly representable in every format.
        fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
            test::<S>("0", "0");
            test::<S>("1", "1");
            test::<S>("1.5625", "1.25");
            test::<S>("2.25", "1.5");
            test::<S>("4", "2");
            test::<S>("5.0625", "2.25");
            test::<S>("9", "3");
            test::<S>("16", "4");
            test::<S>("25", "5");
            test::<S>("36", "6");
            test::<S>("49", "7");
            test::<S>("64", "8");
            test::<S>("81", "9");
            test::<S>("100", "10");

            test::<S>("0.5625", "0.75");
            test::<S>("0.25", "0.5");
            test::<S>("0.0625", "0.25");
            test::<S>("0.00390625", "0.0625");
        }

        exact_tests::<HalfS>();
        exact_tests::<SingleS>();
        exact_tests::<DoubleS>();
        exact_tests::<QuadS>();

        test::<SingleS>("2", "1.4142135");
        test::<DoubleS>("2", "1.4142135623730951");

        test::<SingleS>("1.1", "1.0488088");
        test::<DoubleS>("1.1", "1.0488088481701516");

        test::<SingleS>("2.2", "1.4832398");
        test::<DoubleS>("2.2", "1.4832396974191326");

        // Subnormal inputs.
        test::<SingleS>("1.22101e-40", "1.10499205e-20");
        test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");

        // Largest finite inputs.
        test::<SingleS>("3.4028235e38", "1.8446743e19");
        test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
    }

    /// Checks the non-`Normal` branches of `sqrt`: signed zeros, infinities, negatives and NaN.
    /// These are compared by predicate, because `==` cannot tell `+0` from `-0` or match NaN.
    #[test]
    fn test_sqrt_special_values() {
        fn check<S: rustc_apfloat::ieee::Semantics>() {
            // The sign of zero is preserved.
            let pos_zero = sqrt(IeeeFloat::<S>::ZERO);
            assert!(pos_zero.is_zero() && !pos_zero.is_negative());
            let neg_zero = sqrt(-IeeeFloat::<S>::ZERO);
            assert!(neg_zero.is_zero() && neg_zero.is_negative());

            // sqrt(+∞) = +∞.
            let pos_inf = sqrt(IeeeFloat::<S>::INFINITY);
            assert!(pos_inf.is_infinite() && !pos_inf.is_negative());

            // Negative inputs (finite or -∞) and NaN produce NaN.
            assert!(sqrt(-IeeeFloat::<S>::INFINITY).is_nan());
            assert!(sqrt(-IeeeFloat::<S>::from_u128(4).value).is_nan());
            assert!(sqrt(IeeeFloat::<S>::NAN).is_nan());
        }

        check::<HalfS>();
        check::<SingleS>();
        check::<DoubleS>();
        check::<QuadS>();
    }
}