Skip to content

Commit b448db2

Browse files
Merge pull request #8318 from dotty-staging/add-extension-instance
Add extension instances
2 parents aea5f3c + e732ee5 commit b448db2

20 files changed

+234
-166
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+32-34
Original file line numberDiff line numberDiff line change
@@ -859,21 +859,22 @@ object desugar {
859859
* given object name extends parents { self => body' }
860860
*
861861
* where every definition in `body` is expanded to an extension method
862-
* taking type parameters `tparams` and a leading parameter `(x: T)`.
863-
* See: makeExtensionDef
862+
* taking type parameters `tparams` and a leading paramter `(x: T)`.
863+
* See: collectiveExtensionBody
864864
*/
865865
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
866866
val impl = mdef.impl
867867
val mods = mdef.mods
868868
impl.constr match {
869-
case DefDef(_, tparams, (vparams @ (vparam :: Nil)) :: givenParamss, _, _) =>
869+
case DefDef(_, tparams, vparamss @ (vparam :: Nil) :: givenParamss, _, _) =>
870+
// Transform collective extension
870871
assert(mods.is(Given))
871872
return moduleDef(
872873
cpy.ModuleDef(mdef)(
873874
mdef.name,
874875
cpy.Template(impl)(
875876
constr = emptyConstructor,
876-
body = impl.body.map(makeExtensionDef(_, tparams, vparams, givenParamss)))))
877+
body = collectiveExtensionBody(impl.body, tparams, vparamss))))
877878
case _ =>
878879
}
879880

@@ -916,43 +917,40 @@ object desugar {
916917
}
917918
}
918919

919-
/** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
920-
* map a method definition
920+
/** Transform the statements of a collective extension
921+
* @param stats the original statements as they were parsed
922+
* @param tparams the collective type parameters
923+
* @param vparamss the collective value parameters, consisting
924+
* of a single leading value parameter, followed by
925+
* zero or more context parameter clauses
921926
*
922-
* def foo [Us] paramss ...
927+
* Note: It is already assured by Parser.checkExtensionMethod that all
928+
* statements conform to requirements.
923929
*
924-
* to
930+
* Each method in stats is transformed into an extension method. Example:
931+
*
932+
* extension on [Ts](x: T)(using C):
933+
* def f(y: T) = ???
934+
* def g(z: T) = f(z)
925935
*
926-
* <extension> def foo[Ts ++ Us](x: T) parammss ...
936+
* is turned into
927937
*
928-
* If the given member `mdef` is not of this form, flag it as an error.
938+
* extension:
939+
* <extension> def f[Ts](x: T)(using C)(y: T) = ???
940+
* <extension> def g[Ts](x: T)(using C)(z: T) = f(z)
929941
*/
930-
931-
def makeExtensionDef(mdef: Tree, tparams: List[TypeDef], leadingParams: List[ValDef],
932-
givenParamss: List[List[ValDef]])(using ctx: Context): Tree = {
933-
mdef match {
934-
case mdef: DefDef =>
935-
if (mdef.mods.is(Extension)) {
936-
ctx.error(NoExtensionMethodAllowed(mdef), mdef.sourcePos)
937-
mdef
938-
} else {
939-
if (tparams.nonEmpty && mdef.tparams.nonEmpty) then
940-
ctx.error(ExtensionMethodCannotHaveTypeParams(mdef), mdef.tparams.head.sourcePos)
941-
mdef
942-
else cpy.DefDef(mdef)(
942+
def collectiveExtensionBody(stats: List[Tree],
943+
tparams: List[TypeDef], vparamss: List[List[ValDef]])(using ctx: Context): List[Tree] =
944+
for stat <- stats yield
945+
stat match
946+
case mdef: DefDef =>
947+
cpy.DefDef(mdef)(
943948
tparams = tparams ++ mdef.tparams,
944-
vparamss = leadingParams :: givenParamss ::: mdef.vparamss
949+
vparamss = vparamss ::: mdef.vparamss,
945950
).withMods(mdef.mods | Extension)
946-
}
947-
case mdef: Import =>
948-
mdef
949-
case mdef if !mdef.isEmpty => {
950-
ctx.error(ExtensionCanOnlyHaveDefs(mdef), mdef.sourcePos)
951-
mdef
952-
}
953-
case mdef => mdef
954-
}
955-
}
951+
case mdef =>
952+
mdef
953+
end collectiveExtensionBody
956954

957955
/** Transforms
958956
*

compiler/src/dotty/tools/dotc/core/Flags.scala

+3-3
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ object Flags {
301301
/** A method that has default params */
302302
val (_, DefaultParameterized @ _, _) = newFlags(27, "<defaultparam>")
303303

304-
/** An extension method */
304+
/** An extension method, or a collective extension instance */
305305
val (_, Extension @ _, _) = newFlags(28, "<extension>")
306306

307307
/** An inferable (`given`) parameter */
@@ -499,14 +499,14 @@ object Flags {
499499

500500
/** Flags that can apply to a module val */
501501
val RetainedModuleValFlags: FlagSet = RetainedModuleValAndClassFlags |
502-
Override | Final | Method | Implicit | Given | Lazy |
502+
Override | Final | Method | Implicit | Given | Lazy | Extension |
503503
Accessor | AbsOverride | StableRealizable | Captured | Synchronized | Erased
504504

505505
/** Flags that can apply to a module class */
506506
val RetainedModuleClassFlags: FlagSet = RetainedModuleValAndClassFlags | Enum
507507

508508
/** Flags retained in export forwarders */
509-
val RetainedExportFlags = Given | Implicit | Extension | Inline
509+
val RetainedExportFlags = Given | Implicit | Inline
510510

511511
/** Flags that apply only to classes */
512512
val ClassOnlyFlags = Sealed | Open | Abstract.toTypeFlags

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+34-8
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,11 @@ object Parsers {
931931
lookahead.nextToken()
932932
if lookahead.isIdent && !lookahead.isIdent(nme.on) then
933933
lookahead.nextToken()
934+
if lookahead.isNewLine then
935+
lookahead.nextToken()
934936
lookahead.isIdent(nme.on)
937+
|| lookahead.token == LBRACE
938+
|| lookahead.token == COLON
935939

936940
/* --------- OPERAND/OPERATOR STACK --------------------------------------- */
937941

@@ -3471,6 +3475,23 @@ object Parsers {
34713475
Template(constr, parents, Nil, EmptyValDef, Nil)
34723476
}
34733477

3478+
def checkExtensionMethod(tparams: List[Tree],
3479+
vparamss: List[List[Tree]], stat: Tree): Unit = stat match {
3480+
case stat: DefDef =>
3481+
if stat.mods.is(Extension) && vparamss.nonEmpty then
3482+
syntaxError(i"no extension method allowed here since leading parameter was already given", stat.span)
3483+
else if !stat.mods.is(Extension) && vparamss.isEmpty then
3484+
syntaxError(i"an extension method is required here", stat.span)
3485+
else if tparams.nonEmpty && stat.tparams.nonEmpty then
3486+
syntaxError(i"extension method cannot have type parameters since some were already given previously",
3487+
stat.tparams.head.span)
3488+
else if stat.rhs.isEmpty then
3489+
syntaxError(i"extension method cannot be abstract", stat.span)
3490+
case EmptyTree =>
3491+
case stat =>
3492+
syntaxError(i"extension clause can only define methods", stat.span)
3493+
}
3494+
34743495
/** GivenDef ::= [GivenSig] [‘_’ ‘<:’] Type ‘=’ Expr
34753496
* | [GivenSig] ConstrApps [TemplateBody]
34763497
* GivenSig ::= [id] [DefTypeParamClause] {UsingParamClauses} ‘as’
@@ -3517,22 +3538,27 @@ object Parsers {
35173538
finalizeDef(gdef, mods1, start)
35183539
}
35193540

3520-
/** ExtensionDef ::= [id] ‘on’ ExtParamClause {UsingParamClause} ExtMethods
3541+
/** ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}] TemplateBody
35213542
*/
35223543
def extensionDef(start: Offset, mods: Modifiers): ModuleDef =
35233544
in.nextToken()
35243545
val name = if isIdent && !isIdent(nme.on) then ident() else EmptyTermName
35253546
in.endMarkerScope(if name.isEmpty then nme.extension else name) {
3526-
if !isIdent(nme.on) then syntaxErrorOrIncomplete("`on` expected")
3527-
if isIdent(nme.on) then in.nextToken()
3528-
val tparams = typeParamClauseOpt(ParamOwner.Def)
3529-
val extParams = paramClause(0, prefix = true)
3530-
val givenParamss = paramClauses(givenOnly = true)
3547+
val (tparams, vparamss, extensionFlag) =
3548+
if isIdent(nme.on) then
3549+
in.nextToken()
3550+
val tparams = typeParamClauseOpt(ParamOwner.Def)
3551+
val extParams = paramClause(0, prefix = true)
3552+
val givenParamss = paramClauses(givenOnly = true)
3553+
(tparams, extParams :: givenParamss, Extension)
3554+
else
3555+
(Nil, Nil, EmptyFlags)
35313556
possibleTemplateStart()
35323557
if !in.isNestedStart then syntaxError("Extension without extension methods")
3533-
val templ = templateBodyOpt(makeConstructor(tparams, extParams :: givenParamss), Nil, Nil)
3558+
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), Nil, Nil)
3559+
templ.body.foreach(checkExtensionMethod(tparams, vparamss, _))
35343560
val edef = ModuleDef(name, templ)
3535-
finalizeDef(edef, addFlag(mods, Given), start)
3561+
finalizeDef(edef, addFlag(mods, Given | extensionFlag), start)
35363562
}
35373563

35383564
/* -------- TEMPLATES ------------------------------------------- */

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

+3
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,7 @@ class SymUtils(val self: Symbol) extends AnyVal {
220220
/** Is symbol a splice operation? */
221221
def isSplice(implicit ctx: Context): Boolean =
222222
self == defn.InternalQuoted_exprSplice || self == defn.QuotedType_splice
223+
224+
def isCollectiveExtensionClass(using Context): Boolean =
225+
self.is(ModuleClass) && self.sourceModule.is(Extension, butNot = Method)
223226
}

compiler/src/dotty/tools/dotc/typer/ImportSuggestions.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ trait ImportSuggestions:
205205
.alternatives
206206
.map(mbr => TermRef(site, mbr.symbol))
207207
.filter(ref =>
208-
ref.symbol.is(Extension)
208+
ref.symbol.isAllOf(ExtensionMethod)
209209
&& isApplicableMethodRef(ref, argType :: Nil, WildcardType))
210210
.headOption
211211

compiler/src/dotty/tools/dotc/typer/Namer.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -1096,7 +1096,8 @@ class Namer { typer: Typer =>
10961096
(StableRealizable, ExprType(path.tpe.select(sym)))
10971097
else
10981098
(EmptyFlags, mbr.info.ensureMethodic)
1099-
val mbrFlags = Exported | Method | Final | maybeStable | sym.flags & RetainedExportFlags
1099+
var mbrFlags = Exported | Method | Final | maybeStable | sym.flags & RetainedExportFlags
1100+
if sym.isAllOf(ExtensionMethod) then mbrFlags |= Extension
11001101
val forwarderName = checkNoConflict(alias, isPrivate = false, span)
11011102
ctx.newSymbol(cls, forwarderName, mbrFlags, mbrInfo, coord = span)
11021103
}

compiler/src/dotty/tools/dotc/typer/Nullables.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ object Nullables:
170170
case info :: infos1 =>
171171
if info.asserted.contains(ref) then true
172172
else if info.retracted.contains(ref) then false
173-
else impliesNotNull(infos1)(ref)
173+
else infos1.impliesNotNull(ref)
174174
case _ =>
175175
false
176176

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,9 @@ object RefChecks {
394394
overrideError("is erased, cannot override non-erased member")
395395
else if (other.is(Erased) && !member.isOneOf(Erased | Inline)) // (1.9.1)
396396
overrideError("is not erased, cannot override erased member")
397-
else if (member.is(Extension) && !other.is(Extension)) // (1.9.2)
397+
else if (member.isAllOf(ExtensionMethod) && !other.isAllOf(ExtensionMethod)) // (1.9.2)
398398
overrideError("is an extension method, cannot override a normal method")
399-
else if (other.is(Extension) && !member.is(Extension)) // (1.9.2)
399+
else if (other.isAllOf(ExtensionMethod) && !member.isAllOf(ExtensionMethod)) // (1.9.2)
400400
overrideError("is a normal method, cannot override an extension method")
401401
else if ((member.isInlineMethod || member.isScala2Macro) && other.is(Deferred) &&
402402
member.extendedOverriddenSymbols.forall(_.is(Deferred))) // (1.10)

compiler/src/dotty/tools/dotc/typer/Typer.scala

+21-2
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,19 @@ class Typer extends Namer
401401
if (name == nme.ROOTPKG)
402402
return tree.withType(defn.RootPackage.termRef)
403403

404+
/** Convert a reference `f` to an extension method in a collective extension
405+
* on parameter `x` to `x.f`
406+
*/
407+
def extensionMethodSelect(xmethod: Symbol): untpd.Tree =
408+
val leadParamName = xmethod.info.paramNamess.head.head
409+
def isLeadParam(sym: Symbol) =
410+
sym.is(Param) && sym.owner.owner == xmethod.owner && sym.name == leadParamName
411+
def leadParam(ctx: Context): Symbol =
412+
ctx.scope.lookupAll(leadParamName).find(isLeadParam) match
413+
case Some(param) => param
414+
case None => leadParam(ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next)
415+
untpd.cpy.Select(tree)(untpd.ref(leadParam(ctx).termRef), name)
416+
404417
val rawType = {
405418
val saved1 = unimported
406419
val saved2 = foundUnderScala2
@@ -441,8 +454,14 @@ class Typer extends Namer
441454
errorType(new MissingIdent(tree, kind, name.show), tree.sourcePos)
442455

443456
val tree1 = ownType match {
444-
case ownType: NamedType if !prefixIsElidable(ownType) =>
445-
ref(ownType).withSpan(tree.span)
457+
case ownType: NamedType =>
458+
val sym = ownType.symbol
459+
if sym.isAllOf(ExtensionMethod)
460+
&& sym.owner.isCollectiveExtensionClass
461+
&& ctx.owner.isContainedIn(sym.owner)
462+
then typed(extensionMethodSelect(sym), pt)
463+
else if prefixIsElidable(ownType) then tree.withType(ownType)
464+
else ref(ownType).withSpan(tree.span)
446465
case _ =>
447466
tree.withType(ownType)
448467
}

compiler/test/dotty/tools/dotc/reporting/ErrorMessagesTests.scala

-68
Original file line numberDiff line numberDiff line change
@@ -1840,74 +1840,6 @@ class ErrorMessagesTests extends ErrorMessagesTest {
18401840
assertEquals("given x @ String", x.show)
18411841
}
18421842

1843-
@Test def extensionMethodsNotAllowed =
1844-
checkMessagesAfter(RefChecks.name) {
1845-
"""object Test {
1846-
| extension on[T] (t: T) {
1847-
| def (c: T).f: T = ???
1848-
| }
1849-
|}
1850-
""".stripMargin
1851-
}
1852-
.expect { (ictx, messages)
1853-
implicit val ctx: Context = ictx
1854-
assertMessageCount(1, messages)
1855-
val errorMsg = messages.head.msg
1856-
val NoExtensionMethodAllowed(x) :: Nil = messages
1857-
assertEquals("No extension method allowed here, since collective parameters are given", errorMsg)
1858-
assertEquals("def (c: T) f: T = ???", x.show)
1859-
}
1860-
1861-
@Test def extensionMethodTypeParamsNotAllowed =
1862-
checkMessagesAfter(RefChecks.name) {
1863-
"""object Test {
1864-
| extension on[T] (t: T) {
1865-
| def f[U](u: U): T = ???
1866-
| }
1867-
|}
1868-
""".stripMargin
1869-
}
1870-
.expect { (ictx, messages)
1871-
implicit val ctx: Context = ictx
1872-
assertMessageCount(1, messages)
1873-
val errorMsg = messages.head.msg
1874-
val ExtensionMethodCannotHaveTypeParams(x) :: Nil = messages
1875-
assertEquals("Extension method cannot have type parameters since some were already given previously", errorMsg)
1876-
assertEquals("def f[U](u: U): T = ???", x.show)
1877-
}
1878-
1879-
@Test def extensionMethodCanOnlyHaveDefs =
1880-
checkMessagesAfter(RefChecks.name) {
1881-
"""object Test {
1882-
| extension on[T] (t: T) {
1883-
| val v: T = t
1884-
| }
1885-
|}
1886-
""".stripMargin
1887-
}
1888-
.expect { (ictx, messages)
1889-
implicit val ctx: Context = ictx
1890-
assertMessageCount(1, messages)
1891-
val errorMsg = messages.head.msg
1892-
val ExtensionCanOnlyHaveDefs(x) :: Nil = messages
1893-
assertEquals("Only methods allowed here, since collective parameters are given", errorMsg)
1894-
assertEquals("val v: T = t", x.show)
1895-
}
1896-
1897-
@Test def anonymousInstanceMustImplementAType =
1898-
checkMessagesAfter(RefChecks.name) {
1899-
"""object Test {
1900-
| extension on[T] (t: T) { }
1901-
|}
1902-
""".stripMargin
1903-
}
1904-
.expect { (ictx, messages)
1905-
implicit val ctx: Context = ictx
1906-
assertMessageCount(1, messages)
1907-
val errorMsg = messages.head.msg
1908-
assertEquals("anonymous instance must implement a type or have at least one extension method", errorMsg)
1909-
}
1910-
19111843
@Test def typeSplicesInValPatterns =
19121844
checkMessagesAfter(RefChecks.name) {
19131845
s"""import scala.quoted._

docs/docs/internals/syntax.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ EnumDef ::= id ClassConstr InheritClauses EnumBody
386386
GivenDef ::= [GivenSig] [‘_’ ‘<:’] Type ‘=’ Expr
387387
| [GivenSig] ConstrApps [TemplateBody]
388388
GivenSig ::= [id] [DefTypeParamClause] {UsingParamClause} ‘as’
389-
ExtensionDef ::= [id] ‘on’ ExtParamClause {WithParamsOrTypes} ExtMethods
389+
ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}]
390+
TemplateBody
390391
ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
391392
ExtParamClause ::= [DefTypeParamClause] ‘(’ DefParam ‘)’
392393
Template ::= InheritClauses [TemplateBody] Template(constr, parents, self, stats)

0 commit comments

Comments
 (0)