rapx/analysis/unsafety_isolation/
std_unsafety_isolation.rs

1use super::{
2    UnsafetyIsolationCheck,
3    generate_dot::{NodeType, UigUnit},
4};
5use crate::analysis::utils::fn_info::*;
6use crate::analysis::utils::show_mir::display_mir;
7use rustc_hir::def::DefKind;
8use rustc_hir::def_id::DefId;
9use rustc_middle::mir::Local;
10use rustc_middle::ty::Visibility;
11use rustc_middle::{ty, ty::TyCtxt};
12use rustc_span::Symbol;
13use std::collections::HashMap;
14use std::collections::HashSet;
15
16impl<'tcx> UnsafetyIsolationCheck<'tcx> {
17    pub fn audit_std_unsafe(&mut self) {
18        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
19        // Specific task for vec;
20        let symbol = Symbol::intern("Vec");
21        let vec_def_id = self.tcx.get_diagnostic_item(symbol).unwrap();
22        for &def_id in &all_std_fn_def {
23            let adt_def = get_adt_def_id_by_adt_method(self.tcx, def_id);
24            if adt_def.is_some() && adt_def.unwrap() == vec_def_id {
25                self.insert_uig(
26                    def_id,
27                    get_callees(self.tcx, def_id),
28                    get_cons(self.tcx, def_id),
29                );
30            }
31        }
32        self.render_dot();
33    }
34
35    pub fn get_chains(&mut self) {
36        let v_fn_def = self.tcx.mir_keys(());
37
38        for local_def_id in v_fn_def {
39            let def_id = local_def_id.to_def_id();
40            if !check_visibility(self.tcx, def_id) {
41                continue;
42            }
43            if get_cleaned_def_path_name(self.tcx, def_id) == "std::boxed::Box::<T>::from_raw" {
44                let body = self.tcx.mir_built(local_def_id).steal();
45                display_mir(def_id, &body);
46            }
47            let chains = get_all_std_unsafe_chains(self.tcx, def_id);
48            let valid_chains: Vec<Vec<String>> = chains
49                .into_iter()
50                .filter(|chain| {
51                    if chain.len() > 1 {
52                        return true;
53                    }
54                    if chain.len() == 1 {
55                        let is_unsafe = check_safety(self.tcx, def_id);
56                        return is_unsafe;
57                    }
58                    false
59                })
60                .collect();
61
62            print_unsafe_chains(&valid_chains);
63        }
64    }
65
66    pub fn get_all_std_unsafe_def_id_by_treat_std_as_local_crate(
67        &mut self,
68        tcx: TyCtxt<'tcx>,
69    ) -> HashSet<DefId> {
70        let mut unsafe_fn = HashSet::new();
71        let mut total_cnt = 0;
72        let mut api_cnt = 0;
73        let mut sp_cnt = 0;
74        let mut sp_count_map: HashMap<String, usize> = HashMap::new();
75        let all_std_fn_def = get_all_std_fns_by_rustc_public(self.tcx);
76
77        for def_id in &all_std_fn_def {
78            if check_safety(tcx, *def_id) {
79                let sp_set = get_sp(tcx, *def_id);
80                if !sp_set.is_empty() {
81                    unsafe_fn.insert(*def_id);
82                    let mut flag = false;
83                    for sp in &sp_set {
84                        if sp.is_empty()
85                            || sp == "Function_sp"
86                            || sp == "System_sp"
87                            || sp == "ValidSlice"
88                        {
89                            flag = true;
90                        }
91                    }
92                    if !flag {
93                        api_cnt += 1;
94                        sp_cnt += sp_set.len();
95                    }
96                    total_cnt += 1;
97                    // println!("unsafe fn : {:?}", get_cleaned_def_path_name(self.tcx, def_id));
98                }
99                for sp in sp_set {
100                    *sp_count_map.entry(sp).or_insert(0) += 1;
101                }
102                // self.check_params(def_id);
103            }
104            self.insert_uig(*def_id, get_callees(tcx, *def_id), get_cons(tcx, *def_id));
105        }
106        // self.analyze_struct();
107        // self.analyze_uig();
108        // self.get_units_data(self.tcx);
109        // for (sp, count) in &sp_count_map {
110        //     println!("SP: {}, Count: {}", sp, count);
111        // }
112
113        rap_info!(
114            "fn_def : {}, count : {:?} and {:?}, sp cnt : {}",
115            all_std_fn_def.len(),
116            total_cnt,
117            api_cnt,
118            sp_cnt
119        );
120        // println!("unsafe fn len {}", unsafe_fn.len());
121        unsafe_fn
122    }
123
124    pub fn check_params(&self, def_id: DefId) {
125        let body = self.tcx.optimized_mir(def_id);
126        let locals = body.local_decls.clone();
127        let fn_sig = self.tcx.fn_sig(def_id).skip_binder();
128        let param_len = fn_sig.inputs().skip_binder().len();
129        let return_ty = fn_sig.output().skip_binder();
130        for idx in 1..param_len + 1 {
131            let local_ty = locals[Local::from(idx)].ty;
132            if is_ptr(local_ty) && !return_ty.is_unit() {
133                println!("{:?}", get_cleaned_def_path_name(self.tcx, def_id));
134            }
135        }
136    }
137
138    pub fn analyze_uig(&self) {
139        let mut func_nc = Vec::new();
140        let mut func_pro1 = Vec::new();
141        let mut func_enc1 = Vec::new();
142        let mut m_nc = Vec::new();
143        let mut m_pro1 = Vec::new();
144        let mut m_enc1 = Vec::new();
145        for uig in &self.uigs {
146            if uig.caller.2 == 1 {
147                // method
148                if uig.caller.1 {
149                    m_pro1.push(uig.clone());
150                } else if !uig.caller.1 {
151                    m_enc1.push(uig.clone());
152                }
153            } else {
154                //function
155                if uig.caller.1 {
156                    func_pro1.push(uig.clone());
157                } else if !uig.caller.1 {
158                    func_enc1.push(uig.clone());
159                }
160            }
161        }
162        for uig in &self.single {
163            if uig.caller.2 == 1 {
164                // method
165                m_nc.push(uig.clone());
166            } else {
167                func_nc.push(uig.clone());
168            }
169        }
170        println!(
171            "func: {},{},{}, method: {},{},{}",
172            func_nc.len(),
173            func_pro1.len(),
174            func_enc1.len(),
175            m_nc.len(),
176            m_pro1.len(),
177            m_enc1.len()
178        );
179        println!("units: {}", self.uigs.len() + self.single.len());
180        // let mut no_unsafe_con = Vec::new();
181        // for uig in pro1 {
182        //     let mut flag = 0;
183        //     for con in &uig.caller_cons {
184        //         if con.1 == true {
185        //             flag = 1;
186        //         }
187        //     }
188        //     if flag == 0 {
189        //         no_unsafe_con.push(uig.clone());
190        //     }
191        // }
192    }
193
194    pub fn analyze_struct(&self) {
195        let mut cache = HashSet::new();
196        let mut s = 0;
197        let mut u = 0;
198        let mut e = 0;
199        let mut uc = 0;
200        let mut vi = 0;
201        for uig in &self.uigs {
202            self.get_struct(
203                uig.caller.0,
204                &mut cache,
205                &mut s,
206                &mut u,
207                &mut e,
208                &mut uc,
209                &mut vi,
210            );
211        }
212        for uig in &self.single {
213            self.get_struct(
214                uig.caller.0,
215                &mut cache,
216                &mut s,
217                &mut u,
218                &mut e,
219                &mut uc,
220                &mut vi,
221            );
222        }
223        println!("{},{},{},{}", s, u, e, vi);
224    }
225
226    pub fn get_struct(
227        &self,
228        def_id: DefId,
229        cache: &mut HashSet<DefId>,
230        s: &mut usize,
231        u: &mut usize,
232        e: &mut usize,
233        uc: &mut usize,
234        vi: &mut usize,
235    ) {
236        let tcx = self.tcx;
237        let mut safe_constructors = Vec::new();
238        let mut unsafe_constructors = Vec::new();
239        let mut unsafe_methods = Vec::new();
240        let mut safe_methods = Vec::new();
241        let mut mut_methods = Vec::new();
242        let mut struct_name = "".to_string();
243        let mut ty_flag = 0;
244        let mut vi_flag = false;
245        if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
246            if let Some(impl_id) = assoc_item.impl_container(tcx) {
247                // get struct ty
248                let ty = tcx.type_of(impl_id).skip_binder();
249                if let Some(adt_def) = ty.ty_adt_def() {
250                    if adt_def.is_union() {
251                        ty_flag = 1;
252                    } else if adt_def.is_enum() {
253                        ty_flag = 2;
254                    }
255                    let adt_def_id = adt_def.did();
256                    struct_name = get_cleaned_def_path_name(tcx, adt_def_id);
257                    if !cache.insert(adt_def_id) {
258                        return;
259                    }
260
261                    vi_flag = false;
262                    let impl_vec = get_impls_for_struct(self.tcx, adt_def_id);
263                    for impl_id in impl_vec {
264                        let associated_items = tcx.associated_items(impl_id);
265                        for item in associated_items.in_definition_order() {
266                            if let ty::AssocKind::Fn {
267                                name: _,
268                                has_self: _,
269                            } = item.kind
270                            {
271                                let item_def_id = item.def_id;
272                                if !get_sp(self.tcx, item_def_id).is_empty() {
273                                    vi_flag = true;
274                                }
275                                if get_type(self.tcx, item_def_id) == 0
276                                    && check_safety(self.tcx, item_def_id)
277                                // && get_sp(self.tcx, item_def_id).len() > 0
278                                {
279                                    unsafe_constructors.push(item_def_id);
280                                }
281                                if get_type(self.tcx, item_def_id) == 0
282                                    && !check_safety(self.tcx, item_def_id)
283                                {
284                                    safe_constructors.push(item_def_id);
285                                }
286                                if get_type(self.tcx, item_def_id) == 1
287                                    && check_safety(self.tcx, item_def_id)
288                                // && get_sp(self.tcx, item_def_id).len() > 0
289                                {
290                                    unsafe_methods.push(item_def_id);
291                                }
292                                if get_type(self.tcx, item_def_id) == 1
293                                    && !check_safety(self.tcx, item_def_id)
294                                {
295                                    if !get_callees(tcx, item_def_id).is_empty() {
296                                        safe_methods.push(item_def_id);
297                                    }
298                                }
299                                if get_type(self.tcx, item_def_id) == 1
300                                    && has_mut_self_param(self.tcx, item_def_id)
301                                {
302                                    mut_methods.push(item_def_id);
303                                }
304                            }
305                        }
306                    }
307                }
308            }
309        }
310        if struct_name == *""
311            || (unsafe_constructors.is_empty()
312                && unsafe_methods.is_empty()
313                && safe_methods.is_empty())
314        {
315            return;
316        }
317        if vi_flag {
318            *vi += 1;
319        }
320        if !unsafe_constructors.is_empty() {
321            *uc += 1;
322        }
323        if ty_flag == 0 {
324            *s += 1;
325            // println!("Struct:{:?}", struct_name);
326        }
327        if ty_flag == 1 {
328            *u += 1;
329            // println!("Union:{:?}", struct_name);
330        }
331        if ty_flag == 2 {
332            *e += 1;
333            // println!("Enum:{:?}", struct_name);
334        }
335
336        println!("Safe Cons: {}", safe_constructors.len());
337        for safe_cons in safe_constructors {
338            println!(" {:?}", get_cleaned_def_path_name(tcx, safe_cons));
339        }
340        println!("Unsafe Cons: {}", unsafe_constructors.len());
341        for unsafe_cons in unsafe_constructors {
342            println!(" {:?}", get_cleaned_def_path_name(tcx, unsafe_cons));
343        }
344        println!("Unsafe Methods: {}", unsafe_methods.len());
345        for method in unsafe_methods {
346            println!(" {:?}", get_cleaned_def_path_name(tcx, method));
347        }
348        println!("Safe Methods with unsafe callee: {}", safe_methods.len());
349        for method in safe_methods {
350            println!(" {:?}", get_cleaned_def_path_name(tcx, method));
351        }
352        println!("Mut self Methods: {}", mut_methods.len());
353        for method in mut_methods {
354            println!(" {:?}", get_cleaned_def_path_name(tcx, method));
355        }
356    }
357
358    pub fn get_units_data(&mut self, tcx: TyCtxt<'tcx>) {
359        // [uf/um, sf-uf, sf-um, uf-uf, uf-um, um(sf)-uf, um(uf)-uf, um(sf)-um, um(uf)-um, sm(sf)-uf, sm(uf)-uf, sm(sf)-um, sm(uf)-um]
360        let mut basic_units_data = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
361        let def_id_sets = tcx.mir_keys(());
362        for local_def_id in def_id_sets {
363            let def_id = local_def_id.to_def_id();
364            if Self::filter_mir(def_id) {
365                continue;
366            }
367            if tcx.def_kind(def_id) == DefKind::Fn || tcx.def_kind(def_id) == DefKind::AssocFn {
368                self.insert_uig(def_id, get_callees(tcx, def_id), get_cons(tcx, def_id));
369            }
370        }
371        for uig in &self.uigs {
372            uig.count_basic_units(&mut basic_units_data);
373        }
374        for single in &self.single {
375            single.count_basic_units(&mut basic_units_data);
376        }
377    }
378
379    pub fn process_def_id(
380        &mut self,
381        tcx: TyCtxt<'tcx>,
382        def_id: DefId,
383        visited: &mut HashSet<DefId>,
384        unsafe_fn: &mut HashSet<DefId>,
385    ) {
386        if !visited.insert(def_id) || Self::filter_mir(def_id) {
387            return;
388        }
389        match tcx.def_kind(def_id) {
390            DefKind::Fn | DefKind::AssocFn => {
391                if check_safety(tcx, def_id) && self.tcx.visibility(def_id) == Visibility::Public {
392                    unsafe_fn.insert(def_id);
393                    self.insert_uig(def_id, get_callees(tcx, def_id), get_cons(tcx, def_id));
394                }
395            }
396            DefKind::Mod => {
397                for child in tcx.module_children(def_id) {
398                    if let Some(child_def_id) = child.res.opt_def_id() {
399                        self.process_def_id(tcx, child_def_id, visited, unsafe_fn);
400                    }
401                }
402            }
403            DefKind::Impl { of_trait: _ } => {
404                for item in tcx.associated_item_def_ids(def_id) {
405                    self.process_def_id(tcx, *item, visited, unsafe_fn);
406                }
407            }
408            DefKind::Struct => {
409                let impls = tcx.inherent_impls(def_id);
410                for impl_def_id in impls {
411                    self.process_def_id(tcx, *impl_def_id, visited, unsafe_fn);
412                }
413            }
414            DefKind::Ctor(_of, _kind) => {
415                if tcx.is_mir_available(def_id) {
416                    let _mir = tcx.optimized_mir(def_id);
417                }
418            }
419            _ => {
420                // println!("{:?}",tcx.def_kind(def_id));
421            }
422        }
423    }
424
425    pub fn filter_mir(_def_id: DefId) -> bool {
426        // let def_id_fmt = format!("{:?}", def_id);
427        false
428        // def_id_fmt.contains("core_arch")
429        //     || def_id_fmt.contains("::__")
430        //     || def_id_fmt.contains("backtrace_rs")
431        //     || def_id_fmt.contains("stack_overflow")
432        //     || def_id_fmt.contains("thread_local")
433        //     || def_id_fmt.contains("raw_vec")
434        //     || def_id_fmt.contains("sys_common")
435        //     || def_id_fmt.contains("adapters")
436        //     || def_id_fmt.contains("sys::sync")
437        //     || def_id_fmt.contains("personality")
438        //     || def_id_fmt.contains("collections::btree::borrow")
439        //     || def_id_fmt.contains("num::int_sqrt")
440        //     || def_id_fmt.contains("collections::btree::node")
441        //     || def_id_fmt.contains("collections::btree::navigate")
442        //     || def_id_fmt.contains("core_simd")
443        //     || def_id_fmt.contains("unique")
444    }
445
446    pub fn insert_uig(
447        &mut self,
448        caller: DefId,
449        callee_set: HashSet<DefId>,
450        caller_cons: Vec<NodeType>,
451    ) {
452        let mut pairs = HashSet::new();
453        for callee in &callee_set {
454            let callee_cons = Vec::new();
455            pairs.insert((generate_node_ty(self.tcx, *callee), callee_cons));
456        }
457        if !check_safety(self.tcx, caller) && callee_set.is_empty() {
458            return;
459        }
460        let mut_methods_set = get_all_mutable_methods(self.tcx, caller);
461        let mut_methods = mut_methods_set.keys().copied().collect();
462        let uig = UigUnit::new_by_pair(
463            generate_node_ty(self.tcx, caller),
464            caller_cons,
465            pairs,
466            mut_methods,
467        );
468        if !callee_set.is_empty() {
469            self.uigs.push(uig);
470        } else {
471            self.single.push(uig);
472        }
473    }
474}