rapx/analysis/opt/checking/encoding_checking/
array_encoding.rs

1use std::collections::HashSet;
2
3use once_cell::sync::OnceCell;
4
5use rustc_middle::{mir::Local, ty::TyCtxt};
6use rustc_span::Span;
7
8use super::{report_encoding_bug, value_is_from_const};
9use crate::analysis::{
10    core::dataflow::{graph::*, *},
11    opt::OptCheck,
12    utils::def_path::DefPath,
13};
14
15static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
16
17struct DefPaths {
18    str_from_utf8: DefPath,
19}
20
21impl DefPaths {
22    pub fn new(tcx: &TyCtxt<'_>) -> Self {
23        Self {
24            str_from_utf8: DefPath::new("std::str::from_utf8", &tcx),
25        }
26    }
27}
28
29pub struct ArrayEncodingCheck {
30    record: Vec<Span>,
31}
32
33fn extract_ancestor_set_if_is_str_from(
34    graph: &Graph,
35    node_idx: Local,
36    node: &GraphNode,
37) -> Option<HashSet<Local>> {
38    let def_paths = DEFPATHS.get().unwrap();
39    for op in node.ops.iter() {
40        if let NodeOp::Call(def_id) = op {
41            if *def_id == def_paths.str_from_utf8.last_def_id() {
42                return Some(graph.collect_ancestor_locals(node_idx, false));
43            }
44        }
45    }
46    None
47}
48
49fn is_valid_index_edge(graph: &Graph, edge: &GraphEdge) -> bool {
50    if let EdgeOp::Index = edge.op {
51        // must be Index edge
52        let dst_node = &graph.nodes[edge.dst];
53        if dst_node.in_edges.len() > 2 {
54            // must be the left value
55            let rvalue_edge_idx = dst_node.in_edges[2];
56            let rvalue_idx = graph.edges[rvalue_edge_idx].src;
57            if value_is_from_const(graph, rvalue_idx) {
58                return true;
59            }
60        }
61    }
62    false
63}
64
65impl OptCheck for ArrayEncodingCheck {
66    fn new() -> Self {
67        Self { record: Vec::new() }
68    }
69
70    fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
71        let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
72        let common_ancestor = graph
73            .edges
74            .iter()
75            .filter_map(|edge| {
76                // 另外这里index必须是左值,且右值必须来自const
77                if is_valid_index_edge(graph, edge) {
78                    Some(graph.collect_ancestor_locals(edge.src, true))
79                } else {
80                    None
81                }
82            })
83            .reduce(|set1, set2| set1.into_iter().filter(|k| set2.contains(k)).collect());
84
85        if let Some(common_ancestor) = common_ancestor {
86            for (node_idx, node) in graph.nodes.iter_enumerated() {
87                if let Some(str_from_ancestor_set) =
88                    extract_ancestor_set_if_is_str_from(graph, node_idx, node)
89                {
90                    if !common_ancestor
91                        .intersection(&str_from_ancestor_set)
92                        .next()
93                        .is_some()
94                    {
95                        self.record.clear();
96                        return;
97                    }
98                    self.record.push(node.span);
99                }
100            }
101        }
102    }
103
104    fn report(&self, graph: &Graph) {
105        for span in self.record.iter() {
106            report_encoding_bug(graph, *span);
107        }
108    }
109
110    fn cnt(&self) -> usize {
111        self.record.len()
112    }
113}