miri/shims/x86/
avx.rs

1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::{Double, Single};
3use rustc_middle::mir;
4use rustc_middle::ty::Ty;
5use rustc_span::Symbol;
6use rustc_target::callconv::FnAbi;
7
8use super::{
9    FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
10    horizontal_bin_op, mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked,
11    unary_op_ps,
12};
13use crate::*;
14
15impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
16pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
17    fn emulate_x86_avx_intrinsic(
18        &mut self,
19        link_name: Symbol,
20        abi: &FnAbi<'tcx, Ty<'tcx>>,
21        args: &[OpTy<'tcx>],
22        dest: &MPlaceTy<'tcx>,
23    ) -> InterpResult<'tcx, EmulateItemResult> {
24        let this = self.eval_context_mut();
25        this.expect_target_feature_for_intrinsic(link_name, "avx")?;
26        // Prefix should have already been checked.
27        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
28
29        match unprefixed_name {
30            // Used to implement _mm256_min_ps and _mm256_max_ps functions.
31            // Note that the semantics are a bit different from Rust simd_min
32            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
33            // matches the IEEE min/max operations, while x86 has different
34            // semantics.
35            "min.ps.256" | "max.ps.256" => {
36                let [left, right] =
37                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
38
39                let which = match unprefixed_name {
40                    "min.ps.256" => FloatBinOp::Min,
41                    "max.ps.256" => FloatBinOp::Max,
42                    _ => unreachable!(),
43                };
44
45                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
46            }
47            // Used to implement _mm256_min_pd and _mm256_max_pd functions.
48            "min.pd.256" | "max.pd.256" => {
49                let [left, right] =
50                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
51
52                let which = match unprefixed_name {
53                    "min.pd.256" => FloatBinOp::Min,
54                    "max.pd.256" => FloatBinOp::Max,
55                    _ => unreachable!(),
56                };
57
58                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
59            }
60            // Used to implement the _mm256_round_ps function.
61            // Rounds the elements of `op` according to `rounding`.
62            "round.ps.256" => {
63                let [op, rounding] =
64                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
65
66                round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
67            }
68            // Used to implement the _mm256_round_pd function.
69            // Rounds the elements of `op` according to `rounding`.
70            "round.pd.256" => {
71                let [op, rounding] =
72                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73
74                round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
75            }
76            // Used to implement _mm256_{rcp,rsqrt}_ps functions.
77            // Performs the operations on all components of `op`.
78            "rcp.ps.256" | "rsqrt.ps.256" => {
79                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
80
81                let which = match unprefixed_name {
82                    "rcp.ps.256" => FloatUnaryOp::Rcp,
83                    "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
84                    _ => unreachable!(),
85                };
86
87                unary_op_ps(this, which, op, dest)?;
88            }
89            // Used to implement the _mm256_dp_ps function.
90            "dp.ps.256" => {
91                let [left, right, imm] =
92                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
93
94                conditional_dot_product(this, left, right, imm, dest)?;
95            }
96            // Used to implement the _mm256_h{add,sub}_p{s,d} functions.
97            // Horizontally add/subtract adjacent floating point values
98            // in `left` and `right`.
99            "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
100                let [left, right] =
101                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
102
103                let which = match unprefixed_name {
104                    "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
105                    "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
106                    _ => unreachable!(),
107                };
108
109                horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
110            }
111            // Used to implement the _mm256_cmp_ps function.
112            // Performs a comparison operation on each component of `left`
113            // and `right`. For each component, returns 0 if false or u32::MAX
114            // if true.
115            "cmp.ps.256" => {
116                let [left, right, imm] =
117                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
118
119                let which =
120                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
121
122                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
123            }
124            // Used to implement the _mm256_cmp_pd function.
125            // Performs a comparison operation on each component of `left`
126            // and `right`. For each component, returns 0 if false or u64::MAX
127            // if true.
128            "cmp.pd.256" => {
129                let [left, right, imm] =
130                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
131
132                let which =
133                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
134
135                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
136            }
137            // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
138            // and _mm256_cvttpd_epi32 functions.
139            // Converts packed f32/f64 to packed i32.
140            "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
141                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
142
143                let rnd = match unprefixed_name {
144                    // "current SSE rounding mode", assume nearest
145                    "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
146                    // always truncate
147                    "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
148                    _ => unreachable!(),
149                };
150
151                convert_float_to_int(this, op, rnd, dest)?;
152            }
153            // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
154            // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
155            // chunk is shuffled independently: this means that we view the vector as a
156            // sequence of 4-element arrays, and we shuffle each of these arrays, where
157            // `control` determines which element of the current `data` array is written.
158            "vpermilvar.ps" | "vpermilvar.ps.256" => {
159                let [data, control] =
160                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
161
162                let (data, data_len) = this.project_to_simd(data)?;
163                let (control, control_len) = this.project_to_simd(control)?;
164                let (dest, dest_len) = this.project_to_simd(dest)?;
165
166                assert_eq!(dest_len, data_len);
167                assert_eq!(dest_len, control_len);
168
169                for i in 0..dest_len {
170                    let control = this.project_index(&control, i)?;
171
172                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
173                    // four 32-bit elements, only two bits from `control` are used. To read the
174                    // value from the current chunk, add the destination index truncated to a multiple
175                    // of 4.
176                    let chunk_base = i & !0b11;
177                    let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
178                        .strict_add(chunk_base);
179
180                    this.copy_op(
181                        &this.project_index(&data, src_i)?,
182                        &this.project_index(&dest, i)?,
183                    )?;
184                }
185            }
186            // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
187            // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
188            // chunk is shuffled independently: this means that we view the vector as
189            // a sequence of 2-element arrays, and we shuffle each of these arrays,
190            // where `right` determines which element of the current `left` array is
191            // written.
192            "vpermilvar.pd" | "vpermilvar.pd.256" => {
193                let [data, control] =
194                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
195
196                let (data, data_len) = this.project_to_simd(data)?;
197                let (control, control_len) = this.project_to_simd(control)?;
198                let (dest, dest_len) = this.project_to_simd(dest)?;
199
200                assert_eq!(dest_len, data_len);
201                assert_eq!(dest_len, control_len);
202
203                for i in 0..dest_len {
204                    let control = this.project_index(&control, i)?;
205
206                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
207                    // two 64-bit elements, only the second bit from `control` is used (yes, the
208                    // second instead of the first, ask Intel). To read the value from the current
209                    // chunk, add the destination index truncated to a multiple of 2.
210                    let chunk_base = i & !1;
211                    let src_i =
212                        ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
213
214                    this.copy_op(
215                        &this.project_index(&data, src_i)?,
216                        &this.project_index(&dest, i)?,
217                    )?;
218                }
219            }
220            // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
221            // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
222            // thay all can be considered to operate on vectors of 128-bit elements.
223            // For each 128-bit element of `dest`, copies one from `left`, `right` or
224            // zero, according to `imm`.
225            "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
226                let [left, right, imm] =
227                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
228
229                assert_eq!(dest.layout, left.layout);
230                assert_eq!(dest.layout, right.layout);
231                assert_eq!(dest.layout.size.bits(), 256);
232
233                // Transmute to `[u128; 2]` to process each 128-bit chunk independently.
234                let u128x2_layout =
235                    this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
236                let left = left.transmute(u128x2_layout, this)?;
237                let right = right.transmute(u128x2_layout, this)?;
238                let dest = dest.transmute(u128x2_layout, this)?;
239
240                let imm = this.read_scalar(imm)?.to_u8()?;
241
242                for i in 0..2 {
243                    let dest = this.project_index(&dest, i)?;
244
245                    let imm = match i {
246                        0 => imm & 0xF,
247                        1 => imm >> 4,
248                        _ => unreachable!(),
249                    };
250                    if imm & 0b100 != 0 {
251                        this.write_scalar(Scalar::from_u128(0), &dest)?;
252                    } else {
253                        let src = match imm {
254                            0b00 => this.project_index(&left, 0)?,
255                            0b01 => this.project_index(&left, 1)?,
256                            0b10 => this.project_index(&right, 0)?,
257                            0b11 => this.project_index(&right, 1)?,
258                            _ => unreachable!(),
259                        };
260                        this.copy_op(&src, &dest)?;
261                    }
262                }
263            }
264            // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
265            // and _mm256_maskload_pd functions.
266            // For the element `i`, if the high bit of the `i`-th element of `mask`
267            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
268            // loaded.
269            "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
270                let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
271
272                mask_load(this, ptr, mask, dest)?;
273            }
274            // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
275            // and _mm256_maskstore_pd functions.
276            // For the element `i`, if the high bit of the element `i`-th of `mask`
277            // is one, it is stored into `ptr.wapping_add(i)`.
278            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
279            "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
280                let [ptr, mask, value] =
281                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
282
283                mask_store(this, ptr, mask, value)?;
284            }
285            // Used to implement the _mm256_lddqu_si256 function.
286            // Reads a 256-bit vector from an unaligned pointer. This intrinsic
287            // is expected to perform better than a regular unaligned read when
288            // the data crosses a cache line, but for Miri this is just a regular
289            // unaligned read.
290            "ldu.dq.256" => {
291                let [src_ptr] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
292                let src_ptr = this.read_pointer(src_ptr)?;
293                let dest = dest.force_mplace(this)?;
294
295                // Unaligned copy, which is what we want.
296                this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
297            }
298            // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
299            // _mm256_testnzc_si256 functions.
300            // Tests `op & mask == 0`, `op & mask == mask` or
301            // `op & mask != 0 && op & mask != mask`
302            "ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
303                let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
304
305                let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
306                let res = match unprefixed_name {
307                    "ptestz.256" => all_zero,
308                    "ptestc.256" => masked_set,
309                    "ptestnzc.256" => !all_zero && !masked_set,
310                    _ => unreachable!(),
311                };
312
313                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
314            }
315            // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
316            // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
317            // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
318            // _mm_testnzc_ps functions.
319            // Calculates two booleans:
320            // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
321            // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
322            // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
323            "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
324            | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
325            | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
326                let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
327
328                let (direct, negated) = test_high_bits_masked(this, op, mask)?;
329                let res = match unprefixed_name {
330                    "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
331                    "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
332                    "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
333                        !direct && !negated,
334                    _ => unreachable!(),
335                };
336
337                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
338            }
339            // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
340            // These function clear out the upper 128 bits of all avx registers or
341            // zero out all avx registers respectively.
342            "vzeroupper" | "vzeroall" => {
343                // These functions are purely a performance hint for the CPU.
344                // Any registers currently in use will be saved beforehand by the
345                // compiler, making these functions no-ops.
346
347                // The only thing that needs to be ensured is the correct calling convention.
348                let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
349            }
350            _ => return interp_ok(EmulateItemResult::NotSupported),
351        }
352        interp_ok(EmulateItemResult::NeedsReturn)
353    }
354}