Skip to content

Commit 3e7d5ae

Browse files
noti0na1olhotakevangirardin
committed
Add experimental flexible types feature on top of explicit nulls
Enabled by -Yflexible-types with -Yexplicit-nulls. A flexible type T! is a non-denotable type such that T <: T! <: T|Null and T|Null <: T! <: T. Here we patch return types and parameter types of Java methods and fields to use flexible types. This is unsound and kills subtyping transitivity but makes interop with Java play more nicely with the explicit nulls experimental feature (i.e. fewer nullability casts). Co-authored-by: Ondřej Lhoták <[email protected]> Co-authored-by: Evan Girardin <[email protected]>
1 parent c88c0fe commit 3e7d5ae

File tree

57 files changed

+650
-37
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+650
-37
lines changed

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

+1
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ private sealed trait YSettings:
410410
// Experimental language features
411411
val YnoKindPolymorphism: Setting[Boolean] = BooleanSetting("-Yno-kind-polymorphism", "Disable kind polymorphism.")
412412
val YexplicitNulls: Setting[Boolean] = BooleanSetting("-Yexplicit-nulls", "Make reference types non-nullable. Nullable types can be expressed with unions: e.g. String|Null.")
413+
val YflexibleTypes: Setting[Boolean] = BooleanSetting("-Yflexible-types", "Make Java return types and parameter types use flexible types, which have a nullable lower bound and non-null upper bound.")
413414
val YcheckInit: Setting[Boolean] = BooleanSetting("-Ysafe-init", "Ensure safe initialization of objects.")
414415
val YcheckInitGlobal: Setting[Boolean] = BooleanSetting("-Ysafe-init-global", "Check safe initialization of global objects.")
415416
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation.")

compiler/src/dotty/tools/dotc/core/Contexts.scala

+3
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ object Contexts {
472472
/** Is the explicit nulls option set? */
473473
def explicitNulls: Boolean = base.settings.YexplicitNulls.value
474474

475+
/** Is the flexible types option set? */
476+
def flexibleTypes: Boolean = base.settings.YexplicitNulls.value && base.settings.YflexibleTypes.value
477+
475478
/** A fresh clone of this context embedded in this context. */
476479
def fresh: FreshContext = freshOver(this)
477480

compiler/src/dotty/tools/dotc/core/JavaNullInterop.scala

+17-13
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ object JavaNullInterop {
7878
* but the result type is not nullable.
7979
*/
8080
private def nullifyExceptReturnType(tp: Type)(using Context): Type =
81-
new JavaNullMap(true)(tp)
81+
new JavaNullMap(outermostLevelAlreadyNullable = true)(tp)
8282

8383
/** Nullifies a Java type by adding `| Null` in the relevant places. */
8484
private def nullifyType(tp: Type)(using Context): Type =
85-
new JavaNullMap(false)(tp)
85+
new JavaNullMap(outermostLevelAlreadyNullable = false)(tp)
8686

8787
/** A type map that implements the nullification function on types. Given a Java-sourced type, this adds `| Null`
8888
* in the right places to make the nulls explicit in Scala.
@@ -96,25 +96,29 @@ object JavaNullInterop {
9696
* to `(A & B) | Null`, instead of `(A | Null & B | Null) | Null`.
9797
*/
9898
private class JavaNullMap(var outermostLevelAlreadyNullable: Boolean)(using Context) extends TypeMap {
99+
def nullify(tp: Type): Type = if ctx.flexibleTypes then FlexibleType(tp) else OrNull(tp)
100+
99101
/** Should we nullify `tp` at the outermost level? */
100102
def needsNull(tp: Type): Boolean =
101-
!outermostLevelAlreadyNullable && (tp match {
103+
!(outermostLevelAlreadyNullable || (tp match {
102104
case tp: TypeRef =>
103105
// We don't modify value types because they're non-nullable even in Java.
104-
!tp.symbol.isValueClass &&
106+
tp.symbol.isValueClass
107+
// We don't modify unit types.
108+
|| tp.isRef(defn.UnitClass)
105109
// We don't modify `Any` because it's already nullable.
106-
!tp.isRef(defn.AnyClass) &&
110+
|| tp.isRef(defn.AnyClass)
107111
// We don't nullify Java varargs at the top level.
108112
// Example: if `setNames` is a Java method with signature `void setNames(String... names)`,
109113
// then its Scala signature will be `def setNames(names: (String|Null)*): Unit`.
110114
// This is because `setNames(null)` passes as argument a single-element array containing the value `null`,
111115
// and not a `null` array.
112-
!tp.isRef(defn.RepeatedParamClass)
113-
case _ => true
114-
})
116+
|| !ctx.flexibleTypes && tp.isRef(defn.RepeatedParamClass)
117+
case _ => false
118+
}))
115119

116120
override def apply(tp: Type): Type = tp match {
117-
case tp: TypeRef if needsNull(tp) => OrNull(tp)
121+
case tp: TypeRef if needsNull(tp) => nullify(tp)
118122
case appTp @ AppliedType(tycon, targs) =>
119123
val oldOutermostNullable = outermostLevelAlreadyNullable
120124
// We don't make the outmost levels of type arguments nullable if tycon is Java-defined.
@@ -124,7 +128,7 @@ object JavaNullInterop {
124128
val targs2 = targs map this
125129
outermostLevelAlreadyNullable = oldOutermostNullable
126130
val appTp2 = derivedAppliedType(appTp, tycon, targs2)
127-
if needsNull(tycon) then OrNull(appTp2) else appTp2
131+
if needsNull(tycon) then nullify(appTp2) else appTp2
128132
case ptp: PolyType =>
129133
derivedLambdaType(ptp)(ptp.paramInfos, this(ptp.resType))
130134
case mtp: MethodType =>
@@ -138,12 +142,12 @@ object JavaNullInterop {
138142
// nullify(A & B) = (nullify(A) & nullify(B)) | Null, but take care not to add
139143
// duplicate `Null`s at the outermost level inside `A` and `B`.
140144
outermostLevelAlreadyNullable = true
141-
OrNull(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
142-
case tp: TypeParamRef if needsNull(tp) => OrNull(tp)
145+
nullify(derivedAndType(tp, this(tp.tp1), this(tp.tp2)))
146+
case tp: TypeParamRef if needsNull(tp) => nullify(tp)
143147
// In all other cases, return the type unchanged.
144148
// In particular, if the type is a ConstantType, then we don't nullify it because it is the
145149
// type of a final non-nullable field.
146150
case _ => tp
147151
}
148152
}
149-
}
153+
}

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ import Types.*
88
object NullOpsDecorator:
99

1010
extension (self: Type)
11+
def stripFlexible(using Context): Type = {
12+
self match {
13+
case FlexibleType(tp) => tp
14+
case _ => self
15+
}
16+
}
1117
/** Syntactically strips the nullability from this type.
1218
* If the type is `T1 | ... | Tn`, and `Ti` references to `Null`,
1319
* then return `T1 | ... | Ti-1 | Ti+1 | ... | Tn`.
@@ -33,6 +39,7 @@ object NullOpsDecorator:
3339
if (tp1s ne tp1) && (tp2s ne tp2) then
3440
tp.derivedAndType(tp1s, tp2s)
3541
else tp
42+
case tp: FlexibleType => tp.hi
3643
case tp @ TypeBounds(lo, hi) =>
3744
tp.derivedTypeBounds(strip(lo), strip(hi))
3845
case tp => tp

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

+3
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
564564
case CapturingType(parent, refs) =>
565565
val parent1 = recur(parent)
566566
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
567+
case tp: FlexibleType =>
568+
val underlying = recur(tp.underlying)
569+
if underlying ne tp.underlying then tp.derivedFlexibleType(underlying) else tp
567570
case tp: AnnotatedType =>
568571
val parent1 = recur(tp.parent)
569572
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Contexts.ctx
1010
import dotty.tools.dotc.reporting.trace
1111
import config.Feature.migrateTo3
1212
import config.Printers.*
13+
import dotty.tools.dotc.core.NullOpsDecorator.stripFlexible
1314

1415
trait PatternTypeConstrainer { self: TypeComparer =>
1516

@@ -163,7 +164,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
163164
}
164165
}
165166

166-
def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
167+
def dealiasDropNonmoduleRefs(tp: Type): Type = tp.dealias match {
167168
case tp: TermRef =>
168169
// we drop TermRefs that don't have a class symbol, as they can't
169170
// meaningfully participate in GADT reasoning and just get in the way.
@@ -172,6 +173,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
172173
// additional trait - argument-less enum cases desugar to vals.
173174
// See run/enum-Tree.scala.
174175
if tp.classSymbol.exists then tp else tp.info
176+
case FlexibleType(tp) => dealiasDropNonmoduleRefs(tp)
175177
case tp => tp
176178
}
177179

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

+19-10
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import reporting.trace
2323
import annotation.constructorOnly
2424
import cc.*
2525
import NameKinds.WildcardParamName
26+
import NullOpsDecorator.stripFlexible
2627

2728
/** Provides methods to compare types.
2829
*/
@@ -524,7 +525,6 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
524525
constraint = constraint.hardenTypeVars(tp2)
525526

526527
res
527-
528528
case tp1 @ CapturingType(parent1, refs1) =>
529529
def compareCapturing =
530530
if tp2.isAny then true
@@ -863,6 +863,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
863863
false
864864
}
865865
compareClassInfo
866+
case tp2: FlexibleType =>
867+
recur(tp1, tp2.lo)
866868
case _ =>
867869
fourthTry
868870
}
@@ -1058,6 +1060,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
10581060
case tp1: ExprType if ctx.phaseId > gettersPhase.id =>
10591061
// getters might have converted T to => T, need to compensate.
10601062
recur(tp1.widenExpr, tp2)
1063+
case tp1: FlexibleType =>
1064+
recur(tp1.hi, tp2)
10611065
case _ =>
10621066
false
10631067
}
@@ -2499,15 +2503,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
24992503
NoType
25002504
}
25012505

2502-
private def andTypeGen(tp1: Type, tp2: Type, op: (Type, Type) => Type,
2503-
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1.show}, ${tp2.show})", subtyping, show = true) {
2504-
val t1 = distributeAnd(tp1, tp2)
2505-
if (t1.exists) t1
2506-
else {
2507-
val t2 = distributeAnd(tp2, tp1)
2508-
if (t2.exists) t2
2509-
else if (isErased) erasedGlb(tp1, tp2)
2510-
else liftIfHK(tp1, tp2, op, original, _ | _)
2506+
private def andTypeGen(tp1orig: Type, tp2orig: Type, op: (Type, Type) => Type,
2507+
original: (Type, Type) => Type = _ & _, isErased: Boolean = ctx.erasedTypes): Type = trace(s"andTypeGen(${tp1orig.show}, ${tp2orig.show})", subtyping, show = true) {
2508+
val tp1 = tp1orig.stripFlexible
2509+
val tp2 = tp2orig.stripFlexible
2510+
val ret = {
2511+
val t1 = distributeAnd(tp1, tp2)
2512+
if (t1.exists) t1
2513+
else {
2514+
val t2 = distributeAnd(tp2, tp1)
2515+
if (t2.exists) t2
2516+
else if (isErased) erasedGlb(tp1, tp2)
2517+
else liftIfHK(tp1, tp2, op, original, _ | _)
25112518
// The ` | ` on variances is needed since variances are associated with bounds
25122519
// not lambdas. Example:
25132520
//
@@ -2517,7 +2524,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
25172524
//
25182525
// Here, `F` is treated as bivariant in `O`. That is, only bivariant implementation
25192526
// of `F` are allowed. See neg/hk-variance2s.scala test.
2527+
}
25202528
}
2529+
if(tp1orig.isInstanceOf[FlexibleType] && tp2orig.isInstanceOf[FlexibleType]) FlexibleType(ret) else ret
25212530
}
25222531

25232532
/** Form a normalized conjunction of two types.

compiler/src/dotty/tools/dotc/core/Types.scala

+54
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,7 @@ object Types extends TypeUtils {
343343
/** Is this type guaranteed not to have `null` as a value? */
344344
final def isNotNull(using Context): Boolean = this match {
345345
case tp: ConstantType => tp.value.value != null
346+
case tp: FlexibleType => false
346347
case tp: ClassInfo => !tp.cls.isNullableClass && tp.cls != defn.NothingClass
347348
case tp: AppliedType => tp.superType.isNotNull
348349
case tp: TypeBounds => tp.lo.isNotNull
@@ -372,6 +373,7 @@ object Types extends TypeUtils {
372373
case AppliedType(tycon, args) => tycon.unusableForInference || args.exists(_.unusableForInference)
373374
case RefinedType(parent, _, rinfo) => parent.unusableForInference || rinfo.unusableForInference
374375
case TypeBounds(lo, hi) => lo.unusableForInference || hi.unusableForInference
376+
case FlexibleType(underlying) => underlying.unusableForInference
375377
case tp: AndOrType => tp.tp1.unusableForInference || tp.tp2.unusableForInference
376378
case tp: LambdaType => tp.resultType.unusableForInference || tp.paramInfos.exists(_.unusableForInference)
377379
case WildcardType(optBounds) => optBounds.unusableForInference
@@ -3396,6 +3398,40 @@ object Types extends TypeUtils {
33963398
}
33973399
}
33983400

3401+
// --- FlexibleType -----------------------------------------------------------------
3402+
3403+
/* Represents a nullable type coming from Java code in a similar way to Platform Types
3404+
* in Kotlin. A FlexibleType(T) generally behaves like an abstract type with bad bounds
3405+
* T|Null .. T, so that T|Null <: FlexibleType(T) <: T.
3406+
*/
3407+
case class FlexibleType(original: Type, lo: Type, hi: Type) extends CachedProxyType with ValueType {
3408+
def underlying(using Context): Type = original
3409+
3410+
override def superType(using Context): Type = hi
3411+
3412+
def derivedFlexibleType(original: Type)(using Context): Type =
3413+
if this.original eq original then this else FlexibleType(original)
3414+
3415+
override def computeHash(bs: Binders): Int = doHash(bs, original)
3416+
3417+
override final def baseClasses(using Context): List[ClassSymbol] = original.baseClasses
3418+
}
3419+
3420+
object FlexibleType {
3421+
def apply(original: Type)(using Context): FlexibleType = original match {
3422+
case ft: FlexibleType => ft
3423+
case _ =>
3424+
val hi = original.stripNull
3425+
val lo = if hi eq original then OrNull(hi) else original
3426+
new FlexibleType(original, lo, hi)
3427+
}
3428+
3429+
def unapply(tp: Type)(using Context): Option[Type] = tp match {
3430+
case ft: FlexibleType => Some(ft.original)
3431+
case _ => None
3432+
}
3433+
}
3434+
33993435
// --- AndType/OrType ---------------------------------------------------------------
34003436

34013437
abstract class AndOrType extends CachedGroundType with ValueType {
@@ -5694,6 +5730,8 @@ object Types extends TypeUtils {
56945730
samClass(tp.underlying)
56955731
case tp: AnnotatedType =>
56965732
samClass(tp.underlying)
5733+
case tp: FlexibleType =>
5734+
samClass(tp.superType)
56975735
case _ =>
56985736
NoSymbol
56995737

@@ -5824,6 +5862,8 @@ object Types extends TypeUtils {
58245862
tp.derivedJavaArrayType(elemtp)
58255863
protected def derivedExprType(tp: ExprType, restpe: Type): Type =
58265864
tp.derivedExprType(restpe)
5865+
protected def derivedFlexibleType(tp: FlexibleType, under: Type): Type =
5866+
tp.derivedFlexibleType(under)
58275867
// note: currying needed because Scala2 does not support param-dependencies
58285868
protected def derivedLambdaType(tp: LambdaType)(formals: List[tp.PInfo], restpe: Type): Type =
58295869
tp.derivedLambdaType(tp.paramNames, formals, restpe)
@@ -5947,6 +5987,9 @@ object Types extends TypeUtils {
59475987
case tp: OrType =>
59485988
derivedOrType(tp, this(tp.tp1), this(tp.tp2))
59495989

5990+
case tp: FlexibleType =>
5991+
derivedFlexibleType(tp, this(tp.underlying))
5992+
59505993
case tp: MatchType =>
59515994
val bound1 = this(tp.bound)
59525995
val scrut1 = atVariance(0)(this(tp.scrutinee))
@@ -6234,6 +6277,14 @@ object Types extends TypeUtils {
62346277
if (underlying.isExactlyNothing) underlying
62356278
else tp.derivedAnnotatedType(underlying, annot)
62366279
}
6280+
override protected def derivedFlexibleType(tp: FlexibleType, underlying: Type): Type =
6281+
underlying match {
6282+
case Range(lo, hi) =>
6283+
range(tp.derivedFlexibleType(lo), tp.derivedFlexibleType(hi))
6284+
case _ =>
6285+
if (underlying.isExactlyNothing) underlying
6286+
else tp.derivedFlexibleType(underlying)
6287+
}
62376288
override protected def derivedCapturingType(tp: Type, parent: Type, refs: CaptureSet): Type =
62386289
parent match // TODO ^^^ handle ranges in capture sets as well
62396290
case Range(lo, hi) =>
@@ -6375,6 +6426,9 @@ object Types extends TypeUtils {
63756426
case tp: TypeVar =>
63766427
this(x, tp.underlying)
63776428

6429+
case tp: FlexibleType =>
6430+
this(x, tp.underlying)
6431+
63786432
case ExprType(restpe) =>
63796433
this(x, restpe)
63806434

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

+3
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,9 @@ class TreePickler(pickler: TastyPickler, attributes: Attributes) {
269269
case tpe: OrType =>
270270
writeByte(ORtype)
271271
withLength { pickleType(tpe.tp1, richTypes); pickleType(tpe.tp2, richTypes) }
272+
case tpe: FlexibleType =>
273+
writeByte(FLEXIBLEtype)
274+
withLength { pickleType(tpe.underlying, richTypes) }
272275
case tpe: ExprType =>
273276
writeByte(BYNAMEtype)
274277
pickleType(tpe.underlying)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

+2
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ class TreeUnpickler(reader: TastyReader,
430430
readTypeRef() match {
431431
case binder: LambdaType => binder.paramRefs(readNat())
432432
}
433+
case FLEXIBLEtype =>
434+
FlexibleType(readType())
433435
}
434436
assert(currentAddr == end, s"$start $currentAddr $end ${astTagToString(tag)}")
435437
result

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+2
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ class PlainPrinter(_ctx: Context) extends Printer {
287287
case AnnotatedType(tpe, annot) =>
288288
if annot.symbol == defn.InlineParamAnnot || annot.symbol == defn.ErasedParamAnnot then toText(tpe)
289289
else toTextLocal(tpe) ~ " " ~ toText(annot)
290+
case FlexibleType(tpe) =>
291+
"FlexibleType(" ~ toText(tpe) ~ ")"
290292
case tp: TypeVar =>
291293
def toTextCaret(tp: Type) = if printDebug then toTextLocal(tp) ~ Str("^") else toText(tp)
292294
if (tp.isInstantiated)

compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala

+2
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,8 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
565565
case tp: OrType =>
566566
val s = combineApiTypes(apiType(tp.tp1), apiType(tp.tp2))
567567
withMarker(s, orMarker)
568+
case tp: FlexibleType =>
569+
apiType(tp.underlying)
568570
case ExprType(resultType) =>
569571
withMarker(apiType(resultType), byNameMarker)
570572
case MatchType(bound, scrut, cases) =>

compiler/src/dotty/tools/dotc/typer/Applications.scala

+3-2
Original file line numberDiff line numberDiff line change
@@ -643,14 +643,15 @@ trait Applications extends Compatibility {
643643
missingArg(n)
644644
}
645645

646-
if (formal.isRepeatedParam)
646+
val formal1 = formal.stripFlexible
647+
if (formal1.isRepeatedParam)
647648
args match {
648649
case arg :: Nil if isVarArg(arg) =>
649650
addTyped(arg)
650651
case (arg @ Typed(Literal(Constant(null)), _)) :: Nil if ctx.isAfterTyper =>
651652
addTyped(arg)
652653
case _ =>
653-
val elemFormal = formal.widenExpr.argTypesLo.head
654+
val elemFormal = formal1.widenExpr.argTypesLo.head
654655
val typedArgs =
655656
harmonic(harmonizeArgs, elemFormal) {
656657
args.map { arg =>

0 commit comments

Comments
 (0)