miri/shims/x86/
gfni.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use crate::*;
7
8impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
9pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
10    fn emulate_x86_gfni_intrinsic(
11        &mut self,
12        link_name: Symbol,
13        abi: &FnAbi<'tcx, Ty<'tcx>>,
14        args: &[OpTy<'tcx>],
15        dest: &MPlaceTy<'tcx>,
16    ) -> InterpResult<'tcx, EmulateItemResult> {
17        let this = self.eval_context_mut();
18
19        // Prefix should have already been checked.
20        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
21
22        this.expect_target_feature_for_intrinsic(link_name, "gfni")?;
23        if unprefixed_name.ends_with(".256") {
24            this.expect_target_feature_for_intrinsic(link_name, "avx")?;
25        } else if unprefixed_name.ends_with(".512") {
26            this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
27        }
28
29        match unprefixed_name {
30            // Used to implement the `_mm{, 256, 512}_gf2p8affine_epi64_epi8` functions.
31            // See `affine_transform` for details.
32            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=gf2p8affine_
33            "vgf2p8affineqb.128" | "vgf2p8affineqb.256" | "vgf2p8affineqb.512" => {
34                let [left, right, imm8] =
35                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
36                affine_transform(this, left, right, imm8, dest, /* inverse */ false)?;
37            }
38            // Used to implement the `_mm{, 256, 512}_gf2p8affineinv_epi64_epi8` functions.
39            // See `affine_transform` for details.
40            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=gf2p8affineinv
41            "vgf2p8affineinvqb.128" | "vgf2p8affineinvqb.256" | "vgf2p8affineinvqb.512" => {
42                let [left, right, imm8] =
43                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
44                affine_transform(this, left, right, imm8, dest, /* inverse */ true)?;
45            }
46            // Used to implement the `_mm{, 256, 512}_gf2p8mul_epi8` functions.
47            // Multiplies packed 8-bit integers in `left` and `right` in the finite field GF(2^8)
48            // and store the results in `dst`. The field GF(2^8) is represented in
49            // polynomial representation with the reduction polynomial x^8 + x^4 + x^3 + x + 1.
50            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=gf2p8mul
51            "vgf2p8mulb.128" | "vgf2p8mulb.256" | "vgf2p8mulb.512" => {
52                let [left, right] =
53                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
54                let (left, left_len) = this.project_to_simd(left)?;
55                let (right, right_len) = this.project_to_simd(right)?;
56                let (dest, dest_len) = this.project_to_simd(dest)?;
57
58                assert_eq!(left_len, right_len);
59                assert_eq!(dest_len, right_len);
60
61                for i in 0..dest_len {
62                    let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
63                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
64                    let dest = this.project_index(&dest, i)?;
65                    this.write_scalar(Scalar::from_u8(gf2p8_mul(left, right)), &dest)?;
66                }
67            }
68            _ => return interp_ok(EmulateItemResult::NotSupported),
69        }
70        interp_ok(EmulateItemResult::NeedsReturn)
71    }
72}
73
74/// Calculates the affine transformation `right * left + imm8` inside the finite field GF(2^8).
75/// `right` is an 8x8 bit matrix, `left` and `imm8` are bit vectors.
76/// If `inverse` is set, then the inverse transformation with respect to the reduction polynomial
77/// x^8 + x^4 + x^3 + x + 1 is performed instead.
78fn affine_transform<'tcx>(
79    ecx: &mut MiriInterpCx<'tcx>,
80    left: &OpTy<'tcx>,
81    right: &OpTy<'tcx>,
82    imm8: &OpTy<'tcx>,
83    dest: &MPlaceTy<'tcx>,
84    inverse: bool,
85) -> InterpResult<'tcx, ()> {
86    let (left, left_len) = ecx.project_to_simd(left)?;
87    let (right, right_len) = ecx.project_to_simd(right)?;
88    let (dest, dest_len) = ecx.project_to_simd(dest)?;
89
90    assert_eq!(dest_len, right_len);
91    assert_eq!(dest_len, left_len);
92
93    let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
94
95    // Each 8x8 bit matrix gets multiplied with eight bit vectors.
96    // Therefore, the iteration is done in chunks of eight.
97    for i in (0..dest_len).step_by(8) {
98        // Get the bit matrix.
99        let mut matrix = [0u8; 8];
100        for j in 0..8 {
101            matrix[usize::try_from(j).unwrap()] =
102                ecx.read_scalar(&ecx.project_index(&right, i.wrapping_add(j))?)?.to_u8()?;
103        }
104
105        // Multiply the matrix with the vector and perform the addition.
106        for j in 0..8 {
107            let index = i.wrapping_add(j);
108            let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u8()?;
109            let left = if inverse { TABLE[usize::from(left)] } else { left };
110
111            let mut res = 0;
112
113            // Do the matrix multiplication.
114            for bit in 0u8..8 {
115                let mut b = matrix[usize::from(bit)] & left;
116
117                // Calculate the parity bit.
118                b = (b & 0b1111) ^ (b >> 4);
119                b = (b & 0b11) ^ (b >> 2);
120                b = (b & 0b1) ^ (b >> 1);
121
122                res |= b << 7u8.wrapping_sub(bit);
123            }
124
125            // Perform the addition.
126            res ^= imm8;
127
128            let dest = ecx.project_index(&dest, index)?;
129            ecx.write_scalar(Scalar::from_u8(res), &dest)?;
130        }
131    }
132
133    interp_ok(())
134}
135
136/// A lookup table for computing the inverse byte for the inverse affine transformation.
137// This is a evaluated at compile time. Trait based conversion is not available.
138/// See <https://www.corsix.org/content/galois-field-instructions-2021-cpus> for the
139/// definition of `gf_inv` which was used for the creation of this table.
140static TABLE: [u8; 256] = {
141    let mut array = [0; 256];
142
143    let mut i = 1;
144    while i < 256 {
145        #[expect(clippy::as_conversions)] // no `try_from` in const...
146        let mut x = i as u8;
147        let mut y = gf2p8_mul(x, x);
148        x = y;
149        let mut j = 2;
150        while j < 8 {
151            x = gf2p8_mul(x, x);
152            y = gf2p8_mul(x, y);
153            j += 1;
154        }
155        array[i] = y;
156        i += 1;
157    }
158
159    array
160};
161
/// Multiplies the bytes `left` and `right` as elements of the finite field GF(2^8) and
/// returns the product. The field GF(2^8) is represented in polynomial representation
/// with the reduction polynomial x^8 + x^4 + x^3 + x + 1.
/// See <https://www.corsix.org/content/galois-field-instructions-2021-cpus> for details.
// This is a const function. Trait based conversion is not available.
#[expect(clippy::as_conversions)]
const fn gf2p8_mul(left: u8, right: u8) -> u8 {
    // This implementation is based on the `gf2p8mul_byte` definition found inside the Intel intrinsics guide.
    // See https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=gf2p8mul
    // for more information.

    // Bit pattern of the reduction polynomial x^8 + x^4 + x^3 + x + 1.
    const POLYNOMIAL: u32 = 0x11b;

    // Widen so the unreduced product (up to bit 14) fits.
    let left = left as u32;
    let right = right as u32;

    let mut result = 0u32;

    // Carry-less multiplication: XOR in a shifted copy of `right` for every set bit of `left`.
    let mut i = 0u32;
    while i < 8 {
        if left & (1 << i) != 0 {
            result ^= right << i;
        }
        i = i.wrapping_add(1);
    }

    // Reduce modulo `POLYNOMIAL`: clear bits 14 down to 8, highest first, by XORing in the
    // polynomial aligned to the bit being cleared.
    let mut i = 14u32;
    while i >= 8 {
        if result & (1 << i) != 0 {
            result ^= POLYNOMIAL << i.wrapping_sub(8);
        }
        i = i.wrapping_sub(1);
    }

    // After the reduction only the low eight bits can be set, so this truncation is lossless.
    result as u8
}