miri/intrinsics/
simd.rs

1use either::Either;
2use rand::Rng;
3use rustc_abi::{Endian, HasDataLayout};
4use rustc_apfloat::{Float, Round};
5use rustc_middle::ty::FloatTy;
6use rustc_middle::{mir, ty};
7use rustc_span::{Symbol, sym};
8
9use super::check_intrinsic_arg_count;
10use crate::helpers::{ToHost, ToSoft, bool_to_simd_element, simd_element_to_bool};
11use crate::*;
12
/// Selects which extremum a min/max-style SIMD operation computes.
///
/// Used as the payload of the lane-wise `fmin`/`fmax` operations and of the
/// `reduce_min`/`reduce_max` reductions handled in this file.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum MinMax {
    /// Compute the minimum of the operands.
    Min,
    /// Compute the maximum of the operands.
    Max,
}
18
19impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
20pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
21    /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
    /// Returns an `EmulateItemResult` indicating whether the intrinsic was handled.
23    fn emulate_simd_intrinsic(
24        &mut self,
25        intrinsic_name: &str,
26        generic_args: ty::GenericArgsRef<'tcx>,
27        args: &[OpTy<'tcx>],
28        dest: &MPlaceTy<'tcx>,
29    ) -> InterpResult<'tcx, EmulateItemResult> {
30        let this = self.eval_context_mut();
31        match intrinsic_name {
32            #[rustfmt::skip]
33            | "neg"
34            | "fabs"
35            | "ceil"
36            | "floor"
37            | "round"
38            | "round_ties_even"
39            | "trunc"
40            | "fsqrt"
41            | "fsin"
42            | "fcos"
43            | "fexp"
44            | "fexp2"
45            | "flog"
46            | "flog2"
47            | "flog10"
48            | "ctlz"
49            | "ctpop"
50            | "cttz"
51            | "bswap"
52            | "bitreverse"
53            => {
54                let [op] = check_intrinsic_arg_count(args)?;
55                let (op, op_len) = this.project_to_simd(op)?;
56                let (dest, dest_len) = this.project_to_simd(dest)?;
57
58                assert_eq!(dest_len, op_len);
59
60                #[derive(Copy, Clone)]
61                enum Op<'a> {
62                    MirOp(mir::UnOp),
63                    Abs,
64                    Round(rustc_apfloat::Round),
65                    Numeric(Symbol),
66                    HostOp(&'a str),
67                }
68                let which = match intrinsic_name {
69                    "neg" => Op::MirOp(mir::UnOp::Neg),
70                    "fabs" => Op::Abs,
71                    "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
72                    "floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
73                    "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
74                    "round_ties_even" => Op::Round(rustc_apfloat::Round::NearestTiesToEven),
75                    "trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
76                    "ctlz" => Op::Numeric(sym::ctlz),
77                    "ctpop" => Op::Numeric(sym::ctpop),
78                    "cttz" => Op::Numeric(sym::cttz),
79                    "bswap" => Op::Numeric(sym::bswap),
80                    "bitreverse" => Op::Numeric(sym::bitreverse),
81                    _ => Op::HostOp(intrinsic_name),
82                };
83
84                for i in 0..dest_len {
85                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
86                    let dest = this.project_index(&dest, i)?;
87                    let val = match which {
88                        Op::MirOp(mir_op) => {
89                            // This already does NaN adjustments
90                            this.unary_op(mir_op, &op)?.to_scalar()
91                        }
92                        Op::Abs => {
93                            // Works for f32 and f64.
94                            let ty::Float(float_ty) = op.layout.ty.kind() else {
95                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
96                            };
97                            let op = op.to_scalar();
98                            // "Bitwise" operation, no NaN adjustments
99                            match float_ty {
100                                FloatTy::F16 => unimplemented!("f16_f128"),
101                                FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
102                                FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
103                                FloatTy::F128 => unimplemented!("f16_f128"),
104                            }
105                        }
106                        Op::HostOp(host_op) => {
107                            let ty::Float(float_ty) = op.layout.ty.kind() else {
108                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
109                            };
110                            // Using host floats except for sqrt (but it's fine, these operations do not
111                            // have guaranteed precision).
112                            match float_ty {
113                                FloatTy::F16 => unimplemented!("f16_f128"),
114                                FloatTy::F32 => {
115                                    let f = op.to_scalar().to_f32()?;
116                                    let res = match host_op {
117                                        "fsqrt" => math::sqrt(f),
118                                        "fsin" => f.to_host().sin().to_soft(),
119                                        "fcos" => f.to_host().cos().to_soft(),
120                                        "fexp" => f.to_host().exp().to_soft(),
121                                        "fexp2" => f.to_host().exp2().to_soft(),
122                                        "flog" => f.to_host().ln().to_soft(),
123                                        "flog2" => f.to_host().log2().to_soft(),
124                                        "flog10" => f.to_host().log10().to_soft(),
125                                        _ => bug!(),
126                                    };
127                                    let res = this.adjust_nan(res, &[f]);
128                                    Scalar::from(res)
129                                }
130                                FloatTy::F64 => {
131                                    let f = op.to_scalar().to_f64()?;
132                                    let res = match host_op {
133                                        "fsqrt" => math::sqrt(f),
134                                        "fsin" => f.to_host().sin().to_soft(),
135                                        "fcos" => f.to_host().cos().to_soft(),
136                                        "fexp" => f.to_host().exp().to_soft(),
137                                        "fexp2" => f.to_host().exp2().to_soft(),
138                                        "flog" => f.to_host().ln().to_soft(),
139                                        "flog2" => f.to_host().log2().to_soft(),
140                                        "flog10" => f.to_host().log10().to_soft(),
141                                        _ => bug!(),
142                                    };
143                                    let res = this.adjust_nan(res, &[f]);
144                                    Scalar::from(res)
145                                }
146                                FloatTy::F128 => unimplemented!("f16_f128"),
147                            }
148                        }
149                        Op::Round(rounding) => {
150                            let ty::Float(float_ty) = op.layout.ty.kind() else {
151                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
152                            };
153                            match float_ty {
154                                FloatTy::F16 => unimplemented!("f16_f128"),
155                                FloatTy::F32 => {
156                                    let f = op.to_scalar().to_f32()?;
157                                    let res = f.round_to_integral(rounding).value;
158                                    let res = this.adjust_nan(res, &[f]);
159                                    Scalar::from_f32(res)
160                                }
161                                FloatTy::F64 => {
162                                    let f = op.to_scalar().to_f64()?;
163                                    let res = f.round_to_integral(rounding).value;
164                                    let res = this.adjust_nan(res, &[f]);
165                                    Scalar::from_f64(res)
166                                }
167                                FloatTy::F128 => unimplemented!("f16_f128"),
168                            }
169                        }
170                        Op::Numeric(name) => {
171                            this.numeric_intrinsic(name, op.to_scalar(), op.layout, op.layout)?
172                        }
173                    };
174                    this.write_scalar(val, &dest)?;
175                }
176            }
177            #[rustfmt::skip]
178            | "add"
179            | "sub"
180            | "mul"
181            | "div"
182            | "rem"
183            | "shl"
184            | "shr"
185            | "and"
186            | "or"
187            | "xor"
188            | "eq"
189            | "ne"
190            | "lt"
191            | "le"
192            | "gt"
193            | "ge"
194            | "fmax"
195            | "fmin"
196            | "saturating_add"
197            | "saturating_sub"
198            | "arith_offset"
199            => {
200                use mir::BinOp;
201
202                let [left, right] = check_intrinsic_arg_count(args)?;
203                let (left, left_len) = this.project_to_simd(left)?;
204                let (right, right_len) = this.project_to_simd(right)?;
205                let (dest, dest_len) = this.project_to_simd(dest)?;
206
207                assert_eq!(dest_len, left_len);
208                assert_eq!(dest_len, right_len);
209
210                enum Op {
211                    MirOp(BinOp),
212                    SaturatingOp(BinOp),
213                    FMinMax(MinMax),
214                    WrappingOffset,
215                }
216                let which = match intrinsic_name {
217                    "add" => Op::MirOp(BinOp::Add),
218                    "sub" => Op::MirOp(BinOp::Sub),
219                    "mul" => Op::MirOp(BinOp::Mul),
220                    "div" => Op::MirOp(BinOp::Div),
221                    "rem" => Op::MirOp(BinOp::Rem),
222                    "shl" => Op::MirOp(BinOp::ShlUnchecked),
223                    "shr" => Op::MirOp(BinOp::ShrUnchecked),
224                    "and" => Op::MirOp(BinOp::BitAnd),
225                    "or" => Op::MirOp(BinOp::BitOr),
226                    "xor" => Op::MirOp(BinOp::BitXor),
227                    "eq" => Op::MirOp(BinOp::Eq),
228                    "ne" => Op::MirOp(BinOp::Ne),
229                    "lt" => Op::MirOp(BinOp::Lt),
230                    "le" => Op::MirOp(BinOp::Le),
231                    "gt" => Op::MirOp(BinOp::Gt),
232                    "ge" => Op::MirOp(BinOp::Ge),
233                    "fmax" => Op::FMinMax(MinMax::Max),
234                    "fmin" => Op::FMinMax(MinMax::Min),
235                    "saturating_add" => Op::SaturatingOp(BinOp::Add),
236                    "saturating_sub" => Op::SaturatingOp(BinOp::Sub),
237                    "arith_offset" => Op::WrappingOffset,
238                    _ => unreachable!(),
239                };
240
241                for i in 0..dest_len {
242                    let left = this.read_immediate(&this.project_index(&left, i)?)?;
243                    let right = this.read_immediate(&this.project_index(&right, i)?)?;
244                    let dest = this.project_index(&dest, i)?;
245                    let val = match which {
246                        Op::MirOp(mir_op) => {
247                            // This does NaN adjustments.
248                            let val = this.binary_op(mir_op, &left, &right).map_err_kind(|kind| {
249                                match kind {
250                                    InterpErrorKind::UndefinedBehavior(UndefinedBehaviorInfo::ShiftOverflow { shift_amount, .. }) => {
251                                        // This resets the interpreter backtrace, but it's not worth avoiding that.
252                                        let shift_amount = match shift_amount {
253                                            Either::Left(v) => v.to_string(),
254                                            Either::Right(v) => v.to_string(),
255                                        };
256                                        err_ub_format!("overflowing shift by {shift_amount} in `simd_{intrinsic_name}` in lane {i}")
257                                    }
258                                    kind => kind
259                                }
260                            })?;
261                            if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
262                                // Special handling for boolean-returning operations
263                                assert_eq!(val.layout.ty, this.tcx.types.bool);
264                                let val = val.to_scalar().to_bool().unwrap();
265                                bool_to_simd_element(val, dest.layout.size)
266                            } else {
267                                assert_ne!(val.layout.ty, this.tcx.types.bool);
268                                assert_eq!(val.layout.ty, dest.layout.ty);
269                                val.to_scalar()
270                            }
271                        }
272                        Op::SaturatingOp(mir_op) => {
273                            this.saturating_arith(mir_op, &left, &right)?
274                        }
275                        Op::WrappingOffset => {
276                            let ptr = left.to_scalar().to_pointer(this)?;
277                            let offset_count = right.to_scalar().to_target_isize(this)?;
278                            let pointee_ty = left.layout.ty.builtin_deref(true).unwrap();
279
280                            let pointee_size = i64::try_from(this.layout_of(pointee_ty)?.size.bytes()).unwrap();
281                            let offset_bytes = offset_count.wrapping_mul(pointee_size);
282                            let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
283                            Scalar::from_maybe_pointer(offset_ptr, this)
284                        }
285                        Op::FMinMax(op) => {
286                            this.fminmax_op(op, &left, &right)?
287                        }
288                    };
289                    this.write_scalar(val, &dest)?;
290                }
291            }
292            "fma" | "relaxed_fma" => {
293                let [a, b, c] = check_intrinsic_arg_count(args)?;
294                let (a, a_len) = this.project_to_simd(a)?;
295                let (b, b_len) = this.project_to_simd(b)?;
296                let (c, c_len) = this.project_to_simd(c)?;
297                let (dest, dest_len) = this.project_to_simd(dest)?;
298
299                assert_eq!(dest_len, a_len);
300                assert_eq!(dest_len, b_len);
301                assert_eq!(dest_len, c_len);
302
303                for i in 0..dest_len {
304                    let a = this.read_scalar(&this.project_index(&a, i)?)?;
305                    let b = this.read_scalar(&this.project_index(&b, i)?)?;
306                    let c = this.read_scalar(&this.project_index(&c, i)?)?;
307                    let dest = this.project_index(&dest, i)?;
308
309                    let fuse: bool = intrinsic_name == "fma"
310                        || (this.machine.float_nondet && this.machine.rng.get_mut().random());
311
312                    // Works for f32 and f64.
313                    // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
314                    let ty::Float(float_ty) = dest.layout.ty.kind() else {
315                        span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
316                    };
317                    let val = match float_ty {
318                        FloatTy::F16 => unimplemented!("f16_f128"),
319                        FloatTy::F32 => {
320                            let a = a.to_f32()?;
321                            let b = b.to_f32()?;
322                            let c = c.to_f32()?;
323                            let res = if fuse {
324                                a.mul_add(b, c).value
325                            } else {
326                                ((a * b).value + c).value
327                            };
328                            let res = this.adjust_nan(res, &[a, b, c]);
329                            Scalar::from(res)
330                        }
331                        FloatTy::F64 => {
332                            let a = a.to_f64()?;
333                            let b = b.to_f64()?;
334                            let c = c.to_f64()?;
335                            let res = if fuse {
336                                a.mul_add(b, c).value
337                            } else {
338                                ((a * b).value + c).value
339                            };
340                            let res = this.adjust_nan(res, &[a, b, c]);
341                            Scalar::from(res)
342                        }
343                        FloatTy::F128 => unimplemented!("f16_f128"),
344                    };
345                    this.write_scalar(val, &dest)?;
346                }
347            }
348            #[rustfmt::skip]
349            | "reduce_and"
350            | "reduce_or"
351            | "reduce_xor"
352            | "reduce_any"
353            | "reduce_all"
354            | "reduce_max"
355            | "reduce_min" => {
356                use mir::BinOp;
357
358                let [op] = check_intrinsic_arg_count(args)?;
359                let (op, op_len) = this.project_to_simd(op)?;
360
361                let imm_from_bool =
362                    |b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
363
364                enum Op {
365                    MirOp(BinOp),
366                    MirOpBool(BinOp),
367                    MinMax(MinMax),
368                }
369                let which = match intrinsic_name {
370                    "reduce_and" => Op::MirOp(BinOp::BitAnd),
371                    "reduce_or" => Op::MirOp(BinOp::BitOr),
372                    "reduce_xor" => Op::MirOp(BinOp::BitXor),
373                    "reduce_any" => Op::MirOpBool(BinOp::BitOr),
374                    "reduce_all" => Op::MirOpBool(BinOp::BitAnd),
375                    "reduce_max" => Op::MinMax(MinMax::Max),
376                    "reduce_min" => Op::MinMax(MinMax::Min),
377                    _ => unreachable!(),
378                };
379
380                // Initialize with first lane, then proceed with the rest.
381                let mut res = this.read_immediate(&this.project_index(&op, 0)?)?;
382                if matches!(which, Op::MirOpBool(_)) {
383                    // Convert to `bool` scalar.
384                    res = imm_from_bool(simd_element_to_bool(res)?);
385                }
386                for i in 1..op_len {
387                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
388                    res = match which {
389                        Op::MirOp(mir_op) => {
390                            this.binary_op(mir_op, &res, &op)?
391                        }
392                        Op::MirOpBool(mir_op) => {
393                            let op = imm_from_bool(simd_element_to_bool(op)?);
394                            this.binary_op(mir_op, &res, &op)?
395                        }
396                        Op::MinMax(mmop) => {
397                            if matches!(res.layout.ty.kind(), ty::Float(_)) {
398                                ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
399                            } else {
                                // Just boring integers, so no NaNs to worry about.
401                                let mirop = match mmop {
402                                    MinMax::Min => BinOp::Le,
403                                    MinMax::Max => BinOp::Ge,
404                                };
405                                if this.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
406                                    res
407                                } else {
408                                    op
409                                }
410                            }
411                        }
412                    };
413                }
414                this.write_immediate(*res, dest)?;
415            }
416            #[rustfmt::skip]
417            | "reduce_add_ordered"
418            | "reduce_mul_ordered" => {
419                use mir::BinOp;
420
421                let [op, init] = check_intrinsic_arg_count(args)?;
422                let (op, op_len) = this.project_to_simd(op)?;
423                let init = this.read_immediate(init)?;
424
425                let mir_op = match intrinsic_name {
426                    "reduce_add_ordered" => BinOp::Add,
427                    "reduce_mul_ordered" => BinOp::Mul,
428                    _ => unreachable!(),
429                };
430
431                let mut res = init;
432                for i in 0..op_len {
433                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
434                    res = this.binary_op(mir_op, &res, &op)?;
435                }
436                this.write_immediate(*res, dest)?;
437            }
438            "select" => {
439                let [mask, yes, no] = check_intrinsic_arg_count(args)?;
440                let (mask, mask_len) = this.project_to_simd(mask)?;
441                let (yes, yes_len) = this.project_to_simd(yes)?;
442                let (no, no_len) = this.project_to_simd(no)?;
443                let (dest, dest_len) = this.project_to_simd(dest)?;
444
445                assert_eq!(dest_len, mask_len);
446                assert_eq!(dest_len, yes_len);
447                assert_eq!(dest_len, no_len);
448
449                for i in 0..dest_len {
450                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
451                    let yes = this.read_immediate(&this.project_index(&yes, i)?)?;
452                    let no = this.read_immediate(&this.project_index(&no, i)?)?;
453                    let dest = this.project_index(&dest, i)?;
454
455                    let val = if simd_element_to_bool(mask)? { yes } else { no };
456                    this.write_immediate(*val, &dest)?;
457                }
458            }
459            // Variant of `select` that takes a bitmask rather than a "vector of bool".
460            "select_bitmask" => {
461                let [mask, yes, no] = check_intrinsic_arg_count(args)?;
462                let (yes, yes_len) = this.project_to_simd(yes)?;
463                let (no, no_len) = this.project_to_simd(no)?;
464                let (dest, dest_len) = this.project_to_simd(dest)?;
465                let bitmask_len = dest_len.next_multiple_of(8);
466                if bitmask_len > 64 {
467                    throw_unsup_format!(
468                        "simd_select_bitmask: vectors larger than 64 elements are currently not supported"
469                    );
470                }
471
472                assert_eq!(dest_len, yes_len);
473                assert_eq!(dest_len, no_len);
474
475                // Read the mask, either as an integer or as an array.
476                let mask: u64 = match mask.layout.ty.kind() {
477                    ty::Uint(_) => {
478                        // Any larger integer type is fine.
479                        assert!(mask.layout.size.bits() >= bitmask_len);
480                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
481                    }
482                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
483                        // The array must have exactly the right size.
484                        assert_eq!(mask.layout.size.bits(), bitmask_len);
485                        // Read the raw bytes.
486                        let mask = mask.assert_mem_place(); // arrays cannot be immediate
487                        let mask_bytes =
488                            this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
489                        // Turn them into a `u64` in the right way.
490                        let mask_size = mask.layout.size.bytes_usize();
491                        let mut mask_arr = [0u8; 8];
492                        match this.data_layout().endian {
493                            Endian::Little => {
494                                // Fill the first N bytes.
495                                mask_arr[..mask_size].copy_from_slice(mask_bytes);
496                                u64::from_le_bytes(mask_arr)
497                            }
498                            Endian::Big => {
499                                // Fill the last N bytes.
500                                let i = mask_arr.len().strict_sub(mask_size);
501                                mask_arr[i..].copy_from_slice(mask_bytes);
502                                u64::from_be_bytes(mask_arr)
503                            }
504                        }
505                    }
506                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
507                };
508
509                let dest_len = u32::try_from(dest_len).unwrap();
510                for i in 0..dest_len {
511                    let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
512                    let mask = mask & 1u64.strict_shl(bit_i);
513                    let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
514                    let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
515                    let dest = this.project_index(&dest, i.into())?;
516
517                    let val = if mask != 0 { yes } else { no };
518                    this.write_immediate(*val, &dest)?;
519                }
520                // The remaining bits of the mask are ignored.
521            }
522            // Converts a "vector of bool" into a bitmask.
523            "bitmask" => {
524                let [op] = check_intrinsic_arg_count(args)?;
525                let (op, op_len) = this.project_to_simd(op)?;
526                let bitmask_len = op_len.next_multiple_of(8);
527                if bitmask_len > 64 {
528                    throw_unsup_format!(
529                        "simd_bitmask: vectors larger than 64 elements are currently not supported"
530                    );
531                }
532
533                let op_len = u32::try_from(op_len).unwrap();
534                let mut res = 0u64;
535                for i in 0..op_len {
536                    let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
537                    if simd_element_to_bool(op)? {
538                        let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
539                        res |= 1u64.strict_shl(bit_i);
540                    }
541                }
542                // Write the result, depending on the `dest` type.
543                // Returns either an unsigned integer or array of `u8`.
544                match dest.layout.ty.kind() {
545                    ty::Uint(_) => {
546                        // Any larger integer type is fine, it will be zero-extended.
547                        assert!(dest.layout.size.bits() >= bitmask_len);
548                        this.write_int(res, dest)?;
549                    }
550                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
551                        // The array must have exactly the right size.
552                        assert_eq!(dest.layout.size.bits(), bitmask_len);
553                        // We have to write the result byte-for-byte.
554                        let res_size = dest.layout.size.bytes_usize();
555                        let res_bytes;
556                        let res_bytes_slice = match this.data_layout().endian {
557                            Endian::Little => {
558                                res_bytes = res.to_le_bytes();
559                                &res_bytes[..res_size] // take the first N bytes
560                            }
561                            Endian::Big => {
562                                res_bytes = res.to_be_bytes();
563                                &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
564                            }
565                        };
566                        this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
567                    }
568                    _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
569                }
570            }
571            "cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
572                let [op] = check_intrinsic_arg_count(args)?;
573                let (op, op_len) = this.project_to_simd(op)?;
574                let (dest, dest_len) = this.project_to_simd(dest)?;
575
576                assert_eq!(dest_len, op_len);
577
578                let unsafe_cast = intrinsic_name == "cast";
579                let safe_cast = intrinsic_name == "as";
580                let ptr_cast = intrinsic_name == "cast_ptr";
581                let expose_cast = intrinsic_name == "expose_provenance";
582                let from_exposed_cast = intrinsic_name == "with_exposed_provenance";
583
584                for i in 0..dest_len {
585                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
586                    let dest = this.project_index(&dest, i)?;
587
588                    let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
589                        // Int-to-(int|float): always safe
590                        (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_))
591                            if safe_cast || unsafe_cast =>
592                            this.int_to_int_or_float(&op, dest.layout)?,
593                        // Float-to-float: always safe
594                        (ty::Float(_), ty::Float(_)) if safe_cast || unsafe_cast =>
595                            this.float_to_float_or_int(&op, dest.layout)?,
596                        // Float-to-int in safe mode
597                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
598                            this.float_to_float_or_int(&op, dest.layout)?,
599                        // Float-to-int in unchecked mode
600                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
601                            this.float_to_int_checked(&op, dest.layout, Round::TowardZero)?
602                                .ok_or_else(|| {
603                                    err_ub_format!(
604                                        "`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`",
605                                        dest.layout.ty
606                                    )
607                                })?
608                        }
609                        // Ptr-to-ptr cast
610                        (ty::RawPtr(..), ty::RawPtr(..)) if ptr_cast =>
611                            this.ptr_to_ptr(&op, dest.layout)?,
612                        // Ptr/Int casts
613                        (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) if expose_cast =>
614                            this.pointer_expose_provenance_cast(&op, dest.layout)?,
615                        (ty::Int(_) | ty::Uint(_), ty::RawPtr(..)) if from_exposed_cast =>
616                            this.pointer_with_exposed_provenance_cast(&op, dest.layout)?,
617                        // Error otherwise
618                        _ =>
619                            throw_unsup_format!(
620                                "Unsupported SIMD cast from element type {from_ty} to {to_ty}",
621                                from_ty = op.layout.ty,
622                                to_ty = dest.layout.ty,
623                            ),
624                    };
625                    this.write_immediate(*val, &dest)?;
626                }
627            }
628            "shuffle_const_generic" => {
629                let [left, right] = check_intrinsic_arg_count(args)?;
630                let (left, left_len) = this.project_to_simd(left)?;
631                let (right, right_len) = this.project_to_simd(right)?;
632                let (dest, dest_len) = this.project_to_simd(dest)?;
633
634                let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch();
635                let index_len = index.len();
636
637                assert_eq!(left_len, right_len);
638                assert_eq!(u64::try_from(index_len).unwrap(), dest_len);
639
640                for i in 0..dest_len {
641                    let src_index: u64 =
642                        index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into();
643                    let dest = this.project_index(&dest, i)?;
644
645                    let val = if src_index < left_len {
646                        this.read_immediate(&this.project_index(&left, src_index)?)?
647                    } else if src_index < left_len.strict_add(right_len) {
648                        let right_idx = src_index.strict_sub(left_len);
649                        this.read_immediate(&this.project_index(&right, right_idx)?)?
650                    } else {
651                        throw_ub_format!(
652                            "`simd_shuffle_const_generic` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
653                        );
654                    };
655                    this.write_immediate(*val, &dest)?;
656                }
657            }
658            "shuffle" => {
659                let [left, right, index] = check_intrinsic_arg_count(args)?;
660                let (left, left_len) = this.project_to_simd(left)?;
661                let (right, right_len) = this.project_to_simd(right)?;
662                let (index, index_len) = this.project_to_simd(index)?;
663                let (dest, dest_len) = this.project_to_simd(dest)?;
664
665                assert_eq!(left_len, right_len);
666                assert_eq!(index_len, dest_len);
667
668                for i in 0..dest_len {
669                    let src_index: u64 = this
670                        .read_immediate(&this.project_index(&index, i)?)?
671                        .to_scalar()
672                        .to_u32()?
673                        .into();
674                    let dest = this.project_index(&dest, i)?;
675
676                    let val = if src_index < left_len {
677                        this.read_immediate(&this.project_index(&left, src_index)?)?
678                    } else if src_index < left_len.strict_add(right_len) {
679                        let right_idx = src_index.strict_sub(left_len);
680                        this.read_immediate(&this.project_index(&right, right_idx)?)?
681                    } else {
682                        throw_ub_format!(
683                            "`simd_shuffle` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
684                        );
685                    };
686                    this.write_immediate(*val, &dest)?;
687                }
688            }
689            "gather" => {
690                let [passthru, ptrs, mask] = check_intrinsic_arg_count(args)?;
691                let (passthru, passthru_len) = this.project_to_simd(passthru)?;
692                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
693                let (mask, mask_len) = this.project_to_simd(mask)?;
694                let (dest, dest_len) = this.project_to_simd(dest)?;
695
696                assert_eq!(dest_len, passthru_len);
697                assert_eq!(dest_len, ptrs_len);
698                assert_eq!(dest_len, mask_len);
699
700                for i in 0..dest_len {
701                    let passthru = this.read_immediate(&this.project_index(&passthru, i)?)?;
702                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
703                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
704                    let dest = this.project_index(&dest, i)?;
705
706                    let val = if simd_element_to_bool(mask)? {
707                        let place = this.deref_pointer(&ptr)?;
708                        this.read_immediate(&place)?
709                    } else {
710                        passthru
711                    };
712                    this.write_immediate(*val, &dest)?;
713                }
714            }
715            "scatter" => {
716                let [value, ptrs, mask] = check_intrinsic_arg_count(args)?;
717                let (value, value_len) = this.project_to_simd(value)?;
718                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
719                let (mask, mask_len) = this.project_to_simd(mask)?;
720
721                assert_eq!(ptrs_len, value_len);
722                assert_eq!(ptrs_len, mask_len);
723
724                for i in 0..ptrs_len {
725                    let value = this.read_immediate(&this.project_index(&value, i)?)?;
726                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
727                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
728
729                    if simd_element_to_bool(mask)? {
730                        let place = this.deref_pointer(&ptr)?;
731                        this.write_immediate(*value, &place)?;
732                    }
733                }
734            }
735            "masked_load" => {
736                let [mask, ptr, default] = check_intrinsic_arg_count(args)?;
737                let (mask, mask_len) = this.project_to_simd(mask)?;
738                let ptr = this.read_pointer(ptr)?;
739                let (default, default_len) = this.project_to_simd(default)?;
740                let (dest, dest_len) = this.project_to_simd(dest)?;
741
742                assert_eq!(dest_len, mask_len);
743                assert_eq!(dest_len, default_len);
744
745                for i in 0..dest_len {
746                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
747                    let default = this.read_immediate(&this.project_index(&default, i)?)?;
748                    let dest = this.project_index(&dest, i)?;
749
750                    let val = if simd_element_to_bool(mask)? {
751                        // Size * u64 is implemented as always checked
752                        let ptr = ptr.wrapping_offset(dest.layout.size * i, this);
753                        let place = this.ptr_to_mplace(ptr, dest.layout);
754                        this.read_immediate(&place)?
755                    } else {
756                        default
757                    };
758                    this.write_immediate(*val, &dest)?;
759                }
760            }
761            "masked_store" => {
762                let [mask, ptr, vals] = check_intrinsic_arg_count(args)?;
763                let (mask, mask_len) = this.project_to_simd(mask)?;
764                let ptr = this.read_pointer(ptr)?;
765                let (vals, vals_len) = this.project_to_simd(vals)?;
766
767                assert_eq!(mask_len, vals_len);
768
769                for i in 0..vals_len {
770                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
771                    let val = this.read_immediate(&this.project_index(&vals, i)?)?;
772
773                    if simd_element_to_bool(mask)? {
774                        // Size * u64 is implemented as always checked
775                        let ptr = ptr.wrapping_offset(val.layout.size * i, this);
776                        let place = this.ptr_to_mplace(ptr, val.layout);
777                        this.write_immediate(*val, &place)?
778                    };
779                }
780            }
781
782            _ => return interp_ok(EmulateItemResult::NotSupported),
783        }
784        interp_ok(EmulateItemResult::NeedsReturn)
785    }
786
787    fn fminmax_op(
788        &self,
789        op: MinMax,
790        left: &ImmTy<'tcx>,
791        right: &ImmTy<'tcx>,
792    ) -> InterpResult<'tcx, Scalar> {
793        let this = self.eval_context_ref();
794        assert_eq!(left.layout.ty, right.layout.ty);
795        let ty::Float(float_ty) = left.layout.ty.kind() else {
796            bug!("fmax operand is not a float")
797        };
798        let left = left.to_scalar();
799        let right = right.to_scalar();
800        interp_ok(match float_ty {
801            FloatTy::F16 => unimplemented!("f16_f128"),
802            FloatTy::F32 => {
803                let left = left.to_f32()?;
804                let right = right.to_f32()?;
805                let res = match op {
806                    MinMax::Min => left.min(right),
807                    MinMax::Max => left.max(right),
808                };
809                let res = this.adjust_nan(res, &[left, right]);
810                Scalar::from_f32(res)
811            }
812            FloatTy::F64 => {
813                let left = left.to_f64()?;
814                let right = right.to_f64()?;
815                let res = match op {
816                    MinMax::Min => left.min(right),
817                    MinMax::Max => left.max(right),
818                };
819                let res = this.adjust_nan(res, &[left, right]);
820                Scalar::from_f64(res)
821            }
822            FloatTy::F128 => unimplemented!("f16_f128"),
823        })
824    }
825}
826
827fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
828    assert!(idx < vec_len);
829    match endianness {
830        Endian::Little => idx,
831        #[expect(clippy::arithmetic_side_effects)] // idx < vec_len
832        Endian::Big => vec_len - 1 - idx, // reverse order of bits
833    }
834}