rapx/analysis/opt/checking/encoding_checking/
array_encoding.rs

1use std::collections::HashSet;
2
3use once_cell::sync::OnceCell;
4
5use rustc_middle::mir::Local;
6use rustc_middle::ty::TyCtxt;
7use rustc_span::Span;
8
9use super::{report_encoding_bug, value_is_from_const};
10use crate::analysis::core::dataflow::graph::{EdgeOp, Graph, GraphEdge, GraphNode, NodeOp};
11use crate::analysis::opt::OptCheck;
12use crate::analysis::utils::def_path::DefPath;
13
14static DEFPATHS: OnceCell<DefPaths> = OnceCell::new();
15
16struct DefPaths {
17    str_from_utf8: DefPath,
18}
19
20impl DefPaths {
21    pub fn new(tcx: &TyCtxt<'_>) -> Self {
22        Self {
23            str_from_utf8: DefPath::new("std::str::from_utf8", &tcx),
24        }
25    }
26}
27
/// Check that records the spans of `std::str::from_utf8` call nodes flagged
/// by `check`; the recorded spans are later passed to `report_encoding_bug`.
pub struct ArrayEncodingCheck {
    /// Spans of the flagged call nodes, in the order they were found.
    record: Vec<Span>,
}
31
32fn extract_ancestor_set_if_is_str_from(
33    graph: &Graph,
34    node_idx: Local,
35    node: &GraphNode,
36) -> Option<HashSet<Local>> {
37    let def_paths = DEFPATHS.get().unwrap();
38    for op in node.ops.iter() {
39        if let NodeOp::Call(def_id) = op {
40            if *def_id == def_paths.str_from_utf8.last_def_id() {
41                return Some(graph.collect_ancestor_locals(node_idx, false));
42            }
43        }
44    }
45    None
46}
47
48fn is_valid_index_edge(graph: &Graph, edge: &GraphEdge) -> bool {
49    if let EdgeOp::Index = edge.op {
50        // must be Index edge
51        let dst_node = &graph.nodes[edge.dst];
52        if dst_node.in_edges.len() > 2 {
53            // must be the left value
54            let rvalue_edge_idx = dst_node.in_edges[2];
55            let rvalue_idx = graph.edges[rvalue_edge_idx].src;
56            if value_is_from_const(graph, rvalue_idx) {
57                return true;
58            }
59        }
60    }
61    false
62}
63
64impl OptCheck for ArrayEncodingCheck {
65    fn new() -> Self {
66        Self { record: Vec::new() }
67    }
68
69    fn check(&mut self, graph: &Graph, tcx: &TyCtxt) {
70        let _ = &DEFPATHS.get_or_init(|| DefPaths::new(tcx));
71        let common_ancestor = graph
72            .edges
73            .iter()
74            .filter_map(|edge| {
75                // 另外这里index必须是左值,且右值必须来自const
76                if is_valid_index_edge(graph, edge) {
77                    Some(graph.collect_ancestor_locals(edge.src, true))
78                } else {
79                    None
80                }
81            })
82            .reduce(|set1, set2| set1.into_iter().filter(|k| set2.contains(k)).collect());
83
84        if let Some(common_ancestor) = common_ancestor {
85            for (node_idx, node) in graph.nodes.iter_enumerated() {
86                if let Some(str_from_ancestor_set) =
87                    extract_ancestor_set_if_is_str_from(graph, node_idx, node)
88                {
89                    if !common_ancestor
90                        .intersection(&str_from_ancestor_set)
91                        .next()
92                        .is_some()
93                    {
94                        self.record.clear();
95                        return;
96                    }
97                    self.record.push(node.span);
98                }
99            }
100        }
101    }
102
103    fn report(&self, graph: &Graph) {
104        for span in self.record.iter() {
105            report_encoding_bug(graph, *span);
106        }
107    }
108
109    fn cnt(&self) -> usize {
110        self.record.len()
111    }
112}