rustc_mir_transform/shim/
async_destructor_ctor.rs

1use rustc_hir::def_id::DefId;
2use rustc_hir::lang_items::LangItem;
3use rustc_hir::{CoroutineDesugaring, CoroutineKind, CoroutineSource, Safety};
4use rustc_index::{Idx, IndexVec};
5use rustc_middle::mir::{
6    BasicBlock, BasicBlockData, Body, Local, LocalDecl, MirSource, Operand, Place, Rvalue,
7    SourceInfo, Statement, StatementKind, Terminator, TerminatorKind,
8};
9use rustc_middle::ty::{self, EarlyBinder, Ty, TyCtxt, TypeVisitableExt};
10
11use super::*;
12use crate::patch::MirPatch;
13
14pub(super) fn build_async_destructor_ctor_shim<'tcx>(
15    tcx: TyCtxt<'tcx>,
16    def_id: DefId,
17    ty: Ty<'tcx>,
18) -> Body<'tcx> {
19    debug!("build_async_destructor_ctor_shim(def_id={:?}, ty={:?})", def_id, ty);
20    debug_assert_eq!(Some(def_id), tcx.lang_items().async_drop_in_place_fn());
21    let generic_body = tcx.optimized_mir(def_id);
22    let args = tcx.mk_args(&[ty.into()]);
23    let mut body = EarlyBinder::bind(generic_body.clone()).instantiate(tcx, args);
24
25    // Minimal shim passes except MentionedItems,
26    // it causes error "mentioned_items for DefId(...async_drop_in_place...) have already been set
27    pm::run_passes(
28        tcx,
29        &mut body,
30        &[
31            &simplify::SimplifyCfg::MakeShim,
32            &abort_unwinding_calls::AbortUnwindingCalls,
33            &add_call_guards::CriticalCallEdges,
34        ],
35        None,
36        pm::Optimizations::Allowed,
37    );
38    body
39}
40
41// build_drop_shim analog for async drop glue (for generated coroutine poll function)
42pub(super) fn build_async_drop_shim<'tcx>(
43    tcx: TyCtxt<'tcx>,
44    def_id: DefId,
45    ty: Ty<'tcx>,
46) -> Body<'tcx> {
47    debug!("build_async_drop_shim(def_id={:?}, ty={:?})", def_id, ty);
48    let ty::Coroutine(_, parent_args) = ty.kind() else {
49        bug!();
50    };
51    let typing_env = ty::TypingEnv::fully_monomorphized();
52
53    let drop_ty = parent_args.first().unwrap().expect_ty();
54    let drop_ptr_ty = Ty::new_mut_ptr(tcx, drop_ty);
55
56    assert!(tcx.is_coroutine(def_id));
57    let coroutine_kind = tcx.coroutine_kind(def_id).unwrap();
58
59    assert!(matches!(
60        coroutine_kind,
61        CoroutineKind::Desugared(CoroutineDesugaring::Async, CoroutineSource::Fn)
62    ));
63
64    let needs_async_drop = drop_ty.needs_async_drop(tcx, typing_env);
65    let needs_sync_drop = !needs_async_drop && drop_ty.needs_drop(tcx, typing_env);
66
67    let resume_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, None));
68    let resume_ty = Ty::new_adt(tcx, resume_adt, ty::List::empty());
69
70    let fn_sig = ty::Binder::dummy(tcx.mk_fn_sig(
71        [ty, resume_ty],
72        tcx.types.unit,
73        false,
74        Safety::Safe,
75        ExternAbi::Rust,
76    ));
77    let sig = tcx.instantiate_bound_regions_with_erased(fn_sig);
78
79    assert!(!drop_ty.is_coroutine());
80    let span = tcx.def_span(def_id);
81    let source_info = SourceInfo::outermost(span);
82
83    // The first argument (index 0), but add 1 for the return value.
84    let coroutine_layout = Place::from(Local::new(1 + 0));
85    let coroutine_layout_dropee =
86        tcx.mk_place_field(coroutine_layout, FieldIdx::new(0), drop_ptr_ty);
87
88    let return_block = BasicBlock::new(1);
89    let mut blocks = IndexVec::with_capacity(2);
90    let block = |blocks: &mut IndexVec<_, _>, kind| {
91        blocks.push(BasicBlockData {
92            statements: vec![],
93            terminator: Some(Terminator { source_info, kind }),
94            is_cleanup: false,
95        })
96    };
97    block(
98        &mut blocks,
99        if needs_sync_drop {
100            TerminatorKind::Drop {
101                place: tcx.mk_place_deref(coroutine_layout_dropee),
102                target: return_block,
103                unwind: UnwindAction::Continue,
104                replace: false,
105                drop: None,
106                async_fut: None,
107            }
108        } else {
109            TerminatorKind::Goto { target: return_block }
110        },
111    );
112    block(&mut blocks, TerminatorKind::Return);
113
114    let source = MirSource::from_instance(ty::InstanceKind::AsyncDropGlue(def_id, ty));
115    let mut body =
116        new_body(source, blocks, local_decls_for_sig(&sig, span), sig.inputs().len(), span);
117
118    body.coroutine = Some(Box::new(CoroutineInfo::initial(
119        coroutine_kind,
120        parent_args.as_coroutine().yield_ty(),
121        parent_args.as_coroutine().resume_ty(),
122    )));
123    body.phase = MirPhase::Runtime(RuntimePhase::Initial);
124    if !needs_async_drop || drop_ty.references_error() {
125        // Returning noop body for types without `need async drop`
126        // (or sync Drop in case of !`need async drop` && `need drop`).
127        // And also for error types.
128        return body;
129    }
130
131    let mut dropee_ptr = Place::from(body.local_decls.push(LocalDecl::new(drop_ptr_ty, span)));
132    let st_kind = StatementKind::Assign(Box::new((
133        dropee_ptr,
134        Rvalue::Use(Operand::Move(coroutine_layout_dropee)),
135    )));
136    body.basic_blocks_mut()[START_BLOCK].statements.push(Statement { source_info, kind: st_kind });
137    dropee_ptr = dropee_emit_retag(tcx, &mut body, dropee_ptr, span);
138
139    let dropline = body.basic_blocks.last_index();
140
141    let patch = {
142        let mut elaborator = DropShimElaborator {
143            body: &body,
144            patch: MirPatch::new(&body),
145            tcx,
146            typing_env,
147            produce_async_drops: true,
148        };
149        let dropee = tcx.mk_place_deref(dropee_ptr);
150        let resume_block = elaborator.patch.resume_block();
151        elaborate_drop(
152            &mut elaborator,
153            source_info,
154            dropee,
155            (),
156            return_block,
157            Unwind::To(resume_block),
158            START_BLOCK,
159            dropline,
160        );
161        elaborator.patch
162    };
163    patch.apply(&mut body);
164
165    body
166}
167
168// * For async drop a "normal" coroutine:
169// `async_drop_in_place<T>::{closure}.poll()` is converted into `T.future_drop_poll()`.
170// Every coroutine has its `poll` (calculate yourself a little further)
171// and its `future_drop_poll` (drop yourself a little further).
172//
173// * For async drop of "async drop coroutine" (`async_drop_in_place<T>::{closure}`):
174// Correct drop of such coroutine means normal execution of nested async drop.
175// async_drop(async_drop(T))::future_drop_poll() => async_drop(T)::poll().
176pub(super) fn build_future_drop_poll_shim<'tcx>(
177    tcx: TyCtxt<'tcx>,
178    def_id: DefId,
179    proxy_ty: Ty<'tcx>,
180    impl_ty: Ty<'tcx>,
181) -> Body<'tcx> {
182    let instance = ty::InstanceKind::FutureDropPollShim(def_id, proxy_ty, impl_ty);
183    let ty::Coroutine(coroutine_def_id, _) = impl_ty.kind() else {
184        bug!("build_future_drop_poll_shim not for coroutine impl type: ({:?})", instance);
185    };
186
187    let span = tcx.def_span(def_id);
188
189    if tcx.is_async_drop_in_place_coroutine(*coroutine_def_id) {
190        build_adrop_for_adrop_shim(tcx, proxy_ty, impl_ty, span, instance)
191    } else {
192        build_adrop_for_coroutine_shim(tcx, proxy_ty, impl_ty, span, instance)
193    }
194}
195
196// For async drop a "normal" coroutine:
197// `async_drop_in_place<T>::{closure}.poll()` is converted into `T.future_drop_poll()`.
198// Every coroutine has its `poll` (calculate yourself a little further)
199// and its `future_drop_poll` (drop yourself a little further).
200fn build_adrop_for_coroutine_shim<'tcx>(
201    tcx: TyCtxt<'tcx>,
202    proxy_ty: Ty<'tcx>,
203    impl_ty: Ty<'tcx>,
204    span: Span,
205    instance: ty::InstanceKind<'tcx>,
206) -> Body<'tcx> {
207    let ty::Coroutine(coroutine_def_id, impl_args) = impl_ty.kind() else {
208        bug!("build_adrop_for_coroutine_shim not for coroutine impl type: ({:?})", instance);
209    };
210    let proxy_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, proxy_ty);
211    // taking _1.0 (impl from Pin)
212    let pin_proxy_layout_local = Local::new(1);
213    let source_info = SourceInfo::outermost(span);
214    // converting `(_1: Pin<&mut CorLayout>, _2: &mut Context<'_>) -> Poll<()>`
215    // into `(_1: Pin<&mut ProxyLayout>, _2: &mut Context<'_>) -> Poll<()>`
216    // let mut _x: &mut CorLayout = &*_1.0.0;
217    // Replace old _1.0 accesses into _x accesses;
218    let body = tcx.optimized_mir(*coroutine_def_id).future_drop_poll().unwrap();
219    let mut body: Body<'tcx> = EarlyBinder::bind(body.clone()).instantiate(tcx, impl_args);
220    body.source.instance = instance;
221    body.phase = MirPhase::Runtime(RuntimePhase::Initial);
222    body.var_debug_info.clear();
223    let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, Some(span)));
224    let args = tcx.mk_args(&[proxy_ref.into()]);
225    let pin_proxy_ref = Ty::new_adt(tcx, pin_adt_ref, args);
226
227    let cor_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_ty);
228
229    let proxy_ref_local = body.local_decls.push(LocalDecl::new(proxy_ref, span));
230    let cor_ref_local = body.local_decls.push(LocalDecl::new(cor_ref, span));
231
232    FixProxyFutureDropVisitor { tcx, replace_to: cor_ref_local }.visit_body(&mut body);
233    // Now changing first arg from Pin<&mut ImplCoroutine> to Pin<&mut ProxyCoroutine>
234    body.local_decls[pin_proxy_layout_local] = LocalDecl::new(pin_proxy_ref, span);
235
236    {
237        let mut idx: usize = 0;
238        // _proxy = _1.0 : Pin<&ProxyLayout> ==> &ProxyLayout
239        let proxy_ref_place = Place::from(pin_proxy_layout_local)
240            .project_deeper(&[PlaceElem::Field(FieldIdx::ZERO, proxy_ref)], tcx);
241        body.basic_blocks_mut()[START_BLOCK].statements.insert(
242            idx,
243            Statement {
244                source_info,
245                kind: StatementKind::Assign(Box::new((
246                    Place::from(proxy_ref_local),
247                    Rvalue::CopyForDeref(proxy_ref_place),
248                ))),
249            },
250        );
251        idx += 1;
252        let mut cor_ptr_local = proxy_ref_local;
253        proxy_ty.find_async_drop_impl_coroutine(tcx, |ty| {
254            if ty != proxy_ty {
255                let ty_ptr = Ty::new_mut_ptr(tcx, ty);
256                let impl_ptr_place = Place::from(cor_ptr_local).project_deeper(
257                    &[PlaceElem::Deref, PlaceElem::Field(FieldIdx::ZERO, ty_ptr)],
258                    tcx,
259                );
260                cor_ptr_local = body.local_decls.push(LocalDecl::new(ty_ptr, span));
261                // _cor_ptr = _proxy.0.0 (... .0)
262                body.basic_blocks_mut()[START_BLOCK].statements.insert(
263                    idx,
264                    Statement {
265                        source_info,
266                        kind: StatementKind::Assign(Box::new((
267                            Place::from(cor_ptr_local),
268                            Rvalue::CopyForDeref(impl_ptr_place),
269                        ))),
270                    },
271                );
272                idx += 1;
273            }
274        });
275
276        // _cor_ref = &*cor_ptr
277        let reborrow = Rvalue::Ref(
278            tcx.lifetimes.re_erased,
279            BorrowKind::Mut { kind: MutBorrowKind::Default },
280            tcx.mk_place_deref(Place::from(cor_ptr_local)),
281        );
282        body.basic_blocks_mut()[START_BLOCK].statements.insert(
283            idx,
284            Statement {
285                source_info,
286                kind: StatementKind::Assign(Box::new((Place::from(cor_ref_local), reborrow))),
287            },
288        );
289    }
290    body
291}
292
293// When dropping async drop coroutine, we continue its execution.
294// async_drop(async_drop(T))::future_drop_poll() => async_drop(T)::poll()
295fn build_adrop_for_adrop_shim<'tcx>(
296    tcx: TyCtxt<'tcx>,
297    proxy_ty: Ty<'tcx>,
298    impl_ty: Ty<'tcx>,
299    span: Span,
300    instance: ty::InstanceKind<'tcx>,
301) -> Body<'tcx> {
302    let source_info = SourceInfo::outermost(span);
303    let proxy_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, proxy_ty);
304    // taking _1.0 (impl from Pin)
305    let pin_proxy_layout_local = Local::new(1);
306    let proxy_ref_place = Place::from(pin_proxy_layout_local)
307        .project_deeper(&[PlaceElem::Field(FieldIdx::ZERO, proxy_ref)], tcx);
308    let cor_ref = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, impl_ty);
309
310    // ret_ty = `Poll<()>`
311    let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, None));
312    let ret_ty = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
313    // env_ty = `Pin<&mut proxy_ty>`
314    let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, None));
315    let env_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[proxy_ref.into()]));
316    // sig = `fn (Pin<&mut proxy_ty>, &mut Context) -> Poll<()>`
317    let sig = tcx.mk_fn_sig(
318        [env_ty, Ty::new_task_context(tcx)],
319        ret_ty,
320        false,
321        hir::Safety::Safe,
322        ExternAbi::Rust,
323    );
324    // This function will be called with pinned proxy coroutine layout.
325    // We need to extract `Arg0.0` to get proxy layout, and then get `.0`
326    // further to receive impl coroutine (may be needed)
327    let mut locals = local_decls_for_sig(&sig, span);
328    let mut blocks = IndexVec::with_capacity(3);
329
330    let proxy_ref_local = locals.push(LocalDecl::new(proxy_ref, span));
331
332    let call_bb = BasicBlock::new(1);
333    let return_bb = BasicBlock::new(2);
334
335    let mut statements = Vec::new();
336
337    statements.push(Statement {
338        source_info,
339        kind: StatementKind::Assign(Box::new((
340            Place::from(proxy_ref_local),
341            Rvalue::CopyForDeref(proxy_ref_place),
342        ))),
343    });
344
345    let mut cor_ptr_local = proxy_ref_local;
346    proxy_ty.find_async_drop_impl_coroutine(tcx, |ty| {
347        if ty != proxy_ty {
348            let ty_ptr = Ty::new_mut_ptr(tcx, ty);
349            let impl_ptr_place = Place::from(cor_ptr_local)
350                .project_deeper(&[PlaceElem::Deref, PlaceElem::Field(FieldIdx::ZERO, ty_ptr)], tcx);
351            cor_ptr_local = locals.push(LocalDecl::new(ty_ptr, span));
352            // _cor_ptr = _proxy.0.0 (... .0)
353            statements.push(Statement {
354                source_info,
355                kind: StatementKind::Assign(Box::new((
356                    Place::from(cor_ptr_local),
357                    Rvalue::CopyForDeref(impl_ptr_place),
358                ))),
359            });
360        }
361    });
362
363    // convert impl coroutine ptr into ref
364    let reborrow = Rvalue::Ref(
365        tcx.lifetimes.re_erased,
366        BorrowKind::Mut { kind: MutBorrowKind::Default },
367        tcx.mk_place_deref(Place::from(cor_ptr_local)),
368    );
369    let cor_ref_place = Place::from(locals.push(LocalDecl::new(cor_ref, span)));
370    statements.push(Statement {
371        source_info,
372        kind: StatementKind::Assign(Box::new((cor_ref_place, reborrow))),
373    });
374
375    // cor_pin_ty = `Pin<&mut cor_ref>`
376    let cor_pin_ty = Ty::new_adt(tcx, pin_adt_ref, tcx.mk_args(&[cor_ref.into()]));
377    let cor_pin_place = Place::from(locals.push(LocalDecl::new(cor_pin_ty, span)));
378
379    let pin_fn = tcx.require_lang_item(LangItem::PinNewUnchecked, Some(span));
380    // call Pin<FutTy>::new_unchecked(&mut impl_cor)
381    blocks.push(BasicBlockData {
382        statements,
383        terminator: Some(Terminator {
384            source_info,
385            kind: TerminatorKind::Call {
386                func: Operand::function_handle(tcx, pin_fn, [cor_ref.into()], span),
387                args: [dummy_spanned(Operand::Move(cor_ref_place))].into(),
388                destination: cor_pin_place,
389                target: Some(call_bb),
390                unwind: UnwindAction::Continue,
391                call_source: CallSource::Misc,
392                fn_span: span,
393            },
394        }),
395        is_cleanup: false,
396    });
397    // When dropping async drop coroutine, we continue its execution:
398    // we call impl::poll (impl_layout, ctx)
399    let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, None);
400    let resume_ctx = Place::from(Local::new(2));
401    blocks.push(BasicBlockData {
402        statements: vec![],
403        terminator: Some(Terminator {
404            source_info,
405            kind: TerminatorKind::Call {
406                func: Operand::function_handle(tcx, poll_fn, [impl_ty.into()], span),
407                args: [
408                    dummy_spanned(Operand::Move(cor_pin_place)),
409                    dummy_spanned(Operand::Move(resume_ctx)),
410                ]
411                .into(),
412                destination: Place::return_place(),
413                target: Some(return_bb),
414                unwind: UnwindAction::Continue,
415                call_source: CallSource::Misc,
416                fn_span: span,
417            },
418        }),
419        is_cleanup: false,
420    });
421    blocks.push(BasicBlockData {
422        statements: vec![],
423        terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
424        is_cleanup: false,
425    });
426
427    let source = MirSource::from_instance(instance);
428    let mut body = new_body(source, blocks, locals, sig.inputs().len(), span);
429    body.phase = MirPhase::Runtime(RuntimePhase::Initial);
430    return body;
431}