miri/shims/x86/sse2.rs

1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::Double;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8    FloatBinOp, ShiftOp, bin_op_simd_float_all, bin_op_simd_float_first, convert_float_to_int,
9    packssdw, packsswb, packuswb, shift_simd_by_scalar,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15    fn emulate_x86_sse2_intrinsic(
16        &mut self,
17        link_name: Symbol,
18        abi: &FnAbi<'tcx, Ty<'tcx>>,
19        args: &[OpTy<'tcx>],
20        dest: &MPlaceTy<'tcx>,
21    ) -> InterpResult<'tcx, EmulateItemResult> {
22        let this = self.eval_context_mut();
23        this.expect_target_feature_for_intrinsic(link_name, "sse2")?;
24        // Prefix should have already been checked.
25        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.sse2.").unwrap();
26
27        // These intrinsics operate on 128-bit (f32x4, f64x2, i8x16, i16x8, i32x4, i64x2) SIMD
28        // vectors unless stated otherwise.
29        // Many intrinsic names are sufixed with "ps" (packed single), "ss" (scalar signle),
30        // "pd" (packed double) or "sd" (scalar double), where single means single precision
31        // floating point (f32) and double means double precision floating point (f64). "ps"
32        // and "pd" means thet the operation is performed on each element of the vector, while
33        // "ss" and "sd" means that the operation is performed only on the first element, copying
34        // the remaining elements from the input vector (for binary operations, from the left-hand
35        // side).
36        // Intrinsincs sufixed with "epiX" or "epuX" operate with X-bit signed or unsigned
37        // vectors.
38        match unprefixed_name {
39            // Used to implement the _mm_madd_epi16 function.
40            // Multiplies packed signed 16-bit integers in `left` and `right`, producing
41            // intermediate signed 32-bit integers. Horizontally add adjacent pairs of
42            // intermediate 32-bit integers, and pack the results in `dest`.
43            "pmadd.wd" => {
44                let [left, right] =
45                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
46
47                let (left, left_len) = this.project_to_simd(left)?;
48                let (right, right_len) = this.project_to_simd(right)?;
49                let (dest, dest_len) = this.project_to_simd(dest)?;
50
51                assert_eq!(left_len, right_len);
52                assert_eq!(dest_len.strict_mul(2), left_len);
53
54                for i in 0..dest_len {
55                    let j1 = i.strict_mul(2);
56                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
57                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
58
59                    let j2 = j1.strict_add(1);
60                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
61                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
62
63                    let dest = this.project_index(&dest, i)?;
64
65                    // Multiplications are i16*i16->i32, which will not overflow.
66                    let mul1 = i32::from(left1).strict_mul(right1.into());
67                    let mul2 = i32::from(left2).strict_mul(right2.into());
68                    // However, this addition can overflow in the most extreme case
69                    // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
70                    let res = mul1.wrapping_add(mul2);
71
72                    this.write_scalar(Scalar::from_i32(res), &dest)?;
73                }
74            }
75            // Used to implement the _mm_sad_epu8 function.
76            // Computes the absolute differences of packed unsigned 8-bit integers in `a`
77            // and `b`, then horizontally sum each consecutive 8 differences to produce
78            // two unsigned 16-bit integers, and pack these unsigned 16-bit integers in
79            // the low 16 bits of 64-bit elements returned.
80            //
81            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8
82            "psad.bw" => {
83                let [left, right] =
84                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
85
86                let (left, left_len) = this.project_to_simd(left)?;
87                let (right, right_len) = this.project_to_simd(right)?;
88                let (dest, dest_len) = this.project_to_simd(dest)?;
89
90                // left and right are u8x16, dest is u64x2
91                assert_eq!(left_len, right_len);
92                assert_eq!(left_len, 16);
93                assert_eq!(dest_len, 2);
94
95                for i in 0..dest_len {
96                    let dest = this.project_index(&dest, i)?;
97
98                    let mut res: u16 = 0;
99                    let n = left_len.strict_div(dest_len);
100                    for j in 0..n {
101                        let op_i = j.strict_add(i.strict_mul(n));
102                        let left = this.read_scalar(&this.project_index(&left, op_i)?)?.to_u8()?;
103                        let right =
104                            this.read_scalar(&this.project_index(&right, op_i)?)?.to_u8()?;
105
106                        res = res.strict_add(left.abs_diff(right).into());
107                    }
108
109                    this.write_scalar(Scalar::from_u64(res.into()), &dest)?;
110                }
111            }
112            // Used to implement the _mm_{sll,srl,sra}_epi{16,32,64} functions
113            // (except _mm_sra_epi64, which is not available in SSE2).
114            // Shifts N-bit packed integers in left by the amount in right.
115            // Both operands are 128-bit vectors. However, right is interpreted as
116            // a single 64-bit integer (remaining bits are ignored).
117            // For logic shifts, when right is larger than N - 1, zero is produced.
118            // For arithmetic shifts, when right is larger than N - 1, the sign bit
119            // is copied to remaining bits.
120            "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
121            | "psrl.q" => {
122                let [left, right] =
123                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
124
125                let which = match unprefixed_name {
126                    "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
127                    "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
128                    "psra.w" | "psra.d" => ShiftOp::RightArith,
129                    _ => unreachable!(),
130                };
131
132                shift_simd_by_scalar(this, left, right, which, dest)?;
133            }
134            // Used to implement the _mm_cvtps_epi32, _mm_cvttps_epi32, _mm_cvtpd_epi32
135            // and _mm_cvttpd_epi32 functions.
136            // Converts packed f32/f64 to packed i32.
137            "cvtps2dq" | "cvttps2dq" | "cvtpd2dq" | "cvttpd2dq" => {
138                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
139
140                let (op_len, _) = op.layout.ty.simd_size_and_type(*this.tcx);
141                let (dest_len, _) = dest.layout.ty.simd_size_and_type(*this.tcx);
142                match unprefixed_name {
143                    "cvtps2dq" | "cvttps2dq" => {
144                        // f32x4 to i32x4 conversion
145                        assert_eq!(op_len, 4);
146                        assert_eq!(dest_len, op_len);
147                    }
148                    "cvtpd2dq" | "cvttpd2dq" => {
149                        // f64x2 to i32x4 conversion
150                        // the last two values are filled with zeros
151                        assert_eq!(op_len, 2);
152                        assert_eq!(dest_len, 4);
153                    }
154                    _ => unreachable!(),
155                }
156
157                let rnd = match unprefixed_name {
158                    // "current SSE rounding mode", assume nearest
159                    // https://www.felixcloutier.com/x86/cvtps2dq
160                    // https://www.felixcloutier.com/x86/cvtpd2dq
161                    "cvtps2dq" | "cvtpd2dq" => rustc_apfloat::Round::NearestTiesToEven,
162                    // always truncate
163                    // https://www.felixcloutier.com/x86/cvttps2dq
164                    // https://www.felixcloutier.com/x86/cvttpd2dq
165                    "cvttps2dq" | "cvttpd2dq" => rustc_apfloat::Round::TowardZero,
166                    _ => unreachable!(),
167                };
168
169                convert_float_to_int(this, op, rnd, dest)?;
170            }
171            // Used to implement the _mm_packs_epi16 function.
172            // Converts two 16-bit integer vectors to a single 8-bit integer
173            // vector with signed saturation.
174            "packsswb.128" => {
175                let [left, right] =
176                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
177
178                packsswb(this, left, right, dest)?;
179            }
180            // Used to implement the _mm_packus_epi16 function.
181            // Converts two 16-bit signed integer vectors to a single 8-bit
182            // unsigned integer vector with saturation.
183            "packuswb.128" => {
184                let [left, right] =
185                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
186
187                packuswb(this, left, right, dest)?;
188            }
189            // Used to implement the _mm_packs_epi32 function.
190            // Converts two 32-bit integer vectors to a single 16-bit integer
191            // vector with signed saturation.
192            "packssdw.128" => {
193                let [left, right] =
194                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
195
196                packssdw(this, left, right, dest)?;
197            }
198            // Used to implement _mm_min_sd and _mm_max_sd functions.
199            // Note that the semantics are a bit different from Rust simd_min
200            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
201            // matches the IEEE min/max operations, while x86 has different
202            // semantics.
203            "min.sd" | "max.sd" => {
204                let [left, right] =
205                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
206
207                let which = match unprefixed_name {
208                    "min.sd" => FloatBinOp::Min,
209                    "max.sd" => FloatBinOp::Max,
210                    _ => unreachable!(),
211                };
212
213                bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
214            }
215            // Used to implement _mm_min_pd and _mm_max_pd functions.
216            // Note that the semantics are a bit different from Rust simd_min
217            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
218            // matches the IEEE min/max operations, while x86 has different
219            // semantics.
220            "min.pd" | "max.pd" => {
221                let [left, right] =
222                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
223
224                let which = match unprefixed_name {
225                    "min.pd" => FloatBinOp::Min,
226                    "max.pd" => FloatBinOp::Max,
227                    _ => unreachable!(),
228                };
229
230                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
231            }
232            // Used to implement the _mm_cmp*_sd functions.
233            // Performs a comparison operation on the first component of `left`
234            // and `right`, returning 0 if false or `u64::MAX` if true. The remaining
235            // components are copied from `left`.
236            // _mm_cmp_sd is actually an AVX function where the operation is specified
237            // by a const parameter.
238            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_sd are SSE2 functions
239            // with hard-coded operations.
240            "cmp.sd" => {
241                let [left, right, imm] =
242                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
243
244                let which =
245                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
246
247                bin_op_simd_float_first::<Double>(this, which, left, right, dest)?;
248            }
249            // Used to implement the _mm_cmp*_pd functions.
250            // Performs a comparison operation on each component of `left`
251            // and `right`. For each component, returns 0 if false or `u64::MAX`
252            // if true.
253            // _mm_cmp_pd is actually an AVX function where the operation is specified
254            // by a const parameter.
255            // _mm_cmp{eq,lt,le,gt,ge,neq,nlt,nle,ngt,nge,ord,unord}_pd are SSE2 functions
256            // with hard-coded operations.
257            "cmp.pd" => {
258                let [left, right, imm] =
259                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
260
261                let which =
262                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
263
264                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
265            }
266            // Used to implement _mm_{,u}comi{eq,lt,le,gt,ge,neq}_sd functions.
267            // Compares the first component of `left` and `right` and returns
268            // a scalar value (0 or 1).
269            "comieq.sd" | "comilt.sd" | "comile.sd" | "comigt.sd" | "comige.sd" | "comineq.sd"
270            | "ucomieq.sd" | "ucomilt.sd" | "ucomile.sd" | "ucomigt.sd" | "ucomige.sd"
271            | "ucomineq.sd" => {
272                let [left, right] =
273                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
274
275                let (left, left_len) = this.project_to_simd(left)?;
276                let (right, right_len) = this.project_to_simd(right)?;
277
278                assert_eq!(left_len, right_len);
279
280                let left = this.read_scalar(&this.project_index(&left, 0)?)?.to_f64()?;
281                let right = this.read_scalar(&this.project_index(&right, 0)?)?.to_f64()?;
282                // The difference between the com* and ucom* variants is signaling
283                // of exceptions when either argument is a quiet NaN. We do not
284                // support accessing the SSE status register from miri (or from Rust,
285                // for that matter), so we treat both variants equally.
286                let res = match unprefixed_name {
287                    "comieq.sd" | "ucomieq.sd" => left == right,
288                    "comilt.sd" | "ucomilt.sd" => left < right,
289                    "comile.sd" | "ucomile.sd" => left <= right,
290                    "comigt.sd" | "ucomigt.sd" => left > right,
291                    "comige.sd" | "ucomige.sd" => left >= right,
292                    "comineq.sd" | "ucomineq.sd" => left != right,
293                    _ => unreachable!(),
294                };
295                this.write_scalar(Scalar::from_i32(i32::from(res)), dest)?;
296            }
297            // Use to implement the _mm_cvtsd_si32, _mm_cvttsd_si32,
298            // _mm_cvtsd_si64 and _mm_cvttsd_si64 functions.
299            // Converts the first component of `op` from f64 to i32/i64.
300            "cvtsd2si" | "cvttsd2si" | "cvtsd2si64" | "cvttsd2si64" => {
301                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
302                let (op, _) = this.project_to_simd(op)?;
303
304                let op = this.read_immediate(&this.project_index(&op, 0)?)?;
305
306                let rnd = match unprefixed_name {
307                    // "current SSE rounding mode", assume nearest
308                    // https://www.felixcloutier.com/x86/cvtsd2si
309                    "cvtsd2si" | "cvtsd2si64" => rustc_apfloat::Round::NearestTiesToEven,
310                    // always truncate
311                    // https://www.felixcloutier.com/x86/cvttsd2si
312                    "cvttsd2si" | "cvttsd2si64" => rustc_apfloat::Round::TowardZero,
313                    _ => unreachable!(),
314                };
315
316                let res = this.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
317                    // Fallback to minimum according to SSE semantics.
318                    ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
319                });
320
321                this.write_immediate(*res, dest)?;
322            }
323            // Used to implement the _mm_cvtsd_ss and _mm_cvtss_sd functions.
324            // Converts the first f64/f32 from `right` to f32/f64 and copies
325            // the remaining elements from `left`
326            "cvtsd2ss" | "cvtss2sd" => {
327                let [left, right] =
328                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
329
330                let (left, left_len) = this.project_to_simd(left)?;
331                let (right, _) = this.project_to_simd(right)?;
332                let (dest, dest_len) = this.project_to_simd(dest)?;
333
334                assert_eq!(dest_len, left_len);
335
336                // Convert first element of `right`
337                let right0 = this.read_immediate(&this.project_index(&right, 0)?)?;
338                let dest0 = this.project_index(&dest, 0)?;
339                // `float_to_float_or_int` here will convert from f64 to f32 (cvtsd2ss) or
340                // from f32 to f64 (cvtss2sd).
341                let res0 = this.float_to_float_or_int(&right0, dest0.layout)?;
342                this.write_immediate(*res0, &dest0)?;
343
344                // Copy remaining from `left`
345                for i in 1..dest_len {
346                    this.copy_op(&this.project_index(&left, i)?, &this.project_index(&dest, i)?)?;
347                }
348            }
349            _ => return interp_ok(EmulateItemResult::NotSupported),
350        }
351        interp_ok(EmulateItemResult::NeedsReturn)
352    }
353}