Skip to content

Commit 074488b

Browse files
committed
Properly infer types with type casts
1 parent 75ac37f commit 074488b

File tree

5 files changed

+112
-28
lines changed

5 files changed

+112
-28
lines changed

crates/hir-ty/src/infer.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,15 @@
1313
//! to certain types. To record this, we use the union-find implementation from
1414
//! the `ena` crate, which is extracted from rustc.
1515
16+
mod cast;
17+
pub(crate) mod closure;
18+
mod coerce;
19+
mod expr;
20+
mod mutability;
21+
mod pat;
22+
mod path;
23+
pub(crate) mod unify;
24+
1625
use std::{convert::identity, ops::Index};
1726

1827
use chalk_ir::{
@@ -60,15 +69,8 @@ pub use coerce::could_coerce;
6069
#[allow(unreachable_pub)]
6170
pub use unify::could_unify;
6271

63-
pub(crate) use self::closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
64-
65-
pub(crate) mod unify;
66-
mod path;
67-
mod expr;
68-
mod pat;
69-
mod coerce;
70-
pub(crate) mod closure;
71-
mod mutability;
72+
use cast::CastCheck;
73+
pub(crate) use closure::{CaptureKind, CapturedItem, CapturedItemWithoutTy};
7274

7375
/// The entry point of type inference.
7476
pub(crate) fn infer_query(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult> {
@@ -508,6 +510,8 @@ pub(crate) struct InferenceContext<'a> {
508510
diverges: Diverges,
509511
breakables: Vec<BreakableContext>,
510512

513+
deferred_cast_checks: Vec<CastCheck>,
514+
511515
// fields related to closure capture
512516
current_captures: Vec<CapturedItemWithoutTy>,
513517
current_closure: Option<ClosureId>,
@@ -582,7 +586,8 @@ impl<'a> InferenceContext<'a> {
582586
resolver,
583587
diverges: Diverges::Maybe,
584588
breakables: Vec::new(),
585-
current_captures: vec![],
589+
deferred_cast_checks: Vec::new(),
590+
current_captures: Vec::new(),
586591
current_closure: None,
587592
deferred_closures: FxHashMap::default(),
588593
closure_dependencies: FxHashMap::default(),
@@ -594,7 +599,7 @@ impl<'a> InferenceContext<'a> {
594599
// used this function for another workaround, mention it here. If you really need this function and believe that
595600
// there is no problem in it being `pub(crate)`, remove this comment.
596601
pub(crate) fn resolve_all(self) -> InferenceResult {
597-
let InferenceContext { mut table, mut result, .. } = self;
602+
let InferenceContext { mut table, mut result, deferred_cast_checks, .. } = self;
598603
// Destructure every single field so whenever new fields are added to `InferenceResult` we
599604
// don't forget to handle them here.
600605
let InferenceResult {
@@ -622,6 +627,13 @@ impl<'a> InferenceContext<'a> {
622627

623628
table.fallback_if_possible();
624629

630+
// Comment from rustc:
631+
// Even though coercion casts provide type hints, we check casts after fallback for
632+
// backwards compatibility. This makes fallback a stronger type hint than a cast coercion.
633+
for cast in deferred_cast_checks {
634+
cast.check(&mut table);
635+
}
636+
625637
// FIXME resolve obligations as well (use Guidance if necessary)
626638
table.resolve_obligations_as_possible();
627639

crates/hir-ty/src/infer/cast.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//! Type cast logic. Basically coercion + additional casts.
2+
3+
use crate::{infer::unify::InferenceTable, Interner, Ty, TyExt, TyKind};
4+
5+
#[derive(Clone, Debug)]
6+
pub(super) struct CastCheck {
7+
expr_ty: Ty,
8+
cast_ty: Ty,
9+
}
10+
11+
impl CastCheck {
12+
pub(super) fn new(expr_ty: Ty, cast_ty: Ty) -> Self {
13+
Self { expr_ty, cast_ty }
14+
}
15+
16+
pub(super) fn check(self, table: &mut InferenceTable<'_>) {
17+
// FIXME: This function currently only implements the bits that influence the type
18+
// inference. We should return the adjustments on success and report diagnostics on error.
19+
let expr_ty = table.resolve_ty_shallow(&self.expr_ty);
20+
let cast_ty = table.resolve_ty_shallow(&self.cast_ty);
21+
22+
if expr_ty.contains_unknown() || cast_ty.contains_unknown() {
23+
return;
24+
}
25+
26+
if table.coerce(&expr_ty, &cast_ty).is_ok() {
27+
return;
28+
}
29+
30+
if check_ref_to_ptr_cast(expr_ty, cast_ty, table) {
31+
// Note that this type of cast is actually split into a coercion to a
32+
// pointer type and a cast:
33+
// &[T; N] -> *[T; N] -> *T
34+
return;
35+
}
36+
37+
// FIXME: Check other kinds of non-coercion casts and report error if any?
38+
}
39+
}
40+
41+
fn check_ref_to_ptr_cast(expr_ty: Ty, cast_ty: Ty, table: &mut InferenceTable<'_>) -> bool {
42+
let Some((expr_inner_ty, _, _)) = expr_ty.as_reference() else { return false; };
43+
let Some((cast_inner_ty, _)) = cast_ty.as_raw_ptr() else { return false; };
44+
let TyKind::Array(expr_elt_ty, _) = expr_inner_ty.kind(Interner) else { return false; };
45+
table.coerce(expr_elt_ty, cast_inner_ty).is_ok()
46+
}

crates/hir-ty/src/infer/expr.rs

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ use crate::{
4646
};
4747

4848
use super::{
49-
coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges, Expectation,
50-
InferenceContext, InferenceDiagnostic, TypeMismatch,
49+
cast::CastCheck, coerce::auto_deref_adjust_steps, find_breakable, BreakableContext, Diverges,
50+
Expectation, InferenceContext, InferenceDiagnostic, TypeMismatch,
5151
};
5252

5353
impl InferenceContext<'_> {
@@ -574,16 +574,8 @@ impl InferenceContext<'_> {
574574
}
575575
Expr::Cast { expr, type_ref } => {
576576
let cast_ty = self.make_ty(type_ref);
577-
// FIXME: propagate the "castable to" expectation
578-
let inner_ty = self.infer_expr_no_expect(*expr);
579-
match (inner_ty.kind(Interner), cast_ty.kind(Interner)) {
580-
(TyKind::Ref(_, _, inner), TyKind::Raw(_, cast)) => {
581-
// FIXME: record invalid cast diagnostic in case of mismatch
582-
self.unify(inner, cast);
583-
}
584-
// FIXME check the other kinds of cast...
585-
_ => (),
586-
}
577+
let expr_ty = self.infer_expr(*expr, &Expectation::Castable(cast_ty.clone()));
578+
self.deferred_cast_checks.push(CastCheck::new(expr_ty, cast_ty.clone()));
587579
cast_ty
588580
}
589581
Expr::Ref { expr, rawness, mutability } => {
@@ -1592,7 +1584,7 @@ impl InferenceContext<'_> {
15921584
output: Ty,
15931585
inputs: Vec<Ty>,
15941586
) -> Vec<Ty> {
1595-
if let Some(expected_ty) = expected_output.to_option(&mut self.table) {
1587+
if let Some(expected_ty) = expected_output.only_has_type(&mut self.table) {
15961588
self.table.fudge_inference(|table| {
15971589
if table.try_unify(&expected_ty, &output).is_ok() {
15981590
table.resolve_with_fallback(inputs, &|var, kind, _, _| match kind {

crates/hir-ty/src/tests/regression.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1978,3 +1978,23 @@ fn x(a: [i32; 4]) {
19781978
"#,
19791979
);
19801980
}
1981+
1982+
#[test]
1983+
fn dont_unify_on_casts() {
1984+
// #15246
1985+
check_types(
1986+
r#"
1987+
fn unify(_: [bool; 1]) {}
1988+
fn casted(_: *const bool) {}
1989+
fn default<T>() -> T { loop {} }
1990+
1991+
fn test() {
1992+
let foo = default();
1993+
//^^^ [bool; 1]
1994+
1995+
casted(&foo as *const _);
1996+
unify(foo);
1997+
}
1998+
"#,
1999+
);
2000+
}

crates/hir-ty/src/tests/simple.rs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3513,7 +3513,6 @@ fn func() {
35133513
);
35143514
}
35153515

3516-
// FIXME
35173516
#[test]
35183517
fn castable_to() {
35193518
check_infer(
@@ -3538,10 +3537,10 @@ fn func() {
35383537
120..122 '{}': ()
35393538
138..184 '{ ...0]>; }': ()
35403539
148..149 'x': Box<[i32; 0]>
3541-
152..160 'Box::new': fn new<[{unknown}; 0]>([{unknown}; 0]) -> Box<[{unknown}; 0]>
3542-
152..164 'Box::new([])': Box<[{unknown}; 0]>
3540+
152..160 'Box::new': fn new<[i32; 0]>([i32; 0]) -> Box<[i32; 0]>
3541+
152..164 'Box::new([])': Box<[i32; 0]>
35433542
152..181 'Box::n...2; 0]>': Box<[i32; 0]>
3544-
161..163 '[]': [{unknown}; 0]
3543+
161..163 '[]': [i32; 0]
35453544
"#]],
35463545
);
35473546
}
@@ -3577,6 +3576,21 @@ fn f<T>(t: Ark<T>) {
35773576
);
35783577
}
35793578

3579+
#[test]
3580+
fn ref_to_array_to_ptr_cast() {
3581+
check_types(
3582+
r#"
3583+
fn default<T>() -> T { loop {} }
3584+
fn foo() {
3585+
let arr = [default()];
3586+
//^^^ [i32; 1]
3587+
let ref_to_arr = &arr;
3588+
let casted = ref_to_arr as *const i32;
3589+
}
3590+
"#,
3591+
);
3592+
}
3593+
35803594
#[test]
35813595
fn const_dependent_on_local() {
35823596
check_types(

0 commit comments

Comments
 (0)