1#[allow(unused)]
2use crate::analysis::senryx::contracts::property::PropertyContract;
3use crate::analysis::senryx::matcher::parse_unsafe_api;
4use crate::analysis::unsafety_isolation::generate_dot::NodeType;
5use crate::rap_debug;
6use crate::rap_warn;
7use rustc_data_structures::fx::FxHashMap;
8use rustc_hir::def::DefKind;
9use rustc_hir::def_id::DefId;
10use rustc_hir::Attribute;
11use rustc_hir::ImplItemKind;
12use rustc_middle::mir::BinOp;
13use rustc_middle::mir::Local;
14use rustc_middle::mir::{BasicBlock, Terminator};
15use rustc_middle::ty::{AssocKind, Mutability, Ty, TyCtxt, TyKind};
16use rustc_middle::{
17 mir::{Operand, TerminatorKind},
18 ty,
19};
20use rustc_span::def_id::LocalDefId;
21use rustc_span::kw;
22use rustc_span::sym;
23use std::collections::HashMap;
24use std::collections::HashSet;
25use std::fmt::Debug;
26use std::hash::Hash;
27use syn::Expr;
28
29pub fn generate_node_ty(tcx: TyCtxt<'_>, def_id: DefId) -> NodeType {
30 (def_id, check_safety(tcx, def_id), get_type(tcx, def_id))
31}
32
33pub fn check_visibility(tcx: TyCtxt<'_>, func_defid: DefId) -> bool {
34 if !tcx.visibility(func_defid).is_public() {
35 return false;
36 }
37 true
48}
49
50pub fn is_re_exported(tcx: TyCtxt<'_>, target_defid: DefId, module_defid: LocalDefId) -> bool {
51 for child in tcx.module_children_local(module_defid) {
52 if child.vis.is_public() {
53 if let Some(def_id) = child.res.opt_def_id() {
54 if def_id == target_defid {
55 return true;
56 }
57 }
58 }
59 }
60 false
61}
62
/// Prints every element of `set` on its own line using `Debug`, followed by a
/// separator line. Iteration order is the set's (unspecified) order.
pub fn print_hashset<T: std::fmt::Debug>(set: &HashSet<T>) {
    set.iter().for_each(|element| println!("{:?}", element));
    println!("---------------");
}
69
70pub fn get_cleaned_def_path_name(tcx: TyCtxt<'_>, def_id: DefId) -> String {
71 let def_id_str = format!("{:?}", def_id);
72 let mut parts: Vec<&str> = def_id_str
73 .split("::")
74 .collect();
76
77 let mut remove_first = false;
78 if let Some(first_part) = parts.get_mut(0) {
79 if first_part.contains("core") {
80 *first_part = "core";
81 } else if first_part.contains("std") {
82 *first_part = "std";
83 } else if first_part.contains("alloc") {
84 *first_part = "alloc";
85 } else {
86 remove_first = true;
87 }
88 }
89 if remove_first && !parts.is_empty() {
90 parts.remove(0);
91 }
92
93 let new_parts: Vec<String> = parts
94 .into_iter()
95 .filter_map(|s| {
96 if s.contains("{") {
97 if remove_first {
98 get_struct_name(tcx, def_id)
99 } else {
100 None
101 }
102 } else {
103 Some(s.to_string())
104 }
105 })
106 .collect();
107
108 let mut cleaned_path = new_parts.join("::");
109 cleaned_path = cleaned_path.trim_end_matches(')').to_string();
110 cleaned_path
111}
112
113pub fn get_sp_json() -> serde_json::Value {
114 let json_data: serde_json::Value =
115 serde_json::from_str(include_str!("../unsafety_isolation/data/std_sps.json"))
116 .expect("Unable to parse JSON");
117 json_data
118}
119
120pub fn get_std_api_signature_json() -> serde_json::Value {
121 let json_data: serde_json::Value =
122 serde_json::from_str(include_str!("../unsafety_isolation/data/std_sig.json"))
123 .expect("Unable to parse JSON");
124 json_data
125}
126
127pub fn get_sp(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<String> {
128 let cleaned_path_name = get_cleaned_def_path_name(tcx, def_id);
129 let json_data: serde_json::Value = get_sp_json();
130
131 if let Some(function_info) = json_data.get(&cleaned_path_name) {
132 if let Some(sp_list) = function_info.get("0") {
133 let mut result = HashSet::new();
134 if let Some(sp_array) = sp_list.as_array() {
135 for sp in sp_array {
136 if let Some(sp_name) = sp.as_str() {
137 result.insert(sp_name.to_string());
138 }
139 }
140 }
141 return result;
142 }
143 }
144 HashSet::new()
145}
146
147pub fn get_struct_name(tcx: TyCtxt<'_>, def_id: DefId) -> Option<String> {
148 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
149 if let Some(impl_id) = assoc_item.impl_container(tcx) {
150 let ty = tcx.type_of(impl_id).skip_binder();
151 let type_name = ty.to_string();
152 let struct_name = type_name
153 .split('<')
154 .next()
155 .unwrap_or("")
156 .split("::")
157 .last()
158 .unwrap_or("")
159 .to_string();
160
161 return Some(struct_name);
162 }
163 }
164 None
165}
166
167pub fn check_safety(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
168 let poly_fn_sig = tcx.fn_sig(def_id);
169 let fn_sig = poly_fn_sig.skip_binder();
170 fn_sig.safety() == rustc_hir::Safety::Unsafe
171}
172
173pub fn get_type(tcx: TyCtxt<'_>, def_id: DefId) -> usize {
175 let mut node_type = 2;
176 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
177 match assoc_item.kind {
178 AssocKind::Fn { has_self, .. } => {
179 if has_self {
180 node_type = 1;
181 } else {
182 let fn_sig = tcx.fn_sig(def_id).skip_binder();
183 let output = fn_sig.output().skip_binder();
184 if output.is_param(0) {
186 node_type = 0;
187 }
188 if let Some(impl_id) = assoc_item.impl_container(tcx) {
190 let ty = tcx.type_of(impl_id).skip_binder();
191 if output == ty {
192 node_type = 0;
193 }
194 }
195 match output.kind() {
196 TyKind::Ref(_, ref_ty, _) => {
197 if ref_ty.is_param(0) {
198 node_type = 0;
199 }
200 if let Some(impl_id) = assoc_item.impl_container(tcx) {
201 let ty = tcx.type_of(impl_id).skip_binder();
202 if *ref_ty == ty {
203 node_type = 0;
204 }
205 }
206 }
207 TyKind::Adt(adt_def, substs) => {
208 if adt_def.is_enum()
209 && (tcx.is_diagnostic_item(sym::Option, adt_def.did())
210 || tcx.is_diagnostic_item(sym::Result, adt_def.did())
211 || tcx.is_diagnostic_item(kw::Box, adt_def.did()))
212 {
213 let inner_ty = substs.type_at(0);
214 if inner_ty.is_param(0) {
215 node_type = 0;
216 }
217 if let Some(impl_id) = assoc_item.impl_container(tcx) {
218 let ty_impl = tcx.type_of(impl_id).skip_binder();
219 if inner_ty == ty_impl {
220 node_type = 0;
221 }
222 }
223 }
224 }
225 _ => {}
226 }
227 }
228 }
229 _ => todo!(),
230 }
231 }
232 node_type
233}
234
235pub fn get_adt_ty(tcx: TyCtxt<'_>, def_id: DefId) -> Option<Ty> {
236 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
237 if let Some(impl_id) = assoc_item.impl_container(tcx) {
238 return Some(tcx.type_of(impl_id).skip_binder());
239 }
240 }
241 None
242}
243
244pub fn get_cons(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<NodeType> {
245 let mut cons = Vec::new();
246 if tcx.def_kind(def_id) == DefKind::Fn || get_type(tcx, def_id) == 0 {
247 return cons;
248 }
249 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
250 if let Some(impl_id) = assoc_item.impl_container(tcx) {
251 let ty = tcx.type_of(impl_id).skip_binder();
253 if let Some(adt_def) = ty.ty_adt_def() {
254 let adt_def_id = adt_def.did();
255 let impls = tcx.inherent_impls(adt_def_id);
256 for impl_def_id in impls {
257 for item in tcx.associated_item_def_ids(impl_def_id) {
258 if (tcx.def_kind(item) == DefKind::Fn
259 || tcx.def_kind(item) == DefKind::AssocFn)
260 && get_type(tcx, *item) == 0
261 {
262 cons.push(generate_node_ty(tcx, *item));
263 }
264 }
265 }
266 }
267 }
268 }
269 cons
270}
271
272pub fn get_callees(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<DefId> {
273 let mut callees = HashSet::new();
274 if tcx.is_mir_available(def_id) {
275 let body = tcx.optimized_mir(def_id);
276 for bb in body.basic_blocks.iter() {
277 if let TerminatorKind::Call { func, .. } = &bb.terminator().kind {
278 if let Operand::Constant(func_constant) = func {
279 if let ty::FnDef(ref callee_def_id, _) = func_constant.const_.ty().kind() {
280 if check_safety(tcx, *callee_def_id)
281 {
283 let sp_set = get_sp(tcx, *callee_def_id);
284 if sp_set.len() != 0 {
285 callees.insert(*callee_def_id);
286 }
287 }
288 }
289 }
290 }
291 }
292 }
293 callees
294}
295
296pub fn get_impls_for_struct(tcx: TyCtxt<'_>, struct_def_id: DefId) -> Vec<DefId> {
298 let mut impls = Vec::new();
299 for impl_item_id in tcx.hir_crate_items(()).impl_items() {
300 let impl_item = tcx.hir_impl_item(impl_item_id);
301 match impl_item.kind {
302 ImplItemKind::Type(ty) => {
303 if let rustc_hir::TyKind::Path(ref qpath) = ty.kind {
304 if let rustc_hir::QPath::Resolved(_, path) = qpath {
305 if let rustc_hir::def::Res::Def(_, ref def_id) = path.res {
306 if *def_id == struct_def_id {
307 impls.push(impl_item.owner_id.to_def_id());
308 }
309 }
310 }
311 }
312 }
313 _ => (),
314 }
315 }
316 impls
317}
318
319pub fn get_adt_def_id_by_adt_method(tcx: TyCtxt<'_>, def_id: DefId) -> Option<DefId> {
320 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
321 if let Some(impl_id) = assoc_item.impl_container(tcx) {
322 let ty = tcx.type_of(impl_id).skip_binder();
324 if let Some(adt_def) = ty.ty_adt_def() {
325 return Some(adt_def.did());
326 }
327 }
328 }
329 None
330}
331
332pub fn get_pointee(matched_ty: Ty<'_>) -> Ty<'_> {
334 let pointee = if let ty::RawPtr(ty_mut, _) = matched_ty.kind() {
336 get_pointee(*ty_mut)
337 } else if let ty::Ref(_, referred_ty, _) = matched_ty.kind() {
338 get_pointee(*referred_ty)
339 } else {
340 matched_ty
341 };
342 pointee
343}
344
345pub fn is_ptr(matched_ty: Ty<'_>) -> bool {
346 if let ty::RawPtr(_, _) = matched_ty.kind() {
347 return true;
348 }
349 false
350}
351
352pub fn is_ref(matched_ty: Ty<'_>) -> bool {
353 if let ty::Ref(_, _, _) = matched_ty.kind() {
354 return true;
355 }
356 false
357}
358
359pub fn is_slice(matched_ty: Ty<'_>) -> Option<Ty> {
360 if let ty::Slice(inner) = matched_ty.kind() {
361 return Some(*inner);
362 }
363 None
364}
365
366pub fn has_mut_self_param(tcx: TyCtxt<'_>, def_id: DefId) -> bool {
367 if let Some(assoc_item) = tcx.opt_associated_item(def_id) {
368 match assoc_item.kind {
369 AssocKind::Fn { has_self, .. } => {
370 if has_self {
371 let body = tcx.optimized_mir(def_id);
372 let fst_arg = body.local_decls[Local::from_usize(1)].clone();
373 let ty = fst_arg.ty;
374 let is_mut_ref =
375 matches!(ty.kind(), ty::Ref(_, _, mutbl) if *mutbl == Mutability::Mut);
376 return fst_arg.mutability.is_mut() || is_mut_ref;
377 }
378 }
379 _ => (),
380 }
381 }
382 false
383}
384
385pub fn get_all_mutable_methods(tcx: TyCtxt<'_>, def_id: DefId) -> HashMap<DefId, HashSet<usize>> {
388 let mut results = HashMap::new();
389 let adt_def = get_adt_def_id_by_adt_method(tcx, def_id);
390 let public_fields = adt_def.map_or_else(HashSet::new, |def| get_public_fields(tcx, def));
391 let impl_vec = adt_def.map_or_else(Vec::new, |def| get_impls_for_struct(tcx, def));
392 for impl_id in impl_vec {
393 let associated_items = tcx.associated_items(impl_id);
394 for item in associated_items.in_definition_order() {
395 if let AssocKind::Fn {
396 name: _,
397 has_self: _,
398 } = item.kind
399 {
400 let item_def_id = item.def_id;
401 if has_mut_self_param(tcx, item_def_id) {
402 let modified_fields = public_fields.clone();
404 results.insert(item_def_id, modified_fields);
405 }
406 }
407 }
408 }
409 results
410}
411
412pub fn get_public_fields(tcx: TyCtxt<'_>, def_id: DefId) -> HashSet<usize> {
414 let adt_def = tcx.adt_def(def_id);
415 adt_def
416 .all_fields()
417 .enumerate()
418 .filter_map(|(index, field_def)| tcx.visibility(field_def.did).is_public().then_some(index))
419 .collect()
420}
421
/// Prints each `key: value` pair of `map` in ascending key order. Each line is
/// indented by `level` spaces.
pub fn display_hashmap<K, V>(map: &HashMap<K, V>, level: usize)
where
    K: Ord + Debug + Hash,
    V: Debug,
{
    let indent = " ".repeat(level);
    let mut entries: Vec<(&K, &V)> = map.iter().collect();
    entries.sort_by(|(left, _), (right, _)| left.cmp(right));
    for (key, value) in entries {
        println!("{}{:?}: {:?}", indent, key, value);
    }
}
438
439pub fn get_all_std_unsafe_callees(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<String> {
440 let mut results = Vec::new();
441 let body = tcx.optimized_mir(def_id);
442 let bb_len = body.basic_blocks.len();
443 for i in 0..bb_len {
444 let callees = match_std_unsafe_callee(
445 tcx,
446 body.basic_blocks[BasicBlock::from_usize(i)]
447 .clone()
448 .terminator(),
449 );
450 results.extend(callees);
451 }
452 results
453}
454
455pub fn get_all_std_unsafe_callees_block_id(tcx: TyCtxt<'_>, def_id: DefId) -> Vec<usize> {
456 let mut results = Vec::new();
457 let body = tcx.optimized_mir(def_id);
458 let bb_len = body.basic_blocks.len();
459 for i in 0..bb_len {
460 if match_std_unsafe_callee(
461 tcx,
462 body.basic_blocks[BasicBlock::from_usize(i)]
463 .clone()
464 .terminator(),
465 )
466 .is_empty()
467 {
468 results.push(i);
469 }
470 }
471 results
472}
473
474pub fn match_std_unsafe_callee(tcx: TyCtxt<'_>, terminator: &Terminator<'_>) -> Vec<String> {
475 let mut results = Vec::new();
476 if let TerminatorKind::Call { func, .. } = &terminator.kind {
477 if let Operand::Constant(func_constant) = func {
478 if let ty::FnDef(ref callee_def_id, _raw_list) = func_constant.const_.ty().kind() {
479 let func_name = get_cleaned_def_path_name(tcx, *callee_def_id);
480 if parse_unsafe_api(&func_name).is_some() {
481 results.push(func_name);
482 }
483 }
484 }
485 }
486 results
487}
488
489pub fn is_strict_ty_convert<'tcx>(tcx: TyCtxt<'tcx>, src_ty: Ty<'tcx>, dst_ty: Ty<'tcx>) -> bool {
492 (is_strict_ty(tcx, src_ty) && dst_ty.is_mutable_ptr()) || is_strict_ty(tcx, dst_ty)
493}
494
495pub fn is_strict_ty<'tcx>(tcx: TyCtxt<'tcx>, ori_ty: Ty<'tcx>) -> bool {
497 let ty = get_pointee(ori_ty);
498 let mut flag = false;
499 if let TyKind::Adt(adt_def, substs) = ty.kind() {
500 if adt_def.is_struct() {
501 for field_def in adt_def.all_fields() {
502 flag |= is_strict_ty(tcx, field_def.ty(tcx, substs))
503 }
504 }
505 }
506 ty.is_bool() || ty.is_str() || flag
507}
508
509pub fn reverse_op(op: BinOp) -> BinOp {
510 match op {
511 BinOp::Lt => BinOp::Ge,
512 BinOp::Ge => BinOp::Lt,
513 BinOp::Le => BinOp::Gt,
514 BinOp::Gt => BinOp::Le,
515 BinOp::Eq => BinOp::Eq,
516 BinOp::Ne => BinOp::Ne,
517 _ => op,
518 }
519}
520
521pub fn generate_contract_from_annotation_without_field_types(
523 tcx: TyCtxt<'_>,
524 def_id: DefId,
525) -> Vec<(usize, Vec<usize>, PropertyContract)> {
526 let contracts_with_ty = generate_contract_from_annotation(tcx, def_id);
527
528 contracts_with_ty
529 .into_iter()
530 .map(|(local_id, fields_with_ty, contract)| {
531 let fields: Vec<usize> = fields_with_ty
532 .into_iter()
533 .map(|(field_idx, _)| field_idx)
534 .collect();
535 (local_id, fields, contract)
536 })
537 .collect()
538}
539
540pub fn generate_contract_from_annotation(
545 tcx: TyCtxt<'_>,
546 def_id: DefId,
547) -> Vec<(usize, Vec<(usize, Ty)>, PropertyContract)> {
548 const REGISTER_TOOL: &str = "rapx";
549 let tool_attrs = tcx.get_all_attrs(def_id).filter(|attr| {
550 if let Attribute::Unparsed(tool_attr) = attr
551 && tool_attr.path.segments[0].as_str() == REGISTER_TOOL
552 {
553 return true;
554 }
555 false
556 });
557 let mut results = Vec::new();
558 for attr in tool_attrs {
559 let attr_str = rustc_hir_pretty::attribute_to_string(&tcx, attr);
560 let safety_attr =
561 safety_parser::property_attr::parse_inner_attr_from_str(attr_str.as_str()).unwrap();
562 let attr_name = safety_attr.name;
563 let attr_kind = safety_attr.kind;
564 let contract = PropertyContract::new(tcx, def_id, attr_kind, attr_name, &safety_attr.expr);
565 let (local, fields) = parse_cis_local(tcx, def_id, safety_attr.expr);
566 results.push((local, fields, contract));
567 }
568 results
572}
573
574pub fn parse_cis_local(
590 tcx: TyCtxt<'_>,
591 def_id: DefId,
592 expr: Vec<Expr>,
593) -> (usize, Vec<(usize, Ty<'_>)>) {
594 for e in expr {
596 if let Some((base, fields, _ty)) = parse_expr_into_local_and_ty(tcx, def_id, &e) {
597 return (base, fields);
598 }
599 }
600 (0, Vec::new())
601}
602
603pub fn parse_expr_into_local_and_ty<'tcx>(
605 tcx: TyCtxt<'tcx>,
606 def_id: DefId,
607 expr: &Expr,
608) -> Option<(usize, Vec<(usize, Ty<'tcx>)>, Ty<'tcx>)> {
609 if let Some((base_ident, fields)) = access_ident_recursive(&expr) {
610 let (param_names, param_tys) = parse_signature(tcx, def_id);
611 if param_names[0] == "0".to_string() {
612 return None;
613 }
614 if let Some(param_index) = param_names.iter().position(|name| name == &base_ident) {
615 let mut current_ty = param_tys[param_index];
616 let mut field_indices = Vec::new();
617 for field_name in fields {
618 let peeled_ty = current_ty.peel_refs();
620 if let rustc_middle::ty::TyKind::Adt(adt_def, arg_list) = *peeled_ty.kind() {
621 let variant = adt_def.non_enum_variant();
622 if let Ok(field_idx) = field_name.parse::<usize>() {
624 if field_idx < variant.fields.len() {
625 current_ty = variant.fields[rustc_abi::FieldIdx::from_usize(field_idx)]
626 .ty(tcx, arg_list);
627 field_indices.push((field_idx, current_ty));
628 continue;
629 }
630 }
631 if let Some((idx, _)) = variant
633 .fields
634 .iter()
635 .enumerate()
636 .find(|(_, f)| f.ident(tcx).name.to_string() == field_name.clone())
637 {
638 current_ty =
639 variant.fields[rustc_abi::FieldIdx::from_usize(idx)].ty(tcx, arg_list);
640 field_indices.push((idx, current_ty));
641 }
642 else {
644 break; }
646 }
647 else {
649 break; }
651 }
652 return Some((param_index + 1, field_indices, current_ty));
655 }
656 }
657 None
658}
659
660pub fn parse_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
663 if def_id.as_local().is_some() {
665 return parse_local_signature(tcx, def_id);
666 } else {
667 rap_debug!("{:?} is not local def id.", def_id);
668 return parse_outside_signature(tcx, def_id);
669 };
670}
671
672fn parse_outside_signature<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> (Vec<String>, Vec<Ty<'tcx>>) {
674 let sig = tcx.fn_sig(def_id).skip_binder();
675 let param_tys: Vec<Ty<'tcx>> = sig.inputs().skip_binder().iter().copied().collect();
676
677 if let Some(args_name) = get_known_std_names(tcx, def_id) {
679 return (args_name, param_tys);
686 }
687
688 let args_name = (0..param_tys.len()).map(|i| format!("{}", i)).collect();
690 rap_debug!(
691 "function {:?} has arg: {:?}, arg types: {:?}",
692 def_id,
693 args_name,
694 param_tys
695 );
696 return (args_name, param_tys);
697}
698
699fn get_known_std_names<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Option<Vec<String>> {
703 let std_func_name = get_cleaned_def_path_name(tcx, def_id);
704 let json_data: serde_json::Value = get_std_api_signature_json();
705
706 if let Some(arg_info) = json_data.get(&std_func_name) {
707 if let Some(args_name) = arg_info.as_array() {
708 if args_name.len() == 0 {
710 return Some(vec!["0".to_string()]);
711 }
712 let mut result = Vec::new();
714 for arg in args_name {
715 if let Some(sp_name) = arg.as_str() {
716 result.push(sp_name.to_string());
717 }
718 }
719 return Some(result);
720 }
721 }
722 None
723}
724
725pub fn parse_local_signature(tcx: TyCtxt<'_>, def_id: DefId) -> (Vec<String>, Vec<Ty>) {
727 let local_def_id = def_id.as_local().unwrap();
729 let hir_body = tcx.hir_body_owned_by(local_def_id);
730 if hir_body.params.len() == 0 {
731 return (vec!["0".to_string()], Vec::new());
732 }
733 let params = hir_body.params;
735 let typeck_results = tcx.typeck_body(hir_body.id());
736 let mut param_names = Vec::new();
737 let mut param_tys = Vec::new();
738 for param in params {
739 match param.pat.kind {
740 rustc_hir::PatKind::Binding(_, _, ident, _) => {
741 param_names.push(ident.name.to_string());
742 let ty = typeck_results.pat_ty(param.pat);
743 param_tys.push(ty);
744 }
745 _ => {
746 param_names.push(String::new());
747 param_tys.push(typeck_results.pat_ty(param.pat));
748 }
749 }
750 }
751 (param_names, param_tys)
752}
753
754pub fn access_ident_recursive(expr: &Expr) -> Option<(String, Vec<String>)> {
761 match expr {
762 Expr::Path(syn::ExprPath { path, .. }) => {
763 if path.segments.len() == 1 {
764 rap_debug!("expr2 {:?}", expr);
765 let ident = path.segments[0].ident.to_string();
766 Some((ident, Vec::new()))
767 } else {
768 None
769 }
770 }
771 Expr::Field(syn::ExprField { base, member, .. }) => {
773 let (base_ident, mut fields) =
774 if let Some((base_ident, fields)) = access_ident_recursive(base) {
775 (base_ident, fields)
776 } else {
777 return None;
778 };
779 let field_name = match member {
780 syn::Member::Named(ident) => ident.to_string(),
781 syn::Member::Unnamed(index) => index.index.to_string(),
782 };
783 fields.push(field_name);
784 Some((base_ident, fields))
785 }
786 _ => None,
787 }
788}
789
790pub fn parse_expr_into_number(expr: &Expr) -> Option<usize> {
792 if let Expr::Lit(expr_lit) = expr {
793 if let syn::Lit::Int(lit_int) = &expr_lit.lit {
794 return lit_int.base10_parse::<usize>().ok();
795 }
796 }
797 None
798}
799
800pub fn match_ty_with_ident(tcx: TyCtxt<'_>, def_id: DefId, type_ident: String) -> Option<Ty> {
817 if let Some(primitive_ty) = match_primitive_type(tcx, type_ident.clone()) {
819 return Some(primitive_ty);
820 }
821 return find_generic_param(tcx, def_id, type_ident.clone());
823 }
826
827fn match_primitive_type(tcx: TyCtxt<'_>, type_ident: String) -> Option<Ty> {
829 match type_ident.as_str() {
830 "i8" => Some(tcx.types.i8),
831 "i16" => Some(tcx.types.i16),
832 "i32" => Some(tcx.types.i32),
833 "i64" => Some(tcx.types.i64),
834 "i128" => Some(tcx.types.i128),
835 "isize" => Some(tcx.types.isize),
836 "u8" => Some(tcx.types.u8),
837 "u16" => Some(tcx.types.u16),
838 "u32" => Some(tcx.types.u32),
839 "u64" => Some(tcx.types.u64),
840 "u128" => Some(tcx.types.u128),
841 "usize" => Some(tcx.types.usize),
842 "f16" => Some(tcx.types.f16),
843 "f32" => Some(tcx.types.f32),
844 "f64" => Some(tcx.types.f64),
845 "f128" => Some(tcx.types.f128),
846 "bool" => Some(tcx.types.bool),
847 "char" => Some(tcx.types.char),
848 "str" => Some(tcx.types.str_),
849 _ => None,
850 }
851}
852
853fn find_generic_param(tcx: TyCtxt<'_>, def_id: DefId, type_ident: String) -> Option<Ty> {
855 rap_debug!(
856 "Searching for generic param: {} in {:?}",
857 type_ident,
858 def_id
859 );
860 let (_, param_tys) = parse_signature(tcx, def_id);
861 rap_debug!("Function parameter types: {:?} of {:?}", param_tys, def_id);
862 for &ty in ¶m_tys {
864 if let Some(found) = find_generic_in_ty(tcx, ty, &type_ident) {
865 return Some(found);
866 }
867 }
868
869 None
870}
871
872fn find_generic_in_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>, type_ident: &str) -> Option<Ty<'tcx>> {
874 match ty.kind() {
875 TyKind::Param(param_ty) => {
876 if param_ty.name.as_str() == type_ident {
877 return Some(ty);
878 }
879 }
880 TyKind::RawPtr(ty, _)
881 | TyKind::Ref(_, ty, _)
882 | TyKind::Slice(ty)
883 | TyKind::Array(ty, _) => {
884 if let Some(found) = find_generic_in_ty(tcx, *ty, type_ident) {
885 return Some(found);
886 }
887 }
888 TyKind::Tuple(tys) => {
889 for tuple_ty in tys.iter() {
890 if let Some(found) = find_generic_in_ty(tcx, tuple_ty, type_ident) {
891 return Some(found);
892 }
893 }
894 }
895 TyKind::Adt(adt_def, substs) => {
896 let name = tcx.item_name(adt_def.did()).to_string();
897 if name == type_ident {
898 return Some(ty);
899 }
900 for field in adt_def.all_fields() {
901 let field_ty = field.ty(tcx, substs);
902 if let Some(found) = find_generic_in_ty(tcx, field_ty, type_ident) {
903 return Some(found);
904 }
905 }
906 }
907 _ => {}
908 }
909 None
910}
911
912pub fn reflect_generic<'tcx>(
933 generic_mapping: &FxHashMap<String, Ty<'tcx>>,
934 ty: Ty<'tcx>,
935) -> Ty<'tcx> {
936 match ty.kind() {
937 TyKind::Param(param_ty) => {
938 let generic_name = param_ty.name.to_string();
939 if let Some(actual_ty) = generic_mapping.get(&generic_name) {
940 return *actual_ty;
941 }
942 }
943 _ => {}
944 }
945 ty
946}
947
948