1use std::collections::HashSet;
2
3use super::{
4 contracts::{abstract_state::AlignState, state_lattice::Lattice},
5 matcher::{get_arg_place, UnsafeApi},
6 visitor::{BodyVisitor, CheckResult, PlaceTy},
7};
8use crate::{
9 analysis::{
10 core::{
11 alias_analysis::AAResult,
12 dataflow::{default::DataFlowAnalyzer, DataFlowAnalysis},
13 },
14 senryx::contracts::property::{CisRange, CisRangeItem, PropertyContract},
15 utils::fn_info::{
16 display_hashmap, generate_contract_from_annotation_without_field_types,
17 get_cleaned_def_path_name, is_strict_ty_convert, reflect_generic,
18 },
19 },
20 rap_debug, rap_error, rap_info, rap_warn,
21};
22use rustc_data_structures::fx::FxHashMap;
23use rustc_hir::def_id::DefId;
24use rustc_middle::mir::BinOp;
25use rustc_middle::mir::Operand;
26use rustc_middle::mir::Place;
27use rustc_middle::ty::Ty;
28use rustc_span::source_map::Spanned;
29use rustc_span::Span;
30
31impl<'tcx> BodyVisitor<'tcx> {
32 pub fn handle_std_unsafe_call(
33 &mut self,
34 _dst_place: &Place<'_>,
35 def_id: &DefId,
36 args: &[Spanned<Operand>],
37 _path_index: usize,
38 _fn_map: &FxHashMap<DefId, AAResult>,
39 fn_span: Span,
40 fn_result: UnsafeApi,
41 generic_mapping: FxHashMap<String, Ty<'tcx>>,
42 ) {
43 let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
44 let args_with_contracts =
45 generate_contract_from_annotation_without_field_types(self.tcx, *def_id);
46 rap_debug!(
47 "Checking contracts {:?} for {:?}",
48 args_with_contracts,
49 def_id
50 );
51 let mut count = 0;
52 for (base, fields, contract) in args_with_contracts {
53 rap_debug!("Find contract for {:?}, {base}: {:?}", def_id, contract);
54 if base == 0 {
55 rap_warn!("Wrong base index for {:?}, with {:?}", def_id, contract);
56 continue;
57 }
58 let arg_tuple = get_arg_place(&args[base - 1].node);
59 if arg_tuple.0 {
61 continue; } else {
63 let arg_place = self.chains.find_var_id_with_fields_seq(arg_tuple.1, fields);
64 self.check_contract(
65 arg_place,
66 args,
67 contract,
68 &generic_mapping,
69 func_name.clone(),
70 fn_span,
71 count,
72 );
73 }
74 count += 1;
75 }
76
77 for (idx, sp_set) in fn_result.sps.iter().enumerate() {
78 if args.is_empty() {
79 break;
80 }
81 let arg_tuple = get_arg_place(&args[idx].node);
82 if arg_tuple.0 {
84 continue;
85 }
86 let arg_place = arg_tuple.1;
87 let _self_func_name = get_cleaned_def_path_name(self.tcx, self.def_id);
88 let func_name = get_cleaned_def_path_name(self.tcx, *def_id);
89 for sp in sp_set {
90 match sp.sp_name.as_str() {
91 "NonNull" => {
92 if !self.check_non_null(arg_place) {
93 self.insert_failed_check_result(
94 func_name.clone(),
95 fn_span,
96 idx + 1,
97 "NonNull",
98 );
99 } else {
100 self.insert_successful_check_result(
101 func_name.clone(),
102 fn_span,
103 idx + 1,
104 "NonNull",
105 );
106 }
107 }
108 "AllocatorConsistency" => {
109 if !self.check_allocator_consistency(func_name.clone(), arg_place) {
110 self.insert_failed_check_result(
111 func_name.clone(),
112 fn_span,
113 idx + 1,
114 "AllocatorConsistency",
115 );
116 } else {
117 self.insert_successful_check_result(
118 func_name.clone(),
119 fn_span,
120 idx + 1,
121 "AllocatorConsistency",
122 );
123 }
124 }
125 "!ZST" => {
126 if !self.check_non_zst(arg_place) {
127 self.insert_failed_check_result(
128 func_name.clone(),
129 fn_span,
130 idx + 1,
131 "!ZST",
132 );
133 } else {
134 self.insert_successful_check_result(
135 func_name.clone(),
136 fn_span,
137 idx + 1,
138 "!ZST",
139 );
140 }
141 }
142 "Typed" => {
143 if !self.check_typed(arg_place) {
144 self.insert_failed_check_result(
145 func_name.clone(),
146 fn_span,
147 idx + 1,
148 "Typed",
149 );
150 } else {
151 self.insert_successful_check_result(
152 func_name.clone(),
153 fn_span,
154 idx + 1,
155 "Typed",
156 );
157 }
158 }
159 "Allocated" => {
160 if !self.check_allocated(arg_place) {
161 self.insert_failed_check_result(
162 func_name.clone(),
163 fn_span,
164 idx + 1,
165 "Allocated",
166 );
167 } else {
168 self.insert_successful_check_result(
169 func_name.clone(),
170 fn_span,
171 idx + 1,
172 "Allocated",
173 );
174 }
175 }
176 "ValidString" => {
177 if !self.check_valid_string(arg_place) {
178 self.insert_failed_check_result(
179 func_name.clone(),
180 fn_span,
181 idx + 1,
182 "ValidString",
183 );
184 } else {
185 self.insert_successful_check_result(
186 func_name.clone(),
187 fn_span,
188 idx + 1,
189 "ValidString",
190 );
191 }
192 }
193 "ValidCStr" => {
194 if !self.check_valid_cstr(arg_place) {
195 self.insert_failed_check_result(
196 func_name.clone(),
197 fn_span,
198 idx + 1,
199 "ValidCStr",
200 );
201 } else {
202 self.insert_successful_check_result(
203 func_name.clone(),
204 fn_span,
205 idx + 1,
206 "ValidCStr",
207 );
208 }
209 }
210 "ValidInt" => {
211 if !self.check_valid_num(arg_place) {
212 self.insert_failed_check_result(
213 func_name.clone(),
214 fn_span,
215 idx + 1,
216 "ValidNum",
217 );
218 } else {
219 self.insert_successful_check_result(
220 func_name.clone(),
221 fn_span,
222 idx + 1,
223 "ValidInt",
224 );
225 }
226 }
227 "Init" => {
228 if !self.check_init(arg_place) {
229 self.insert_failed_check_result(
230 func_name.clone(),
231 fn_span,
232 idx + 1,
233 "Init",
234 );
235 } else {
236 self.insert_successful_check_result(
237 func_name.clone(),
238 fn_span,
239 idx + 1,
240 "Init",
241 );
242 }
243 }
244 "ValidPtr" => {
245 if !self.check_valid_ptr(arg_place) {
246 self.insert_failed_check_result(
247 func_name.clone(),
248 fn_span,
249 idx + 1,
250 "ValidPtr",
251 );
252 } else {
253 self.insert_successful_check_result(
254 func_name.clone(),
255 fn_span,
256 idx + 1,
257 "ValidPtr",
258 );
259 }
260 }
261 "Ref2Ptr" => {
262 if !self.check_ref_to_ptr(arg_place) {
263 self.insert_failed_check_result(
264 func_name.clone(),
265 fn_span,
266 idx + 1,
267 "Ref2Ptr",
268 );
269 } else {
270 self.insert_successful_check_result(
271 func_name.clone(),
272 fn_span,
273 idx + 1,
274 "Ref2Ptr",
275 );
276 }
277 }
278 _ => {}
279 }
280 }
281 }
282 }
283
284 pub fn insert_failed_check_result(
285 &mut self,
286 func_name: String,
287 fn_span: Span,
288 idx: usize,
289 sp: &str,
290 ) {
291 if let Some(existing) = self
292 .check_results
293 .iter_mut()
294 .find(|result| result.func_name == func_name && result.func_span == fn_span)
295 {
296 if let Some(passed_set) = existing.passed_contracts.get_mut(&idx) {
297 passed_set.remove(sp);
298 if passed_set.is_empty() {
299 existing.passed_contracts.remove(&idx);
300 }
301 }
302 existing
303 .failed_contracts
304 .entry(idx)
305 .and_modify(|set| {
306 set.insert(sp.to_string());
307 })
308 .or_insert_with(|| {
309 let mut new_set = HashSet::new();
310 new_set.insert(sp.to_string());
311 new_set
312 });
313 } else {
314 let mut new_result = CheckResult::new(&func_name, fn_span);
315 new_result
316 .failed_contracts
317 .insert(idx, HashSet::from([sp.to_string()]));
318 self.check_results.push(new_result);
319 }
320 }
321
322 pub fn insert_successful_check_result(
323 &mut self,
324 func_name: String,
325 fn_span: Span,
326 idx: usize,
327 sp: &str,
328 ) {
329 if let Some(existing) = self
330 .check_results
331 .iter_mut()
332 .find(|result| result.func_name == func_name && result.func_span == fn_span)
333 {
334 if let Some(failed_set) = existing.failed_contracts.get_mut(&idx) {
335 if failed_set.contains(sp) {
336 return;
337 }
338 }
339
340 existing
341 .passed_contracts
342 .entry(idx)
343 .and_modify(|set| {
344 set.insert(sp.to_string());
345 })
346 .or_insert_with(|| HashSet::from([sp.to_string()]));
347 } else {
348 let mut new_result = CheckResult::new(&func_name, fn_span);
349 new_result
350 .passed_contracts
351 .insert(idx, HashSet::from([sp.to_string()]));
352 self.check_results.push(new_result);
353 }
354 }
355
356 pub fn insert_checking_result(
357 &mut self,
358 sp: &str,
359 is_passed: bool,
360 func_name: String,
361 fn_span: Span,
362 idx: usize,
363 ) {
364 if is_passed {
365 self.insert_successful_check_result(func_name.clone(), fn_span, idx + 1, sp);
366 } else {
367 self.insert_failed_check_result(func_name.clone(), fn_span, idx + 1, sp);
368 }
369 }
370
371 pub fn check_contract(
372 &mut self,
373 arg: usize,
374 args: &[Spanned<Operand>],
375 contract: PropertyContract<'tcx>,
376 generic_mapping: &FxHashMap<String, Ty<'tcx>>,
377 func_name: String,
378 fn_span: Span,
379 idx: usize,
380 ) -> bool {
381 match contract {
382 PropertyContract::Align(ty) => {
383 let contract_required_ty = reflect_generic(generic_mapping, ty);
384 rap_debug!(
385 "peel generic ty for {:?}, actual_ty is {:?}",
386 func_name.clone(),
387 contract_required_ty
388 );
389 if !self.check_align(arg, contract_required_ty) {
390 self.insert_checking_result("Align", false, func_name, fn_span, idx);
391 } else {
392 rap_debug!("Checking Align passed for {func_name} in {:?}!", fn_span);
393 self.insert_checking_result("Align", true, func_name, fn_span, idx);
394 }
395 }
396 PropertyContract::InBound(ty, contract_len) => {
397 let contract_ty = reflect_generic(generic_mapping, ty);
398 if let CisRangeItem::Var(base, len_fields) = contract_len {
399 let base_tuple = get_arg_place(&args[base - 1].node);
400 let length_arg = self
401 .chains
402 .find_var_id_with_fields_seq(base_tuple.1, len_fields);
403 if !self.check_inbound(arg, length_arg, contract_ty) {
404 self.insert_checking_result("InBound", false, func_name, fn_span, idx);
405 } else {
406 rap_info!("Checking InBound passed for {func_name} in {:?}!", fn_span);
407 self.insert_checking_result("InBound", true, func_name, fn_span, idx);
408 }
409 } else {
410 rap_error!("Wrong arg {:?} in Inbound safety check!", contract_len);
411 }
412 }
413 PropertyContract::NonNull => {
414 self.check_non_null(arg);
415 }
416 _ => {}
417 }
418 true
419 }
420
421 pub fn check_align(&self, arg: usize, contract_required_ty: Ty<'tcx>) -> bool {
425 let var = self.chains.get_var_node(arg).unwrap();
429 let required_ty = self.visit_ty_and_get_layout(contract_required_ty);
430 for cis in &var.cis.contracts {
431 if let PropertyContract::Align(cis_ty) = cis {
432 let ori_ty = self.visit_ty_and_get_layout(*cis_ty);
433 return AlignState::Cast(ori_ty, required_ty).check();
434 }
435 }
436 let mem = self.chains.get_obj_ty_through_chain(arg);
438 let mem_ty = self.visit_ty_and_get_layout(mem.unwrap());
439 let cur_ty = self.visit_ty_and_get_layout(var.ty.unwrap());
440 let point_to_id = self.chains.get_point_to_id(arg);
441 let var_ty = self.chains.get_var_node(point_to_id);
442 return AlignState::Cast(mem_ty, cur_ty).check() && var_ty.unwrap().ots.align;
445 }
446
447 pub fn check_non_zst(&self, arg: usize) -> bool {
448 let obj_ty = self.chains.get_obj_ty_through_chain(arg);
449 if obj_ty.is_none() {
450 self.show_error_info(arg);
451 }
452 let ori_ty = self.visit_ty_and_get_layout(obj_ty.unwrap());
453 match ori_ty {
454 PlaceTy::Ty(_align, size) => size == 0,
455 PlaceTy::GenericTy(_, _, tys) => {
456 if tys.is_empty() {
457 return false;
458 }
459 for (_, size) in tys {
460 if size != 0 {
461 return false;
462 }
463 }
464 true
465 }
466 _ => false,
467 }
468 }
469
470 pub fn check_typed(&self, arg: usize) -> bool {
472 let obj_ty = self.chains.get_obj_ty_through_chain(arg).unwrap();
473 let var = self.chains.get_var_node(arg);
474 let var_ty = var.unwrap().ty.unwrap();
476 if obj_ty != var_ty && is_strict_ty_convert(self.tcx, obj_ty, var_ty) {
477 return false;
478 }
479 self.check_init(arg)
480 }
481
482 pub fn check_non_null(&self, arg: usize) -> bool {
483 let point_to_id = self.chains.get_point_to_id(arg);
484 let var_ty = self.chains.get_var_node(point_to_id);
485 if var_ty.is_none() {
486 self.show_error_info(arg);
487 }
488 var_ty.unwrap().ots.nonnull
489 }
490
491 pub fn check_init(&self, arg: usize) -> bool {
494 let point_to_id = self.chains.get_point_to_id(arg);
495 let var = self.chains.get_var_node(point_to_id);
496 if var.unwrap().field.is_empty() {
498 let mut init_flag = true;
499 for field in &var.unwrap().field {
500 init_flag &= self.check_init(*field.1);
501 }
502 init_flag
503 } else {
504 var.unwrap().ots.init
505 }
506 }
507
508 pub fn check_allocator_consistency(&self, _func_name: String, _arg: usize) -> bool {
509 true
510 }
511
512 pub fn check_allocated(&self, _arg: usize) -> bool {
513 true
514 }
515
516 pub fn check_inbound(&self, arg: usize, length_arg: usize, contract_ty: Ty<'tcx>) -> bool {
517 let mem_arg = self.chains.get_point_to_id(arg);
519 let mem_var = self.chains.get_var_node(mem_arg).unwrap();
520 for cis in &mem_var.cis.contracts {
521 if let PropertyContract::InBound(cis_ty, len) = cis {
522 return self.check_le_op(&contract_ty, length_arg, cis_ty, len);
524 }
525 }
526 false
527 }
528
529 fn check_le_op(
531 &self,
532 left_ty: &Ty<'tcx>,
533 left_arg: usize,
534 right_ty: &Ty<'tcx>,
535 right_len: &CisRangeItem,
536 ) -> bool {
537 if left_ty == right_ty {
541 return self
542 .compare_patial_order_of_two_args(left_arg, right_len.get_var_base().unwrap());
543 }
544 let left_layout = self.visit_ty_and_get_layout(*left_ty);
546 let right_layout = self.visit_ty_and_get_layout(*right_ty);
547 let get_size_range = |layout: &PlaceTy<'tcx>| -> Option<(u128, u128)> {
548 match layout {
549 PlaceTy::Ty(_, size) => Some((*size as u128, *size as u128)),
550 PlaceTy::GenericTy(_, _, layouts) if !layouts.is_empty() => {
551 let sizes: Vec<u128> = layouts.iter().map(|(_, s)| *s as u128).collect();
552 let min = *sizes.iter().min().unwrap();
553 let max = *sizes.iter().max().unwrap();
554 Some((min, max))
555 }
556 _ => None,
557 }
558 };
559 let (left_min_size, left_max_size) = match get_size_range(&left_layout) {
560 Some(range) => range,
561 None => return false, };
563 let (right_min_size, right_max_size) = match get_size_range(&right_layout) {
564 Some(range) => range,
565 None => return false, };
567 false
570 }
571
572 fn compare_patial_order_of_two_args(&self, left: usize, right: usize) -> bool {
574 let mut dataflow_analyzer = DataFlowAnalyzer::new(self.tcx, false);
576 dataflow_analyzer.build_graph(self.def_id);
577 let left_local = rustc_middle::mir::Local::from(left);
578 let right_local = rustc_middle::mir::Local::from(right);
579 let left_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, left_local);
580 let right_local_set = dataflow_analyzer.collect_equivalent_locals(self.def_id, right_local);
581 if right_local_set.contains(&rustc_middle::mir::Local::from(left)) {
583 return true;
584 }
585 for left_local_item in left_local_set {
593 let left_var = self.chains.get_var_node(left_local_item.as_usize());
594 if left_var.is_none() {
595 continue;
596 }
597 for cis in &left_var.unwrap().cis.contracts {
598 if let PropertyContract::ValidNum(cis_range) = cis {
599 let cis_len = &cis_range.range;
600 match cis_range.bin_op {
601 BinOp::Le | BinOp::Lt | BinOp::Eq => {
602 return cis_len.get_var_base().is_some()
603 && right_local_set.contains(&rustc_middle::mir::Local::from(
604 cis_len.get_var_base().unwrap(),
605 ));
606 }
607 _ => {}
608 }
609 }
610 }
611 }
612 false
613 }
614
615 pub fn check_valid_string(&self, _arg: usize) -> bool {
620 true
621 }
622
623 pub fn check_valid_cstr(&self, _arg: usize) -> bool {
624 true
625 }
626
627 pub fn check_valid_num(&self, _arg: usize) -> bool {
628 true
629 }
630
631 pub fn check_alias(&self, _arg: usize) -> bool {
632 true
633 }
634
635 pub fn check_valid_ptr(&self, arg: usize) -> bool {
637 !self.check_non_zst(arg) || (self.check_non_zst(arg) && self.check_deref(arg))
638 }
639
640 pub fn check_deref(&self, arg: usize) -> bool {
641 self.check_allocated(arg)
642 }
644
645 pub fn check_ref_to_ptr(&self, arg: usize) -> bool {
646 self.check_deref(arg)
647 && self.check_init(arg)
648 && self.check_alias(arg)
650 }
651
652 pub fn show_error_info(&self, arg: usize) {
653 rap_warn!(
654 "In func {:?}, visitor checker error! Can't get {arg} in chain!",
655 get_cleaned_def_path_name(self.tcx, self.def_id)
656 );
657 display_hashmap(&self.chains.variables, 1);
658 }
659}