miri/shims/x86/
mod.rs

1use rustc_abi::{CanonAbi, FieldIdx, Size};
2use rustc_apfloat::Float;
3use rustc_apfloat::ieee::Single;
4use rustc_middle::ty::Ty;
5use rustc_middle::{mir, ty};
6use rustc_span::Symbol;
7use rustc_target::callconv::FnAbi;
8
9use self::helpers::bool_to_simd_element;
10use crate::*;
11
12mod aesni;
13mod avx;
14mod avx2;
15mod bmi;
16mod gfni;
17mod sha;
18mod sse;
19mod sse2;
20mod sse3;
21mod sse41;
22mod sse42;
23mod ssse3;
24
25impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
26pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
27    fn emulate_x86_intrinsic(
28        &mut self,
29        link_name: Symbol,
30        abi: &FnAbi<'tcx, Ty<'tcx>>,
31        args: &[OpTy<'tcx>],
32        dest: &MPlaceTy<'tcx>,
33    ) -> InterpResult<'tcx, EmulateItemResult> {
34        let this = self.eval_context_mut();
35        // Prefix should have already been checked.
36        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
37        match unprefixed_name {
38            // Used to implement the `_addcarry_u{32, 64}` and the `_subborrow_u{32, 64}` functions.
39            // Computes a + b or a - b with input and output carry/borrow. The input carry/borrow is an 8-bit
40            // value, which is interpreted as 1 if it is non-zero. The output carry/borrow is an 8-bit value that will be 0 or 1.
41            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarry-u32-addcarry-u64.html
42            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/subborrow-u32-subborrow-u64.html
43            "addcarry.32" | "addcarry.64" | "subborrow.32" | "subborrow.64" => {
44                if unprefixed_name.ends_with("64") && this.tcx.sess.target.arch != "x86_64" {
45                    return interp_ok(EmulateItemResult::NotSupported);
46                }
47
48                let [cb_in, a, b] =
49                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
50                let op = if unprefixed_name.starts_with("add") {
51                    mir::BinOp::AddWithOverflow
52                } else {
53                    mir::BinOp::SubWithOverflow
54                };
55
56                let (sum, cb_out) = carrying_add(this, cb_in, a, b, op)?;
57                this.write_scalar(cb_out, &this.project_field(dest, FieldIdx::ZERO)?)?;
58                this.write_immediate(*sum, &this.project_field(dest, FieldIdx::ONE)?)?;
59            }
60
61            // Used to implement the `_addcarryx_u{32, 64}` functions. They are semantically identical with the `_addcarry_u{32, 64}` functions,
62            // except for a slightly different type signature and the requirement for the "adx" target feature.
63            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarryx-u32-addcarryx-u64.html
64            "addcarryx.u32" | "addcarryx.u64" => {
65                this.expect_target_feature_for_intrinsic(link_name, "adx")?;
66
67                let is_u64 = unprefixed_name.ends_with("64");
68                if is_u64 && this.tcx.sess.target.arch != "x86_64" {
69                    return interp_ok(EmulateItemResult::NotSupported);
70                }
71                let [c_in, a, b, out] =
72                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73                let out = this.deref_pointer_as(
74                    out,
75                    if is_u64 { this.machine.layouts.u64 } else { this.machine.layouts.u32 },
76                )?;
77
78                let (sum, c_out) = carrying_add(this, c_in, a, b, mir::BinOp::AddWithOverflow)?;
79                this.write_scalar(c_out, dest)?;
80                this.write_immediate(*sum, &out)?;
81            }
82
83            // Used to implement the `_mm_pause` function.
84            // The intrinsic is used to hint the processor that the code is in a spin-loop.
85            // It is compiled down to a `pause` instruction. When SSE2 is not available,
86            // the instruction behaves like a no-op, so it is always safe to call the
87            // intrinsic.
88            "sse2.pause" => {
89                let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
90                // Only exhibit the spin-loop hint behavior when SSE2 is enabled.
91                if this.tcx.sess.unstable_target_features.contains(&Symbol::intern("sse2")) {
92                    this.yield_active_thread();
93                }
94            }
95
96            "pclmulqdq" | "pclmulqdq.256" | "pclmulqdq.512" => {
97                let mut len = 2; // in units of 64bits
98                this.expect_target_feature_for_intrinsic(link_name, "pclmulqdq")?;
99                if unprefixed_name.ends_with(".256") {
100                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
101                    len = 4;
102                } else if unprefixed_name.ends_with(".512") {
103                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
104                    this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
105                    len = 8;
106                }
107
108                let [left, right, imm] =
109                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
110
111                pclmulqdq(this, left, right, imm, dest, len)?;
112            }
113
114            name if name.starts_with("bmi.") => {
115                return bmi::EvalContextExt::emulate_x86_bmi_intrinsic(
116                    this, link_name, abi, args, dest,
117                );
118            }
119            // The GFNI extension does not get its own namespace.
120            // Check for instruction names instead.
121            name if name.starts_with("vgf2p8affine") || name.starts_with("vgf2p8mulb") => {
122                return gfni::EvalContextExt::emulate_x86_gfni_intrinsic(
123                    this, link_name, abi, args, dest,
124                );
125            }
126            name if name.starts_with("sha") => {
127                return sha::EvalContextExt::emulate_x86_sha_intrinsic(
128                    this, link_name, abi, args, dest,
129                );
130            }
131            name if name.starts_with("sse.") => {
132                return sse::EvalContextExt::emulate_x86_sse_intrinsic(
133                    this, link_name, abi, args, dest,
134                );
135            }
136            name if name.starts_with("sse2.") => {
137                return sse2::EvalContextExt::emulate_x86_sse2_intrinsic(
138                    this, link_name, abi, args, dest,
139                );
140            }
141            name if name.starts_with("sse3.") => {
142                return sse3::EvalContextExt::emulate_x86_sse3_intrinsic(
143                    this, link_name, abi, args, dest,
144                );
145            }
146            name if name.starts_with("ssse3.") => {
147                return ssse3::EvalContextExt::emulate_x86_ssse3_intrinsic(
148                    this, link_name, abi, args, dest,
149                );
150            }
151            name if name.starts_with("sse41.") => {
152                return sse41::EvalContextExt::emulate_x86_sse41_intrinsic(
153                    this, link_name, abi, args, dest,
154                );
155            }
156            name if name.starts_with("sse42.") => {
157                return sse42::EvalContextExt::emulate_x86_sse42_intrinsic(
158                    this, link_name, abi, args, dest,
159                );
160            }
161            name if name.starts_with("aesni.") => {
162                return aesni::EvalContextExt::emulate_x86_aesni_intrinsic(
163                    this, link_name, abi, args, dest,
164                );
165            }
166            name if name.starts_with("avx.") => {
167                return avx::EvalContextExt::emulate_x86_avx_intrinsic(
168                    this, link_name, abi, args, dest,
169                );
170            }
171            name if name.starts_with("avx2.") => {
172                return avx2::EvalContextExt::emulate_x86_avx2_intrinsic(
173                    this, link_name, abi, args, dest,
174                );
175            }
176
177            _ => return interp_ok(EmulateItemResult::NotSupported),
178        }
179        interp_ok(EmulateItemResult::NeedsReturn)
180    }
181}
182
/// Binary floating-point operations used by the x86 float intrinsic shims
/// (evaluated per element by `bin_op_float`).
#[derive(Copy, Clone)]
enum FloatBinOp {
    /// Comparison
    ///
    /// The semantics of this operator is a case distinction: we compare the two operands,
    /// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
    /// which class they fall into.
    ///
    /// AVX supports all 16 combinations, SSE only a subset
    ///
    /// <https://www.felixcloutier.com/x86/cmpss>
    /// <https://www.felixcloutier.com/x86/cmpps>
    /// <https://www.felixcloutier.com/x86/cmpsd>
    /// <https://www.felixcloutier.com/x86/cmppd>
    Cmp {
        /// Result when lhs > rhs
        gt: bool,
        /// Result when lhs < rhs
        lt: bool,
        /// Result when lhs == rhs
        eq: bool,
        /// Result when lhs is NaN or rhs is NaN
        unord: bool,
    },
    /// Minimum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/minss>
    /// <https://www.felixcloutier.com/x86/minps>
    /// <https://www.felixcloutier.com/x86/minsd>
    /// <https://www.felixcloutier.com/x86/minpd>
    Min,
    /// Maximum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/maxss>
    /// <https://www.felixcloutier.com/x86/maxps>
    /// <https://www.felixcloutier.com/x86/maxsd>
    /// <https://www.felixcloutier.com/x86/maxpd>
    Max,
}
222
223impl FloatBinOp {
224    /// Convert from the `imm` argument used to specify the comparison
225    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
226    fn cmp_from_imm<'tcx>(
227        ecx: &crate::MiriInterpCx<'tcx>,
228        imm: i8,
229        intrinsic: Symbol,
230    ) -> InterpResult<'tcx, Self> {
231        // Only bits 0..=4 are used, remaining should be zero.
232        if imm & !0b1_1111 != 0 {
233            panic!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
234        }
235        // Bit 4 specifies whether the operation is quiet or signaling, which
236        // we do not care in Miri.
237        // Bits 0..=2 specifies the operation.
238        // `gt` indicates the result to be returned when the LHS is strictly
239        // greater than the RHS, and so on.
240        let (gt, lt, eq, mut unord) = match imm & 0b111 {
241            // Equal
242            0x0 => (false, false, true, false),
243            // Less-than
244            0x1 => (false, true, false, false),
245            // Less-or-equal
246            0x2 => (false, true, true, false),
247            // Unordered (either is NaN)
248            0x3 => (false, false, false, true),
249            // Not equal
250            0x4 => (true, true, false, true),
251            // Not less-than
252            0x5 => (true, false, true, true),
253            // Not less-or-equal
254            0x6 => (true, false, false, true),
255            // Ordered (neither is NaN)
256            0x7 => (true, true, true, false),
257            _ => unreachable!(),
258        };
259        // When bit 3 is 1 (only possible in AVX), unord is toggled.
260        if imm & 0b1000 != 0 {
261            ecx.expect_target_feature_for_intrinsic(intrinsic, "avx")?;
262            unord = !unord;
263        }
264        interp_ok(Self::Cmp { gt, lt, eq, unord })
265    }
266}
267
268/// Performs `which` scalar operation on `left` and `right` and returns
269/// the result.
270fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
271    which: FloatBinOp,
272    left: &ImmTy<'tcx>,
273    right: &ImmTy<'tcx>,
274) -> InterpResult<'tcx, Scalar> {
275    match which {
276        FloatBinOp::Cmp { gt, lt, eq, unord } => {
277            let left = left.to_scalar().to_float::<F>()?;
278            let right = right.to_scalar().to_float::<F>()?;
279
280            let res = match left.partial_cmp(&right) {
281                None => unord,
282                Some(std::cmp::Ordering::Less) => lt,
283                Some(std::cmp::Ordering::Equal) => eq,
284                Some(std::cmp::Ordering::Greater) => gt,
285            };
286            interp_ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
287        }
288        FloatBinOp::Min => {
289            let left_scalar = left.to_scalar();
290            let left = left_scalar.to_float::<F>()?;
291            let right_scalar = right.to_scalar();
292            let right = right_scalar.to_float::<F>()?;
293            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
294            // is true when `x` is either +0 or -0.
295            if (left == F::ZERO && right == F::ZERO)
296                || left.is_nan()
297                || right.is_nan()
298                || left >= right
299            {
300                interp_ok(right_scalar)
301            } else {
302                interp_ok(left_scalar)
303            }
304        }
305        FloatBinOp::Max => {
306            let left_scalar = left.to_scalar();
307            let left = left_scalar.to_float::<F>()?;
308            let right_scalar = right.to_scalar();
309            let right = right_scalar.to_float::<F>()?;
310            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
311            // is true when `x` is either +0 or -0.
312            if (left == F::ZERO && right == F::ZERO)
313                || left.is_nan()
314                || right.is_nan()
315                || left <= right
316            {
317                interp_ok(right_scalar)
318            } else {
319                interp_ok(left_scalar)
320            }
321        }
322    }
323}
324
325/// Performs `which` operation on the first component of `left` and `right`
326/// and copies the other components from `left`. The result is stored in `dest`.
327fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
328    ecx: &mut crate::MiriInterpCx<'tcx>,
329    which: FloatBinOp,
330    left: &OpTy<'tcx>,
331    right: &OpTy<'tcx>,
332    dest: &MPlaceTy<'tcx>,
333) -> InterpResult<'tcx, ()> {
334    let (left, left_len) = ecx.project_to_simd(left)?;
335    let (right, right_len) = ecx.project_to_simd(right)?;
336    let (dest, dest_len) = ecx.project_to_simd(dest)?;
337
338    assert_eq!(dest_len, left_len);
339    assert_eq!(dest_len, right_len);
340
341    let res0 = bin_op_float::<F>(
342        which,
343        &ecx.read_immediate(&ecx.project_index(&left, 0)?)?,
344        &ecx.read_immediate(&ecx.project_index(&right, 0)?)?,
345    )?;
346    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
347
348    for i in 1..dest_len {
349        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
350    }
351
352    interp_ok(())
353}
354
355/// Performs `which` operation on each component of `left` and
356/// `right`, storing the result is stored in `dest`.
357fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
358    ecx: &mut crate::MiriInterpCx<'tcx>,
359    which: FloatBinOp,
360    left: &OpTy<'tcx>,
361    right: &OpTy<'tcx>,
362    dest: &MPlaceTy<'tcx>,
363) -> InterpResult<'tcx, ()> {
364    let (left, left_len) = ecx.project_to_simd(left)?;
365    let (right, right_len) = ecx.project_to_simd(right)?;
366    let (dest, dest_len) = ecx.project_to_simd(dest)?;
367
368    assert_eq!(dest_len, left_len);
369    assert_eq!(dest_len, right_len);
370
371    for i in 0..dest_len {
372        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
373        let right = ecx.read_immediate(&ecx.project_index(&right, i)?)?;
374        let dest = ecx.project_index(&dest, i)?;
375
376        let res = bin_op_float::<F>(which, &left, &right)?;
377        ecx.write_scalar(res, &dest)?;
378    }
379
380    interp_ok(())
381}
382
/// Unary single-precision (`f32`) operations whose hardware results are only
/// approximations. `unary_op_f32` models them by perturbing the exact result
/// with a random relative error on the order of 2^-12.
#[derive(Copy, Clone)]
enum FloatUnaryOp {
    /// Approximation of 1/x
    ///
    /// <https://www.felixcloutier.com/x86/rcpss>
    /// <https://www.felixcloutier.com/x86/rcpps>
    Rcp,
    /// Approximation of 1/sqrt(x)
    ///
    /// <https://www.felixcloutier.com/x86/rsqrtss>
    /// <https://www.felixcloutier.com/x86/rsqrtps>
    Rsqrt,
}
396
397/// Performs `which` scalar operation on `op` and returns the result.
398fn unary_op_f32<'tcx>(
399    ecx: &mut crate::MiriInterpCx<'tcx>,
400    which: FloatUnaryOp,
401    op: &ImmTy<'tcx>,
402) -> InterpResult<'tcx, Scalar> {
403    match which {
404        FloatUnaryOp::Rcp => {
405            let op = op.to_scalar().to_f32()?;
406            let div = (Single::from_u128(1).value / op).value;
407            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
408            // inaccuracy of RCP.
409            let res = math::apply_random_float_error(ecx, div, -12);
410            interp_ok(Scalar::from_f32(res))
411        }
412        FloatUnaryOp::Rsqrt => {
413            let op = op.to_scalar().to_f32()?;
414            let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
415            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
416            // inaccuracy of RSQRT.
417            let res = math::apply_random_float_error(ecx, rsqrt, -12);
418            interp_ok(Scalar::from_f32(res))
419        }
420    }
421}
422
423/// Performs `which` operation on the first component of `op` and copies
424/// the other components. The result is stored in `dest`.
425fn unary_op_ss<'tcx>(
426    ecx: &mut crate::MiriInterpCx<'tcx>,
427    which: FloatUnaryOp,
428    op: &OpTy<'tcx>,
429    dest: &MPlaceTy<'tcx>,
430) -> InterpResult<'tcx, ()> {
431    let (op, op_len) = ecx.project_to_simd(op)?;
432    let (dest, dest_len) = ecx.project_to_simd(dest)?;
433
434    assert_eq!(dest_len, op_len);
435
436    let res0 = unary_op_f32(ecx, which, &ecx.read_immediate(&ecx.project_index(&op, 0)?)?)?;
437    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
438
439    for i in 1..dest_len {
440        ecx.copy_op(&ecx.project_index(&op, i)?, &ecx.project_index(&dest, i)?)?;
441    }
442
443    interp_ok(())
444}
445
446/// Performs `which` operation on each component of `op`, storing the
447/// result is stored in `dest`.
448fn unary_op_ps<'tcx>(
449    ecx: &mut crate::MiriInterpCx<'tcx>,
450    which: FloatUnaryOp,
451    op: &OpTy<'tcx>,
452    dest: &MPlaceTy<'tcx>,
453) -> InterpResult<'tcx, ()> {
454    let (op, op_len) = ecx.project_to_simd(op)?;
455    let (dest, dest_len) = ecx.project_to_simd(dest)?;
456
457    assert_eq!(dest_len, op_len);
458
459    for i in 0..dest_len {
460        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
461        let dest = ecx.project_index(&dest, i)?;
462
463        let res = unary_op_f32(ecx, which, &op)?;
464        ecx.write_scalar(res, &dest)?;
465    }
466
467    interp_ok(())
468}
469
/// Kind of bitwise shift applied to each lane by `shift_simd_by_scalar` and
/// `shift_simd_by_simd`.
///
/// Derives `Copy`/`Clone` for consistency with the other by-value operation
/// selectors in this module (`FloatBinOp`, `FloatUnaryOp`).
#[derive(Copy, Clone)]
enum ShiftOp {
    /// Shift left, logically (shift in zeros) -- same as shift left, arithmetically
    Left,
    /// Shift right, logically (shift in zeros)
    RightLogic,
    /// Shift right, arithmetically (shift in sign)
    RightArith,
}
478
479/// Shifts each element of `left` by a scalar amount. The shift amount
480/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
481///
482/// For logic shifts, when right is larger than BITS - 1, zero is produced.
483/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
484/// bit is copied to all bits.
485fn shift_simd_by_scalar<'tcx>(
486    ecx: &mut crate::MiriInterpCx<'tcx>,
487    left: &OpTy<'tcx>,
488    right: &OpTy<'tcx>,
489    which: ShiftOp,
490    dest: &MPlaceTy<'tcx>,
491) -> InterpResult<'tcx, ()> {
492    let (left, left_len) = ecx.project_to_simd(left)?;
493    let (dest, dest_len) = ecx.project_to_simd(dest)?;
494
495    assert_eq!(dest_len, left_len);
496    // `right` may have a different length, and we only care about its
497    // lowest 64bit anyway.
498
499    // Get the 64-bit shift operand and convert it to the type expected
500    // by checked_{shl,shr} (u32).
501    // It is ok to saturate the value to u32::MAX because any value
502    // above BITS - 1 will produce the same result.
503    let shift = u32::try_from(extract_first_u64(ecx, right)?).unwrap_or(u32::MAX);
504
505    for i in 0..dest_len {
506        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
507        let dest = ecx.project_index(&dest, i)?;
508
509        let res = match which {
510            ShiftOp::Left => {
511                let left = left.to_uint(dest.layout.size)?;
512                let res = left.checked_shl(shift).unwrap_or(0);
513                // `truncate` is needed as left-shift can make the absolute value larger.
514                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
515            }
516            ShiftOp::RightLogic => {
517                let left = left.to_uint(dest.layout.size)?;
518                let res = left.checked_shr(shift).unwrap_or(0);
519                // No `truncate` needed as right-shift can only make the absolute value smaller.
520                Scalar::from_uint(res, dest.layout.size)
521            }
522            ShiftOp::RightArith => {
523                let left = left.to_int(dest.layout.size)?;
524                // On overflow, copy the sign bit to the remaining bits
525                let res = left.checked_shr(shift).unwrap_or(left >> 127);
526                // No `truncate` needed as right-shift can only make the absolute value smaller.
527                Scalar::from_int(res, dest.layout.size)
528            }
529        };
530        ecx.write_scalar(res, &dest)?;
531    }
532
533    interp_ok(())
534}
535
536/// Shifts each element of `left` by the corresponding element of `right`.
537///
538/// For logic shifts, when right is larger than BITS - 1, zero is produced.
539/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
540/// bit is copied to all bits.
541fn shift_simd_by_simd<'tcx>(
542    ecx: &mut crate::MiriInterpCx<'tcx>,
543    left: &OpTy<'tcx>,
544    right: &OpTy<'tcx>,
545    which: ShiftOp,
546    dest: &MPlaceTy<'tcx>,
547) -> InterpResult<'tcx, ()> {
548    let (left, left_len) = ecx.project_to_simd(left)?;
549    let (right, right_len) = ecx.project_to_simd(right)?;
550    let (dest, dest_len) = ecx.project_to_simd(dest)?;
551
552    assert_eq!(dest_len, left_len);
553    assert_eq!(dest_len, right_len);
554
555    for i in 0..dest_len {
556        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
557        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?;
558        let dest = ecx.project_index(&dest, i)?;
559
560        // It is ok to saturate the value to u32::MAX because any value
561        // above BITS - 1 will produce the same result.
562        let shift = u32::try_from(right.to_uint(dest.layout.size)?).unwrap_or(u32::MAX);
563
564        let res = match which {
565            ShiftOp::Left => {
566                let left = left.to_uint(dest.layout.size)?;
567                let res = left.checked_shl(shift).unwrap_or(0);
568                // `truncate` is needed as left-shift can make the absolute value larger.
569                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
570            }
571            ShiftOp::RightLogic => {
572                let left = left.to_uint(dest.layout.size)?;
573                let res = left.checked_shr(shift).unwrap_or(0);
574                // No `truncate` needed as right-shift can only make the absolute value smaller.
575                Scalar::from_uint(res, dest.layout.size)
576            }
577            ShiftOp::RightArith => {
578                let left = left.to_int(dest.layout.size)?;
579                // On overflow, copy the sign bit to the remaining bits
580                let res = left.checked_shr(shift).unwrap_or(left >> 127);
581                // No `truncate` needed as right-shift can only make the absolute value smaller.
582                Scalar::from_int(res, dest.layout.size)
583            }
584        };
585        ecx.write_scalar(res, &dest)?;
586    }
587
588    interp_ok(())
589}
590
591/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
592/// the first value.
593fn extract_first_u64<'tcx>(
594    ecx: &crate::MiriInterpCx<'tcx>,
595    op: &OpTy<'tcx>,
596) -> InterpResult<'tcx, u64> {
597    // Transmute vector to `[u64; 2]`
598    let array_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, 2))?;
599    let op = op.transmute(array_layout, ecx)?;
600
601    // Get the first u64 from the array
602    ecx.read_scalar(&ecx.project_index(&op, 0)?)?.to_u64()
603}
604
605// Rounds the first element of `right` according to `rounding`
606// and copies the remaining elements from `left`.
607fn round_first<'tcx, F: rustc_apfloat::Float>(
608    ecx: &mut crate::MiriInterpCx<'tcx>,
609    left: &OpTy<'tcx>,
610    right: &OpTy<'tcx>,
611    rounding: &OpTy<'tcx>,
612    dest: &MPlaceTy<'tcx>,
613) -> InterpResult<'tcx, ()> {
614    let (left, left_len) = ecx.project_to_simd(left)?;
615    let (right, right_len) = ecx.project_to_simd(right)?;
616    let (dest, dest_len) = ecx.project_to_simd(dest)?;
617
618    assert_eq!(dest_len, left_len);
619    assert_eq!(dest_len, right_len);
620
621    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
622
623    let op0: F = ecx.read_scalar(&ecx.project_index(&right, 0)?)?.to_float()?;
624    let res = op0.round_to_integral(rounding).value;
625    ecx.write_scalar(
626        Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
627        &ecx.project_index(&dest, 0)?,
628    )?;
629
630    for i in 1..dest_len {
631        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
632    }
633
634    interp_ok(())
635}
636
637// Rounds all elements of `op` according to `rounding`.
638fn round_all<'tcx, F: rustc_apfloat::Float>(
639    ecx: &mut crate::MiriInterpCx<'tcx>,
640    op: &OpTy<'tcx>,
641    rounding: &OpTy<'tcx>,
642    dest: &MPlaceTy<'tcx>,
643) -> InterpResult<'tcx, ()> {
644    let (op, op_len) = ecx.project_to_simd(op)?;
645    let (dest, dest_len) = ecx.project_to_simd(dest)?;
646
647    assert_eq!(dest_len, op_len);
648
649    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
650
651    for i in 0..dest_len {
652        let op: F = ecx.read_scalar(&ecx.project_index(&op, i)?)?.to_float()?;
653        let res = op.round_to_integral(rounding).value;
654        ecx.write_scalar(
655            Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
656            &ecx.project_index(&dest, i)?,
657        )?;
658    }
659
660    interp_ok(())
661}
662
663/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
664/// `round.{ss,sd,ps,pd}` intrinsics.
665fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
666    // The fourth bit of `rounding` only affects the SSE status
667    // register, which cannot be accessed from Miri (or from Rust,
668    // for that matter), so we can ignore it.
669    match rounding & !0b1000 {
670        // When the third bit is 0, the rounding mode is determined by the
671        // first two bits.
672        0b000 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
673        0b001 => interp_ok(rustc_apfloat::Round::TowardNegative),
674        0b010 => interp_ok(rustc_apfloat::Round::TowardPositive),
675        0b011 => interp_ok(rustc_apfloat::Round::TowardZero),
676        // When the third bit is 1, the rounding mode is determined by the
677        // SSE status register. Since we do not support modifying it from
678        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
679        0b100..=0b111 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
680        rounding => panic!("invalid rounding mode 0x{rounding:02x}"),
681    }
682}
683
684/// Converts each element of `op` from floating point to signed integer.
685///
686/// When the input value is NaN or out of range, fall back to minimum value.
687///
688/// If `op` has more elements than `dest`, extra elements are ignored. If `op`
689/// has less elements than `dest`, the rest is filled with zeros.
690fn convert_float_to_int<'tcx>(
691    ecx: &mut crate::MiriInterpCx<'tcx>,
692    op: &OpTy<'tcx>,
693    rnd: rustc_apfloat::Round,
694    dest: &MPlaceTy<'tcx>,
695) -> InterpResult<'tcx, ()> {
696    let (op, op_len) = ecx.project_to_simd(op)?;
697    let (dest, dest_len) = ecx.project_to_simd(dest)?;
698
699    // Output must be *signed* integers.
700    assert!(matches!(dest.layout.field(ecx, 0).ty.kind(), ty::Int(_)));
701
702    for i in 0..op_len.min(dest_len) {
703        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
704        let dest = ecx.project_index(&dest, i)?;
705
706        let res = ecx.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
707            // Fallback to minimum according to SSE/AVX semantics.
708            ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
709        });
710        ecx.write_immediate(*res, &dest)?;
711    }
712    // Fill remainder with zeros
713    for i in op_len..dest_len {
714        let dest = ecx.project_index(&dest, i)?;
715        ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
716    }
717
718    interp_ok(())
719}
720
721/// Calculates absolute value of integers in `op` and stores the result in `dest`.
722///
723/// In case of overflow (when the operand is the minimum value), the operation
724/// will wrap around.
725fn int_abs<'tcx>(
726    ecx: &mut crate::MiriInterpCx<'tcx>,
727    op: &OpTy<'tcx>,
728    dest: &MPlaceTy<'tcx>,
729) -> InterpResult<'tcx, ()> {
730    let (op, op_len) = ecx.project_to_simd(op)?;
731    let (dest, dest_len) = ecx.project_to_simd(dest)?;
732
733    assert_eq!(op_len, dest_len);
734
735    let zero = ImmTy::from_int(0, op.layout.field(ecx, 0));
736
737    for i in 0..dest_len {
738        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
739        let dest = ecx.project_index(&dest, i)?;
740
741        let lt_zero = ecx.binary_op(mir::BinOp::Lt, &op, &zero)?;
742        let res =
743            if lt_zero.to_scalar().to_bool()? { ecx.unary_op(mir::UnOp::Neg, &op)? } else { op };
744
745        ecx.write_immediate(*res, &dest)?;
746    }
747
748    interp_ok(())
749}
750
751/// Splits `op` (which must be a SIMD vector) into 128-bit chunks.
752///
753/// Returns a tuple where:
754/// * The first element is the number of 128-bit chunks (let's call it `N`).
755/// * The second element is the number of elements per chunk (let's call it `M`).
756/// * The third element is the `op` vector split into chunks, i.e, it's
757///   type is `[[T; M]; N]` where `T` is the element type of `op`.
758fn split_simd_to_128bit_chunks<'tcx, P: Projectable<'tcx, Provenance>>(
759    ecx: &mut crate::MiriInterpCx<'tcx>,
760    op: &P,
761) -> InterpResult<'tcx, (u64, u64, P)> {
762    let simd_layout = op.layout();
763    let (simd_len, element_ty) = simd_layout.ty.simd_size_and_type(ecx.tcx.tcx);
764
765    assert_eq!(simd_layout.size.bits() % 128, 0);
766    let num_chunks = simd_layout.size.bits() / 128;
767    let items_per_chunk = simd_len.strict_div(num_chunks);
768
769    // Transmute to `[[T; items_per_chunk]; num_chunks]`
770    let chunked_layout = ecx
771        .layout_of(Ty::new_array(
772            ecx.tcx.tcx,
773            Ty::new_array(ecx.tcx.tcx, element_ty, items_per_chunk),
774            num_chunks,
775        ))
776        .unwrap();
777    let chunked_op = op.transmute(chunked_layout, ecx)?;
778
779    interp_ok((num_chunks, items_per_chunk, chunked_op))
780}
781
782/// Horizontally performs `which` operation on adjacent values of
783/// `left` and `right` SIMD vectors and stores the result in `dest`.
784/// "Horizontal" means that the i-th output element is calculated
785/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
786/// `right`.
787///
788/// Each 128-bit chunk is treated independently (i.e., the value for
789/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
790/// 128-bit chunks of `left` and `right`).
791fn horizontal_bin_op<'tcx>(
792    ecx: &mut crate::MiriInterpCx<'tcx>,
793    which: mir::BinOp,
794    saturating: bool,
795    left: &OpTy<'tcx>,
796    right: &OpTy<'tcx>,
797    dest: &MPlaceTy<'tcx>,
798) -> InterpResult<'tcx, ()> {
799    assert_eq!(left.layout, dest.layout);
800    assert_eq!(right.layout, dest.layout);
801
802    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
803    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
804    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
805
806    let middle = items_per_chunk / 2;
807    for i in 0..num_chunks {
808        let left = ecx.project_index(&left, i)?;
809        let right = ecx.project_index(&right, i)?;
810        let dest = ecx.project_index(&dest, i)?;
811
812        for j in 0..items_per_chunk {
813            // `j` is the index in `dest`
814            // `k` is the index of the 2-item chunk in `src`
815            let (k, src) = if j < middle { (j, &left) } else { (j.strict_sub(middle), &right) };
816            // `base_i` is the index of the first item of the 2-item chunk in `src`
817            let base_i = k.strict_mul(2);
818            let lhs = ecx.read_immediate(&ecx.project_index(src, base_i)?)?;
819            let rhs = ecx.read_immediate(&ecx.project_index(src, base_i.strict_add(1))?)?;
820
821            let res = if saturating {
822                Immediate::from(ecx.saturating_arith(which, &lhs, &rhs)?)
823            } else {
824                *ecx.binary_op(which, &lhs, &rhs)?
825            };
826
827            ecx.write_immediate(res, &ecx.project_index(&dest, j)?)?;
828        }
829    }
830
831    interp_ok(())
832}
833
834/// Conditionally multiplies the packed floating-point elements in
835/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
836/// products (up to 4), and conditionally stores the sum in `dest` using
837/// the low 4 bits of `imm`.
838///
839/// Each 128-bit chunk is treated independently (i.e., the value for
840/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
841/// 128-bit blocks of `left` and `right`).
842fn conditional_dot_product<'tcx>(
843    ecx: &mut crate::MiriInterpCx<'tcx>,
844    left: &OpTy<'tcx>,
845    right: &OpTy<'tcx>,
846    imm: &OpTy<'tcx>,
847    dest: &MPlaceTy<'tcx>,
848) -> InterpResult<'tcx, ()> {
849    assert_eq!(left.layout, dest.layout);
850    assert_eq!(right.layout, dest.layout);
851
852    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
853    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
854    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
855
856    let element_layout = left.layout.field(ecx, 0).field(ecx, 0);
857    assert!(items_per_chunk <= 4);
858
859    // `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/
860    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
861
862    for i in 0..num_chunks {
863        let left = ecx.project_index(&left, i)?;
864        let right = ecx.project_index(&right, i)?;
865        let dest = ecx.project_index(&dest, i)?;
866
867        // Calculate dot product
868        // Elements are floating point numbers, but we can use `from_int`
869        // for the initial value because the representation of 0.0 is all zero bits.
870        let mut sum = ImmTy::from_int(0u8, element_layout);
871        for j in 0..items_per_chunk {
872            if imm & (1 << j.strict_add(4)) != 0 {
873                let left = ecx.read_immediate(&ecx.project_index(&left, j)?)?;
874                let right = ecx.read_immediate(&ecx.project_index(&right, j)?)?;
875
876                let mul = ecx.binary_op(mir::BinOp::Mul, &left, &right)?;
877                sum = ecx.binary_op(mir::BinOp::Add, &sum, &mul)?;
878            }
879        }
880
881        // Write to destination (conditioned to imm)
882        for j in 0..items_per_chunk {
883            let dest = ecx.project_index(&dest, j)?;
884
885            if imm & (1 << j) != 0 {
886                ecx.write_immediate(*sum, &dest)?;
887            } else {
888                ecx.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
889            }
890        }
891    }
892
893    interp_ok(())
894}
895
896/// Calculates two booleans.
897///
898/// The first is true when all the bits of `op & mask` are zero.
899/// The second is true when `(op & mask) == mask`
900fn test_bits_masked<'tcx>(
901    ecx: &crate::MiriInterpCx<'tcx>,
902    op: &OpTy<'tcx>,
903    mask: &OpTy<'tcx>,
904) -> InterpResult<'tcx, (bool, bool)> {
905    assert_eq!(op.layout, mask.layout);
906
907    let (op, op_len) = ecx.project_to_simd(op)?;
908    let (mask, mask_len) = ecx.project_to_simd(mask)?;
909
910    assert_eq!(op_len, mask_len);
911
912    let mut all_zero = true;
913    let mut masked_set = true;
914    for i in 0..op_len {
915        let op = ecx.project_index(&op, i)?;
916        let mask = ecx.project_index(&mask, i)?;
917
918        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
919        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
920        all_zero &= (op & mask) == 0;
921        masked_set &= (op & mask) == mask;
922    }
923
924    interp_ok((all_zero, masked_set))
925}
926
927/// Calculates two booleans.
928///
929/// The first is true when the highest bit of each element of `op & mask` is zero.
930/// The second is true when the highest bit of each element of `!op & mask` is zero.
931fn test_high_bits_masked<'tcx>(
932    ecx: &crate::MiriInterpCx<'tcx>,
933    op: &OpTy<'tcx>,
934    mask: &OpTy<'tcx>,
935) -> InterpResult<'tcx, (bool, bool)> {
936    assert_eq!(op.layout, mask.layout);
937
938    let (op, op_len) = ecx.project_to_simd(op)?;
939    let (mask, mask_len) = ecx.project_to_simd(mask)?;
940
941    assert_eq!(op_len, mask_len);
942
943    let high_bit_offset = op.layout.field(ecx, 0).size.bits().strict_sub(1);
944
945    let mut direct = true;
946    let mut negated = true;
947    for i in 0..op_len {
948        let op = ecx.project_index(&op, i)?;
949        let mask = ecx.project_index(&mask, i)?;
950
951        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
952        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
953        direct &= (op & mask) >> high_bit_offset == 0;
954        negated &= (!op & mask) >> high_bit_offset == 0;
955    }
956
957    interp_ok((direct, negated))
958}
959
960/// Conditionally loads from `ptr` according the high bit of each
961/// element of `mask`. `ptr` does not need to be aligned.
962fn mask_load<'tcx>(
963    ecx: &mut crate::MiriInterpCx<'tcx>,
964    ptr: &OpTy<'tcx>,
965    mask: &OpTy<'tcx>,
966    dest: &MPlaceTy<'tcx>,
967) -> InterpResult<'tcx, ()> {
968    let (mask, mask_len) = ecx.project_to_simd(mask)?;
969    let (dest, dest_len) = ecx.project_to_simd(dest)?;
970
971    assert_eq!(dest_len, mask_len);
972
973    let mask_item_size = mask.layout.field(ecx, 0).size;
974    let high_bit_offset = mask_item_size.bits().strict_sub(1);
975
976    let ptr = ecx.read_pointer(ptr)?;
977    for i in 0..dest_len {
978        let mask = ecx.project_index(&mask, i)?;
979        let dest = ecx.project_index(&dest, i)?;
980
981        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
982            let ptr = ptr.wrapping_offset(dest.layout.size * i, &ecx.tcx);
983            // Unaligned copy, which is what we want.
984            ecx.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
985        } else {
986            ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
987        }
988    }
989
990    interp_ok(())
991}
992
993/// Conditionally stores into `ptr` according the high bit of each
994/// element of `mask`. `ptr` does not need to be aligned.
995fn mask_store<'tcx>(
996    ecx: &mut crate::MiriInterpCx<'tcx>,
997    ptr: &OpTy<'tcx>,
998    mask: &OpTy<'tcx>,
999    value: &OpTy<'tcx>,
1000) -> InterpResult<'tcx, ()> {
1001    let (mask, mask_len) = ecx.project_to_simd(mask)?;
1002    let (value, value_len) = ecx.project_to_simd(value)?;
1003
1004    assert_eq!(value_len, mask_len);
1005
1006    let mask_item_size = mask.layout.field(ecx, 0).size;
1007    let high_bit_offset = mask_item_size.bits().strict_sub(1);
1008
1009    let ptr = ecx.read_pointer(ptr)?;
1010    for i in 0..value_len {
1011        let mask = ecx.project_index(&mask, i)?;
1012        let value = ecx.project_index(&value, i)?;
1013
1014        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
1015            // *Non-inbounds* pointer arithmetic to compute the destination.
1016            // (That's why we can't use a place projection.)
1017            let ptr = ptr.wrapping_offset(value.layout.size * i, &ecx.tcx);
1018            // Deref the pointer *unaligned*, and do the copy.
1019            let dest = ecx.ptr_to_mplace_unaligned(ptr, value.layout);
1020            ecx.copy_op(&value, &dest)?;
1021        }
1022    }
1023
1024    interp_ok(())
1025}
1026
1027/// Compute the sum of absolute differences of quadruplets of unsigned
1028/// 8-bit integers in `left` and `right`, and store the 16-bit results
1029/// in `right`. Quadruplets are selected from `left` and `right` with
1030/// offsets specified in `imm`.
1031///
1032/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
1033/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8>
1034///
1035/// Each 128-bit chunk is treated independently (i.e., the value for
1036/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1037/// 128-bit chunks of `left` and `right`).
1038fn mpsadbw<'tcx>(
1039    ecx: &mut crate::MiriInterpCx<'tcx>,
1040    left: &OpTy<'tcx>,
1041    right: &OpTy<'tcx>,
1042    imm: &OpTy<'tcx>,
1043    dest: &MPlaceTy<'tcx>,
1044) -> InterpResult<'tcx, ()> {
1045    assert_eq!(left.layout, right.layout);
1046    assert_eq!(left.layout.size, dest.layout.size);
1047
1048    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1049    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1050    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1051
1052    assert_eq!(op_items_per_chunk, dest_items_per_chunk.strict_mul(2));
1053
1054    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
1055    // Bit 2 of `imm` specifies the offset for indices of `left`.
1056    // The offset is 0 when the bit is 0 or 4 when the bit is 1.
1057    let left_offset = u64::try_from((imm >> 2) & 1).unwrap().strict_mul(4);
1058    // Bits 0..=1 of `imm` specify the offset for indices of
1059    // `right` in blocks of 4 elements.
1060    let right_offset = u64::try_from(imm & 0b11).unwrap().strict_mul(4);
1061
1062    for i in 0..num_chunks {
1063        let left = ecx.project_index(&left, i)?;
1064        let right = ecx.project_index(&right, i)?;
1065        let dest = ecx.project_index(&dest, i)?;
1066
1067        for j in 0..dest_items_per_chunk {
1068            let left_offset = left_offset.strict_add(j);
1069            let mut res: u16 = 0;
1070            for k in 0..4 {
1071                let left = ecx
1072                    .read_scalar(&ecx.project_index(&left, left_offset.strict_add(k))?)?
1073                    .to_u8()?;
1074                let right = ecx
1075                    .read_scalar(&ecx.project_index(&right, right_offset.strict_add(k))?)?
1076                    .to_u8()?;
1077                res = res.strict_add(left.abs_diff(right).into());
1078            }
1079            ecx.write_scalar(Scalar::from_u16(res), &ecx.project_index(&dest, j)?)?;
1080        }
1081    }
1082
1083    interp_ok(())
1084}
1085
1086/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
1087/// product to the 18 most significant bits by right-shifting, and then
1088/// divides the 18-bit value by 2 (rounding to nearest) by first adding
1089/// 1 and then taking the bits `1..=16`.
1090///
1091/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mulhrs_epi16>
1092/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16>
1093fn pmulhrsw<'tcx>(
1094    ecx: &mut crate::MiriInterpCx<'tcx>,
1095    left: &OpTy<'tcx>,
1096    right: &OpTy<'tcx>,
1097    dest: &MPlaceTy<'tcx>,
1098) -> InterpResult<'tcx, ()> {
1099    let (left, left_len) = ecx.project_to_simd(left)?;
1100    let (right, right_len) = ecx.project_to_simd(right)?;
1101    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1102
1103    assert_eq!(dest_len, left_len);
1104    assert_eq!(dest_len, right_len);
1105
1106    for i in 0..dest_len {
1107        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?.to_i16()?;
1108        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_i16()?;
1109        let dest = ecx.project_index(&dest, i)?;
1110
1111        let res = (i32::from(left).strict_mul(right.into()) >> 14).strict_add(1) >> 1;
1112
1113        // The result of this operation can overflow a signed 16-bit integer.
1114        // When `left` and `right` are -0x8000, the result is 0x8000.
1115        #[expect(clippy::as_conversions)]
1116        let res = res as i16;
1117
1118        ecx.write_scalar(Scalar::from_i16(res), &dest)?;
1119    }
1120
1121    interp_ok(())
1122}
1123
1124/// Perform a carry-less multiplication of two 64-bit integers, selected from `left` and `right` according to `imm8`,
1125/// and store the results in `dst`.
1126///
1127/// `left` and `right` are both vectors of type `len` x i64. Only bits 0 and 4 of `imm8` matter;
1128/// they select the element of `left` and `right`, respectively.
1129///
1130/// `len` is the SIMD vector length (in counts of `i64` values). It is expected to be one of
1131/// `2`, `4`, or `8`.
1132///
1133/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_clmulepi64_si128>
1134fn pclmulqdq<'tcx>(
1135    ecx: &mut MiriInterpCx<'tcx>,
1136    left: &OpTy<'tcx>,
1137    right: &OpTy<'tcx>,
1138    imm8: &OpTy<'tcx>,
1139    dest: &MPlaceTy<'tcx>,
1140    len: u64,
1141) -> InterpResult<'tcx, ()> {
1142    assert_eq!(left.layout, right.layout);
1143    assert_eq!(left.layout.size, dest.layout.size);
1144    assert!([2u64, 4, 8].contains(&len));
1145
1146    // Transmute the input into arrays of `[u64; len]`.
1147    // Transmute the output into an array of `[u128, len / 2]`.
1148
1149    let src_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, len))?;
1150    let dest_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u128, len / 2))?;
1151
1152    let left = left.transmute(src_layout, ecx)?;
1153    let right = right.transmute(src_layout, ecx)?;
1154    let dest = dest.transmute(dest_layout, ecx)?;
1155
1156    let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
1157
1158    for i in 0..(len / 2) {
1159        let lo = i.strict_mul(2);
1160        let hi = i.strict_mul(2).strict_add(1);
1161
1162        // select the 64-bit integer from left that the user specified (low or high)
1163        let index = if (imm8 & 0x01) == 0 { lo } else { hi };
1164        let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u64()?;
1165
1166        // select the 64-bit integer from right that the user specified (low or high)
1167        let index = if (imm8 & 0x10) == 0 { lo } else { hi };
1168        let right = ecx.read_scalar(&ecx.project_index(&right, index)?)?.to_u64()?;
1169
1170        // Perform carry-less multiplication.
1171        //
1172        // This operation is like long multiplication, but ignores all carries.
1173        // That idea corresponds to the xor operator, which is used in the implementation.
1174        //
1175        // Wikipedia has an example https://en.wikipedia.org/wiki/Carry-less_product#Example
1176        let mut result: u128 = 0;
1177
1178        for i in 0..64 {
1179            // if the i-th bit in right is set
1180            if (right & (1 << i)) != 0 {
1181                // xor result with `left` shifted to the left by i positions
1182                result ^= u128::from(left) << i;
1183            }
1184        }
1185
1186        let dest = ecx.project_index(&dest, i)?;
1187        ecx.write_scalar(Scalar::from_u128(result), &dest)?;
1188    }
1189
1190    interp_ok(())
1191}
1192
1193/// Packs two N-bit integer vectors to a single N/2-bit integers.
1194///
1195/// The conversion from N-bit to N/2-bit should be provided by `f`.
1196///
1197/// Each 128-bit chunk is treated independently (i.e., the value for
1198/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1199/// 128-bit chunks of `left` and `right`).
1200fn pack_generic<'tcx>(
1201    ecx: &mut crate::MiriInterpCx<'tcx>,
1202    left: &OpTy<'tcx>,
1203    right: &OpTy<'tcx>,
1204    dest: &MPlaceTy<'tcx>,
1205    f: impl Fn(Scalar) -> InterpResult<'tcx, Scalar>,
1206) -> InterpResult<'tcx, ()> {
1207    assert_eq!(left.layout, right.layout);
1208    assert_eq!(left.layout.size, dest.layout.size);
1209
1210    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1211    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1212    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1213
1214    assert_eq!(dest_items_per_chunk, op_items_per_chunk.strict_mul(2));
1215
1216    for i in 0..num_chunks {
1217        let left = ecx.project_index(&left, i)?;
1218        let right = ecx.project_index(&right, i)?;
1219        let dest = ecx.project_index(&dest, i)?;
1220
1221        for j in 0..op_items_per_chunk {
1222            let left = ecx.read_scalar(&ecx.project_index(&left, j)?)?;
1223            let right = ecx.read_scalar(&ecx.project_index(&right, j)?)?;
1224            let left_dest = ecx.project_index(&dest, j)?;
1225            let right_dest = ecx.project_index(&dest, j.strict_add(op_items_per_chunk))?;
1226
1227            let left_res = f(left)?;
1228            let right_res = f(right)?;
1229
1230            ecx.write_scalar(left_res, &left_dest)?;
1231            ecx.write_scalar(right_res, &right_dest)?;
1232        }
1233    }
1234
1235    interp_ok(())
1236}
1237
1238/// Converts two 16-bit integer vectors to a single 8-bit integer
1239/// vector with signed saturation.
1240///
1241/// Each 128-bit chunk is treated independently (i.e., the value for
1242/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1243/// 128-bit chunks of `left` and `right`).
1244fn packsswb<'tcx>(
1245    ecx: &mut crate::MiriInterpCx<'tcx>,
1246    left: &OpTy<'tcx>,
1247    right: &OpTy<'tcx>,
1248    dest: &MPlaceTy<'tcx>,
1249) -> InterpResult<'tcx, ()> {
1250    pack_generic(ecx, left, right, dest, |op| {
1251        let op = op.to_i16()?;
1252        let res = i8::try_from(op).unwrap_or(if op < 0 { i8::MIN } else { i8::MAX });
1253        interp_ok(Scalar::from_i8(res))
1254    })
1255}
1256
1257/// Converts two 16-bit signed integer vectors to a single 8-bit
1258/// unsigned integer vector with saturation.
1259///
1260/// Each 128-bit chunk is treated independently (i.e., the value for
1261/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1262/// 128-bit chunks of `left` and `right`).
1263fn packuswb<'tcx>(
1264    ecx: &mut crate::MiriInterpCx<'tcx>,
1265    left: &OpTy<'tcx>,
1266    right: &OpTy<'tcx>,
1267    dest: &MPlaceTy<'tcx>,
1268) -> InterpResult<'tcx, ()> {
1269    pack_generic(ecx, left, right, dest, |op| {
1270        let op = op.to_i16()?;
1271        let res = u8::try_from(op).unwrap_or(if op < 0 { 0 } else { u8::MAX });
1272        interp_ok(Scalar::from_u8(res))
1273    })
1274}
1275
1276/// Converts two 32-bit integer vectors to a single 16-bit integer
1277/// vector with signed saturation.
1278///
1279/// Each 128-bit chunk is treated independently (i.e., the value for
1280/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1281/// 128-bit chunks of `left` and `right`).
1282fn packssdw<'tcx>(
1283    ecx: &mut crate::MiriInterpCx<'tcx>,
1284    left: &OpTy<'tcx>,
1285    right: &OpTy<'tcx>,
1286    dest: &MPlaceTy<'tcx>,
1287) -> InterpResult<'tcx, ()> {
1288    pack_generic(ecx, left, right, dest, |op| {
1289        let op = op.to_i32()?;
1290        let res = i16::try_from(op).unwrap_or(if op < 0 { i16::MIN } else { i16::MAX });
1291        interp_ok(Scalar::from_i16(res))
1292    })
1293}
1294
1295/// Converts two 32-bit integer vectors to a single 16-bit integer
1296/// vector with unsigned saturation.
1297///
1298/// Each 128-bit chunk is treated independently (i.e., the value for
1299/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1300/// 128-bit chunks of `left` and `right`).
1301fn packusdw<'tcx>(
1302    ecx: &mut crate::MiriInterpCx<'tcx>,
1303    left: &OpTy<'tcx>,
1304    right: &OpTy<'tcx>,
1305    dest: &MPlaceTy<'tcx>,
1306) -> InterpResult<'tcx, ()> {
1307    pack_generic(ecx, left, right, dest, |op| {
1308        let op = op.to_i32()?;
1309        let res = u16::try_from(op).unwrap_or(if op < 0 { 0 } else { u16::MAX });
1310        interp_ok(Scalar::from_u16(res))
1311    })
1312}
1313
1314/// Negates elements from `left` when the corresponding element in
1315/// `right` is negative. If an element from `right` is zero, zero
1316/// is written to the corresponding output element.
1317/// In other words, multiplies `left` with `right.signum()`.
1318fn psign<'tcx>(
1319    ecx: &mut crate::MiriInterpCx<'tcx>,
1320    left: &OpTy<'tcx>,
1321    right: &OpTy<'tcx>,
1322    dest: &MPlaceTy<'tcx>,
1323) -> InterpResult<'tcx, ()> {
1324    let (left, left_len) = ecx.project_to_simd(left)?;
1325    let (right, right_len) = ecx.project_to_simd(right)?;
1326    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1327
1328    assert_eq!(dest_len, left_len);
1329    assert_eq!(dest_len, right_len);
1330
1331    for i in 0..dest_len {
1332        let dest = ecx.project_index(&dest, i)?;
1333        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
1334        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_int(dest.layout.size)?;
1335
1336        let res =
1337            ecx.binary_op(mir::BinOp::Mul, &left, &ImmTy::from_int(right.signum(), dest.layout))?;
1338
1339        ecx.write_immediate(*res, &dest)?;
1340    }
1341
1342    interp_ok(())
1343}
1344
1345/// Calcultates either `a + b + cb_in` or `a - b - cb_in` depending on the value
1346/// of `op` and returns both the sum and the overflow bit. `op` is expected to be
1347/// either one of `mir::BinOp::AddWithOverflow` and `mir::BinOp::SubWithOverflow`.
1348fn carrying_add<'tcx>(
1349    ecx: &mut crate::MiriInterpCx<'tcx>,
1350    cb_in: &OpTy<'tcx>,
1351    a: &OpTy<'tcx>,
1352    b: &OpTy<'tcx>,
1353    op: mir::BinOp,
1354) -> InterpResult<'tcx, (ImmTy<'tcx>, Scalar)> {
1355    assert!(op == mir::BinOp::AddWithOverflow || op == mir::BinOp::SubWithOverflow);
1356
1357    let cb_in = ecx.read_scalar(cb_in)?.to_u8()? != 0;
1358    let a = ecx.read_immediate(a)?;
1359    let b = ecx.read_immediate(b)?;
1360
1361    let (sum, overflow1) = ecx.binary_op(op, &a, &b)?.to_pair(ecx);
1362    let (sum, overflow2) =
1363        ecx.binary_op(op, &sum, &ImmTy::from_uint(cb_in, a.layout))?.to_pair(ecx);
1364    let cb_out = overflow1.to_scalar().to_bool()? | overflow2.to_scalar().to_bool()?;
1365
1366    interp_ok((sum, Scalar::from_u8(cb_out.into())))
1367}