rapx/analysis/senryx/
visitor_check.rs

1use std::collections::HashSet;
2
3use super::{
4    contracts::{abstract_state::AlignState, state_lattice::Lattice},
5    matcher::{get_arg_place, UnsafeApi},
6    visitor::{BodyVisitor, CheckResult, PlaceTy},
7};
8use crate::{
9    analysis::{
10        core::{
11            alias_analysis::AAResult,
12            dataflow::{default::DataFlowAnalyzer, DataFlowAnalysis},
13        },
14        senryx::contracts::property::{CisRange, CisRangeItem, PropertyContract},
15        utils::fn_info::{
16            display_hashmap, generate_contract_from_annotation_without_field_types,
17            get_cleaned_def_path_name, is_strict_ty_convert, reflect_generic,
18        },
19    },
20    rap_debug, rap_error, rap_info, rap_warn,
21};
22use rustc_data_structures::fx::FxHashMap;
23use rustc_hir::def_id::DefId;
24use rustc_middle::mir::BinOp;
25use rustc_middle::mir::Operand;
26use rustc_middle::mir::Place;
27use rustc_middle::ty::Ty;
28use rustc_span::source_map::Spanned;
29use rustc_span::Span;
30
31impl<'tcx> BodyVisitor<'tcx> {
32    pub fn handle_std_unsafe_call(
33        &mut self,
34        _dst_place: &Place<'_>,
35        def_id: &DefId,
36        args: &[Spanned<Operand>],
37        _path_index: usize,
38        _fn_map: &FxHashMap<DefId, AAResult>,
39        fn_span: Span,
40        fn_result: UnsafeApi,
41        generic_mapping: FxHashMap<String, Ty<'tcx>>,
42    ) {
43        let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
44        let args_with_contracts =
45            generate_contract_from_annotation_without_field_types(self.tcx, *def_id);
46        rap_debug!(
47            "Checking contracts {:?} for {:?}",
48            args_with_contracts,
49            def_id
50        );
51        let mut count = 0;
52        for (base, fields, contract) in args_with_contracts {
53            rap_debug!("Find contract for {:?}, {base}: {:?}", def_id, contract);
54            if base == 0 {
55                rap_warn!("Wrong base index for {:?}, with {:?}", def_id, contract);
56                continue;
57            }
58            let arg_tuple = get_arg_place(&args[base - 1].node);
59            // if this arg is a constant
60            if arg_tuple.0 {
61                continue; //TODO: check the constant value
62            } else {
63                let arg_place = self.chains.find_var_id_with_fields_seq(arg_tuple.1, fields);
64                self.check_contract(
65                    arg_place,
66                    args,
67                    contract,
68                    &generic_mapping,
69                    func_name.clone(),
70                    fn_span,
71                    count,
72                );
73            }
74            count += 1;
75        }
76
77        for (idx, sp_set) in fn_result.sps.iter().enumerate() {
78            if args.is_empty() {
79                break;
80            }
81            let arg_tuple = get_arg_place(&args[idx].node);
82            // if this arg is a constant
83            if arg_tuple.0 {
84                continue;
85            }
86            let arg_place = arg_tuple.1;
87            let _self_func_name = get_cleaned_def_path_name(self.tcx, self.def_id);
88            let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
89            for sp in sp_set {
90                match sp.sp_name.as_str() {
91                    "NonNull" => {
92                        if !self.check_non_null(arg_place) {
93                            self.insert_failed_check_result(
94                                func_name.clone(),
95                                fn_span,
96                                idx + 1,
97                                "NonNull",
98                            );
99                        } else {
100                            self.insert_successful_check_result(
101                                func_name.clone(),
102                                fn_span,
103                                idx + 1,
104                                "NonNull",
105                            );
106                        }
107                    }
108                    "AllocatorConsistency" => {
109                        if !self.check_allocator_consistency(func_name.clone(), arg_place) {
110                            self.insert_failed_check_result(
111                                func_name.clone(),
112                                fn_span,
113                                idx + 1,
114                                "AllocatorConsistency",
115                            );
116                        } else {
117                            self.insert_successful_check_result(
118                                func_name.clone(),
119                                fn_span,
120                                idx + 1,
121                                "AllocatorConsistency",
122                            );
123                        }
124                    }
125                    "!ZST" => {
126                        if !self.check_non_zst(arg_place) {
127                            self.insert_failed_check_result(
128                                func_name.clone(),
129                                fn_span,
130                                idx + 1,
131                                "!ZST",
132                            );
133                        } else {
134                            self.insert_successful_check_result(
135                                func_name.clone(),
136                                fn_span,
137                                idx + 1,
138                                "!ZST",
139                            );
140                        }
141                    }
142                    "Typed" => {
143                        if !self.check_typed(arg_place) {
144                            self.insert_failed_check_result(
145                                func_name.clone(),
146                                fn_span,
147                                idx + 1,
148                                "Typed",
149                            );
150                        } else {
151                            self.insert_successful_check_result(
152                                func_name.clone(),
153                                fn_span,
154                                idx + 1,
155                                "Typed",
156                            );
157                        }
158                    }
159                    "Allocated" => {
160                        if !self.check_allocated(arg_place) {
161                            self.insert_failed_check_result(
162                                func_name.clone(),
163                                fn_span,
164                                idx + 1,
165                                "Allocated",
166                            );
167                        } else {
168                            self.insert_successful_check_result(
169                                func_name.clone(),
170                                fn_span,
171                                idx + 1,
172                                "Allocated",
173                            );
174                        }
175                    }
176                    "ValidString" => {
177                        if !self.check_valid_string(arg_place) {
178                            self.insert_failed_check_result(
179                                func_name.clone(),
180                                fn_span,
181                                idx + 1,
182                                "ValidString",
183                            );
184                        } else {
185                            self.insert_successful_check_result(
186                                func_name.clone(),
187                                fn_span,
188                                idx + 1,
189                                "ValidString",
190                            );
191                        }
192                    }
193                    "ValidCStr" => {
194                        if !self.check_valid_cstr(arg_place) {
195                            self.insert_failed_check_result(
196                                func_name.clone(),
197                                fn_span,
198                                idx + 1,
199                                "ValidCStr",
200                            );
201                        } else {
202                            self.insert_successful_check_result(
203                                func_name.clone(),
204                                fn_span,
205                                idx + 1,
206                                "ValidCStr",
207                            );
208                        }
209                    }
210                    "ValidInt" => {
211                        if !self.check_valid_num(arg_place) {
212                            self.insert_failed_check_result(
213                                func_name.clone(),
214                                fn_span,
215                                idx + 1,
216                                "ValidNum",
217                            );
218                        } else {
219                            self.insert_successful_check_result(
220                                func_name.clone(),
221                                fn_span,
222                                idx + 1,
223                                "ValidInt",
224                            );
225                        }
226                    }
227                    "Init" => {
228                        if !self.check_init(arg_place) {
229                            self.insert_failed_check_result(
230                                func_name.clone(),
231                                fn_span,
232                                idx + 1,
233                                "Init",
234                            );
235                        } else {
236                            self.insert_successful_check_result(
237                                func_name.clone(),
238                                fn_span,
239                                idx + 1,
240                                "Init",
241                            );
242                        }
243                    }
244                    "ValidPtr" => {
245                        if !self.check_valid_ptr(arg_place) {
246                            self.insert_failed_check_result(
247                                func_name.clone(),
248                                fn_span,
249                                idx + 1,
250                                "ValidPtr",
251                            );
252                        } else {
253                            self.insert_successful_check_result(
254                                func_name.clone(),
255                                fn_span,
256                                idx + 1,
257                                "ValidPtr",
258                            );
259                        }
260                    }
261                    "Ref2Ptr" => {
262                        if !self.check_ref_to_ptr(arg_place) {
263                            self.insert_failed_check_result(
264                                func_name.clone(),
265                                fn_span,
266                                idx + 1,
267                                "Ref2Ptr",
268                            );
269                        } else {
270                            self.insert_successful_check_result(
271                                func_name.clone(),
272                                fn_span,
273                                idx + 1,
274                                "Ref2Ptr",
275                            );
276                        }
277                    }
278                    _ => {}
279                }
280            }
281        }
282    }
283
284    pub fn insert_failed_check_result(
285        &mut self,
286        func_name: String,
287        fn_span: Span,
288        idx: usize,
289        sp: &str,
290    ) {
291        if let Some(existing) = self
292            .check_results
293            .iter_mut()
294            .find(|result| result.func_name == func_name && result.func_span == fn_span)
295        {
296            if let Some(passed_set) = existing.passed_contracts.get_mut(&idx) {
297                passed_set.remove(sp);
298                if passed_set.is_empty() {
299                    existing.passed_contracts.remove(&idx);
300                }
301            }
302            existing
303                .failed_contracts
304                .entry(idx)
305                .and_modify(|set| {
306                    set.insert(sp.to_string());
307                })
308                .or_insert_with(|| {
309                    let mut new_set = HashSet::new();
310                    new_set.insert(sp.to_string());
311                    new_set
312                });
313        } else {
314            let mut new_result = CheckResult::new(&func_name, fn_span);
315            new_result
316                .failed_contracts
317                .insert(idx, HashSet::from([sp.to_string()]));
318            self.check_results.push(new_result);
319        }
320    }
321
322    pub fn insert_successful_check_result(
323        &mut self,
324        func_name: String,
325        fn_span: Span,
326        idx: usize,
327        sp: &str,
328    ) {
329        if let Some(existing) = self
330            .check_results
331            .iter_mut()
332            .find(|result| result.func_name == func_name && result.func_span == fn_span)
333        {
334            if let Some(failed_set) = existing.failed_contracts.get_mut(&idx) {
335                if failed_set.contains(sp) {
336                    return;
337                }
338            }
339
340            existing
341                .passed_contracts
342                .entry(idx)
343                .and_modify(|set| {
344                    set.insert(sp.to_string());
345                })
346                .or_insert_with(|| HashSet::from([sp.to_string()]));
347        } else {
348            let mut new_result = CheckResult::new(&func_name, fn_span);
349            new_result
350                .passed_contracts
351                .insert(idx, HashSet::from([sp.to_string()]));
352            self.check_results.push(new_result);
353        }
354    }
355
356    pub fn insert_checking_result(
357        &mut self,
358        sp: &str,
359        is_passed: bool,
360        func_name: String,
361        fn_span: Span,
362        idx: usize,
363    ) {
364        if is_passed {
365            self.insert_successful_check_result(func_name.clone(), fn_span, idx + 1, sp);
366        } else {
367            self.insert_failed_check_result(func_name.clone(), fn_span, idx + 1, sp);
368        }
369    }
370
371    pub fn check_contract(
372        &mut self,
373        arg: usize,
374        args: &[Spanned<Operand>],
375        contract: PropertyContract<'tcx>,
376        generic_mapping: &FxHashMap<String, Ty<'tcx>>,
377        func_name: String,
378        fn_span: Span,
379        idx: usize,
380    ) -> bool {
381        match contract {
382            PropertyContract::Align(ty) => {
383                let contract_required_ty = reflect_generic(generic_mapping, ty);
384                rap_debug!(
385                    "peel generic ty for {:?}, actual_ty is {:?}",
386                    func_name.clone(),
387                    contract_required_ty
388                );
389                if !self.check_align(arg, contract_required_ty) {
390                    self.insert_checking_result("Align", false, func_name, fn_span, idx);
391                } else {
392                    rap_debug!("Checking Align passed for {func_name} in {:?}!", fn_span);
393                    self.insert_checking_result("Align", true, func_name, fn_span, idx);
394                }
395            }
396            PropertyContract::InBound(ty, contract_len) => {
397                let contract_ty = reflect_generic(generic_mapping, ty);
398                if let CisRangeItem::Var(base, len_fields) = contract_len {
399                    let base_tuple = get_arg_place(&args[base - 1].node);
400                    let length_arg = self
401                        .chains
402                        .find_var_id_with_fields_seq(base_tuple.1, len_fields);
403                    if !self.check_inbound(arg, length_arg, contract_ty) {
404                        self.insert_checking_result("InBound", false, func_name, fn_span, idx);
405                    } else {
406                        rap_info!("Checking InBound passed for {func_name} in {:?}!", fn_span);
407                        self.insert_checking_result("InBound", true, func_name, fn_span, idx);
408                    }
409                } else {
410                    rap_error!("Wrong arg {:?} in Inbound safety check!", contract_len);
411                }
412            }
413            PropertyContract::NonNull => {
414                self.check_non_null(arg);
415            }
416            PropertyContract::Typed(ty) => {
417                self.check_typed(arg);
418            }
419            PropertyContract::ValidPtr(ty, contract_len) => {
420                self.check_valid_ptr(arg);
421            }
422            _ => {}
423        }
424        true
425    }
426
427    // ----------------------Sp checking functions--------------------------
428
429    // TODO: Currently can not support unaligned offset checking
430    pub fn check_align(&self, arg: usize, contract_required_ty: Ty<'tcx>) -> bool {
431        // rap_warn!("Checking Align {arg}!");
432        // display_hashmap(&self.chains.variables, 1);
433        // 1. Check the var's cis.
434        let var = self.chains.get_var_node(arg).unwrap();
435        let required_ty = self.visit_ty_and_get_layout(contract_required_ty);
436        for cis in &var.cis.contracts {
437            if let PropertyContract::Align(cis_ty) = cis {
438                let ori_ty = self.visit_ty_and_get_layout(*cis_ty);
439                return AlignState::Cast(ori_ty, required_ty).check();
440            }
441        }
442        // 2. If the var does not have cis, then check its type and the value type
443        let mem = self.chains.get_obj_ty_through_chain(arg);
444        let mem_ty = self.visit_ty_and_get_layout(mem.unwrap());
445        let cur_ty = self.visit_ty_and_get_layout(var.ty.unwrap());
446        let point_to_id = self.chains.get_point_to_id(arg);
447        let var_ty = self.chains.get_var_node(point_to_id);
448        // display_hashmap(&self.chains.variables, 1);
449        // rap_warn!("{:?}, {:?}, {:?}, {:?}", arg, cur_ty, point_to_id, mem_ty);
450        return AlignState::Cast(mem_ty, cur_ty).check() && var_ty.unwrap().ots.align;
451    }
452
453    pub fn check_non_zst(&self, arg: usize) -> bool {
454        let obj_ty = self.chains.get_obj_ty_through_chain(arg);
455        if obj_ty.is_none() {
456            self.show_error_info(arg);
457        }
458        let ori_ty = self.visit_ty_and_get_layout(obj_ty.unwrap());
459        match ori_ty {
460            PlaceTy::Ty(_align, size) => size == 0,
461            PlaceTy::GenericTy(_, _, tys) => {
462                if tys.is_empty() {
463                    return false;
464                }
465                for (_, size) in tys {
466                    if size != 0 {
467                        return false;
468                    }
469                }
470                true
471            }
472            _ => false,
473        }
474    }
475
476    // checking the value ptr points to is valid for its type
477    pub fn check_typed(&self, arg: usize) -> bool {
478        let obj_ty = self.chains.get_obj_ty_through_chain(arg).unwrap();
479        let var = self.chains.get_var_node(arg);
480        // display_hashmap(&self.chains.variables, 1);
481        let var_ty = var.unwrap().ty.unwrap();
482        if obj_ty != var_ty && is_strict_ty_convert(self.tcx, obj_ty, var_ty) {
483            return false;
484        }
485        self.check_init(arg)
486    }
487
488    pub fn check_non_null(&self, arg: usize) -> bool {
489        let point_to_id = self.chains.get_point_to_id(arg);
490        let var_ty = self.chains.get_var_node(point_to_id);
491        if var_ty.is_none() {
492            self.show_error_info(arg);
493        }
494        var_ty.unwrap().ots.nonnull
495    }
496
497    // check each field's init state in the tree.
498    // check arg itself when it doesn't have fields.
499    pub fn check_init(&self, arg: usize) -> bool {
500        let point_to_id = self.chains.get_point_to_id(arg);
501        let var = self.chains.get_var_node(point_to_id);
502        // display_hashmap(&self.chains.variables, 1);
503        if var.unwrap().field.is_empty() {
504            let mut init_flag = true;
505            for field in &var.unwrap().field {
506                init_flag &= self.check_init(*field.1);
507            }
508            init_flag
509        } else {
510            var.unwrap().ots.init
511        }
512    }
513
514    pub fn check_allocator_consistency(&self, _func_name: String, _arg: usize) -> bool {
515        true
516    }
517
518    pub fn check_allocated(&self, _arg: usize) -> bool {
519        true
520    }
521
522    pub fn check_inbound(&self, arg: usize, length_arg: usize, contract_ty: Ty<'tcx>) -> bool {
523        // 1. Check the var's cis.
524        let mem_arg = self.chains.get_point_to_id(arg);
525        let mem_var = self.chains.get_var_node(mem_arg).unwrap();
526        for cis in &mem_var.cis.contracts {
527            if let PropertyContract::InBound(cis_ty, len) = cis {
528                // display_hashmap(&self.chains.variables, 1);
529                return self.check_le_op(&contract_ty, length_arg, cis_ty, len);
530            }
531        }
532        false
533    }
534
535    /// return the result of less equal comparison (left_len * left_ty <= right_len * right_ty)
536    fn check_le_op(
537        &self,
538        left_ty: &Ty<'tcx>,
539        left_arg: usize,
540        right_ty: &Ty<'tcx>,
541        right_len: &CisRangeItem,
542    ) -> bool {
543        // If they have same types, then compare the length
544        // rap_warn!("{:?}, {left_arg}, {:?}, {:?}", left_ty, right_ty, right_len);
545        // If they have the same type, compare their patial order
546        if left_ty == right_ty {
547            return self
548                .compare_patial_order_of_two_args(left_arg, right_len.get_var_base().unwrap());
549        }
550        // Otherwise, take size of types into consideration
551        let left_layout = self.visit_ty_and_get_layout(*left_ty);
552        let right_layout = self.visit_ty_and_get_layout(*right_ty);
553        let get_size_range = |layout: &PlaceTy<'tcx>| -> Option<(u128, u128)> {
554            match layout {
555                PlaceTy::Ty(_, size) => Some((*size as u128, *size as u128)),
556                PlaceTy::GenericTy(_, _, layouts) if !layouts.is_empty() => {
557                    let sizes: Vec<u128> = layouts.iter().map(|(_, s)| *s as u128).collect();
558                    let min = *sizes.iter().min().unwrap();
559                    let max = *sizes.iter().max().unwrap();
560                    Some((min, max))
561                }
562                _ => None,
563            }
564        };
565        let (left_min_size, left_max_size) = match get_size_range(&left_layout) {
566            Some(range) => range,
567            None => return false, // Can not detemine size
568        };
569        let (right_min_size, right_max_size) = match get_size_range(&right_layout) {
570            Some(range) => range,
571            None => return false, // Can not detemine size
572        };
573        // TODO:
574
575        false
576    }
577
578    /// compare two args, return true if left <= right
579    fn compare_patial_order_of_two_args(&self, left: usize, right: usize) -> bool {
580        // Find the same value node set
581        let mut dataflow_analyzer = DataFlowAnalyzer::new(self.tcx, false);
582        dataflow_analyzer.build_graph(self.def_id);
583        let left_local = rustc_middle::mir::Local::from(left);
584        let right_local = rustc_middle::mir::Local::from(right);
585        let left_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, left_local);
586        let right_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, right_local);
587        // If left == right
588        if right_local_set.contains(&rustc_middle::mir::Local::from(left)) {
589            return true;
590        }
591        // rap_warn!(
592        //     "left_local: {:?}, left set: {:?}, right_local:{:?}, right set: {:?}",
593        //     left_local,
594        //     left_local_set,
595        //     right_local,
596        //     right_local_set
597        // );
598        for left_local_item in left_local_set {
599            let left_var = self.chains.get_var_node(left_local_item.as_usize());
600            if left_var.is_none() {
601                continue;
602            }
603            for cis in &left_var.unwrap().cis.contracts {
604                if let PropertyContract::ValidNum(cis_range) = cis {
605                    let cis_len = &cis_range.range;
606                    match cis_range.bin_op {
607                        BinOp::Le | BinOp::Lt | BinOp::Eq => {
608                            return cis_len.get_var_base().is_some()
609                                && right_local_set.contains(&rustc_middle::mir::Local::from(
610                                    cis_len.get_var_base().unwrap(),
611                                ));
612                        }
613                        _ => {}
614                    }
615                }
616            }
617        }
618        false
619    }
620
621    // fn compare_cis_range(&self, cis_range: CisRange, right_len: &CisRangeItem) -> bool {
622    //     false
623    // }
624
625    pub fn check_valid_string(&self, _arg: usize) -> bool {
626        true
627    }
628
629    pub fn check_valid_cstr(&self, _arg: usize) -> bool {
630        true
631    }
632
633    pub fn check_valid_num(&self, _arg: usize) -> bool {
634        true
635    }
636
637    pub fn check_alias(&self, _arg: usize) -> bool {
638        true
639    }
640
641    // Compound SPs
642    pub fn check_valid_ptr(&self, arg: usize) -> bool {
643        !self.check_non_zst(arg) || (self.check_non_zst(arg) && self.check_deref(arg))
644    }
645
646    pub fn check_deref(&self, arg: usize) -> bool {
647        self.check_allocated(arg)
648        // && self.check_inbounded(arg)
649    }
650
651    pub fn check_ref_to_ptr(&self, arg: usize) -> bool {
652        self.check_deref(arg)
653            && self.check_init(arg)
654            // && self.check_align(arg)
655            && self.check_alias(arg)
656    }
657
658    pub fn show_error_info(&self, arg: usize) {
659        rap_warn!(
660            "In func {:?}, visitor checker error! Can't get {arg} in chain!",
661            get_cleaned_def_path_name(self.tcx, self.def_id)
662        );
663        display_hashmap(&self.chains.variables, 1);
664    }
665}