// rapx/analysis/unsafety_isolation/generate_dot.rs

1use crate::analysis::utils::fn_info::*;
2use petgraph::Graph;
3use petgraph::dot::{Config, Dot};
4use petgraph::graph::{DiGraph, EdgeReference, NodeIndex};
5use rustc_hir::def_id::DefId;
6use rustc_middle::ty::TyCtxt;
7use std::collections::HashSet;
8use std::fmt::{self, Write};
9
10#[derive(Debug, Clone, Eq, PartialEq, Hash)]
11pub enum UigNode {
12    Safe(DefId, String),
13    Unsafe(DefId, String),
14    MergedCallerCons(String),
15    MutMethods(String),
16}
17
/// An edge kind of the unsafety isolation graph (UIG).
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
pub enum UigEdge {
    /// From a caller node to one of its callees.
    CallerToCallee,
    /// From a constructor (or the merged constructor node) to the node it is linked to.
    ConsToMethod,
    /// From the merged mutable-methods node to the caller.
    MutToCaller,
}

25impl fmt::Display for UigNode {
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        match self {
28            UigNode::Safe(_, _) => write!(f, "Safe"),
29            UigNode::Unsafe(_, _) => write!(f, "Unsafe"),
30            UigNode::MergedCallerCons(_) => write!(f, "MergedCallerCons"),
31            UigNode::MutMethods(_) => write!(f, "MutMethods"),
32        }
33    }
34}
35
36impl fmt::Display for UigEdge {
37    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
38        match self {
39            UigEdge::CallerToCallee => write!(f, "CallerToCallee"),
40            UigEdge::ConsToMethod => write!(f, "ConsToMethod"),
41            UigEdge::MutToCaller => write!(f, "MutToCaller"),
42        }
43    }
44}
45
46// def_id, is_unsafe_function(true, false), function type(0-constructor, 1-method, 2-function)
47pub type NodeType = (DefId, bool, usize);
48
49#[derive(Debug, Clone)]
50pub struct UigUnit {
51    pub caller: NodeType,
52    pub caller_cons: Vec<NodeType>,
53    pub callee_cons_pair: HashSet<(NodeType, Vec<NodeType>)>,
54    pub mut_methods: Vec<DefId>,
55}
56
57impl UigUnit {
58    pub fn new(caller: NodeType, caller_cons: Vec<NodeType>) -> Self {
59        Self {
60            caller,
61            caller_cons,
62            callee_cons_pair: HashSet::default(),
63            mut_methods: Vec::new(),
64        }
65    }
66
67    pub fn new_by_pair(
68        caller: NodeType,
69        caller_cons: Vec<NodeType>,
70        callee_cons_pair: HashSet<(NodeType, Vec<NodeType>)>,
71        mut_methods: Vec<DefId>,
72    ) -> Self {
73        Self {
74            caller,
75            caller_cons,
76            callee_cons_pair,
77            mut_methods,
78        }
79    }
80
81    pub fn count_basic_units(&self, data: &mut [u32]) {
82        if self.caller.1 && self.callee_cons_pair.is_empty() {
83            data[0] += 1;
84        }
85        if !self.caller.1 && self.caller.2 != 1 {
86            for (callee, _) in &self.callee_cons_pair {
87                if callee.2 == 1 {
88                    data[2] += 1;
89                } else {
90                    data[1] += 1;
91                }
92            }
93        }
94        if self.caller.1 && self.caller.2 != 1 {
95            for (callee, _) in &self.callee_cons_pair {
96                if callee.2 == 1 {
97                    data[4] += 1;
98                } else {
99                    data[3] += 1;
100                }
101            }
102        }
103        if self.caller.1 && self.caller.2 == 1 {
104            let mut unsafe_cons = 0;
105            let mut safe_cons = 0;
106            for cons in &self.caller_cons {
107                if cons.1 {
108                    unsafe_cons += 1;
109                } else {
110                    safe_cons += 1;
111                }
112            }
113            if unsafe_cons == 0 && safe_cons == 0 {
114                safe_cons = 1;
115            }
116            for (callee, _) in &self.callee_cons_pair {
117                if callee.2 == 1 {
118                    data[7] += safe_cons;
119                    data[8] += unsafe_cons;
120                } else {
121                    data[5] += safe_cons;
122                    data[6] += unsafe_cons;
123                }
124            }
125        }
126        if !self.caller.1 && self.caller.2 == 1 {
127            let mut unsafe_cons = 0;
128            let mut safe_cons = 0;
129            for cons in &self.caller_cons {
130                if cons.1 {
131                    unsafe_cons += 1;
132                } else {
133                    safe_cons += 1;
134                }
135            }
136            if unsafe_cons == 0 && safe_cons == 0 {
137                safe_cons = 1;
138            }
139            for (callee, _) in &self.callee_cons_pair {
140                if callee.2 == 1 {
141                    data[11] += safe_cons;
142                    data[12] += unsafe_cons;
143                } else {
144                    data[9] += safe_cons;
145                    data[10] += unsafe_cons;
146                }
147            }
148        }
149    }
150
151    /// (node.0, node.1, node.2) : (def_id, is_unsafe, type_of_func--0:cons,1:method,2:function)
152    pub fn get_node_ty(node: NodeType) -> UigNode {
153        match (node.1, node.2) {
154            (true, 0) => UigNode::Unsafe(node.0, "doublecircle".to_string()),
155            (true, 1) => UigNode::Unsafe(node.0, "ellipse".to_string()),
156            (true, 2) => UigNode::Unsafe(node.0, "box".to_string()),
157            (false, 0) => UigNode::Safe(node.0, "doublecircle".to_string()),
158            (false, 1) => UigNode::Safe(node.0, "ellipse".to_string()),
159            (false, 2) => UigNode::Safe(node.0, "box".to_string()),
160            _ => UigNode::Safe(node.0, "ellipse".to_string()),
161        }
162    }
163
164    pub fn generate_dot_str(&self) -> String {
165        let mut graph: Graph<UigNode, UigEdge> = DiGraph::new();
166        let get_edge_attr = |_graph: &Graph<UigNode, UigEdge>,
167                             edge_ref: EdgeReference<'_, UigEdge>| {
168            match edge_ref.weight() {
169                UigEdge::CallerToCallee => "color=black, style=solid",
170                UigEdge::ConsToMethod => "color=black, style=dotted",
171                UigEdge::MutToCaller => "color=blue, style=dashed",
172            }
173            .to_owned()
174        };
175        let get_node_attr = |_graph: &Graph<UigNode, UigEdge>, node_ref: (NodeIndex, &UigNode)| {
176            match node_ref.1 {
177                UigNode::Safe(def_id, shape) => {
178                    format!("label=\"{:?}\", color=black, shape={:?}", def_id, shape)
179                }
180                UigNode::Unsafe(def_id, shape) => {
181                    // let sps = self.get_sp(*def_id);
182                    // let mut label = format!("{:?}\n ", def_id);
183                    // for sp_name in sps {
184                    //     label.push_str(&format!(" {}", sp_name));
185                    // }
186                    let label = format!("{:?}\n ", def_id);
187                    let node_attr = format!("label={:?}, shape={:?}, color=red", label, shape);
188                    node_attr
189                }
190                UigNode::MergedCallerCons(label) => {
191                    format!(
192                        "label=\"{}\", shape=box, style=filled, fillcolor=lightgrey",
193                        label
194                    )
195                }
196                UigNode::MutMethods(label) => {
197                    format!(
198                        "label=\"{}\", shape=octagon, style=filled, fillcolor=lightyellow",
199                        label
200                    )
201                }
202            }
203        };
204
205        let caller_node = graph.add_node(Self::get_node_ty(self.caller));
206        if !self.caller_cons.is_empty() {
207            let cons_labels: Vec<String> = self
208                .caller_cons
209                .iter()
210                .map(|(def_id, _, _)| format!("{:?}", def_id))
211                .collect();
212            let merged_label = format!("Caller Constructors\n{}", cons_labels.join("\n"));
213            let merged_cons_node = graph.add_node(UigNode::MergedCallerCons(merged_label));
214            graph.add_edge(merged_cons_node, caller_node, UigEdge::ConsToMethod);
215        }
216
217        if !self.mut_methods.is_empty() {
218            let mut_method_labels: Vec<String> = self
219                .mut_methods
220                .iter()
221                .map(|def_id| format!("{:?}", def_id))
222                .collect();
223            let merged_label = format!("Mutable Methods\n{}", mut_method_labels.join("\n"));
224
225            let mut_methods_node = graph.add_node(UigNode::MutMethods(merged_label));
226            graph.add_edge(mut_methods_node, caller_node, UigEdge::MutToCaller);
227        }
228
229        for (callee, cons) in &self.callee_cons_pair {
230            let callee_node = graph.add_node(Self::get_node_ty(*callee));
231            for callee_cons in cons {
232                let callee_cons_node = graph.add_node(Self::get_node_ty(*callee_cons));
233                graph.add_edge(callee_cons_node, callee_node, UigEdge::ConsToMethod);
234            }
235            graph.add_edge(caller_node, callee_node, UigEdge::CallerToCallee);
236        }
237
238        let mut dot_str = String::new();
239        let dot = Dot::with_attr_getters(
240            &graph,
241            // &[Config::NodeNoLabel, Config::EdgeNoLabel],
242            &[Config::NodeNoLabel],
243            &get_edge_attr,
244            &get_node_attr,
245        );
246
247        write!(dot_str, "{}", dot).unwrap();
248        // println!("{}", dot_str);
249        dot_str
250    }
251
252    pub fn print_self(&self, tcx: TyCtxt<'_>) {
253        let caller_sp = get_sp(tcx, self.caller.0);
254        let caller_label: Vec<_> = caller_sp.clone().into_iter().collect();
255
256        let mut combined_callee_sp = HashSet::new();
257        for (callee, _sp_vec) in &self.callee_cons_pair {
258            let callee_sp = get_sp(tcx, callee.0);
259            combined_callee_sp.extend(callee_sp); // Merge sp of each callee
260        }
261        let combined_labels: Vec<_> = combined_callee_sp.clone().into_iter().collect();
262        println!(
263            "Caller: {:?}.\n--Caller's constructors: {:?}.\n--SP labels: {:?}.",
264            get_cleaned_def_path_name(tcx, self.caller.0),
265            self.caller_cons
266                .iter()
267                .filter(|cons| cons.1)
268                .map(|node_type| get_cleaned_def_path_name(tcx, node_type.0))
269                .collect::<Vec<_>>(),
270            caller_label
271        );
272        println!(
273            "Callee: {:?}.\n--Combined Callee Labels: {:?}",
274            self.callee_cons_pair
275                .iter()
276                .map(|(node_type, _)| get_cleaned_def_path_name(tcx, node_type.0))
277                .collect::<Vec<_>>(),
278            combined_labels
279        );
280    }
281}