Skip to content

Add class script wrapper #2033

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions modules/build/src/main/scala/scala/build/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import scala.build.errors.*
import scala.build.input.VirtualScript.VirtualScriptNameRegex
import scala.build.input.*
import scala.build.internal.resource.ResourceMapper
import scala.build.internal.{Constants, CustomCodeWrapper, MainClass, Util}
import scala.build.internal.{Constants, MainClass, Util}
import scala.build.options.ScalaVersionUtil.asVersion
import scala.build.options.*
import scala.build.options.validation.ValidationException
Expand Down Expand Up @@ -227,7 +227,6 @@ object Build {
CrossSources.forInputs(
inputs,
Sources.defaultPreprocessors(
options.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
options.archiveCache,
options.internal.javaClassNameVersionOpt,
() => options.javaHome().value.javaCommand
Expand Down Expand Up @@ -266,8 +265,11 @@ object Build {
overrideOptions: BuildOptions
): Either[BuildException, NonCrossBuilds] = either {

val baseOptions = overrideOptions.orElse(sharedOptions)
val scopedSources = value(crossSources.scopedSources(baseOptions))
val baseOptions = overrideOptions.orElse(sharedOptions)

val wrappedScriptsSources = crossSources.withWrappedScripts(baseOptions)

val scopedSources = value(wrappedScriptsSources.scopedSources(baseOptions))

val mainSources = scopedSources.sources(Scope.Main, baseOptions)
val mainOptions = mainSources.buildOptions
Expand Down
112 changes: 103 additions & 9 deletions modules/build/src/main/scala/scala/build/CrossSources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,94 @@ import scala.build.testrunner.DynamicTestRunner.globPattern
import scala.util.Try
import scala.util.chaining.*

final case class CrossSources(
/** CrossSources with unwrapped scripts, use [[withWrappedScripts]] to wrap them and obtain an
* instance of CrossSources
*
* See [[CrossSources]] for more information
*
* @param paths
* paths and realtive paths to sources on disk, wrapped in their build requirements
* @param inMemory
* in memory sources (e.g. snippets) wrapped in their build requirements
* @param defaultMainClass
* @param resourceDirs
* @param buildOptions
* build options from sources
* @param unwrappedScripts
* in memory script sources, their code must be wrapped before compiling
*/
sealed class UnwrappedCrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
defaultMainClass: Option[String],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
buildOptions: Seq[WithBuildRequirements[BuildOptions]],
unwrappedScripts: Seq[WithBuildRequirements[Sources.UnwrappedScript]]
) {

/** For all unwrapped script sources contained in this object wrap them according to provided
* BuildOptions
*
* @param buildOptions
* options used to choose the script wrapper
* @return
* CrossSources with all the scripts wrapped
*/
def withWrappedScripts(buildOptions: BuildOptions): CrossSources = {
val codeWrapper = ScriptPreprocessor.getScriptWrapper(buildOptions)

val wrappedScripts = unwrappedScripts.map { unwrapppedWithRequirements =>
unwrapppedWithRequirements.map(_.wrap(codeWrapper))
}

CrossSources(
paths,
inMemory ++ wrappedScripts,
defaultMainClass,
resourceDirs,
this.buildOptions
)
}

def sharedOptions(baseOptions: BuildOptions): BuildOptions =
buildOptions
.filter(_.requirements.isEmpty)
.map(_.value)
.foldLeft(baseOptions)(_ orElse _)

private def needsScalaVersion =
protected def needsScalaVersion =
paths.exists(_.needsScalaVersion) ||
inMemory.exists(_.needsScalaVersion) ||
resourceDirs.exists(_.needsScalaVersion) ||
buildOptions.exists(_.needsScalaVersion)
}

/** Information gathered from preprocessing command inputs - sources and build options from using
* directives
*
* @param paths
* paths and realtive paths to sources on disk, wrapped in their build requirements
* @param inMemory
* in memory sources (e.g. snippets and wrapped scripts) wrapped in their build requirements
* @param defaultMainClass
* @param resourceDirs
* @param buildOptions
* build options from sources
*/
final case class CrossSources(
paths: Seq[WithBuildRequirements[(os.Path, os.RelPath)]],
inMemory: Seq[WithBuildRequirements[Sources.InMemory]],
defaultMainClass: Option[String],
resourceDirs: Seq[WithBuildRequirements[os.Path]],
buildOptions: Seq[WithBuildRequirements[BuildOptions]]
) extends UnwrappedCrossSources(
paths,
inMemory,
defaultMainClass,
resourceDirs,
buildOptions,
Nil
) {
def scopedSources(baseOptions: BuildOptions): Either[BuildException, ScopedSources] = either {

val sharedOptions0 = sharedOptions(baseOptions)
Expand Down Expand Up @@ -114,7 +182,6 @@ final case class CrossSources(
crossSources0.buildOptions.map(_.scopedValue(defaultScope))
)
}

}

object CrossSources {
Expand All @@ -141,7 +208,7 @@ object CrossSources {
suppressWarningOptions: SuppressWarningOptions,
exclude: Seq[Positioned[String]] = Nil,
maybeRecoverOnError: BuildException => Option[BuildException] = e => Some(e)
)(using ScalaCliInvokeData): Either[BuildException, (CrossSources, Inputs)] = either {
)(using ScalaCliInvokeData): Either[BuildException, (UnwrappedCrossSources, Inputs)] = either {

def preprocessSources(elems: Seq[SingleElement])
: Either[BuildException, Seq[PreprocessedSource]] =
Expand Down Expand Up @@ -262,6 +329,16 @@ object CrossSources {
Sources.InMemory(m.originalPath, m.relPath, m.code, m.ignoreLen)
) -> m.directivesPositions
}
val unwrappedScriptsWithDirectivePositions
: Seq[(WithBuildRequirements[Sources.UnwrappedScript], Option[DirectivesPositions])] =
preprocessedSources.collect {
case m: PreprocessedSource.UnwrappedScript =>
val baseReqs0 = baseReqs(m.scopePath)
WithBuildRequirements(
m.requirements.fold(baseReqs0)(_ orElse baseReqs0),
Sources.UnwrappedScript(m.originalPath, m.relPath, m.wrapScriptFun)
) -> m.directivesPositions
}

val resourceDirs: Seq[WithBuildRequirements[os.Path]] = allInputs.elements.collect {
case r: ResourceDirectory =>
Expand All @@ -271,14 +348,20 @@ object CrossSources {
)

lazy val allPathsWithDirectivesByScope: Map[Scope, Seq[(os.Path, DirectivesPositions)]] =
(pathsWithDirectivePositions ++ inMemoryWithDirectivePositions)
(pathsWithDirectivePositions ++
inMemoryWithDirectivePositions ++
unwrappedScriptsWithDirectivePositions)
.flatMap { (withBuildRequirements, directivesPositions) =>
val scope = withBuildRequirements.scopedValue(Scope.Main).scope
val path: os.Path = withBuildRequirements.value match
case im: Sources.InMemory =>
im.originalPath match
case Right((_, p: os.Path)) => p
case _ => inputs.workspace / im.generatedRelPath
case us: Sources.UnwrappedScript =>
us.originalPath match
case Right((_, p: os.Path)) => p
case _ => inputs.workspace / us.generatedRelPath
case (p: os.Path, _) => p
directivesPositions.map((path, scope, _))
}
Expand Down Expand Up @@ -306,9 +389,20 @@ object CrossSources {
}
}

val paths = pathsWithDirectivePositions.map(_._1)
val inMemory = inMemoryWithDirectivePositions.map(_._1)
(CrossSources(paths, inMemory, defaultMainClassOpt, resourceDirs, buildOptions), allInputs)
val paths = pathsWithDirectivePositions.map(_._1)
val inMemory = inMemoryWithDirectivePositions.map(_._1)
val unwrappedScripts = unwrappedScriptsWithDirectivePositions.map(_._1)
(
UnwrappedCrossSources(
paths,
inMemory,
defaultMainClassOpt,
resourceDirs,
buildOptions,
unwrappedScripts
),
allInputs
)
}

private def resolveInputsFromSources(sources: Seq[Positioned[os.Path]], enableMarkdown: Boolean) =
Expand Down
14 changes: 12 additions & 2 deletions modules/build/src/main/scala/scala/build/Sources.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,17 @@ object Sources {
topWrapperLen: Int
)

final case class UnwrappedScript(
originalPath: Either[String, (os.SubPath, os.Path)],
generatedRelPath: os.RelPath,
wrapScriptFun: CodeWrapper => (String, Int)
) {
def wrap(wrapper: CodeWrapper): InMemory = {
val (content, topWrapperLen) = wrapScriptFun(wrapper)
InMemory(originalPath, generatedRelPath, content, topWrapperLen)
}
}

/** The default preprocessor list.
*
* @param codeWrapper
Expand All @@ -86,13 +97,12 @@ object Sources {
* @return
*/
def defaultPreprocessors(
codeWrapper: CodeWrapper,
archiveCache: ArchiveCache[Task],
javaClassNameVersionOpt: Option[String],
javaCommand: () => String
): Seq[Preprocessor] =
Seq(
ScriptPreprocessor(codeWrapper),
ScriptPreprocessor,
MarkdownPreprocessor,
JavaPreprocessor(archiveCache, javaClassNameVersionOpt, javaCommand),
ScalaPreprocessor,
Expand Down
13 changes: 12 additions & 1 deletion modules/build/src/main/scala/scala/build/bsp/BspClient.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package scala.build.bsp

import ch.epfl.scala.{bsp4j => b}
import ch.epfl.scala.bsp4j as b

import java.lang.Boolean as JBoolean
import java.net.URI
Expand All @@ -10,6 +10,7 @@ import java.util.concurrent.{ConcurrentHashMap, ExecutorService}
import scala.build.Position.File
import scala.build.bsp.protocol.TextEdit
import scala.build.errors.{BuildException, CompositeBuildException, Diagnostic, Severity}
import scala.build.internal.util.WarningMessages
import scala.build.postprocessing.LineConversion
import scala.build.{BloopBuildClient, GeneratedSource, Logger}
import scala.jdk.CollectionConverters.*
Expand Down Expand Up @@ -48,6 +49,16 @@ class BspClient(
val diag0 = diag.duplicate()
diag0.getRange.getStart.setLine(startLine)
diag0.getRange.getEnd.setLine(endLine)

if (
diag0.getMessage.contains(
"cannot be a main method since it cannot be accessed statically"
)
)
diag0.setMessage(
WarningMessages.mainAnnotationNotSupported( /* annotationIgnored */ false)
)

diag0
}
updatedDiagOpt.getOrElse(diag)
Expand Down
16 changes: 10 additions & 6 deletions modules/build/src/main/scala/scala/build/bsp/BspImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import scala.build.errors.{
ParsingInputsException
}
import scala.build.input.{Inputs, ScalaCliInvokeData}
import scala.build.internal.{Constants, CustomCodeWrapper}
import scala.build.internal.Constants
import scala.build.options.{BuildOptions, Scope}
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.DurationInt
Expand Down Expand Up @@ -101,7 +101,6 @@ final class BspImpl(
CrossSources.forInputs(
inputs = inputs,
preprocessors = Sources.defaultPreprocessors(
buildOptions.scriptOptions.codeWrapper.getOrElse(CustomCodeWrapper),
buildOptions.archiveCache,
buildOptions.internal.javaClassNameVersionOpt,
() => buildOptions.javaHome().value.javaCommand
Expand All @@ -113,16 +112,21 @@ final class BspImpl(
).left.map((_, Scope.Main))
}

val wrappedScriptsSources = crossSources.withWrappedScripts(buildOptions)

if (verbosity >= 3)
pprint.err.log(crossSources)
pprint.err.log(wrappedScriptsSources)

val scopedSources = value(crossSources.scopedSources(buildOptions).left.map((_, Scope.Main)))
val scopedSources =
value(wrappedScriptsSources.scopedSources(buildOptions).left.map((_, Scope.Main)))

if (verbosity >= 3)
pprint.err.log(scopedSources)

val sourcesMain = scopedSources.sources(Scope.Main, crossSources.sharedOptions(buildOptions))
val sourcesTest = scopedSources.sources(Scope.Test, crossSources.sharedOptions(buildOptions))
val sourcesMain =
scopedSources.sources(Scope.Main, wrappedScriptsSources.sharedOptions(buildOptions))
val sourcesTest =
scopedSources.sources(Scope.Test, wrappedScriptsSources.sharedOptions(buildOptions))

if (verbosity >= 3)
pprint.err.log(sourcesMain)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package scala.build.internal

/** Script code wrapper that solves problem of deadlocks when using threads. The code is placed in a
* class instance constructor, the created object is kept in 'mainObjectCode'.script to support
* running interconnected scripts using Scala CLI <br> <br> Incompatible with Scala 2 - it uses
* Scala 3 feature 'export'<br> Incompatible with native JS members - the wrapper is a class
*/
case object ClassCodeWrapper extends CodeWrapper {
private val userCodeNestingLevel = 1
def apply(
code: String,
pkgName: Seq[Name],
indexedWrapperName: Name,
extraCode: String,
scriptPath: String
) = {
val name = CodeWrapper.mainClassObject(indexedWrapperName).backticked
val wrapperClassName = Name(indexedWrapperName.raw ++ "$_").backticked
val mainObjectCode =
AmmUtil.normalizeNewlines(s"""|object $name {
| private var args$$opt0 = Option.empty[Array[String]]
| def args$$set(args: Array[String]): Unit = {
| args$$opt0 = Some(args)
| }
| def args$$opt: Option[Array[String]] = args$$opt0
| def args$$: Array[String] = args$$opt.getOrElse {
| sys.error("No arguments passed to this script")
| }
|
| lazy val script = new $wrapperClassName
|
| def main(args: Array[String]): Unit = {
| args$$set(args)
| script.hashCode() // hashCode to clear scalac warning about pure expression in statement position
| }
|}
|
|export $name.script as ${indexedWrapperName.backticked}
|""".stripMargin)

val packageDirective =
if (pkgName.isEmpty) "" else s"package ${AmmUtil.encodeScalaSourcePath(pkgName)}" + "\n"

// indentation is important in the generated code, so we don't want scalafmt to touch that
// format: off
val top = AmmUtil.normalizeNewlines(s"""
$packageDirective


final class $wrapperClassName {
def args = $name.args$$
def scriptPath = \"\"\"$scriptPath\"\"\"
""")
val bottom = AmmUtil.normalizeNewlines(s"""
$extraCode
}

$mainObjectCode
""")
// format: on

(top, bottom, userCodeNestingLevel)
}
}
Loading