Skip to content

Commit d36f43b

Browse files
committed
Allow cross references in collective extensions
1 parent b85bd2d commit d36f43b

File tree

5 files changed

+84
-36
lines changed

5 files changed

+84
-36
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -860,20 +860,21 @@ object desugar {
860860
*
861861
* where every definition in `body` is expanded to an extension method
862862
* taking type parameters `tparams` and a leading paramter `(x: T)`.
863-
* See: makeExtensionDef
863+
* See: collectiveExtensionBody
864864
*/
865865
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
866866
val impl = mdef.impl
867867
val mods = mdef.mods
868868
impl.constr match {
869-
case DefDef(_, tparams, (vparams @ (vparam :: Nil)) :: givenParamss, _, _) =>
869+
case DefDef(_, tparams, vparamss @ (vparam :: Nil) :: givenParamss, _, _) =>
870+
// Transform collective extension
870871
assert(mods.is(Given))
871872
return moduleDef(
872873
cpy.ModuleDef(mdef)(
873874
mdef.name,
874875
cpy.Template(impl)(
875876
constr = emptyConstructor,
876-
body = impl.body.map(makeExtensionDef(_, tparams, vparams, givenParamss)))))
877+
body = collectiveExtensionBody(impl.body, tparams, vparamss))))
877878
case _ =>
878879
}
879880

@@ -916,38 +917,67 @@ object desugar {
916917
}
917918
}
918919

919-
/** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
920-
* map a method definition
920+
/** Transform the statements of a collective extension
921+
* @param stats the original statements as they were parsed
922+
* @param tparams the collective type parameters
923+
* @param vparamss the collective value parameters, consisting
924+
* of a single leading value parameter, followed by
925+
* zero or more context parameter clauses
921926
*
922-
* def foo [Us] paramss ...
927+
* Note: It is already assured by Parser.checkExtensionMethod that all
928+
* statements conform to requirements.
923929
*
924-
* to
930+
* Each method in stats is transformed into an extension method. Furthermore,
931+
* identifier references to other methods are turned into selections on the common
932+
* parameter.
933+
*
934+
* Example:
925935
*
926-
* <extension> def foo[Ts ++ Us](x: T) parammss ...
936+
* extension on [Ts](x: T)(using C):
937+
* def f(y: T) = ???
938+
* def g(z: T) = f(z)
927939
*
928-
* If the given member `mdef` is not of this form, flag it as an error.
940+
* is turned into
941+
*
942+
* extension:
943+
* <extension> def f[Ts](x: T)(using C)(y: T) = ???
944+
* <extension> def g[Ts](x: T)(using C)(z: T) = x.f(z)
929945
*/
930-
931-
def makeExtensionDef(mdef: Tree, tparams: List[TypeDef], leadingParams: List[ValDef],
932-
givenParamss: List[List[ValDef]])(using ctx: Context): Tree = {
933-
val allowed = "allowed here, since collective parameters are given"
934-
mdef match {
935-
case mdef: DefDef =>
936-
if (mdef.mods.is(Extension)) {
937-
ctx.error(em"No extension method $allowed", mdef.sourcePos)
946+
def collectiveExtensionBody(stats: List[Tree],
947+
tparams: List[TypeDef], vparamss: List[List[ValDef]])(using ctx: Context): List[Tree] =
948+
val methodNames: Set[Name] =
949+
stats.collect { case stat: DefDef => stat.name }.toSet
950+
951+
object linkMethods extends UntypedTreeMap:
952+
private val paramName = vparamss.head.head.name
953+
private var prefixName = paramName
954+
955+
override def transform(tree: Tree)(using Context): Tree = tree match
956+
case tree: NamedDefTree if tree.name == paramName =>
957+
prefixName = UniqueName.fresh()
958+
super.transform(tree)
959+
case tree: Ident if methodNames.contains(tree.name) =>
960+
cpy.Select(tree)(Ident(prefixName), tree.name)
961+
case _ =>
962+
super.transform(tree)
963+
964+
def apply(rhs: Tree): Tree =
965+
val rhs1 = transform(rhs)
966+
if prefixName == paramName then rhs1
967+
else Block(ValDef(prefixName, TypeTree(), Ident(paramName)), rhs1)
968+
end linkMethods
969+
970+
for stat <- stats yield
971+
stat match
972+
case mdef: DefDef =>
973+
cpy.DefDef(mdef)(
974+
tparams = tparams ++ mdef.tparams,
975+
vparamss = vparamss ::: mdef.vparamss,
976+
rhs = linkMethods(mdef.rhs)
977+
).withMods(mdef.mods | Extension)
978+
case mdef =>
938979
mdef
939-
}
940-
else cpy.DefDef(mdef)(
941-
tparams = tparams ++ mdef.tparams,
942-
vparamss = leadingParams :: givenParamss ::: mdef.vparamss
943-
).withMods(mdef.mods | Extension)
944-
case mdef: Import =>
945-
mdef
946-
case mdef =>
947-
ctx.error(em"Only methods $allowed", mdef.sourcePos)
948-
mdef
949-
}
950-
}
980+
end collectiveExtensionBody
951981

952982
/** Transforms
953983
*

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3484,6 +3484,8 @@ object Parsers {
34843484
else if tparams.nonEmpty && stat.tparams.nonEmpty then
34853485
syntaxError(i"extension method cannot have type parameters since some were already given previously",
34863486
stat.tparams.head.span)
3487+
else if stat.rhs.isEmpty then
3488+
syntaxError(i"extension method cannot be abstract", stat.span)
34873489
case stat =>
34883490
syntaxError(i"extension clause can only define methods", stat.span)
34893491
}

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ object Nullables:
170170
case info :: infos1 =>
171171
if info.asserted.contains(ref) then true
172172
else if info.retracted.contains(ref) then false
173-
else impliesNotNull(infos1)(ref)
173+
else infos1.impliesNotNull(ref)
174174
case _ =>
175175
false
176176

library/src/scala/tasty/Reflection.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -566,12 +566,12 @@ class Reflection(private[scala] val internal: CompilerInterface) { self =>
566566
Some((cdef.name, cdef.constructor, cdef.parents, cdef.derived, cdef.self, cdef.body))
567567
}
568568

569-
extension ClassDefOps on (self: ClassDef) {
570-
def constructor(using ctx: Context): DefDef = internal.ClassDef_constructor(self)
571-
def parents(using ctx: Context): List[Tree /* Term | TypeTree */] = internal.ClassDef_parents(self)
572-
def derived(using ctx: Context): List[TypeTree] = internal.ClassDef_derived(self)
573-
def self(using ctx: Context): Option[ValDef] = internal.ClassDef_self(self)
574-
def body(using ctx: Context): List[Statement] = internal.ClassDef_body(self)
569+
extension ClassDefOps on (_this: ClassDef) {
570+
def constructor(using ctx: Context): DefDef = internal.ClassDef_constructor(_this)
571+
def parents(using ctx: Context): List[Tree /* Term | TypeTree */] = internal.ClassDef_parents(_this)
572+
def derived(using ctx: Context): List[TypeTree] = internal.ClassDef_derived(_this)
573+
def self(using ctx: Context): Option[ValDef] = internal.ClassDef_self(_this)
574+
def body(using ctx: Context): List[Statement] = internal.ClassDef_body(_this)
575575
}
576576

577577
// DefDef

tests/run/collective-extensions.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
extension on (x: String):
2+
def foo(y: String): String = x ++ y
3+
def bar(y: String): String = foo(y)
4+
def baz(y: String): String =
5+
val x = y
6+
bar(x)
7+
def bam(y: String): String = this.baz(x)(y)
8+
def app(n: Int, suffix: String): String =
9+
if n == 0 then x ++ suffix
10+
else app(n - 1, suffix ++ suffix)
11+
12+
@main def Test =
13+
assert("abc".bar("def") == "abcdef")
14+
assert("abc".baz("def") == "abcdef")
15+
assert("abc".bam("def") == "abcdef")
16+
assert("abc".app(3, "!") == "abc!!!!!!!!")

0 commit comments

Comments
 (0)