1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use crate::*;
7
8impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
9pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
10 fn emulate_x86_gfni_intrinsic(
11 &mut self,
12 link_name: Symbol,
13 abi: &FnAbi<'tcx, Ty<'tcx>>,
14 args: &[OpTy<'tcx>],
15 dest: &MPlaceTy<'tcx>,
16 ) -> InterpResult<'tcx, EmulateItemResult> {
17 let this = self.eval_context_mut();
18
19 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
21
22 this.expect_target_feature_for_intrinsic(link_name, "gfni")?;
23 if unprefixed_name.ends_with(".256") {
24 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
25 } else if unprefixed_name.ends_with(".512") {
26 this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
27 }
28
29 match unprefixed_name {
30 "vgf2p8affineqb.128" | "vgf2p8affineqb.256" | "vgf2p8affineqb.512" => {
34 let [left, right, imm8] =
35 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
36 affine_transform(this, left, right, imm8, dest, false)?;
37 }
38 "vgf2p8affineinvqb.128" | "vgf2p8affineinvqb.256" | "vgf2p8affineinvqb.512" => {
42 let [left, right, imm8] =
43 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
44 affine_transform(this, left, right, imm8, dest, true)?;
45 }
46 "vgf2p8mulb.128" | "vgf2p8mulb.256" | "vgf2p8mulb.512" => {
52 let [left, right] =
53 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
54 let (left, left_len) = this.project_to_simd(left)?;
55 let (right, right_len) = this.project_to_simd(right)?;
56 let (dest, dest_len) = this.project_to_simd(dest)?;
57
58 assert_eq!(left_len, right_len);
59 assert_eq!(dest_len, right_len);
60
61 for i in 0..dest_len {
62 let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
63 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
64 let dest = this.project_index(&dest, i)?;
65 this.write_scalar(Scalar::from_u8(gf2p8_mul(left, right)), &dest)?;
66 }
67 }
68 _ => return interp_ok(EmulateItemResult::NotSupported),
69 }
70 interp_ok(EmulateItemResult::NeedsReturn)
71 }
72}
73
74fn affine_transform<'tcx>(
79 ecx: &mut MiriInterpCx<'tcx>,
80 left: &OpTy<'tcx>,
81 right: &OpTy<'tcx>,
82 imm8: &OpTy<'tcx>,
83 dest: &MPlaceTy<'tcx>,
84 inverse: bool,
85) -> InterpResult<'tcx, ()> {
86 let (left, left_len) = ecx.project_to_simd(left)?;
87 let (right, right_len) = ecx.project_to_simd(right)?;
88 let (dest, dest_len) = ecx.project_to_simd(dest)?;
89
90 assert_eq!(dest_len, right_len);
91 assert_eq!(dest_len, left_len);
92
93 let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
94
95 for i in (0..dest_len).step_by(8) {
98 let mut matrix = [0u8; 8];
100 for j in 0..8 {
101 matrix[usize::try_from(j).unwrap()] =
102 ecx.read_scalar(&ecx.project_index(&right, i.wrapping_add(j))?)?.to_u8()?;
103 }
104
105 for j in 0..8 {
107 let index = i.wrapping_add(j);
108 let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u8()?;
109 let left = if inverse { TABLE[usize::from(left)] } else { left };
110
111 let mut res = 0;
112
113 for bit in 0u8..8 {
115 let mut b = matrix[usize::from(bit)] & left;
116
117 b = (b & 0b1111) ^ (b >> 4);
119 b = (b & 0b11) ^ (b >> 2);
120 b = (b & 0b1) ^ (b >> 1);
121
122 res |= b << 7u8.wrapping_sub(bit);
123 }
124
125 res ^= imm8;
127
128 let dest = ecx.project_index(&dest, index)?;
129 ecx.write_scalar(Scalar::from_u8(res), &dest)?;
130 }
131 }
132
133 interp_ok(())
134}
135
136static TABLE: [u8; 256] = {
141 let mut array = [0; 256];
142
143 let mut i = 1;
144 while i < 256 {
145 #[expect(clippy::as_conversions)] let mut x = i as u8;
147 let mut y = gf2p8_mul(x, x);
148 x = y;
149 let mut j = 2;
150 while j < 8 {
151 x = gf2p8_mul(x, x);
152 y = gf2p8_mul(x, y);
153 j += 1;
154 }
155 array[i] = y;
156 i += 1;
157 }
158
159 array
160};
161
/// Multiplies two elements of GF(2^8) using the reduction polynomial
/// x^8 + x^4 + x^3 + x + 1 (`0x11b`).
///
/// This is a `const fn` so it can build `TABLE` at compile time, which is why it uses
/// `while` loops rather than iterators.
#[expect(clippy::as_conversions)]
const fn gf2p8_mul(left: u8, right: u8) -> u8 {
    const POLYNOMIAL: u32 = 0x11b;

    let multiplier = left as u32;

    // Carry-less (XOR) long multiplication; the product has degree at most 14.
    let mut acc = 0u32;
    let mut shifted = right as u32;
    let mut bit = 0u32;
    while bit < 8 {
        if (multiplier >> bit) & 1 == 1 {
            acc ^= shifted;
        }
        shifted <<= 1;
        bit = bit.wrapping_add(1);
    }

    // Reduce modulo the polynomial, clearing bits 14 down to 8.
    let mut degree = 14u32;
    while degree >= 8 {
        if (acc >> degree) & 1 == 1 {
            acc ^= POLYNOMIAL << degree.wrapping_sub(8);
        }
        degree = degree.wrapping_sub(1);
    }

    acc as u8
}