Skip to content

(some normalization improvements) #104133

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/rustc_hir_analysis/src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub fn provide(providers: &mut Providers) {
*providers = Providers {
opt_const_param_of: type_of::opt_const_param_of,
type_of: type_of::type_of,
fully_revealed_type_of: type_of::fully_revealed_type_of,
item_bounds: item_bounds::item_bounds,
explicit_item_bounds: item_bounds::explicit_item_bounds,
generics_of: generics_of::generics_of,
Expand Down
33 changes: 32 additions & 1 deletion compiler/rustc_hir_analysis/src/collect/type_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use rustc_hir::{HirId, Node};
use rustc_middle::hir::nested_filter;
use rustc_middle::ty::subst::InternalSubsts;
use rustc_middle::ty::util::IntTypeExt;
use rustc_middle::ty::{self, DefIdTree, Ty, TyCtxt, TypeFolder, TypeSuperFoldable, TypeVisitable};
use rustc_middle::ty::{
self, DefIdTree, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable,
};
use rustc_span::symbol::Ident;
use rustc_span::{Span, DUMMY_SP};

Expand Down Expand Up @@ -538,6 +540,35 @@ pub(super) fn type_of(tcx: TyCtxt<'_>, def_id: DefId) -> Ty<'_> {
}
}

pub fn fully_revealed_type_of<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> Ty<'tcx> {
let ty = tcx.type_of(def_id);
if ty.has_opaque_types() { ty.fold_with(&mut DeeperTypeFolder { tcx }) } else { ty }
}

struct DeeperTypeFolder<'tcx> {
tcx: TyCtxt<'tcx>,
}

impl<'tcx> TypeFolder<'tcx> for DeeperTypeFolder<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if !ty.has_opaque_types() {
return ty;
}

let ty = ty.super_fold_with(self);

if let ty::Opaque(def_id, substs) = *ty.kind() {
self.tcx.bound_fully_revealed_type_of(def_id).subst(self.tcx, substs)
} else {
ty
}
}
}

#[instrument(skip(tcx), level = "debug")]
/// Checks "defining uses" of opaque `impl Trait` types to ensure that they meet the restrictions
/// laid for "higher-order pattern unification".
Expand Down
5 changes: 5 additions & 0 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,11 @@ impl<'tcx> InferCtxtBuilder<'tcx> {
self
}

pub fn considering_regions(mut self, c: bool) -> Self {
self.considering_regions = c;
self
}

pub fn with_normalize_fn_sig_for_diagnostic(
mut self,
fun: Lrc<dyn Fn(&InferCtxt<'tcx>, ty::PolyFnSig<'tcx>) -> ty::PolyFnSig<'tcx>>,
Expand Down
19 changes: 17 additions & 2 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,21 @@ rustc_queries! {
separate_provide_extern
}

query fully_revealed_type_of(key: DefId) -> Ty<'tcx> {
desc { |tcx|
"fully {action} `{path}`",
action = {
use rustc_hir::def::DefKind;
match tcx.def_kind(key) {
DefKind::TyAlias => "expanding type alias",
DefKind::TraitAlias => "expanding trait alias",
_ => "computing type of",
}
},
path = tcx.def_path_str(key),
}
}

query collect_trait_impl_trait_tys(key: DefId)
-> Result<&'tcx FxHashMap<DefId, Ty<'tcx>>, ErrorGuaranteed>
{
Expand Down Expand Up @@ -1873,12 +1888,12 @@ rustc_queries! {

/// Do not call this query directly: invoke `normalize` instead.
query normalize_projection_ty(
goal: CanonicalProjectionGoal<'tcx>
key: CanonicalProjectionGoal<'tcx>
) -> Result<
&'tcx Canonical<'tcx, canonical::QueryResponse<'tcx, NormalizationResult<'tcx>>>,
NoSolution,
> {
desc { "normalizing `{}`", goal.value.value }
desc { "normalizing `{}` {}", key.value.value.projection_ty, if key.value.value.considering_regions { "considering regions" } else { "modulo regions" } }
remap_env_constness
}

Expand Down
9 changes: 8 additions & 1 deletion compiler/rustc_middle/src/traits/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,14 @@ pub mod type_op {
}

pub type CanonicalProjectionGoal<'tcx> =
Canonical<'tcx, ty::ParamEnvAnd<'tcx, ty::ProjectionTy<'tcx>>>;
Canonical<'tcx, ty::ParamEnvAnd<'tcx, ProjectionGoal<'tcx>>>;

#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, HashStable, Lift)]
#[derive(TypeFoldable, TypeVisitable)]
pub struct ProjectionGoal<'tcx> {
pub projection_ty: ty::ProjectionTy<'tcx>,
pub considering_regions: bool,
}

pub type CanonicalTyGoal<'tcx> = Canonical<'tcx, ty::ParamEnvAnd<'tcx, Ty<'tcx>>>;

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_middle/src/ty/normalize_erasing_regions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ impl<'tcx> TyCtxt<'tcx> {
let value = self.erase_regions(value);
debug!(?value);

if !value.has_projections() {
if !value.needs_normalization(param_env.reveal()) {
value
} else {
value.fold_with(&mut NormalizeAfterErasingRegionsFolder { tcx: self, param_env })
Expand Down Expand Up @@ -84,7 +84,7 @@ impl<'tcx> TyCtxt<'tcx> {
let value = self.erase_regions(value);
debug!(?value);

if !value.has_projections() {
if !value.needs_normalization(param_env.reveal()) {
Ok(value)
} else {
let mut folder = TryNormalizeAfterErasingRegionsFolder::new(self, param_env);
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_middle/src/ty/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,10 @@ impl<'tcx> TyCtxt<'tcx> {
ty::EarlyBinder(self.type_of(def_id))
}

pub fn bound_fully_revealed_type_of(self, def_id: DefId) -> ty::EarlyBinder<Ty<'tcx>> {
ty::EarlyBinder(self.fully_revealed_type_of(def_id))
}

pub fn bound_trait_impl_trait_tys(
self,
def_id: DefId,
Expand Down
15 changes: 14 additions & 1 deletion compiler/rustc_middle/src/ty/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,19 @@ pub trait TypeVisitable<'tcx>: fmt::Debug + Clone {
fn still_further_specializable(&self) -> bool {
self.has_type_flags(TypeFlags::STILL_FURTHER_SPECIALIZABLE)
}

fn needs_normalization(&self, reveal: ty::Reveal) -> bool {
match reveal {
ty::Reveal::UserFacing => {
self.has_type_flags(TypeFlags::HAS_TY_PROJECTION | TypeFlags::HAS_CT_PROJECTION)
}
ty::Reveal::All => self.has_type_flags(
TypeFlags::HAS_TY_PROJECTION
| TypeFlags::HAS_TY_OPAQUE
| TypeFlags::HAS_CT_PROJECTION,
),
}
}
}

pub trait TypeSuperVisitable<'tcx>: TypeVisitable<'tcx> {
Expand Down Expand Up @@ -537,7 +550,7 @@ struct FoundFlags;

// FIXME: Optimize for checking for infer flags
struct HasTypeFlagsVisitor {
flags: ty::TypeFlags,
flags: TypeFlags,
}

impl std::fmt::Debug for HasTypeFlagsVisitor {
Expand Down
2 changes: 1 addition & 1 deletion compiler/rustc_trait_selection/src/traits/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ impl<'a, 'tcx> ObligationProcessor for FulfillProcessor<'a, 'tcx> {

let infcx = self.selcx.infcx();

if obligation.predicate.has_projections() {
if obligation.predicate.needs_normalization(obligation.param_env.reveal()) {
let mut obligations = Vec::new();
let predicate = crate::traits::project::try_normalize_with_depth_to(
&mut self.selcx,
Expand Down
26 changes: 9 additions & 17 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,18 +381,6 @@ where
result
}

pub(crate) fn needs_normalization<'tcx, T: TypeVisitable<'tcx>>(value: &T, reveal: Reveal) -> bool {
match reveal {
Reveal::UserFacing => value
.has_type_flags(ty::TypeFlags::HAS_TY_PROJECTION | ty::TypeFlags::HAS_CT_PROJECTION),
Reveal::All => value.has_type_flags(
ty::TypeFlags::HAS_TY_PROJECTION
| ty::TypeFlags::HAS_TY_OPAQUE
| ty::TypeFlags::HAS_CT_PROJECTION,
),
}
}

struct AssocTypeNormalizer<'a, 'b, 'tcx> {
selcx: &'a mut SelectionContext<'b, 'tcx>,
param_env: ty::ParamEnv<'tcx>,
Expand Down Expand Up @@ -453,7 +441,7 @@ impl<'a, 'b, 'tcx> AssocTypeNormalizer<'a, 'b, 'tcx> {
value
);

if !needs_normalization(&value, self.param_env.reveal()) {
if !value.needs_normalization(self.param_env.reveal()) {
value
} else {
value.fold_with(self)
Expand All @@ -477,7 +465,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
if !needs_normalization(&ty, self.param_env.reveal()) {
if !ty.needs_normalization(self.param_env.reveal()) {
return ty;
}

Expand Down Expand Up @@ -526,7 +514,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
}

let substs = substs.fold_with(self);
let generic_ty = self.tcx().bound_type_of(def_id);
let generic_ty = self.tcx().bound_fully_revealed_type_of(def_id);
let concrete_ty = generic_ty.subst(self.tcx(), substs);
self.depth += 1;
let folded_ty = self.fold_ty(concrete_ty);
Expand Down Expand Up @@ -650,6 +638,10 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
if tcx.lazy_normalization() {
constant
} else {
if !constant.needs_normalization(self.param_env.reveal()) {
return constant;
}

let constant = constant.super_fold_with(self);
debug!(?constant, ?self.param_env);
with_replaced_escaping_bound_vars(
Expand All @@ -663,7 +655,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {

#[inline]
fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) {
if p.allow_normalization() && p.needs_normalization(self.param_env.reveal()) {
p.super_fold_with(self)
} else {
p
Expand Down Expand Up @@ -1124,7 +1116,7 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(

let projected_term = selcx.infcx().resolve_vars_if_possible(projected_term);

let mut result = if projected_term.has_projections() {
let mut result = if projected_term.needs_normalization(param_env.reveal()) {
let mut normalizer = AssocTypeNormalizer::new(
selcx,
param_env,
Expand Down
25 changes: 18 additions & 7 deletions compiler/rustc_trait_selection/src/traits/query/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::infer::at::At;
use crate::infer::canonical::OriginalQueryValues;
use crate::infer::{InferCtxt, InferOk};
use crate::traits::error_reporting::TypeErrCtxtExt;
use crate::traits::project::{needs_normalization, BoundVarReplacer, PlaceholderReplacer};
use crate::traits::project::{BoundVarReplacer, PlaceholderReplacer};
use crate::traits::{Obligation, ObligationCause, PredicateObligation, Reveal};
use rustc_data_structures::sso::SsoHashMap;
use rustc_data_structures::stack::ensure_sufficient_stack;
Expand All @@ -21,6 +21,7 @@ use std::ops::ControlFlow;
use super::NoSolution;

pub use rustc_middle::traits::query::NormalizationResult;
use rustc_middle::traits::query::ProjectionGoal;

pub trait AtExt<'tcx> {
fn normalize<T>(&self, value: T) -> Result<Normalized<'tcx, T>, NoSolution>
Expand Down Expand Up @@ -53,7 +54,7 @@ impl<'cx, 'tcx> AtExt<'tcx> for At<'cx, 'tcx> {
self.param_env,
self.cause,
);
if !needs_normalization(&value, self.param_env.reveal()) {
if !value.needs_normalization(self.param_env.reveal()) {
return Ok(Normalized { value, obligations: vec![] });
}

Expand Down Expand Up @@ -182,7 +183,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {

#[instrument(level = "debug", skip(self))]
fn try_fold_ty(&mut self, ty: Ty<'tcx>) -> Result<Ty<'tcx>, Self::Error> {
if !needs_normalization(&ty, self.param_env.reveal()) {
if !ty.needs_normalization(self.param_env.reveal()) {
return Ok(ty);
}

Expand Down Expand Up @@ -216,7 +217,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
self.infcx.err_ctxt().report_overflow_error(&obligation, true);
}

let generic_ty = self.tcx().bound_type_of(def_id);
let generic_ty = self.tcx().bound_fully_revealed_type_of(def_id);
let concrete_ty = generic_ty.subst(self.tcx(), substs);
self.anon_depth += 1;
if concrete_ty == ty {
Expand All @@ -241,7 +242,10 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
// we don't need to replace them with placeholders (see branch below).

let tcx = self.infcx.tcx;
let data = data.try_fold_with(self)?;
let data = ProjectionGoal {
projection_ty: data.try_fold_with(self)?,
considering_regions: self.infcx.considering_regions,
};

let mut orig_values = OriginalQueryValues::default();
// HACK(matthewjasper) `'static` is special-cased in selection,
Expand Down Expand Up @@ -292,7 +296,10 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
let infcx = self.infcx;
let (data, mapped_regions, mapped_types, mapped_consts) =
BoundVarReplacer::replace_bound_vars(infcx, &mut self.universes, data);
let data = data.try_fold_with(self)?;
let data = ProjectionGoal {
projection_ty: data.try_fold_with(self)?,
considering_regions: self.infcx.considering_regions,
};

let mut orig_values = OriginalQueryValues::default();
// HACK(matthewjasper) `'static` is special-cased in selection,
Expand Down Expand Up @@ -353,6 +360,10 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
&mut self,
constant: ty::Const<'tcx>,
) -> Result<ty::Const<'tcx>, Self::Error> {
if !constant.needs_normalization(self.param_env.reveal()) {
return Ok(constant);
}

let constant = constant.try_super_fold_with(self)?;
debug!(?constant, ?self.param_env);
Ok(crate::traits::project::with_replaced_escaping_bound_vars(
Expand All @@ -368,7 +379,7 @@ impl<'cx, 'tcx> FallibleTypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
&mut self,
p: ty::Predicate<'tcx>,
) -> Result<ty::Predicate<'tcx>, Self::Error> {
if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) {
if p.allow_normalization() && p.needs_normalization(self.param_env.reveal()) {
p.try_super_fold_with(self)
} else {
Ok(p)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@ where
type QueryResponse = T;

fn try_fast_path(_tcx: TyCtxt<'tcx>, key: &ParamEnvAnd<'tcx, Self>) -> Option<T> {
if !key.value.value.has_projections() { Some(key.value.value) } else { None }
if !key.value.value.needs_normalization(key.param_env.reveal()) {
Some(key.value.value)
} else {
None
}
}

fn perform_query(
Expand Down
Loading