miri/shims/x86/avx.rs
1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::{Double, Single};
3use rustc_middle::mir;
4use rustc_middle::ty::Ty;
5use rustc_span::Symbol;
6use rustc_target::callconv::FnAbi;
7
8use super::{
9 FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
10 horizontal_bin_op, mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked,
11 unary_op_ps,
12};
13use crate::*;
14
15impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
16pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
17 fn emulate_x86_avx_intrinsic(
18 &mut self,
19 link_name: Symbol,
20 abi: &FnAbi<'tcx, Ty<'tcx>>,
21 args: &[OpTy<'tcx>],
22 dest: &MPlaceTy<'tcx>,
23 ) -> InterpResult<'tcx, EmulateItemResult> {
24 let this = self.eval_context_mut();
25 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
26 // Prefix should have already been checked.
27 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
28
29 match unprefixed_name {
30 // Used to implement _mm256_min_ps and _mm256_max_ps functions.
31 // Note that the semantics are a bit different from Rust simd_min
32 // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
33 // matches the IEEE min/max operations, while x86 has different
34 // semantics.
35 "min.ps.256" | "max.ps.256" => {
36 let [left, right] =
37 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
38
39 let which = match unprefixed_name {
40 "min.ps.256" => FloatBinOp::Min,
41 "max.ps.256" => FloatBinOp::Max,
42 _ => unreachable!(),
43 };
44
45 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
46 }
47 // Used to implement _mm256_min_pd and _mm256_max_pd functions.
48 "min.pd.256" | "max.pd.256" => {
49 let [left, right] =
50 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
51
52 let which = match unprefixed_name {
53 "min.pd.256" => FloatBinOp::Min,
54 "max.pd.256" => FloatBinOp::Max,
55 _ => unreachable!(),
56 };
57
58 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
59 }
60 // Used to implement the _mm256_round_ps function.
61 // Rounds the elements of `op` according to `rounding`.
62 "round.ps.256" => {
63 let [op, rounding] =
64 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
65
66 round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
67 }
68 // Used to implement the _mm256_round_pd function.
69 // Rounds the elements of `op` according to `rounding`.
70 "round.pd.256" => {
71 let [op, rounding] =
72 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73
74 round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
75 }
76 // Used to implement _mm256_{rcp,rsqrt}_ps functions.
77 // Performs the operations on all components of `op`.
78 "rcp.ps.256" | "rsqrt.ps.256" => {
79 let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
80
81 let which = match unprefixed_name {
82 "rcp.ps.256" => FloatUnaryOp::Rcp,
83 "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
84 _ => unreachable!(),
85 };
86
87 unary_op_ps(this, which, op, dest)?;
88 }
89 // Used to implement the _mm256_dp_ps function.
90 "dp.ps.256" => {
91 let [left, right, imm] =
92 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
93
94 conditional_dot_product(this, left, right, imm, dest)?;
95 }
96 // Used to implement the _mm256_h{add,sub}_p{s,d} functions.
97 // Horizontally add/subtract adjacent floating point values
98 // in `left` and `right`.
99 "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
100 let [left, right] =
101 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
102
103 let which = match unprefixed_name {
104 "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
105 "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
106 _ => unreachable!(),
107 };
108
109 horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
110 }
111 // Used to implement the _mm256_cmp_ps function.
112 // Performs a comparison operation on each component of `left`
113 // and `right`. For each component, returns 0 if false or u32::MAX
114 // if true.
115 "cmp.ps.256" => {
116 let [left, right, imm] =
117 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
118
119 let which =
120 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
121
122 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
123 }
124 // Used to implement the _mm256_cmp_pd function.
125 // Performs a comparison operation on each component of `left`
126 // and `right`. For each component, returns 0 if false or u64::MAX
127 // if true.
128 "cmp.pd.256" => {
129 let [left, right, imm] =
130 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
131
132 let which =
133 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
134
135 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
136 }
137 // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
138 // and _mm256_cvttpd_epi32 functions.
139 // Converts packed f32/f64 to packed i32.
140 "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
141 let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
142
143 let rnd = match unprefixed_name {
144 // "current SSE rounding mode", assume nearest
145 "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
146 // always truncate
147 "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
148 _ => unreachable!(),
149 };
150
151 convert_float_to_int(this, op, rnd, dest)?;
152 }
153 // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
154 // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
155 // chunk is shuffled independently: this means that we view the vector as a
156 // sequence of 4-element arrays, and we shuffle each of these arrays, where
157 // `control` determines which element of the current `data` array is written.
158 "vpermilvar.ps" | "vpermilvar.ps.256" => {
159 let [data, control] =
160 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
161
162 let (data, data_len) = this.project_to_simd(data)?;
163 let (control, control_len) = this.project_to_simd(control)?;
164 let (dest, dest_len) = this.project_to_simd(dest)?;
165
166 assert_eq!(dest_len, data_len);
167 assert_eq!(dest_len, control_len);
168
169 for i in 0..dest_len {
170 let control = this.project_index(&control, i)?;
171
172 // Each 128-bit chunk is shuffled independently. Since each chunk contains
173 // four 32-bit elements, only two bits from `control` are used. To read the
174 // value from the current chunk, add the destination index truncated to a multiple
175 // of 4.
176 let chunk_base = i & !0b11;
177 let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
178 .strict_add(chunk_base);
179
180 this.copy_op(
181 &this.project_index(&data, src_i)?,
182 &this.project_index(&dest, i)?,
183 )?;
184 }
185 }
186 // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
187 // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
188 // chunk is shuffled independently: this means that we view the vector as
189 // a sequence of 2-element arrays, and we shuffle each of these arrays,
190 // where `right` determines which element of the current `left` array is
191 // written.
192 "vpermilvar.pd" | "vpermilvar.pd.256" => {
193 let [data, control] =
194 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
195
196 let (data, data_len) = this.project_to_simd(data)?;
197 let (control, control_len) = this.project_to_simd(control)?;
198 let (dest, dest_len) = this.project_to_simd(dest)?;
199
200 assert_eq!(dest_len, data_len);
201 assert_eq!(dest_len, control_len);
202
203 for i in 0..dest_len {
204 let control = this.project_index(&control, i)?;
205
206 // Each 128-bit chunk is shuffled independently. Since each chunk contains
207 // two 64-bit elements, only the second bit from `control` is used (yes, the
208 // second instead of the first, ask Intel). To read the value from the current
209 // chunk, add the destination index truncated to a multiple of 2.
210 let chunk_base = i & !1;
211 let src_i =
212 ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
213
214 this.copy_op(
215 &this.project_index(&data, src_i)?,
216 &this.project_index(&dest, i)?,
217 )?;
218 }
219 }
220 // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
221 // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
222 // thay all can be considered to operate on vectors of 128-bit elements.
223 // For each 128-bit element of `dest`, copies one from `left`, `right` or
224 // zero, according to `imm`.
225 "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
226 let [left, right, imm] =
227 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
228
229 assert_eq!(dest.layout, left.layout);
230 assert_eq!(dest.layout, right.layout);
231 assert_eq!(dest.layout.size.bits(), 256);
232
233 // Transmute to `[u128; 2]` to process each 128-bit chunk independently.
234 let u128x2_layout =
235 this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
236 let left = left.transmute(u128x2_layout, this)?;
237 let right = right.transmute(u128x2_layout, this)?;
238 let dest = dest.transmute(u128x2_layout, this)?;
239
240 let imm = this.read_scalar(imm)?.to_u8()?;
241
242 for i in 0..2 {
243 let dest = this.project_index(&dest, i)?;
244
245 let imm = match i {
246 0 => imm & 0xF,
247 1 => imm >> 4,
248 _ => unreachable!(),
249 };
250 if imm & 0b100 != 0 {
251 this.write_scalar(Scalar::from_u128(0), &dest)?;
252 } else {
253 let src = match imm {
254 0b00 => this.project_index(&left, 0)?,
255 0b01 => this.project_index(&left, 1)?,
256 0b10 => this.project_index(&right, 0)?,
257 0b11 => this.project_index(&right, 1)?,
258 _ => unreachable!(),
259 };
260 this.copy_op(&src, &dest)?;
261 }
262 }
263 }
264 // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
265 // and _mm256_maskload_pd functions.
266 // For the element `i`, if the high bit of the `i`-th element of `mask`
267 // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
268 // loaded.
269 "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
270 let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
271
272 mask_load(this, ptr, mask, dest)?;
273 }
274 // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
275 // and _mm256_maskstore_pd functions.
276 // For the element `i`, if the high bit of the element `i`-th of `mask`
277 // is one, it is stored into `ptr.wapping_add(i)`.
278 // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
279 "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
280 let [ptr, mask, value] =
281 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
282
283 mask_store(this, ptr, mask, value)?;
284 }
285 // Used to implement the _mm256_lddqu_si256 function.
286 // Reads a 256-bit vector from an unaligned pointer. This intrinsic
287 // is expected to perform better than a regular unaligned read when
288 // the data crosses a cache line, but for Miri this is just a regular
289 // unaligned read.
290 "ldu.dq.256" => {
291 let [src_ptr] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
292 let src_ptr = this.read_pointer(src_ptr)?;
293 let dest = dest.force_mplace(this)?;
294
295 // Unaligned copy, which is what we want.
296 this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
297 }
298 // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
299 // _mm256_testnzc_si256 functions.
300 // Tests `op & mask == 0`, `op & mask == mask` or
301 // `op & mask != 0 && op & mask != mask`
302 "ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
303 let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
304
305 let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
306 let res = match unprefixed_name {
307 "ptestz.256" => all_zero,
308 "ptestc.256" => masked_set,
309 "ptestnzc.256" => !all_zero && !masked_set,
310 _ => unreachable!(),
311 };
312
313 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
314 }
315 // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
316 // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
317 // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
318 // _mm_testnzc_ps functions.
319 // Calculates two booleans:
320 // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
321 // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
322 // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
323 "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
324 | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
325 | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
326 let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
327
328 let (direct, negated) = test_high_bits_masked(this, op, mask)?;
329 let res = match unprefixed_name {
330 "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
331 "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
332 "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
333 !direct && !negated,
334 _ => unreachable!(),
335 };
336
337 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
338 }
339 // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
340 // These function clear out the upper 128 bits of all avx registers or
341 // zero out all avx registers respectively.
342 "vzeroupper" | "vzeroall" => {
343 // These functions are purely a performance hint for the CPU.
344 // Any registers currently in use will be saved beforehand by the
345 // compiler, making these functions no-ops.
346
347 // The only thing that needs to be ensured is the correct calling convention.
348 let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
349 }
350 _ => return interp_ok(EmulateItemResult::NotSupported),
351 }
352 interp_ok(EmulateItemResult::NeedsReturn)
353 }
354}