Skip to content

Add extension instances #8318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 3, 2020
66 changes: 32 additions & 34 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -859,21 +859,22 @@ object desugar {
* given object name extends parents { self => body' }
*
* where every definition in `body` is expanded to an extension method
* taking type parameters `tparams` and a leading parameter `(x: T)`.
* See: makeExtensionDef
* taking type parameters `tparams` and a leading paramter `(x: T)`.
* See: collectiveExtensionBody
*/
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
val impl = mdef.impl
val mods = mdef.mods
impl.constr match {
case DefDef(_, tparams, (vparams @ (vparam :: Nil)) :: givenParamss, _, _) =>
case DefDef(_, tparams, vparamss @ (vparam :: Nil) :: givenParamss, _, _) =>
// Transform collective extension
assert(mods.is(Given))
return moduleDef(
cpy.ModuleDef(mdef)(
mdef.name,
cpy.Template(impl)(
constr = emptyConstructor,
body = impl.body.map(makeExtensionDef(_, tparams, vparams, givenParamss)))))
body = collectiveExtensionBody(impl.body, tparams, vparamss))))
case _ =>
}

Expand Down Expand Up @@ -916,43 +917,40 @@ object desugar {
}
}

/** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
* map a method definition
/** Transform the statements of a collective extension
* @param stats the original statements as they were parsed
* @param tparams the collective type parameters
* @param vparamss the collective value parameters, consisting
* of a single leading value parameter, followed by
* zero or more context parameter clauses
*
* def foo [Us] paramss ...
* Note: It is already assured by Parser.checkExtensionMethod that all
* statements conform to requirements.
*
* to
* Each method in stats is transformed into an extension method. Example:
*
* extension on [Ts](x: T)(using C):
* def f(y: T) = ???
* def g(z: T) = f(z)
*
* <extension> def foo[Ts ++ Us](x: T) parammss ...
* is turned into
*
* If the given member `mdef` is not of this form, flag it as an error.
* extension:
* <extension> def f[Ts](x: T)(using C)(y: T) = ???
* <extension> def g[Ts](x: T)(using C)(z: T) = f(z)
*/

def makeExtensionDef(mdef: Tree, tparams: List[TypeDef], leadingParams: List[ValDef],
givenParamss: List[List[ValDef]])(using ctx: Context): Tree = {
mdef match {
case mdef: DefDef =>
if (mdef.mods.is(Extension)) {
ctx.error(NoExtensionMethodAllowed(mdef), mdef.sourcePos)
mdef
} else {
if (tparams.nonEmpty && mdef.tparams.nonEmpty) then
ctx.error(ExtensionMethodCannotHaveTypeParams(mdef), mdef.tparams.head.sourcePos)
mdef
else cpy.DefDef(mdef)(
def collectiveExtensionBody(stats: List[Tree],
tparams: List[TypeDef], vparamss: List[List[ValDef]])(using ctx: Context): List[Tree] =
for stat <- stats yield
stat match
case mdef: DefDef =>
cpy.DefDef(mdef)(
tparams = tparams ++ mdef.tparams,
vparamss = leadingParams :: givenParamss ::: mdef.vparamss
vparamss = vparamss ::: mdef.vparamss,
).withMods(mdef.mods | Extension)
}
case mdef: Import =>
mdef
case mdef if !mdef.isEmpty => {
ctx.error(ExtensionCanOnlyHaveDefs(mdef), mdef.sourcePos)
mdef
}
case mdef => mdef
}
}
case mdef =>
mdef
end collectiveExtensionBody

/** Transforms
*
Expand Down
6 changes: 3 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ object Flags {
/** A method that has default params */
val (_, DefaultParameterized @ _, _) = newFlags(27, "<defaultparam>")

/** An extension method */
/** An extension method, or a collective extension instance */
val (_, Extension @ _, _) = newFlags(28, "<extension>")

/** An inferable (`given`) parameter */
Expand Down Expand Up @@ -499,14 +499,14 @@ object Flags {

/** Flags that can apply to a module val */
val RetainedModuleValFlags: FlagSet = RetainedModuleValAndClassFlags |
Override | Final | Method | Implicit | Given | Lazy |
Override | Final | Method | Implicit | Given | Lazy | Extension |
Accessor | AbsOverride | StableRealizable | Captured | Synchronized | Erased

/** Flags that can apply to a module class */
val RetainedModuleClassFlags: FlagSet = RetainedModuleValAndClassFlags | Enum

/** Flags retained in export forwarders */
val RetainedExportFlags = Given | Implicit | Extension | Inline
val RetainedExportFlags = Given | Implicit | Inline

/** Flags that apply only to classes */
val ClassOnlyFlags = Sealed | Open | Abstract.toTypeFlags
Expand Down
42 changes: 34 additions & 8 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,11 @@ object Parsers {
lookahead.nextToken()
if lookahead.isIdent && !lookahead.isIdent(nme.on) then
lookahead.nextToken()
if lookahead.isNewLine then
lookahead.nextToken()
lookahead.isIdent(nme.on)
|| lookahead.token == LBRACE
|| lookahead.token == COLON

/* --------- OPERAND/OPERATOR STACK --------------------------------------- */

Expand Down Expand Up @@ -3470,6 +3474,23 @@ object Parsers {
Template(constr, parents, Nil, EmptyValDef, Nil)
}

def checkExtensionMethod(tparams: List[Tree],
vparamss: List[List[Tree]], stat: Tree): Unit = stat match {
case stat: DefDef =>
if stat.mods.is(Extension) && vparamss.nonEmpty then
syntaxError(i"no extension method allowed here since leading parameter was already given", stat.span)
else if !stat.mods.is(Extension) && vparamss.isEmpty then
syntaxError(i"an extension method is required here", stat.span)
else if tparams.nonEmpty && stat.tparams.nonEmpty then
syntaxError(i"extension method cannot have type parameters since some were already given previously",
stat.tparams.head.span)
else if stat.rhs.isEmpty then
syntaxError(i"extension method cannot be abstract", stat.span)
case EmptyTree =>
case stat =>
syntaxError(i"extension clause can only define methods", stat.span)
}

/** GivenDef ::= [GivenSig] [‘_’ ‘<:’] Type ‘=’ Expr
* | [GivenSig] ConstrApps [TemplateBody]
* GivenSig ::= [id] [DefTypeParamClause] {UsingParamClauses} ‘as’
Expand Down Expand Up @@ -3516,22 +3537,27 @@ object Parsers {
finalizeDef(gdef, mods1, start)
}

/** ExtensionDef ::= [id] ‘on’ ExtParamClause {UsingParamClause} ExtMethods
/** ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}] TemplateBody
*/
def extensionDef(start: Offset, mods: Modifiers): ModuleDef =
in.nextToken()
val name = if isIdent && !isIdent(nme.on) then ident() else EmptyTermName
in.endMarkerScope(if name.isEmpty then nme.extension else name) {
if !isIdent(nme.on) then syntaxErrorOrIncomplete("`on` expected")
if isIdent(nme.on) then in.nextToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val extParams = paramClause(0, prefix = true)
val givenParamss = paramClauses(givenOnly = true)
val (tparams, vparamss, extensionFlag) =
if isIdent(nme.on) then
in.nextToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val extParams = paramClause(0, prefix = true)
val givenParamss = paramClauses(givenOnly = true)
(tparams, extParams :: givenParamss, Extension)
else
(Nil, Nil, EmptyFlags)
possibleTemplateStart()
if !in.isNestedStart then syntaxError("Extension without extension methods")
val templ = templateBodyOpt(makeConstructor(tparams, extParams :: givenParamss), Nil, Nil)
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), Nil, Nil)
templ.body.foreach(checkExtensionMethod(tparams, vparamss, _))
val edef = ModuleDef(name, templ)
finalizeDef(edef, addFlag(mods, Given), start)
finalizeDef(edef, addFlag(mods, Given | extensionFlag), start)
}

/* -------- TEMPLATES ------------------------------------------- */
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,4 +220,7 @@ class SymUtils(val self: Symbol) extends AnyVal {
/** Is symbol a splice operation? */
def isSplice(implicit ctx: Context): Boolean =
self == defn.InternalQuoted_exprSplice || self == defn.QuotedType_splice

def isCollectiveExtensionClass(using Context): Boolean =
self.is(ModuleClass) && self.sourceModule.is(Extension, butNot = Method)
}
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ trait ImportSuggestions:
.alternatives
.map(mbr => TermRef(site, mbr.symbol))
.filter(ref =>
ref.symbol.is(Extension)
ref.symbol.isAllOf(ExtensionMethod)
&& isApplicableMethodRef(ref, argType :: Nil, WildcardType))
.headOption

Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1096,7 +1096,8 @@ class Namer { typer: Typer =>
(StableRealizable, ExprType(path.tpe.select(sym)))
else
(EmptyFlags, mbr.info.ensureMethodic)
val mbrFlags = Exported | Method | Final | maybeStable | sym.flags & RetainedExportFlags
var mbrFlags = Exported | Method | Final | maybeStable | sym.flags & RetainedExportFlags
if sym.isAllOf(ExtensionMethod) then mbrFlags |= Extension
val forwarderName = checkNoConflict(alias, isPrivate = false, span)
ctx.newSymbol(cls, forwarderName, mbrFlags, mbrInfo, coord = span)
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Nullables.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ object Nullables:
case info :: infos1 =>
if info.asserted.contains(ref) then true
else if info.retracted.contains(ref) then false
else impliesNotNull(infos1)(ref)
else infos1.impliesNotNull(ref)
case _ =>
false

Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,9 @@ object RefChecks {
overrideError("is erased, cannot override non-erased member")
else if (other.is(Erased) && !member.isOneOf(Erased | Inline)) // (1.9.1)
overrideError("is not erased, cannot override erased member")
else if (member.is(Extension) && !other.is(Extension)) // (1.9.2)
else if (member.isAllOf(ExtensionMethod) && !other.isAllOf(ExtensionMethod)) // (1.9.2)
overrideError("is an extension method, cannot override a normal method")
else if (other.is(Extension) && !member.is(Extension)) // (1.9.2)
else if (other.isAllOf(ExtensionMethod) && !member.isAllOf(ExtensionMethod)) // (1.9.2)
overrideError("is a normal method, cannot override an extension method")
else if ((member.isInlineMethod || member.isScala2Macro) && other.is(Deferred) &&
member.extendedOverriddenSymbols.forall(_.is(Deferred))) // (1.10)
Expand Down
23 changes: 21 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,19 @@ class Typer extends Namer
if (name == nme.ROOTPKG)
return tree.withType(defn.RootPackage.termRef)

/** Convert a reference `f` to an extension method in a collective extension
* on parameter `x` to `x.f`
*/
def extensionMethodSelect(xmethod: Symbol): untpd.Tree =
val leadParamName = xmethod.info.paramNamess.head.head
def isLeadParam(sym: Symbol) =
sym.is(Param) && sym.owner.owner == xmethod.owner && sym.name == leadParamName
def leadParam(ctx: Context): Symbol =
ctx.scope.lookupAll(leadParamName).find(isLeadParam) match
case Some(param) => param
case None => leadParam(ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next)
untpd.cpy.Select(tree)(untpd.ref(leadParam(ctx).termRef), name)

val rawType = {
val saved1 = unimported
val saved2 = foundUnderScala2
Expand Down Expand Up @@ -441,8 +454,14 @@ class Typer extends Namer
errorType(new MissingIdent(tree, kind, name.show), tree.sourcePos)

val tree1 = ownType match {
case ownType: NamedType if !prefixIsElidable(ownType) =>
ref(ownType).withSpan(tree.span)
case ownType: NamedType =>
val sym = ownType.symbol
if sym.isAllOf(ExtensionMethod)
&& sym.owner.isCollectiveExtensionClass
&& ctx.owner.isContainedIn(sym.owner)
then typed(extensionMethodSelect(sym), pt)
else if prefixIsElidable(ownType) then tree.withType(ownType)
else ref(ownType).withSpan(tree.span)
case _ =>
tree.withType(ownType)
}
Expand Down
68 changes: 0 additions & 68 deletions compiler/test/dotty/tools/dotc/reporting/ErrorMessagesTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1840,74 +1840,6 @@ class ErrorMessagesTests extends ErrorMessagesTest {
assertEquals("given x @ String", x.show)
}

@Test def extensionMethodsNotAllowed =
checkMessagesAfter(RefChecks.name) {
"""object Test {
| extension on[T] (t: T) {
| def (c: T).f: T = ???
| }
|}
""".stripMargin
}
.expect { (ictx, messages) ⇒
implicit val ctx: Context = ictx
assertMessageCount(1, messages)
val errorMsg = messages.head.msg
val NoExtensionMethodAllowed(x) :: Nil = messages
assertEquals("No extension method allowed here, since collective parameters are given", errorMsg)
assertEquals("def (c: T) f: T = ???", x.show)
}

@Test def extensionMethodTypeParamsNotAllowed =
checkMessagesAfter(RefChecks.name) {
"""object Test {
| extension on[T] (t: T) {
| def f[U](u: U): T = ???
| }
|}
""".stripMargin
}
.expect { (ictx, messages) ⇒
implicit val ctx: Context = ictx
assertMessageCount(1, messages)
val errorMsg = messages.head.msg
val ExtensionMethodCannotHaveTypeParams(x) :: Nil = messages
assertEquals("Extension method cannot have type parameters since some were already given previously", errorMsg)
assertEquals("def f[U](u: U): T = ???", x.show)
}

@Test def extensionMethodCanOnlyHaveDefs =
checkMessagesAfter(RefChecks.name) {
"""object Test {
| extension on[T] (t: T) {
| val v: T = t
| }
|}
""".stripMargin
}
.expect { (ictx, messages) ⇒
implicit val ctx: Context = ictx
assertMessageCount(1, messages)
val errorMsg = messages.head.msg
val ExtensionCanOnlyHaveDefs(x) :: Nil = messages
assertEquals("Only methods allowed here, since collective parameters are given", errorMsg)
assertEquals("val v: T = t", x.show)
}

@Test def anonymousInstanceMustImplementAType =
checkMessagesAfter(RefChecks.name) {
"""object Test {
| extension on[T] (t: T) { }
|}
""".stripMargin
}
.expect { (ictx, messages) ⇒
implicit val ctx: Context = ictx
assertMessageCount(1, messages)
val errorMsg = messages.head.msg
assertEquals("anonymous instance must implement a type or have at least one extension method", errorMsg)
}

@Test def typeSplicesInValPatterns =
checkMessagesAfter(RefChecks.name) {
s"""import scala.quoted._
Expand Down
3 changes: 2 additions & 1 deletion docs/docs/internals/syntax.md
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ EnumDef ::= id ClassConstr InheritClauses EnumBody
GivenDef ::= [GivenSig] [‘_’ ‘<:’] Type ‘=’ Expr
| [GivenSig] ConstrApps [TemplateBody]
GivenSig ::= [id] [DefTypeParamClause] {UsingParamClause} ‘as’
ExtensionDef ::= [id] ‘on’ ExtParamClause {WithParamsOrTypes} ExtMethods
ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}]
TemplateBody
ExtMethods ::= [nl] ‘{’ ‘def’ DefDef {semi ‘def’ DefDef} ‘}’
ExtParamClause ::= [DefTypeParamClause] ‘(’ DefParam ‘)’
Template ::= InheritClauses [TemplateBody] Template(constr, parents, self, stats)
Expand Down
Loading