1#[allow(unused)]
2use crate::analysis::senryx::contracts::property::PropertyContract;
3use crate::analysis::senryx::matcher::parse_unsafe_api;
4use crate::analysis::unsafety_isolation::generate_dot::NodeType;
5use crate::rap_debug;
6use crate::rap_warn;
7use rustc_data_structures::fx::FxHashMap;
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_hir::Attribute;
11use rustc_hir::ImplItemKind;
12use rustc_middle::mir::BinOp;
13use rustc_middle::mir::Local;
14use rustc_middle::mir::{BasicBlock, Terminator};
15use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
16use rustc_middle::{
17 mir::{Operand, TerminatorKind},
18 ty,
19};
20use rustc_span::def_id::LocalDefId;
21use rustc_span::kw;
22use rustc_span::sym;
23use std::collections::HashMap;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::hash::Hash;
27use syn::Expr;
28
29pub fn generate_node_ty(tcx: TyCtxt, def_id: DefId) -> NodeType {
30 (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
31}
32
33pub fn check_visibility(tcx: TyCtxt, func_defid: DefId) -> bool {
34 if !tcx.visibility(func_defid).is_public() {
35 return false;
36 }
37 true
48}
49
50pub fn is_re_exported(tcx: TyCtxt, target_defid: DefId, module_defid: LocalDefId) -> bool {
51 for child in tcx.module_children_local(module_defid) {
52 if child.vis.is_public() {
53 if let Some(def_id) = child.res.opt_def_id() {
54 if def_id == target_defid {
55 return true;
56 }
57 }
58 }
59 }
60 false
61}
62
/// Prints each element of `set` on its own line using `Debug`, followed by a
/// separator line. Iteration order is the set's (unspecified) hash order.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|entry| println!("{:?}", entry));
    println!("---------------");
}
69
70pub fn get_cleaned_def_path_name(tcx: TyCtxt, def_id: DefId) -> String {
71 let def_id_str = format!("{:?}", def_id);
72 let mut parts: Vec<&str> = def_id_str.split("::").collect();
73
74 let mut remove_first = false;
75 if let Some(first_part) = parts.get_mut(0) {
76 if first_part.contains("core") {
77 *first_part = "core";
78 } else if first_part.contains("std") {
79 *first_part = "std";
80 } else if first_part.contains("alloc") {
81 *first_part = "alloc";
82 } else {
83 remove_first = true;
84 }
85 }
86 if remove_first && !parts.is_empty() {
87 parts.remove(0);
88 }
89
90 let new_parts: Vec<String> = parts
91 .into_iter()
92 .filter_map(|s| {
93 if s.contains("{") {
94 if remove_first {
95 get_struct_name(tcx, def_id)
96 } else {
97 None
98 }
99 } else {
100 Some(s.to_string())
101 }
102 })
103 .collect();
104
105 let mut cleaned_path = new_parts.join("::");
106 cleaned_path = cleaned_path.trim_end_matches(')').to_string();
107 cleaned_path
108}
109
110pub fn get_sp_json() -> serde_json::Value {
111 let json_data: serde_json::Value =
112 serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
113 .expect("Unable to parse JSON");
114 json_data
115}
116
117pub fn get_std_api_signature_json() -> serde_json::Value {
118 let json_data: serde_json::Value =
119 serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
120 .expect("Unable to parse JSON");
121 json_data
122}
123
124pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
125 let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
126 let json_data: serde_json::Value = get_sp_json();
127
128 if let Some(function_info) = json_data.get(&cleaned_path_name) {
129 if let Some(sp_list) = function_info.get("0") {
130 let mut result = HashSet::new();
131 if let Some(sp_array) = sp_list.as_array() {
132 for sp in sp_array {
133 if let Some(sp_name) = sp.as_str() {
134 result.insert(sp_name.to_string());
135 }
136 }
137 }
138 return result;
139 }
140 }
141 HashSet::new()
142}
143
144pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
145 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
146 if let Some(impl_id) = assoc_item.impl_container(tcx) {
147 let ty = tcx.type_of(impl_id).skip_binder();
148 let type_name = ty.to_string();
149 let struct_name = type_name
150 .split('<')
151 .next()
152 .unwrap_or("")
153 .split("::")
154 .last()
155 .unwrap_or("")
156 .to_string();
157
158 return Some(struct_name);
159 }
160 }
161 None
162}
163
164pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
165 let poly_fn_sig = tcx.fn_sig(def_id);
166 let fn_sig = poly_fn_sig.skip_binder();
167 fn_sig.safety() == rustc_hir::Safety::Unsafe
168}
169
170pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
172 let mut node_type = 2;
173 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
174 match assoc_item.kind {
175 AssocKind::Fn { has_self, .. } => {
176 if has_self {
177 node_type = 1;
178 } else {
179 let fn_sig = tcx.fn_sig(def_id).skip_binder();
180 let output = fn_sig.output().skip_binder();
181 if output.is_param(0) {
183 node_type = 0;
184 }
185 if let Some(impl_id) = assoc_item.impl_container(tcx) {
187 let ty = tcx.type_of(impl_id).skip_binder();
188 if output == ty {
189 node_type = 0;
190 }
191 }
192 match output.kind() {
193 TyKind::Ref(_, ref_ty, _) => {
194 if ref_ty.is_param(0) {
195 node_type = 0;
196 }
197 if let Some(impl_id) = assoc_item.impl_container(tcx) {
198 let ty = tcx.type_of(impl_id).skip_binder();
199 if *ref_ty == ty {
200 node_type = 0;
201 }
202 }
203 }
204 TyKind::Adt(adt_def, substs) => {
205 if adt_def.is_enum()
206 && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
207 || tcx.is_diagnostic_item(sym::Result, adt_def.did())
208 || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
209 {
210 let inner_ty = substs.type_at(0);
211 if inner_ty.is_param(0) {
212 node_type = 0;
213 }
214 if let Some(impl_id) = assoc_item.impl_container(tcx) {
215 let ty_impl = tcx.type_of(impl_id).skip_binder();
216 if inner_ty == ty_impl {
217 node_type = 0;
218 }
219 }
220 }
221 }
222 _ => {}
223 }
224 }
225 }
226 _ => todo!(),
227 }
228 }
229 node_type
230}
231
232pub fn get_adt_ty(tcx: TyCtxt, def_id: DefId) -> Option<Ty> {
233 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
234 if let Some(impl_id) = assoc_item.impl_container(tcx) {
235 return Some(tcx.type_of(impl_id).skip_binder());
236 }
237 }
238 None
239}
240
241pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
242 let mut cons = Vec::new();
243 if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
244 return cons;
245 }
246 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
247 if let Some(impl_id) = assoc_item.impl_container(tcx) {
248 let ty = tcx.type_of(impl_id).skip_binder();
250 if let Some(adt_def) = ty.ty_adt_def() {
251 let adt_def_id = adt_def.did();
252 let impls = tcx.inherent_impls(adt_def_id);
253 for impl_def_id in impls {
254 for item in tcx.associated_item_def_ids(impl_def_id) {
255 if (tcx.def_kind(item) == DefKind::Fn
256 || tcx.def_kind(item) == DefKind::AssocFn)
257 && get_type(tcx, *item) == 0
258 {
259 cons.push(generate_node_ty(tcx, *item));
260 }
261 }
262 }
263 }
264 }
265 }
266 cons
267}
268
269pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
270 let mut callees = HashSet::new();
271 if tcx.is_mir_available(def_id) {
272 let body = tcx.optimized_mir(def_id);
273 for bb in body.basic_blocks.iter() {
274 if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
275 if let Operand::Constant(func_constant) = func {
276 if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
277 if check_safety(tcx, *callee_def_id)
278 {
280 let sp_set = get_sp(tcx, *callee_def_id);
281 if sp_set.len() != 0 {
282 callees.insert(*callee_def_id);
283 }
284 }
285 }
286 }
287 }
288 }
289 }
290 callees
291}
292
293pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
295 let mut impls = Vec::new();
296 for impl_item_id in tcx.hir_crate_items(()).impl_items() {
297 let impl_item = tcx.hir_impl_item(impl_item_id);
298 match impl_item.kind {
299 ImplItemKind::Type(ty) => {
300 if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
301 if let rustc_hir::QPath::Resolved(_, path) = qpath {
302 if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
303 if *def_id == struct_def_id {
304 impls.push(impl_item.owner_id.to_def_id());
305 }
306 }
307 }
308 }
309 }
310 _ => (),
311 }
312 }
313 impls
314}
315
316pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
317 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
318 if let Some(impl_id) = assoc_item.impl_container(tcx) {
319 let ty = tcx.type_of(impl_id).skip_binder();
321 if let Some(adt_def) = ty.ty_adt_def() {
322 return Some(adt_def.did());
323 }
324 }
325 }
326 None
327}
328
329pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
331 let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
333 get_pointee(*ty_mut)
334 } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
335 get_pointee(*referred_ty)
336 } else {
337 matched_ty
338 };
339 pointee
340}
341
342pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
343 if let ty::RawPtr(_, _) = matched_ty.kind() {
344 return true;
345 }
346 false
347}
348
349pub fn is_ref(matched_ty: Ty<'_>) -> bool {
350 if let ty::Ref(_, _, _) = matched_ty.kind() {
351 return true;
352 }
353 false
354}
355
356pub fn is_slice(matched_ty: Ty) -> Option<Ty> {
357 if let ty::Slice(inner) = matched_ty.kind() {
358 return Some(*inner);
359 }
360 None
361}
362
363pub fn has_mut_self_param(tcx: TyCtxt, def_id: DefId) -> bool {
364 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
365 match assoc_item.kind {
366 AssocKind::Fn { has_self, .. } => {
367 if has_self {
368 let body = tcx.optimized_mir(def_id);
369 let fst_arg = body.local_decls[Local::from_usize(1)].clone();
370 let ty = fst_arg.ty;
371 let is_mut_ref =
372 matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
373 return fst_arg.mutability.is_mut() || is_mut_ref;
374 }
375 }
376 _ => (),
377 }
378 }
379 false
380}
381
382pub fn get_all_mutable_methods(tcx: TyCtxt, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
385 let mut results = HashMap::new();
386 let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
387 let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
388 let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
389 for impl_id in impl_vec {
390 let associated_items = tcx.associated_items(impl_id);
391 for item in associated_items.in_definition_order() {
392 if let AssocKind::Fn {
393 name: _,
394 has_self: _,
395 } = item.kind
396 {
397 let item_def_id = item.def_id;
398 if has_mut_self_param(tcx, item_def_id) {
399 let modified_fields = public_fields.clone();
401 results.insert(item_def_id, modified_fields);
402 }
403 }
404 }
405 }
406 results
407}
408
409pub fn get_public_fields(tcx: TyCtxt, def_id: DefId) -> HashSet<usize> {
411 let adt_def = tcx.adt_def(def_id);
412 adt_def
413 .all_fields()
414 .enumerate()
415 .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
416 .collect()
417}
418
/// Prints `map` one `key: value` entry per line, sorted by key.
///
/// Each line is indented by `level` repetitions of the indent string.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = " ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|(left, _), (right, _)| left.cmp(right));
    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
435
436pub fn get_all_std_unsafe_callees(tcx: TyCtxt, def_id: DefId) -> Vec<String> {
437 let mut results = Vec::new();
438 let body = tcx.optimized_mir(def_id);
439 let bb_len = body.basic_blocks.len();
440 for i in 0..bb_len {
441 let callees = match_std_unsafe_callee(
442 tcx,
443 body.basic_blocks[BasicBlock::from_usize(i)]
444 .clone()
445 .terminator(),
446 );
447 results.extend(callees);
448 }
449 results
450}
451
452pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt, def_id: DefId) -> Vec<usize> {
453 let mut results = Vec::new();
454 let body = tcx.optimized_mir(def_id);
455 let bb_len = body.basic_blocks.len();
456 for i in 0..bb_len {
457 if match_std_unsafe_callee(
458 tcx,
459 body.basic_blocks[BasicBlock::from_usize(i)]
460 .clone()
461 .terminator(),
462 )
463 .is_empty()
464 {
465 results.push(i);
466 }
467 }
468 results
469}
470
471pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
472 let mut results = Vec::new();
473 if let TerminatorKind::Call { func, .. } = &terminator.kind {
474 if let Operand::Constant(func_constant) = func {
475 if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
476 let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
477 if parse_unsafe_api(&func_name).is_some() {
478 results.push(func_name);
479 }
480 }
481 }
482 }
483 results
484}
485
486pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
489 (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
490}
491
492pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
494 let ty = get_pointee(ori_ty);
495 let mut flag = false;
496 if let TyKind::Adt(adt_def, substs) = ty.kind() {
497 if adt_def.is_struct() {
498 for field_def in adt_def.all_fields() {
499 flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
500 }
501 }
502 }
503 ty.is_bool() || ty.is_str() || flag
504}
505
506pub fn reverse_op(op: BinOp) -> BinOp {
507 match op {
508 BinOp::Lt => BinOp::Ge,
509 BinOp::Ge => BinOp::Lt,
510 BinOp::Le => BinOp::Gt,
511 BinOp::Gt => BinOp::Le,
512 BinOp::Eq => BinOp::Eq,
513 BinOp::Ne => BinOp::Ne,
514 _ => op,
515 }
516}
517
518pub fn generate_contract_from_annotation_without_field_types(
520 tcx: TyCtxt,
521 def_id: DefId,
522) -> Vec<(usize, Vec<usize>, PropertyContract)> {
523 let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
524
525 contracts_with_ty
526 .into_iter()
527 .map(|(local_id, fields_with_ty, contract)| {
528 let fields: Vec<usize> = fields_with_ty
529 .into_iter()
530 .map(|(field_idx, _)| field_idx)
531 .collect();
532 (local_id, fields, contract)
533 })
534 .collect()
535}
536
537pub fn is_verify_target_func(tcx: TyCtxt, def_id: DefId) -> bool {
539 const REGISTER_TOOL: &str = "rapx";
540 for attr in tcx.get_all_attrs(def_id).into_iter() {
541 if let Attribute::Unparsed(tool_attr) = attr {
542 if tool_attr.path.segments[0].as_str() == REGISTER_TOOL
543 && tool_attr.path.segments[1].as_str() == "proof"
544 {
545 return true;
546 }
547 }
548 }
549 false
550}
551
552pub fn generate_contract_from_annotation(
557 tcx: TyCtxt,
558 def_id: DefId,
559) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
560 const REGISTER_TOOL: &str = "rapx";
561 let tool_attrs = tcx.get_all_attrs(def_id).into_iter().filter(|attr| {
562 if let Attribute::Unparsed(tool_attr) = attr {
563 if tool_attr.path.segments[0].as_str() == REGISTER_TOOL
564 && tool_attr.path.segments[1].as_str() != "proof"
565 {
566 return true;
567 }
568 }
569 false
570 });
571 let mut results = Vec::new();
572 for attr in tool_attrs {
573 let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
574 let safety_attr =
575 safety_parser::property_attr::parse_inner_attr_from_str(attr_str.as_str()).unwrap();
576 let attr_name = safety_attr.name;
577 let attr_kind = safety_attr.kind;
578 let contract = PropertyContract::new(tcx, def_id, attr_kind, attr_name, &safety_attr.expr);
579 let (local, fields) = parse_cis_local(tcx, def_id, safety_attr.expr);
580 results.push((local, fields, contract));
581 }
582 results
586}
587
588pub fn parse_cis_local(tcx: TyCtxt, def_id: DefId, expr: Vec<Expr>) -> (usize, Vec<(usize, Ty)>) {
604 for e in expr {
606 if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
607 return (base, fields);
608 }
609 }
610 (0, Vec::new())
611}
612
613pub fn parse_expr_into_local_and_ty<'tcx>(
615 tcx: TyCtxt<'tcx>,
616 def_id: DefId,
617 expr: &Expr,
618) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
619 if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
620 let (param_names, param_tys) = parse_signature(tcx, def_id);
621 if param_names[0] == "0".to_string() {
622 return None;
623 }
624 if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
625 let mut current_ty = param_tys[param_index];
626 let mut field_indices = Vec::new();
627 for field_name in fields {
628 let peeled_ty = current_ty.peel_refs();
630 if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
631 let variant = adt_def.non_enum_variant();
632 if let Ok(field_idx) = field_name.parse::<usize>() {
634 if field_idx < variant.fields.len() {
635 current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
636 .ty(tcx, arg_list);
637 field_indices.push((field_idx, current_ty));
638 continue;
639 }
640 }
641 if let Some((idx, _)) = variant
643 .fields
644 .iter()
645 .enumerate()
646 .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
647 {
648 current_ty =
649 variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
650 field_indices.push((idx, current_ty));
651 }
652 else {
654 break; }
656 }
657 else {
659 break; }
661 }
662 return Some((param_index + 1, field_indices, current_ty));
665 }
666 }
667 None
668}
669
670pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
673 if def_id.as_local().is_some() {
675 return parse_local_signature(tcx, def_id);
676 } else {
677 rap_debug!("{:?} is not local def id.", def_id);
678 return parse_outside_signature(tcx, def_id);
679 };
680}
681
682fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
684 let sig = tcx.fn_sig(def_id).skip_binder();
685 let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
686
687 if let Some(args_name) = get_known_std_names(tcx, def_id) {
689 return (args_name, param_tys);
696 }
697
698 let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
700 rap_debug!(
701 "function {:?} has arg: {:?}, arg types: {:?}",
702 def_id,
703 args_name,
704 param_tys
705 );
706 return (args_name, param_tys);
707}
708
709fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
713 let std_func_name = get_cleaned_def_path_name(tcx, def_id);
714 let json_data: serde_json::Value = get_std_api_signature_json();
715
716 if let Some(arg_info) = json_data.get(&std_func_name) {
717 if let Some(args_name) = arg_info.as_array() {
718 if args_name.len() == 0 {
720 return Some(vec!["0".to_string()]);
721 }
722 let mut result = Vec::new();
724 for arg in args_name {
725 if let Some(sp_name) = arg.as_str() {
726 result.push(sp_name.to_string());
727 }
728 }
729 return Some(result);
730 }
731 }
732 None
733}
734
735pub fn parse_local_signature(tcx: TyCtxt, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
737 let local_def_id = def_id.as_local().unwrap();
739 let hir_body = tcx.hir_body_owned_by(local_def_id);
740 if hir_body.params.len() == 0 {
741 return (vec!["0".to_string()], Vec::new());
742 }
743 let params = hir_body.params;
745 let typeck_results = tcx.typeck_body(hir_body.id());
746 let mut param_names = Vec::new();
747 let mut param_tys = Vec::new();
748 for param in params {
749 match param.pat.kind {
750 rustc_hir::PatKind::Binding(_, _, ident, _) => {
751 param_names.push(ident.name.to_string());
752 let ty = typeck_results.pat_ty(param.pat);
753 param_tys.push(ty);
754 }
755 _ => {
756 param_names.push(String::new());
757 param_tys.push(typeck_results.pat_ty(param.pat));
758 }
759 }
760 }
761 (param_names, param_tys)
762}
763
764pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
771 match expr {
772 Expr::Path(syn::ExprPath { path, .. }) => {
773 if path.segments.len() == 1 {
774 rap_debug!("expr2 {:?}", expr);
775 let ident = path.segments[0].ident.to_string();
776 Some((ident, Vec::new()))
777 } else {
778 None
779 }
780 }
781 Expr::Field(syn::ExprField { base, member, .. }) => {
783 let (base_ident, mut fields) =
784 if let Some((base_ident, fields)) = access_ident_recursive(base) {
785 (base_ident, fields)
786 } else {
787 return None;
788 };
789 let field_name = match member {
790 syn::Member::Named(ident) => ident.to_string(),
791 syn::Member::Unnamed(index) => index.index.to_string(),
792 };
793 fields.push(field_name);
794 Some((base_ident, fields))
795 }
796 _ => None,
797 }
798}
799
800pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
802 if let Expr::Lit(expr_lit) = expr {
803 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
804 return lit_int.base10_parse::<usize>().ok();
805 }
806 }
807 None
808}
809
810pub fn match_ty_with_ident(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
827 if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
829 return Some(primitive_ty);
830 }
831 return find_generic_param(tcx, def_id, type_ident.clone());
833 }
836
837fn match_primitive_type(tcx: TyCtxt, type_ident: String) -> Option<Ty> {
839 match type_ident.as_str() {
840 "i8" => Some(tcx.types.i8),
841 "i16" => Some(tcx.types.i16),
842 "i32" => Some(tcx.types.i32),
843 "i64" => Some(tcx.types.i64),
844 "i128" => Some(tcx.types.i128),
845 "isize" => Some(tcx.types.isize),
846 "u8" => Some(tcx.types.u8),
847 "u16" => Some(tcx.types.u16),
848 "u32" => Some(tcx.types.u32),
849 "u64" => Some(tcx.types.u64),
850 "u128" => Some(tcx.types.u128),
851 "usize" => Some(tcx.types.usize),
852 "f16" => Some(tcx.types.f16),
853 "f32" => Some(tcx.types.f32),
854 "f64" => Some(tcx.types.f64),
855 "f128" => Some(tcx.types.f128),
856 "bool" => Some(tcx.types.bool),
857 "char" => Some(tcx.types.char),
858 "str" => Some(tcx.types.str_),
859 _ => None,
860 }
861}
862
863fn find_generic_param(tcx: TyCtxt, def_id: DefId, type_ident: String) -> Option<Ty> {
865 rap_debug!(
866 "Searching for generic param: {} in {:?}",
867 type_ident,
868 def_id
869 );
870 let (_, param_tys) = parse_signature(tcx, def_id);
871 rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
872 for &ty in ¶m_tys {
874 if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
875 return Some(found);
876 }
877 }
878
879 None
880}
881
882fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
884 match ty.kind() {
885 TyKind::Param(param_ty) => {
886 if param_ty.name.as_str() == type_ident {
887 return Some(ty);
888 }
889 }
890 TyKind::RawPtr(ty, _)
891 | TyKind::Ref(_, ty, _)
892 | TyKind::Slice(ty)
893 | TyKind::Array(ty, _) => {
894 if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
895 return Some(found);
896 }
897 }
898 TyKind::Tuple(tys) => {
899 for tuple_ty in tys.iter() {
900 if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
901 return Some(found);
902 }
903 }
904 }
905 TyKind::Adt(adt_def, substs) => {
906 let name = tcx.item_name(adt_def.did()).to_string();
907 if name == type_ident {
908 return Some(ty);
909 }
910 for field in adt_def.all_fields() {
911 let field_ty = field.ty(tcx, substs);
912 if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
913 return Some(found);
914 }
915 }
916 }
917 _ => {}
918 }
919 None
920}
921
922pub fn reflect_generic<'tcx>(
943 generic_mapping: &FxHashMap<String, Ty<'tcx>>,
944 ty: Ty<'tcx>,
945) -> Ty<'tcx> {
946 match ty.kind() {
947 TyKind::Param(param_ty) => {
948 let generic_name = param_ty.name.to_string();
949 if let Some(actual_ty) = generic_mapping.get(&generic_name) {
950 return *actual_ty;
951 }
952 }
953 _ => {}
954 }
955 ty
956}
957
958