@@ -860,20 +860,21 @@ object desugar {
860
860
*
861
861
* where every definition in `body` is expanded to an extension method
862
862
* taking type parameters `tparams` and a leading paramter `(x: T)`.
863
- * See: makeExtensionDef
863
+ * See: collectiveExtensionBody
864
864
*/
865
865
def moduleDef (mdef : ModuleDef )(implicit ctx : Context ): Tree = {
866
866
val impl = mdef.impl
867
867
val mods = mdef.mods
868
868
impl.constr match {
869
- case DefDef (_, tparams, (vparams @ (vparam :: Nil )) :: givenParamss, _, _) =>
869
+ case DefDef (_, tparams, vparamss @ (vparam :: Nil ) :: givenParamss, _, _) =>
870
+ // Transform collective extension
870
871
assert(mods.is(Given ))
871
872
return moduleDef(
872
873
cpy.ModuleDef (mdef)(
873
874
mdef.name,
874
875
cpy.Template (impl)(
875
876
constr = emptyConstructor,
876
- body = impl.body.map(makeExtensionDef(_ , tparams, vparams, givenParamss) ))))
877
+ body = collectiveExtensionBody( impl.body, tparams, vparamss ))))
877
878
case _ =>
878
879
}
879
880
@@ -916,38 +917,67 @@ object desugar {
916
917
}
917
918
}
918
919
919
- /** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
920
- * map a method definition
920
+ /** Transform the statements of a collective extension
921
+ * @param stats the original statements as they were parsed
922
+ * @param tparams the collective type parameters
923
+ * @param vparamss the collective value parameters, consisting
924
+ * of a single leading value parameter, followed by
925
+ * zero or more context parameter clauses
921
926
*
922
- * def foo [Us] paramss ...
927
+ * Note: It is already assured by Parser.checkExtensionMethod that all
928
+ * statements conform to requirements.
923
929
*
924
- * to
930
+ * Each method in stats is transformed into an extension method. Furthermore,
931
+ * identifier references to other methods are turned into selections on the common
932
+ * parameter.
933
+ *
934
+ * Example:
925
935
*
926
- * <extension> def foo[Ts ++ Us](x: T) parammss ...
936
+ * extension on [Ts](x: T)(using C):
937
+ * def f(y: T) = ???
938
+ * def g(z: T) = f(z)
927
939
*
928
- * If the given member `mdef` is not of this form, flag it as an error.
940
+ * is turned into
941
+ *
942
+ * extension:
943
+ * <extension> def f[Ts](x: T)(using C)(y: T) = ???
944
+ * <extension> def g[Ts](x: T)(using C)(z: T) = x.f(z)
929
945
*/
930
-
931
- def makeExtensionDef (mdef : Tree , tparams : List [TypeDef ], leadingParams : List [ValDef ],
932
- givenParamss : List [List [ValDef ]])(using ctx : Context ): Tree = {
933
- val allowed = " allowed here, since collective parameters are given"
934
- mdef match {
935
- case mdef : DefDef =>
936
- if (mdef.mods.is(Extension )) {
937
- ctx.error(em " No extension method $allowed" , mdef.sourcePos)
946
+ def collectiveExtensionBody (stats : List [Tree ],
947
+ tparams : List [TypeDef ], vparamss : List [List [ValDef ]])(using ctx : Context ): List [Tree ] =
948
+ val methodNames : Set [Name ] =
949
+ stats.collect { case stat : DefDef => stat.name }.toSet
950
+
951
+ object linkMethods extends UntypedTreeMap :
952
+ private val paramName = vparamss.head.head.name
953
+ private var prefixName = paramName
954
+
955
+ override def transform (tree : Tree )(using Context ): Tree = tree match
956
+ case tree : NamedDefTree if tree.name == paramName =>
957
+ prefixName = UniqueName .fresh()
958
+ super .transform(tree)
959
+ case tree : Ident if methodNames.contains(tree.name) =>
960
+ cpy.Select (tree)(Ident (prefixName), tree.name)
961
+ case _ =>
962
+ super .transform(tree)
963
+
964
+ def apply (rhs : Tree ): Tree =
965
+ val rhs1 = transform(rhs)
966
+ if prefixName == paramName then rhs1
967
+ else Block (ValDef (prefixName, TypeTree (), Ident (paramName)), rhs1)
968
+ end linkMethods
969
+
970
+ for stat <- stats yield
971
+ stat match
972
+ case mdef : DefDef =>
973
+ cpy.DefDef (mdef)(
974
+ tparams = tparams ++ mdef.tparams,
975
+ vparamss = vparamss ::: mdef.vparamss,
976
+ rhs = linkMethods(mdef.rhs)
977
+ ).withMods(mdef.mods | Extension )
978
+ case mdef =>
938
979
mdef
939
- }
940
- else cpy.DefDef (mdef)(
941
- tparams = tparams ++ mdef.tparams,
942
- vparamss = leadingParams :: givenParamss ::: mdef.vparamss
943
- ).withMods(mdef.mods | Extension )
944
- case mdef : Import =>
945
- mdef
946
- case mdef =>
947
- ctx.error(em " Only methods $allowed" , mdef.sourcePos)
948
- mdef
949
- }
950
- }
980
+ end collectiveExtensionBody
951
981
952
982
/** Transforms
953
983
*
0 commit comments