rapx/analysis/senryx/
visitor_check.rs

1use std::collections::HashSet;
2
3use super::{
4    contracts::{abstract_state::AlignState, state_lattice::Lattice},
5    matcher::{get_arg_place, UnsafeApi},
6    visitor::{BodyVisitor, CheckResult, PlaceTy},
7};
8use crate::{
9    analysis::{
10        core::{
11            alias_analysis::AAResult,
12            dataflow::{default::DataFlowAnalyzer, DataFlowAnalysis},
13        },
14        senryx::contracts::property::{CisRange, CisRangeItem, PropertyContract},
15        utils::fn_info::{
16            display_hashmap, generate_contract_from_annotation_without_field_types,
17            get_cleaned_def_path_name, is_strict_ty_convert, reflect_generic,
18        },
19    },
20    rap_debug, rap_error, rap_info, rap_warn,
21};
22use rustc_data_structures::fx::FxHashMap;
23use rustc_hir::def_id::DefId;
24use rustc_middle::mir::BinOp;
25use rustc_middle::mir::Operand;
26use rustc_middle::mir::Place;
27use rustc_middle::ty::Ty;
28use rustc_span::source_map::Spanned;
29use rustc_span::Span;
30
31impl<'tcx> BodyVisitor<'tcx> {
32    pub fn handle_std_unsafe_call(
33        &mut self,
34        _dst_place: &Place<'_>,
35        def_id: &DefId,
36        args: &[Spanned<Operand>],
37        _path_index: usize,
38        _fn_map: &FxHashMap<DefId, AAResult>,
39        fn_span: Span,
40        fn_result: UnsafeApi,
41        generic_mapping: FxHashMap<String, Ty<'tcx>>,
42    ) {
43        let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
44        let args_with_contracts =
45            generate_contract_from_annotation_without_field_types(self.tcx, *def_id);
46        rap_debug!(
47            "Checking contracts {:?} for {:?}",
48            args_with_contracts,
49            def_id
50        );
51        let mut count = 0;
52        for (base, fields, contract) in args_with_contracts {
53            rap_debug!("Find contract for {:?}, {base}: {:?}", def_id, contract);
54            if base == 0 {
55                rap_warn!("Wrong base index for {:?}, with {:?}", def_id, contract);
56                continue;
57            }
58            let arg_tuple = get_arg_place(&args[base - 1].node);
59            // if this arg is a constant
60            if arg_tuple.0 {
61                continue; //TODO: check the constant value
62            } else {
63                let arg_place = self.chains.find_var_id_with_fields_seq(arg_tuple.1, fields);
64                self.check_contract(
65                    arg_place,
66                    args,
67                    contract,
68                    &generic_mapping,
69                    func_name.clone(),
70                    fn_span,
71                    count,
72                );
73            }
74            count += 1;
75        }
76
77        for (idx, sp_set) in fn_result.sps.iter().enumerate() {
78            if args.is_empty() {
79                break;
80            }
81            let arg_tuple = get_arg_place(&args[idx].node);
82            // if this arg is a constant
83            if arg_tuple.0 {
84                continue;
85            }
86            let arg_place = arg_tuple.1;
87            let _self_func_name = get_cleaned_def_path_name(self.tcx, self.def_id);
88            let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
89            for sp in sp_set {
90                match sp.sp_name.as_str() {
91                    "NonNull" => {
92                        if !self.check_non_null(arg_place) {
93                            self.insert_failed_check_result(
94                                func_name.clone(),
95                                fn_span,
96                                idx + 1,
97                                "NonNull",
98                            );
99                        } else {
100                            self.insert_successful_check_result(
101                                func_name.clone(),
102                                fn_span,
103                                idx + 1,
104                                "NonNull",
105                            );
106                        }
107                    }
108                    "AllocatorConsistency" => {
109                        if !self.check_allocator_consistency(func_name.clone(), arg_place) {
110                            self.insert_failed_check_result(
111                                func_name.clone(),
112                                fn_span,
113                                idx + 1,
114                                "AllocatorConsistency",
115                            );
116                        } else {
117                            self.insert_successful_check_result(
118                                func_name.clone(),
119                                fn_span,
120                                idx + 1,
121                                "AllocatorConsistency",
122                            );
123                        }
124                    }
125                    "!ZST" => {
126                        if !self.check_non_zst(arg_place) {
127                            self.insert_failed_check_result(
128                                func_name.clone(),
129                                fn_span,
130                                idx + 1,
131                                "!ZST",
132                            );
133                        } else {
134                            self.insert_successful_check_result(
135                                func_name.clone(),
136                                fn_span,
137                                idx + 1,
138                                "!ZST",
139                            );
140                        }
141                    }
142                    "Typed" => {
143                        if !self.check_typed(arg_place) {
144                            self.insert_failed_check_result(
145                                func_name.clone(),
146                                fn_span,
147                                idx + 1,
148                                "Typed",
149                            );
150                        } else {
151                            self.insert_successful_check_result(
152                                func_name.clone(),
153                                fn_span,
154                                idx + 1,
155                                "Typed",
156                            );
157                        }
158                    }
159                    "Allocated" => {
160                        if !self.check_allocated(arg_place) {
161                            self.insert_failed_check_result(
162                                func_name.clone(),
163                                fn_span,
164                                idx + 1,
165                                "Allocated",
166                            );
167                        } else {
168                            self.insert_successful_check_result(
169                                func_name.clone(),
170                                fn_span,
171                                idx + 1,
172                                "Allocated",
173                            );
174                        }
175                    }
176                    "ValidString" => {
177                        if !self.check_valid_string(arg_place) {
178                            self.insert_failed_check_result(
179                                func_name.clone(),
180                                fn_span,
181                                idx + 1,
182                                "ValidString",
183                            );
184                        } else {
185                            self.insert_successful_check_result(
186                                func_name.clone(),
187                                fn_span,
188                                idx + 1,
189                                "ValidString",
190                            );
191                        }
192                    }
193                    "ValidCStr" => {
194                        if !self.check_valid_cstr(arg_place) {
195                            self.insert_failed_check_result(
196                                func_name.clone(),
197                                fn_span,
198                                idx + 1,
199                                "ValidCStr",
200                            );
201                        } else {
202                            self.insert_successful_check_result(
203                                func_name.clone(),
204                                fn_span,
205                                idx + 1,
206                                "ValidCStr",
207                            );
208                        }
209                    }
210                    "ValidInt" => {
211                        if !self.check_valid_num(arg_place) {
212                            self.insert_failed_check_result(
213                                func_name.clone(),
214                                fn_span,
215                                idx + 1,
216                                "ValidNum",
217                            );
218                        } else {
219                            self.insert_successful_check_result(
220                                func_name.clone(),
221                                fn_span,
222                                idx + 1,
223                                "ValidInt",
224                            );
225                        }
226                    }
227                    "Init" => {
228                        if !self.check_init(arg_place) {
229                            self.insert_failed_check_result(
230                                func_name.clone(),
231                                fn_span,
232                                idx + 1,
233                                "Init",
234                            );
235                        } else {
236                            self.insert_successful_check_result(
237                                func_name.clone(),
238                                fn_span,
239                                idx + 1,
240                                "Init",
241                            );
242                        }
243                    }
244                    "ValidPtr" => {
245                        if !self.check_valid_ptr(arg_place) {
246                            self.insert_failed_check_result(
247                                func_name.clone(),
248                                fn_span,
249                                idx + 1,
250                                "ValidPtr",
251                            );
252                        } else {
253                            self.insert_successful_check_result(
254                                func_name.clone(),
255                                fn_span,
256                                idx + 1,
257                                "ValidPtr",
258                            );
259                        }
260                    }
261                    "Ref2Ptr" => {
262                        if !self.check_ref_to_ptr(arg_place) {
263                            self.insert_failed_check_result(
264                                func_name.clone(),
265                                fn_span,
266                                idx + 1,
267                                "Ref2Ptr",
268                            );
269                        } else {
270                            self.insert_successful_check_result(
271                                func_name.clone(),
272                                fn_span,
273                                idx + 1,
274                                "Ref2Ptr",
275                            );
276                        }
277                    }
278                    _ => {}
279                }
280            }
281        }
282    }
283
284    pub fn insert_failed_check_result(
285        &mut self,
286        func_name: String,
287        fn_span: Span,
288        idx: usize,
289        sp: &str,
290    ) {
291        if let Some(existing) = self
292            .check_results
293            .iter_mut()
294            .find(|result| result.func_name == func_name && result.func_span == fn_span)
295        {
296            if let Some(passed_set) = existing.passed_contracts.get_mut(&idx) {
297                passed_set.remove(sp);
298                if passed_set.is_empty() {
299                    existing.passed_contracts.remove(&idx);
300                }
301            }
302            existing
303                .failed_contracts
304                .entry(idx)
305                .and_modify(|set| {
306                    set.insert(sp.to_string());
307                })
308                .or_insert_with(|| {
309                    let mut new_set = HashSet::new();
310                    new_set.insert(sp.to_string());
311                    new_set
312                });
313        } else {
314            let mut new_result = CheckResult::new(&func_name, fn_span);
315            new_result
316                .failed_contracts
317                .insert(idx, HashSet::from([sp.to_string()]));
318            self.check_results.push(new_result);
319        }
320    }
321
322    pub fn insert_successful_check_result(
323        &mut self,
324        func_name: String,
325        fn_span: Span,
326        idx: usize,
327        sp: &str,
328    ) {
329        if let Some(existing) = self
330            .check_results
331            .iter_mut()
332            .find(|result| result.func_name == func_name && result.func_span == fn_span)
333        {
334            if let Some(failed_set) = existing.failed_contracts.get_mut(&idx) {
335                if failed_set.contains(sp) {
336                    return;
337                }
338            }
339
340            existing
341                .passed_contracts
342                .entry(idx)
343                .and_modify(|set| {
344                    set.insert(sp.to_string());
345                })
346                .or_insert_with(|| HashSet::from([sp.to_string()]));
347        } else {
348            let mut new_result = CheckResult::new(&func_name, fn_span);
349            new_result
350                .passed_contracts
351                .insert(idx, HashSet::from([sp.to_string()]));
352            self.check_results.push(new_result);
353        }
354    }
355
356    pub fn insert_checking_result(
357        &mut self,
358        sp: &str,
359        is_passed: bool,
360        func_name: String,
361        fn_span: Span,
362        idx: usize,
363    ) {
364        if is_passed {
365            self.insert_successful_check_result(func_name.clone(), fn_span, idx + 1, sp);
366        } else {
367            self.insert_failed_check_result(func_name.clone(), fn_span, idx + 1, sp);
368        }
369    }
370
371    pub fn check_contract(
372        &mut self,
373        arg: usize,
374        args: &[Spanned<Operand>],
375        contract: PropertyContract<'tcx>,
376        generic_mapping: &FxHashMap<String, Ty<'tcx>>,
377        func_name: String,
378        fn_span: Span,
379        idx: usize,
380    ) -> bool {
381        match contract {
382            PropertyContract::Align(ty) => {
383                let contract_required_ty = reflect_generic(generic_mapping, ty);
384                rap_debug!(
385                    "peel generic ty for {:?}, actual_ty is {:?}",
386                    func_name.clone(),
387                    contract_required_ty
388                );
389                if !self.check_align(arg, contract_required_ty) {
390                    self.insert_checking_result("Align", false, func_name, fn_span, idx);
391                } else {
392                    rap_debug!("Checking Align passed for {func_name} in {:?}!", fn_span);
393                    self.insert_checking_result("Align", true, func_name, fn_span, idx);
394                }
395            }
396            PropertyContract::InBound(ty, contract_len) => {
397                let contract_ty = reflect_generic(generic_mapping, ty);
398                if let CisRangeItem::Var(base, len_fields) = contract_len {
399                    let base_tuple = get_arg_place(&args[base - 1].node);
400                    let length_arg = self
401                        .chains
402                        .find_var_id_with_fields_seq(base_tuple.1, len_fields);
403                    if !self.check_inbound(arg, length_arg, contract_ty) {
404                        self.insert_checking_result("InBound", false, func_name, fn_span, idx);
405                    } else {
406                        rap_info!("Checking InBound passed for {func_name} in {:?}!", fn_span);
407                        self.insert_checking_result("InBound", true, func_name, fn_span, idx);
408                    }
409                } else {
410                    rap_error!("Wrong arg {:?} in Inbound safety check!", contract_len);
411                }
412            }
413            PropertyContract::NonNull => {
414                self.check_non_null(arg);
415            }
416            _ => {}
417        }
418        true
419    }
420
421    // ----------------------Sp checking functions--------------------------
422
423    // TODO: Currently can not support unaligned offset checking
424    pub fn check_align(&self, arg: usize, contract_required_ty: Ty<'tcx>) -> bool {
425        // rap_warn!("Checking Align {arg}!");
426        // display_hashmap(&self.chains.variables, 1);
427        // 1. Check the var's cis.
428        let var = self.chains.get_var_node(arg).unwrap();
429        let required_ty = self.visit_ty_and_get_layout(contract_required_ty);
430        for cis in &var.cis.contracts {
431            if let PropertyContract::Align(cis_ty) = cis {
432                let ori_ty = self.visit_ty_and_get_layout(*cis_ty);
433                return AlignState::Cast(ori_ty, required_ty).check();
434            }
435        }
436        // 2. If the var does not have cis, then check its type and the value type
437        let mem = self.chains.get_obj_ty_through_chain(arg);
438        let mem_ty = self.visit_ty_and_get_layout(mem.unwrap());
439        let cur_ty = self.visit_ty_and_get_layout(var.ty.unwrap());
440        let point_to_id = self.chains.get_point_to_id(arg);
441        let var_ty = self.chains.get_var_node(point_to_id);
442        // display_hashmap(&self.chains.variables, 1);
443        // rap_warn!("{:?}, {:?}, {:?}, {:?}", arg, cur_ty, point_to_id, mem_ty);
444        return AlignState::Cast(mem_ty, cur_ty).check() && var_ty.unwrap().ots.align;
445    }
446
447    pub fn check_non_zst(&self, arg: usize) -> bool {
448        let obj_ty = self.chains.get_obj_ty_through_chain(arg);
449        if obj_ty.is_none() {
450            self.show_error_info(arg);
451        }
452        let ori_ty = self.visit_ty_and_get_layout(obj_ty.unwrap());
453        match ori_ty {
454            PlaceTy::Ty(_align, size) => size == 0,
455            PlaceTy::GenericTy(_, _, tys) => {
456                if tys.is_empty() {
457                    return false;
458                }
459                for (_, size) in tys {
460                    if size != 0 {
461                        return false;
462                    }
463                }
464                true
465            }
466            _ => false,
467        }
468    }
469
470    // checking the value ptr points to is valid for its type
471    pub fn check_typed(&self, arg: usize) -> bool {
472        let obj_ty = self.chains.get_obj_ty_through_chain(arg).unwrap();
473        let var = self.chains.get_var_node(arg);
474        // display_hashmap(&self.chains.variables, 1);
475        let var_ty = var.unwrap().ty.unwrap();
476        if obj_ty != var_ty && is_strict_ty_convert(self.tcx, obj_ty, var_ty) {
477            return false;
478        }
479        self.check_init(arg)
480    }
481
482    pub fn check_non_null(&self, arg: usize) -> bool {
483        let point_to_id = self.chains.get_point_to_id(arg);
484        let var_ty = self.chains.get_var_node(point_to_id);
485        if var_ty.is_none() {
486            self.show_error_info(arg);
487        }
488        var_ty.unwrap().ots.nonnull
489    }
490
491    // check each field's init state in the tree.
492    // check arg itself when it doesn't have fields.
493    pub fn check_init(&self, arg: usize) -> bool {
494        let point_to_id = self.chains.get_point_to_id(arg);
495        let var = self.chains.get_var_node(point_to_id);
496        // display_hashmap(&self.chains.variables, 1);
497        if var.unwrap().field.is_empty() {
498            let mut init_flag = true;
499            for field in &var.unwrap().field {
500                init_flag &= self.check_init(*field.1);
501            }
502            init_flag
503        } else {
504            var.unwrap().ots.init
505        }
506    }
507
508    pub fn check_allocator_consistency(&self, _func_name: String, _arg: usize) -> bool {
509        true
510    }
511
512    pub fn check_allocated(&self, _arg: usize) -> bool {
513        true
514    }
515
516    pub fn check_inbound(&self, arg: usize, length_arg: usize, contract_ty: Ty<'tcx>) -> bool {
517        // 1. Check the var's cis.
518        let mem_arg = self.chains.get_point_to_id(arg);
519        let mem_var = self.chains.get_var_node(mem_arg).unwrap();
520        for cis in &mem_var.cis.contracts {
521            if let PropertyContract::InBound(cis_ty, len) = cis {
522                // display_hashmap(&self.chains.variables, 1);
523                return self.check_le_op(&contract_ty, length_arg, cis_ty, len);
524            }
525        }
526        false
527    }
528
529    /// return the result of less equal comparison (left_len * left_ty <= right_len * right_ty)
530    fn check_le_op(
531        &self,
532        left_ty: &Ty<'tcx>,
533        left_arg: usize,
534        right_ty: &Ty<'tcx>,
535        right_len: &CisRangeItem,
536    ) -> bool {
537        // If they have same types, then compare the length
538        // rap_warn!("{:?}, {left_arg}, {:?}, {:?}", left_ty, right_ty, right_len);
539        // If they have the same type, compare their patial order
540        if left_ty == right_ty {
541            return self
542                .compare_patial_order_of_two_args(left_arg, right_len.get_var_base().unwrap());
543        }
544        // Otherwise, take size of types into consideration
545        let left_layout = self.visit_ty_and_get_layout(*left_ty);
546        let right_layout = self.visit_ty_and_get_layout(*right_ty);
547        let get_size_range = |layout: &PlaceTy<'tcx>| -> Option<(u128, u128)> {
548            match layout {
549                PlaceTy::Ty(_, size) => Some((*size as u128, *size as u128)),
550                PlaceTy::GenericTy(_, _, layouts) if !layouts.is_empty() => {
551                    let sizes: Vec<u128> = layouts.iter().map(|(_, s)| *s as u128).collect();
552                    let min = *sizes.iter().min().unwrap();
553                    let max = *sizes.iter().max().unwrap();
554                    Some((min, max))
555                }
556                _ => None,
557            }
558        };
559        let (left_min_size, left_max_size) = match get_size_range(&left_layout) {
560            Some(range) => range,
561            None => return false, // Can not detemine size
562        };
563        let (right_min_size, right_max_size) = match get_size_range(&right_layout) {
564            Some(range) => range,
565            None => return false, // Can not detemine size
566        };
567        // TODO:
568
569        false
570    }
571
572    /// compare two args, return true if left <= right
573    fn compare_patial_order_of_two_args(&self, left: usize, right: usize) -> bool {
574        // Find the same value node set
575        let mut dataflow_analyzer = DataFlowAnalyzer::new(self.tcx, false);
576        dataflow_analyzer.build_graph(self.def_id);
577        let left_local = rustc_middle::mir::Local::from(left);
578        let right_local = rustc_middle::mir::Local::from(right);
579        let left_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, left_local);
580        let right_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, right_local);
581        // If left == right
582        if right_local_set.contains(&rustc_middle::mir::Local::from(left)) {
583            return true;
584        }
585        // rap_warn!(
586        //     "left_local: {:?}, left set: {:?}, right_local:{:?}, right set: {:?}",
587        //     left_local,
588        //     left_local_set,
589        //     right_local,
590        //     right_local_set
591        // );
592        for left_local_item in left_local_set {
593            let left_var = self.chains.get_var_node(left_local_item.as_usize());
594            if left_var.is_none() {
595                continue;
596            }
597            for cis in &left_var.unwrap().cis.contracts {
598                if let PropertyContract::ValidNum(cis_range) = cis {
599                    let cis_len = &cis_range.range;
600                    match cis_range.bin_op {
601                        BinOp::Le | BinOp::Lt | BinOp::Eq => {
602                            return cis_len.get_var_base().is_some()
603                                && right_local_set.contains(&rustc_middle::mir::Local::from(
604                                    cis_len.get_var_base().unwrap(),
605                                ));
606                        }
607                        _ => {}
608                    }
609                }
610            }
611        }
612        false
613    }
614
615    // fn compare_cis_range(&self, cis_range: CisRange, right_len: &CisRangeItem) -> bool {
616    //     false
617    // }
618
619    pub fn check_valid_string(&self, _arg: usize) -> bool {
620        true
621    }
622
623    pub fn check_valid_cstr(&self, _arg: usize) -> bool {
624        true
625    }
626
627    pub fn check_valid_num(&self, _arg: usize) -> bool {
628        true
629    }
630
631    pub fn check_alias(&self, _arg: usize) -> bool {
632        true
633    }
634
635    // Compound SPs
636    pub fn check_valid_ptr(&self, arg: usize) -> bool {
637        !self.check_non_zst(arg) || (self.check_non_zst(arg) && self.check_deref(arg))
638    }
639
640    pub fn check_deref(&self, arg: usize) -> bool {
641        self.check_allocated(arg)
642        // && self.check_inbounded(arg)
643    }
644
645    pub fn check_ref_to_ptr(&self, arg: usize) -> bool {
646        self.check_deref(arg)
647            && self.check_init(arg)
648            // && self.check_align(arg)
649            && self.check_alias(arg)
650    }
651
652    pub fn show_error_info(&self, arg: usize) {
653        rap_warn!(
654            "In func {:?}, visitor checker error! Can't get {arg} in chain!",
655            get_cleaned_def_path_name(self.tcx, self.def_id)
656        );
657        display_hashmap(&self.chains.variables, 1);
658    }
659}