miri/shims/x86/
avx2.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::mir;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8    ShiftOp, horizontal_bin_op, int_abs, mask_load, mask_store, mpsadbw, packssdw, packsswb,
9    packusdw, packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15    fn emulate_x86_avx2_intrinsic(
16        &mut self,
17        link_name: Symbol,
18        abi: &FnAbi<'tcx, Ty<'tcx>>,
19        args: &[OpTy<'tcx>],
20        dest: &MPlaceTy<'tcx>,
21    ) -> InterpResult<'tcx, EmulateItemResult> {
22        let this = self.eval_context_mut();
23        this.expect_target_feature_for_intrinsic(link_name, "avx2")?;
24        // Prefix should have already been checked.
25        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx2.").unwrap();
26
27        match unprefixed_name {
28            // Used to implement the _mm256_abs_epi{8,16,32} functions.
29            // Calculates the absolute value of packed 8/16/32-bit integers.
30            "pabs.b" | "pabs.w" | "pabs.d" => {
31                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
32
33                int_abs(this, op, dest)?;
34            }
35            // Used to implement the _mm256_h{add,adds,sub}_epi{16,32} functions.
36            // Horizontally add / add with saturation / subtract adjacent 16/32-bit
37            // integer values in `left` and `right`.
38            "phadd.w" | "phadd.sw" | "phadd.d" | "phsub.w" | "phsub.sw" | "phsub.d" => {
39                let [left, right] =
40                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
41
42                let (which, saturating) = match unprefixed_name {
43                    "phadd.w" | "phadd.d" => (mir::BinOp::Add, false),
44                    "phadd.sw" => (mir::BinOp::Add, true),
45                    "phsub.w" | "phsub.d" => (mir::BinOp::Sub, false),
46                    "phsub.sw" => (mir::BinOp::Sub, true),
47                    _ => unreachable!(),
48                };
49
50                horizontal_bin_op(this, which, saturating, left, right, dest)?;
51            }
52            // Used to implement `_mm{,_mask}_{i32,i64}gather_{epi32,epi64,pd,ps}` functions
53            // Gathers elements from `slice` using `offsets * scale` as indices.
54            // When the highest bit of the corresponding element of `mask` is 0,
55            // the value is copied from `src` instead.
56            "gather.d.d" | "gather.d.d.256" | "gather.d.q" | "gather.d.q.256" | "gather.q.d"
57            | "gather.q.d.256" | "gather.q.q" | "gather.q.q.256" | "gather.d.pd"
58            | "gather.d.pd.256" | "gather.q.pd" | "gather.q.pd.256" | "gather.d.ps"
59            | "gather.d.ps.256" | "gather.q.ps" | "gather.q.ps.256" => {
60                let [src, slice, offsets, mask, scale] =
61                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
62
63                assert_eq!(dest.layout, src.layout);
64
65                let (src, _) = this.project_to_simd(src)?;
66                let (offsets, offsets_len) = this.project_to_simd(offsets)?;
67                let (mask, mask_len) = this.project_to_simd(mask)?;
68                let (dest, dest_len) = this.project_to_simd(dest)?;
69
70                // There are cases like dest: i32x4, offsets: i64x2
71                // If dest has more elements than offset, extra dest elements are filled with zero.
72                // If offsets has more elements than dest, extra offsets are ignored.
73                let actual_len = dest_len.min(offsets_len);
74
75                assert_eq!(dest_len, mask_len);
76
77                let mask_item_size = mask.layout.field(this, 0).size;
78                let high_bit_offset = mask_item_size.bits().strict_sub(1);
79
80                let scale = this.read_scalar(scale)?.to_i8()?;
81                if !matches!(scale, 1 | 2 | 4 | 8) {
82                    panic!("invalid gather scale {scale}");
83                }
84                let scale = i64::from(scale);
85
86                let slice = this.read_pointer(slice)?;
87                for i in 0..actual_len {
88                    let mask = this.project_index(&mask, i)?;
89                    let dest = this.project_index(&dest, i)?;
90
91                    if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
92                        let offset = this.project_index(&offsets, i)?;
93                        let offset =
94                            i64::try_from(this.read_scalar(&offset)?.to_int(offset.layout.size)?)
95                                .unwrap();
96                        let ptr = slice.wrapping_signed_offset(offset.strict_mul(scale), &this.tcx);
97                        // Unaligned copy, which is what we want.
98                        this.mem_copy(
99                            ptr,
100                            dest.ptr(),
101                            dest.layout.size,
102                            /*nonoverlapping*/ true,
103                        )?;
104                    } else {
105                        this.copy_op(&this.project_index(&src, i)?, &dest)?;
106                    }
107                }
108                for i in actual_len..dest_len {
109                    let dest = this.project_index(&dest, i)?;
110                    this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
111                }
112            }
113            // Used to implement the _mm256_madd_epi16 function.
114            // Multiplies packed signed 16-bit integers in `left` and `right`, producing
115            // intermediate signed 32-bit integers. Horizontally add adjacent pairs of
116            // intermediate 32-bit integers, and pack the results in `dest`.
117            "pmadd.wd" => {
118                let [left, right] =
119                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
120
121                let (left, left_len) = this.project_to_simd(left)?;
122                let (right, right_len) = this.project_to_simd(right)?;
123                let (dest, dest_len) = this.project_to_simd(dest)?;
124
125                assert_eq!(left_len, right_len);
126                assert_eq!(dest_len.strict_mul(2), left_len);
127
128                for i in 0..dest_len {
129                    let j1 = i.strict_mul(2);
130                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
131                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
132
133                    let j2 = j1.strict_add(1);
134                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
135                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
136
137                    let dest = this.project_index(&dest, i)?;
138
139                    // Multiplications are i16*i16->i32, which will not overflow.
140                    let mul1 = i32::from(left1).strict_mul(right1.into());
141                    let mul2 = i32::from(left2).strict_mul(right2.into());
142                    // However, this addition can overflow in the most extreme case
143                    // (-0x8000)*(-0x8000)+(-0x8000)*(-0x8000) = 0x80000000
144                    let res = mul1.wrapping_add(mul2);
145
146                    this.write_scalar(Scalar::from_i32(res), &dest)?;
147                }
148            }
149            // Used to implement the _mm256_maddubs_epi16 function.
150            // Multiplies packed 8-bit unsigned integers from `left` and packed
151            // signed 8-bit integers from `right` into 16-bit signed integers. Then,
152            // the saturating sum of the products with indices `2*i` and `2*i+1`
153            // produces the output at index `i`.
154            "pmadd.ub.sw" => {
155                let [left, right] =
156                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
157
158                let (left, left_len) = this.project_to_simd(left)?;
159                let (right, right_len) = this.project_to_simd(right)?;
160                let (dest, dest_len) = this.project_to_simd(dest)?;
161
162                assert_eq!(left_len, right_len);
163                assert_eq!(dest_len.strict_mul(2), left_len);
164
165                for i in 0..dest_len {
166                    let j1 = i.strict_mul(2);
167                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
168                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;
169
170                    let j2 = j1.strict_add(1);
171                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
172                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;
173
174                    let dest = this.project_index(&dest, i)?;
175
176                    // Multiplication of a u8 and an i8 into an i16 cannot overflow.
177                    let mul1 = i16::from(left1).strict_mul(right1.into());
178                    let mul2 = i16::from(left2).strict_mul(right2.into());
179                    let res = mul1.saturating_add(mul2);
180
181                    this.write_scalar(Scalar::from_i16(res), &dest)?;
182                }
183            }
184            // Used to implement the _mm_maskload_epi32, _mm_maskload_epi64,
185            // _mm256_maskload_epi32 and _mm256_maskload_epi64 functions.
186            // For the element `i`, if the high bit of the `i`-th element of `mask`
187            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
188            // loaded.
189            "maskload.d" | "maskload.q" | "maskload.d.256" | "maskload.q.256" => {
190                let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
191
192                mask_load(this, ptr, mask, dest)?;
193            }
194            // Used to implement the _mm_maskstore_epi32, _mm_maskstore_epi64,
195            // _mm256_maskstore_epi32 and _mm256_maskstore_epi64 functions.
196            // For the element `i`, if the high bit of the element `i`-th of `mask`
197            // is one, it is stored into `ptr.wapping_add(i)`.
198            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
199            "maskstore.d" | "maskstore.q" | "maskstore.d.256" | "maskstore.q.256" => {
200                let [ptr, mask, value] =
201                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
202
203                mask_store(this, ptr, mask, value)?;
204            }
205            // Used to implement the _mm256_mpsadbw_epu8 function.
206            // Compute the sum of absolute differences of quadruplets of unsigned
207            // 8-bit integers in `left` and `right`, and store the 16-bit results
208            // in `right`. Quadruplets are selected from `left` and `right` with
209            // offsets specified in `imm`.
210            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8
211            "mpsadbw" => {
212                let [left, right, imm] =
213                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
214
215                mpsadbw(this, left, right, imm, dest)?;
216            }
217            // Used to implement the _mm256_mulhrs_epi16 function.
218            // Multiplies packed 16-bit signed integer values, truncates the 32-bit
219            // product to the 18 most significant bits by right-shifting, and then
220            // divides the 18-bit value by 2 (rounding to nearest) by first adding
221            // 1 and then taking the bits `1..=16`.
222            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16
223            "pmul.hr.sw" => {
224                let [left, right] =
225                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
226
227                pmulhrsw(this, left, right, dest)?;
228            }
229            // Used to implement the _mm256_packs_epi16 function.
230            // Converts two 16-bit integer vectors to a single 8-bit integer
231            // vector with signed saturation.
232            "packsswb" => {
233                let [left, right] =
234                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
235
236                packsswb(this, left, right, dest)?;
237            }
238            // Used to implement the _mm256_packs_epi32 function.
239            // Converts two 32-bit integer vectors to a single 16-bit integer
240            // vector with signed saturation.
241            "packssdw" => {
242                let [left, right] =
243                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
244
245                packssdw(this, left, right, dest)?;
246            }
247            // Used to implement the _mm256_packus_epi16 function.
248            // Converts two 16-bit signed integer vectors to a single 8-bit
249            // unsigned integer vector with saturation.
250            "packuswb" => {
251                let [left, right] =
252                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
253
254                packuswb(this, left, right, dest)?;
255            }
256            // Used to implement the _mm256_packus_epi32 function.
257            // Concatenates two 32-bit signed integer vectors and converts
258            // the result to a 16-bit unsigned integer vector with saturation.
259            "packusdw" => {
260                let [left, right] =
261                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
262
263                packusdw(this, left, right, dest)?;
264            }
265            // Used to implement the _mm256_permutevar8x32_epi32 and
266            // _mm256_permutevar8x32_ps function.
267            // Shuffles `left` using the three low bits of each element of `right`
268            // as indices.
269            "permd" | "permps" => {
270                let [left, right] =
271                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
272
273                let (left, left_len) = this.project_to_simd(left)?;
274                let (right, right_len) = this.project_to_simd(right)?;
275                let (dest, dest_len) = this.project_to_simd(dest)?;
276
277                assert_eq!(dest_len, left_len);
278                assert_eq!(dest_len, right_len);
279
280                for i in 0..dest_len {
281                    let dest = this.project_index(&dest, i)?;
282                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
283                    let left = this.project_index(&left, (right & 0b111).into())?;
284
285                    this.copy_op(&left, &dest)?;
286                }
287            }
288            // Used to implement the _mm256_permute2x128_si256 function.
289            // Shuffles 128-bit blocks of `a` and `b` using `imm` as pattern.
290            "vperm2i128" => {
291                let [left, right, imm] =
292                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
293
294                assert_eq!(left.layout.size.bits(), 256);
295                assert_eq!(right.layout.size.bits(), 256);
296                assert_eq!(dest.layout.size.bits(), 256);
297
298                // Transmute to `[i128; 2]`
299
300                let array_layout =
301                    this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.i128, 2))?;
302                let left = left.transmute(array_layout, this)?;
303                let right = right.transmute(array_layout, this)?;
304                let dest = dest.transmute(array_layout, this)?;
305
306                let imm = this.read_scalar(imm)?.to_u8()?;
307
308                for i in 0..2 {
309                    let dest = this.project_index(&dest, i)?;
310                    let src = match (imm >> i.strict_mul(4)) & 0b11 {
311                        0 => this.project_index(&left, 0)?,
312                        1 => this.project_index(&left, 1)?,
313                        2 => this.project_index(&right, 0)?,
314                        3 => this.project_index(&right, 1)?,
315                        _ => unreachable!(),
316                    };
317
318                    this.copy_op(&src, &dest)?;
319                }
320            }
321            // Used to implement the _mm256_sad_epu8 function.
322            // Compute the absolute differences of packed unsigned 8-bit integers
323            // in `left` and `right`, then horizontally sum each consecutive 8
324            // differences to produce four unsigned 16-bit integers, and pack
325            // these unsigned 16-bit integers in the low 16 bits of 64-bit elements
326            // in `dest`.
327            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8
328            "psad.bw" => {
329                let [left, right] =
330                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
331
332                let (left, left_len) = this.project_to_simd(left)?;
333                let (right, right_len) = this.project_to_simd(right)?;
334                let (dest, dest_len) = this.project_to_simd(dest)?;
335
336                assert_eq!(left_len, right_len);
337                assert_eq!(left_len, dest_len.strict_mul(8));
338
339                for i in 0..dest_len {
340                    let dest = this.project_index(&dest, i)?;
341
342                    let mut acc: u16 = 0;
343                    for j in 0..8 {
344                        let src_index = i.strict_mul(8).strict_add(j);
345
346                        let left = this.project_index(&left, src_index)?;
347                        let left = this.read_scalar(&left)?.to_u8()?;
348
349                        let right = this.project_index(&right, src_index)?;
350                        let right = this.read_scalar(&right)?.to_u8()?;
351
352                        acc = acc.strict_add(left.abs_diff(right).into());
353                    }
354
355                    this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
356                }
357            }
358            // Used to implement the _mm256_shuffle_epi8 intrinsic.
359            // Shuffles bytes from `left` using `right` as pattern.
360            // Each 128-bit block is shuffled independently.
361            "pshuf.b" => {
362                let [left, right] =
363                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
364
365                let (left, left_len) = this.project_to_simd(left)?;
366                let (right, right_len) = this.project_to_simd(right)?;
367                let (dest, dest_len) = this.project_to_simd(dest)?;
368
369                assert_eq!(dest_len, left_len);
370                assert_eq!(dest_len, right_len);
371
372                for i in 0..dest_len {
373                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
374                    let dest = this.project_index(&dest, i)?;
375
376                    let res = if right & 0x80 == 0 {
377                        // Shuffle each 128-bit (16-byte) block independently.
378                        let j = u64::from(right % 16).strict_add(i & !15);
379                        this.read_scalar(&this.project_index(&left, j)?)?
380                    } else {
381                        // If the highest bit in `right` is 1, write zero.
382                        Scalar::from_u8(0)
383                    };
384
385                    this.write_scalar(res, &dest)?;
386                }
387            }
388            // Used to implement the _mm256_sign_epi{8,16,32} functions.
389            // Negates elements from `left` when the corresponding element in
390            // `right` is negative. If an element from `right` is zero, zero
391            // is writen to the corresponding output element.
392            // Basically, we multiply `left` with `right.signum()`.
393            "psign.b" | "psign.w" | "psign.d" => {
394                let [left, right] =
395                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
396
397                psign(this, left, right, dest)?;
398            }
399            // Used to implement the _mm256_{sll,srl,sra}_epi{16,32,64} functions
400            // (except _mm256_sra_epi64, which is not available in AVX2).
401            // Shifts N-bit packed integers in left by the amount in right.
402            // `right` is as 128-bit vector. but it is interpreted as a single
403            // 64-bit integer (remaining bits are ignored).
404            // For logic shifts, when right is larger than N - 1, zero is produced.
405            // For arithmetic shifts, when right is larger than N - 1, the sign bit
406            // is copied to remaining bits.
407            "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
408            | "psrl.q" => {
409                let [left, right] =
410                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
411
412                let which = match unprefixed_name {
413                    "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
414                    "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
415                    "psra.w" | "psra.d" => ShiftOp::RightArith,
416                    _ => unreachable!(),
417                };
418
419                shift_simd_by_scalar(this, left, right, which, dest)?;
420            }
421            // Used to implement the _mm{,256}_{sllv,srlv,srav}_epi{32,64} functions
422            // (except _mm{,256}_srav_epi64, which are not available in AVX2).
423            "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" | "psrlv.d" | "psrlv.d.256"
424            | "psrlv.q" | "psrlv.q.256" | "psrav.d" | "psrav.d.256" => {
425                let [left, right] =
426                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
427
428                let which = match unprefixed_name {
429                    "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" => ShiftOp::Left,
430                    "psrlv.d" | "psrlv.d.256" | "psrlv.q" | "psrlv.q.256" => ShiftOp::RightLogic,
431                    "psrav.d" | "psrav.d.256" => ShiftOp::RightArith,
432                    _ => unreachable!(),
433                };
434
435                shift_simd_by_simd(this, left, right, which, dest)?;
436            }
437            _ => return interp_ok(EmulateItemResult::NotSupported),
438        }
439        interp_ok(EmulateItemResult::NeedsReturn)
440    }
441}