rapx/analysis/unsafety_isolation/
generate_dot.rs

1use crate::analysis::unsafety_isolation::UnsafetyIsolationCheck;
2use crate::analysis::utils::fn_info::*;
3use petgraph::dot::{Config, Dot};
4use petgraph::graph::{DiGraph, EdgeReference, NodeIndex};
5use petgraph::Graph;
6use rustc_hir::def_id::DefId;
7use rustc_middle::ty::TyCtxt;
8use std::collections::HashSet;
9use std::fmt::{self, Write};
10
11#[derive(Debug, Clone, Eq, PartialEq, Hash)]
12pub enum UigNode {
13    Safe(DefId, String),
14    Unsafe(DefId, String),
15}
16
/// Kind of edge in the unsafety isolation graph.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum UigEdge {
    /// From a caller to a function it calls.
    CallerToCallee,
    /// From a constructor to the method it is linked to.
    ConsToMethod,
}
22
23impl fmt::Display for UigNode {
24    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25        match self {
26            UigNode::Safe(_, _) => write!(f, "Safe"),
27            UigNode::Unsafe(_, _) => write!(f, "Unsafe"),
28        }
29    }
30}
31
32impl fmt::Display for UigEdge {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            UigEdge::CallerToCallee => write!(f, "CallerToCallee"),
36            UigEdge::ConsToMethod => write!(f, "ConsToMethod"),
37        }
38    }
39}
40
41// def_id, is_unsafe_function(true, false), function type(0-constructor, 1-method, 2-function)
42pub type NodeType = (DefId, bool, usize);
43
44#[derive(Debug, Clone)]
45pub struct UigUnit {
46    pub caller: NodeType,
47    pub caller_cons: Vec<NodeType>,
48    pub callee_cons_pair: HashSet<(NodeType, Vec<NodeType>)>,
49}
50
51impl UigUnit {
52    pub fn new(caller: NodeType, caller_cons: Vec<NodeType>) -> Self {
53        Self {
54            caller,
55            caller_cons,
56            callee_cons_pair: HashSet::default(),
57        }
58    }
59
60    pub fn new_by_pair(
61        caller: NodeType,
62        caller_cons: Vec<NodeType>,
63        callee_cons_pair: HashSet<(NodeType, Vec<NodeType>)>,
64    ) -> Self {
65        Self {
66            caller,
67            caller_cons,
68            callee_cons_pair,
69        }
70    }
71
72    pub fn count_basic_units(&self, data: &mut [u32]) {
73        if self.caller.1 && self.callee_cons_pair.is_empty() {
74            data[0] += 1;
75        }
76        if !self.caller.1 && self.caller.2 != 1 {
77            for (callee, _) in &self.callee_cons_pair {
78                if callee.2 == 1 {
79                    data[2] += 1;
80                } else {
81                    data[1] += 1;
82                }
83            }
84        }
85        if self.caller.1 && self.caller.2 != 1 {
86            for (callee, _) in &self.callee_cons_pair {
87                if callee.2 == 1 {
88                    data[4] += 1;
89                } else {
90                    data[3] += 1;
91                }
92            }
93        }
94        if self.caller.1 && self.caller.2 == 1 {
95            let mut unsafe_cons = 0;
96            let mut safe_cons = 0;
97            for cons in &self.caller_cons {
98                if cons.1 {
99                    unsafe_cons += 1;
100                } else {
101                    safe_cons += 1;
102                }
103            }
104            if unsafe_cons == 0 && safe_cons == 0 {
105                safe_cons = 1;
106            }
107            for (callee, _) in &self.callee_cons_pair {
108                if callee.2 == 1 {
109                    data[7] += safe_cons;
110                    data[8] += unsafe_cons;
111                } else {
112                    data[5] += safe_cons;
113                    data[6] += unsafe_cons;
114                }
115            }
116        }
117        if !self.caller.1 && self.caller.2 == 1 {
118            let mut unsafe_cons = 0;
119            let mut safe_cons = 0;
120            for cons in &self.caller_cons {
121                if cons.1 {
122                    unsafe_cons += 1;
123                } else {
124                    safe_cons += 1;
125                }
126            }
127            if unsafe_cons == 0 && safe_cons == 0 {
128                safe_cons = 1;
129            }
130            for (callee, _) in &self.callee_cons_pair {
131                if callee.2 == 1 {
132                    data[11] += safe_cons;
133                    data[12] += unsafe_cons;
134                } else {
135                    data[9] += safe_cons;
136                    data[10] += unsafe_cons;
137                }
138            }
139        }
140    }
141
142    pub fn get_node_ty(node: NodeType) -> UigNode {
143        match (node.1, node.2) {
144            (true, 0) => UigNode::Unsafe(node.0, "doublecircle".to_string()),
145            (true, 1) => UigNode::Unsafe(node.0, "ellipse".to_string()),
146            (true, 2) => UigNode::Unsafe(node.0, "box".to_string()),
147            (false, 0) => UigNode::Safe(node.0, "doublecircle".to_string()),
148            (false, 1) => UigNode::Safe(node.0, "ellipse".to_string()),
149            (false, 2) => UigNode::Safe(node.0, "box".to_string()),
150            _ => UigNode::Safe(node.0, "ellipse".to_string()),
151        }
152    }
153
154    pub fn generate_dot_str(&self) -> String {
155        let mut graph: Graph<UigNode, UigEdge> = DiGraph::new();
156        let get_edge_attr = |_graph: &Graph<UigNode, UigEdge>,
157                             edge_ref: EdgeReference<'_, UigEdge>| {
158            match edge_ref.weight() {
159                UigEdge::CallerToCallee => "color=black, style=solid",
160                UigEdge::ConsToMethod => "color=black, style=dotted",
161            }
162            .to_owned()
163        };
164        let get_node_attr = |_graph: &Graph<UigNode, UigEdge>, node_ref: (NodeIndex, &UigNode)| {
165            match node_ref.1 {
166                UigNode::Safe(def_id, shape) => {
167                    format!("label=\"{:?}\", color=black, shape={:?}", def_id, shape)
168                }
169                UigNode::Unsafe(def_id, shape) => {
170                    // let sps = self.get_sp(*def_id);
171                    // let mut label = format!("{:?}\n ", def_id);
172                    // for sp_name in sps {
173                    //     label.push_str(&format!(" {}", sp_name));
174                    // }
175                    let label = format!("{:?}\n ", def_id);
176                    let node_attr = format!("label={:?}, shape={:?}, color=red", label, shape);
177                    node_attr
178                }
179            }
180        };
181
182        let caller_node = graph.add_node(Self::get_node_ty(self.caller));
183        for caller_cons in &self.caller_cons {
184            let caller_cons_node = graph.add_node(Self::get_node_ty(*caller_cons));
185            graph.add_edge(caller_cons_node, caller_node, UigEdge::ConsToMethod);
186        }
187        for (callee, cons) in &self.callee_cons_pair {
188            let callee_node = graph.add_node(Self::get_node_ty(*callee));
189            for callee_cons in cons {
190                let callee_cons_node = graph.add_node(Self::get_node_ty(*callee_cons));
191                graph.add_edge(callee_cons_node, callee_node, UigEdge::ConsToMethod);
192            }
193            graph.add_edge(caller_node, callee_node, UigEdge::CallerToCallee);
194        }
195
196        let mut dot_str = String::new();
197        let dot = Dot::with_attr_getters(
198            &graph,
199            // &[Config::NodeNoLabel, Config::EdgeNoLabel],
200            &[Config::NodeNoLabel],
201            &get_edge_attr,
202            &get_node_attr,
203        );
204
205        write!(dot_str, "{}", dot).unwrap();
206        println!("{}", dot_str);
207        dot_str
208    }
209
210    pub fn compare_labels(&self, tcx: TyCtxt<'_>) {
211        let caller_sp = get_sp(tcx, self.caller.0);
212        // for caller_con in &self.caller_cons {
213        //     if caller_con.1 != true {
214        //         // if constructor is safe, it won't have labels.
215        //         continue;
216        //     }
217        //     let caller_con_sp = Self::get_sp(caller_con.0);
218        //     caller_sp.extend(caller_con_sp); // Merge sp of each unsafe constructor
219        // }
220        let caller_label: Vec<_> = caller_sp.clone().into_iter().collect();
221
222        let mut combined_callee_sp = HashSet::new();
223        for (callee, _sp_vec) in &self.callee_cons_pair {
224            let callee_sp = get_sp(tcx, callee.0);
225            combined_callee_sp.extend(callee_sp); // Merge sp of each callee
226        }
227        let combined_labels: Vec<_> = combined_callee_sp.clone().into_iter().collect();
228
229        if caller_sp == combined_callee_sp {
230            // println!("----------same sp------------");
231            // println!(
232            //     "Caller: {:?}.\n--Caller's constructors: {:?}.\n--SP labels: {:?}.",
233            //     Self::get_cleaned_def_path_name(self.caller.0),
234            //     self.caller_cons
235            //         .iter()
236            //         .map(|node_type| Self::get_cleaned_def_path_name(node_type.0))
237            //         .collect::<Vec<_>>(),
238            //     caller_label
239            // );
240            // println!(
241            //     "Callee: {:?}.\n--Combined Callee Labels: {:?}",
242            //     self.callee_cons_pair
243            //         .iter()
244            //         .map(|(node_type, _)| Self::get_cleaned_def_path_name(node_type.0))
245            //         .collect::<Vec<_>>(),
246            //     combined_labels
247            // );
248        } else {
249            println!("----------unmatched sp------------");
250            println!(
251                "Caller: {:?}.\n--Caller's constructors: {:?}.\n--SP labels: {:?}.",
252                get_cleaned_def_path_name(tcx, self.caller.0),
253                self.caller_cons
254                    .iter()
255                    .map(|node_type| get_cleaned_def_path_name(tcx, node_type.0))
256                    .collect::<Vec<_>>(),
257                caller_label
258            );
259            println!(
260                "Callee: {:?}.\n--Combined Callee Labels: {:?}",
261                self.callee_cons_pair
262                    .iter()
263                    .map(|(node_type, _)| get_cleaned_def_path_name(tcx, node_type.0))
264                    .collect::<Vec<_>>(),
265                combined_labels
266            );
267        }
268    }
269
270    pub fn print_self(&self, tcx: TyCtxt<'_>) {
271        let caller_sp = get_sp(tcx, self.caller.0);
272        let caller_label: Vec<_> = caller_sp.clone().into_iter().collect();
273
274        let mut combined_callee_sp = HashSet::new();
275        for (callee, _sp_vec) in &self.callee_cons_pair {
276            let callee_sp = get_sp(tcx, callee.0);
277            combined_callee_sp.extend(callee_sp); // Merge sp of each callee
278        }
279        let combined_labels: Vec<_> = combined_callee_sp.clone().into_iter().collect();
280        println!(
281            "Caller: {:?}.\n--Caller's constructors: {:?}.\n--SP labels: {:?}.",
282            get_cleaned_def_path_name(tcx, self.caller.0),
283            self.caller_cons
284                .iter()
285                .filter(|cons| cons.1)
286                .map(|node_type| get_cleaned_def_path_name(tcx, node_type.0))
287                .collect::<Vec<_>>(),
288            caller_label
289        );
290        println!(
291            "Callee: {:?}.\n--Combined Callee Labels: {:?}",
292            self.callee_cons_pair
293                .iter()
294                .map(|(node_type, _)| get_cleaned_def_path_name(tcx, node_type.0))
295                .collect::<Vec<_>>(),
296            combined_labels
297        );
298    }
299}
300
/// Operation to perform on the unsafety isolation graph.
///
/// Derives `Debug`, `Clone`, `Copy` and `Eq` in addition to `PartialEq`, so the
/// (field-less) mode can be logged and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UigOp {
    /// Picture-drawing mode (presumably DOT rendering; confirm at call sites).
    DrawPic,
    /// Type-counting mode (presumably statistics collection; confirm at call sites).
    TypeCount,
}
306
307impl UnsafetyIsolationCheck<'_> {
308    pub fn get_node_name_by_def_id(&self, def_id: DefId) -> String {
309        if let Some(node) = self.nodes.iter().find(|n| n.node_id == def_id) {
310            return node.node_name.clone();
311        }
312        String::new()
313    }
314
315    pub fn get_node_type_by_def_id(&self, def_id: DefId) -> usize {
316        if let Some(node) = self.nodes.iter().find(|n| n.node_id == def_id) {
317            return node.node_type;
318        }
319        2
320    }
321
322    pub fn get_node_unsafety_by_def_id(&self, def_id: DefId) -> bool {
323        if let Some(node) = self.nodes.iter().find(|n| n.node_id == def_id) {
324            return node.node_unsafety;
325        }
326        false
327    }
328
329    pub fn get_adjacent_nodes_by_def_id(&self, def_id: DefId) -> Vec<DefId> {
330        let mut nodes = Vec::new();
331        if let Some(node) = self.nodes.iter().find(|n| n.node_id == def_id) {
332            nodes.extend(node.callees.clone());
333            nodes.extend(node.methods.clone());
334            nodes.extend(node.callers.clone());
335            nodes.extend(node.constructors.clone());
336        }
337        nodes
338    }
339
340    pub fn get_constructor_nodes_by_def_id(&self, def_id: DefId) -> Vec<DefId> {
341        let mut nodes = Vec::new();
342        if let Some(node) = self.nodes.iter().find(|n| n.node_id == def_id) {
343            nodes.extend(node.constructors.clone());
344        }
345        nodes
346    }
347}