|
| 1 | +/* |
| 2 | + * Use of this source code is governed by the MIT license that can be |
| 3 | + * found in the LICENSE file. |
| 4 | + */ |
| 5 | + |
| 6 | +package org.rust.ide.injected |
| 7 | + |
| 8 | +import com.intellij.injected.editor.VirtualFileWindow |
| 9 | +import com.intellij.lang.injection.MultiHostInjector |
| 10 | +import com.intellij.lang.injection.MultiHostRegistrar |
| 11 | +import com.intellij.openapi.project.Project |
| 12 | +import com.intellij.openapi.util.TextRange |
| 13 | +import com.intellij.openapi.vfs.VirtualFile |
| 14 | +import com.intellij.psi.PsiElement |
| 15 | +import com.intellij.psi.tree.IElementType |
| 16 | +import com.intellij.util.text.CharArrayUtil |
| 17 | +import org.rust.cargo.project.settings.rustSettings |
| 18 | +import org.rust.cargo.project.workspace.PackageOrigin |
| 19 | +import org.rust.cargo.util.AutoInjectedCrates |
| 20 | +import org.rust.lang.RsLanguage |
| 21 | +import org.rust.lang.core.psi.RS_DOC_COMMENTS |
| 22 | +import org.rust.lang.core.psi.RsDocCommentImpl |
| 23 | +import org.rust.lang.core.psi.RsFile |
| 24 | +import org.rust.lang.core.psi.ext.* |
| 25 | +import org.rust.lang.doc.psi.RsDocKind |
| 26 | +import org.rust.openapiext.toPsiFile |
| 27 | +import org.rust.stdext.nextOrNull |
| 28 | +import java.util.regex.Pattern |
| 29 | + |
| 30 | +// See https://github.com/rust-lang/rust/blob/5182cc1c/src/librustdoc/html/markdown.rs#L646 |
| 31 | +private val LANG_SPLIT_REGEX = Pattern.compile("[^\\w-]+", Pattern.UNICODE_CHARACTER_CLASS) |
| 32 | +private val RUST_LANG_ALIASES = listOf( |
| 33 | + "rust", |
| 34 | + "allow_fail", |
| 35 | + "should_panic", |
| 36 | + "no_run", |
| 37 | + "test_harness", |
| 38 | +// "compile_fail", // don't highlight code that is expected to contain errors |
| 39 | + "edition2018", |
| 40 | + "edition2015" |
| 41 | +) |
| 42 | + |
| 43 | +class RsDoctestLanguageInjector : MultiHostInjector { |
| 44 | + private data class CodeRange(val start: Int, val end: Int, val codeStart: Int) { |
| 45 | + fun isCodeNotEmpty(): Boolean = codeStart + 1 < end |
| 46 | + |
| 47 | + val indent: Int = codeStart - start |
| 48 | + |
| 49 | + fun offsetIndent(indent: Int): CodeRange? = |
| 50 | + if (start + indent < end) CodeRange(start + indent, end, codeStart) else null |
| 51 | + } |
| 52 | + |
| 53 | + override fun elementsToInjectIn(): List<Class<out PsiElement>> = |
| 54 | + listOf(RsDocCommentImpl::class.java) |
| 55 | + |
| 56 | + override fun getLanguagesToInject(registrar: MultiHostRegistrar, context: PsiElement) { |
| 57 | + if (context !is RsDocCommentImpl) return |
| 58 | + if (!context.isValidHost || context.elementType !in RS_DOC_COMMENTS) return |
| 59 | + if (!context.project.rustSettings.doctestInjectionEnabled) return |
| 60 | + |
| 61 | + val rsElement = context.ancestorStrict<RsElement>() ?: return |
| 62 | + val cargoTarget = rsElement.containingCargoTarget ?: return |
| 63 | + if (!cargoTarget.isLib) return // only library targets can have doctests |
| 64 | + val crateName = cargoTarget.normName |
| 65 | + val text = context.text |
| 66 | + |
| 67 | + findDoctestInjectableRanges(text, context.elementType).map { ranges -> |
| 68 | + ranges.map { |
| 69 | + CodeRange( |
| 70 | + it.startOffset, |
| 71 | + it.endOffset, |
| 72 | + CharArrayUtil.shiftForward(text, it.startOffset, it.endOffset, " \t") |
| 73 | + ) |
| 74 | + } |
| 75 | + }.map { ranges -> |
| 76 | + val commonIndent = ranges.filter { it.isCodeNotEmpty() }.map { it.indent }.min() |
| 77 | + val indentedRanges = if (commonIndent != null) ranges.mapNotNull { it.offsetIndent(commonIndent) } else ranges |
| 78 | + |
| 79 | + indentedRanges.map { (start, end, codeStart) -> |
| 80 | + // `cargo doc` has special rules for code lines which start with `#`: |
| 81 | + // * `# ` prefix is used to mark lines that should be skipped in rendered documentation. |
| 82 | + // * `##` prefix is converted to `#` |
| 83 | + // See https://github.com/rust-lang/rust/blob/5182cc1c/src/librustdoc/html/markdown.rs#L114 |
| 84 | + when { |
| 85 | + text.startsWith("##", codeStart) -> TextRange(codeStart + 1, end) |
| 86 | + text.startsWith("# ", codeStart) -> TextRange(codeStart + 2, end) |
| 87 | + else -> TextRange(start, end) |
| 88 | + } |
| 89 | + } |
| 90 | + }.forEach { ranges -> |
| 91 | + val inj = registrar.startInjecting(RsLanguage) |
| 92 | + |
| 93 | + ranges.forEachIndexed { index, range -> |
| 94 | + val isFirstIteration = index == 0 |
| 95 | + val isLastIteration = index == ranges.size - 1 |
| 96 | + |
| 97 | + val prefix = if (isFirstIteration) { |
| 98 | + buildString { |
| 99 | + // Yes, we want to skip the only "std" crate. Not core/alloc/etc, the "std" only |
| 100 | + val isStdCrate = crateName == AutoInjectedCrates.STD && |
| 101 | + cargoTarget.pkg.origin == PackageOrigin.STDLIB |
| 102 | + if (!isStdCrate) { |
| 103 | + append("extern crate ") |
| 104 | + append(crateName) |
| 105 | + append("; ") |
| 106 | + } |
| 107 | + append("fn main() {") |
| 108 | + } |
| 109 | + } else { |
| 110 | + null |
| 111 | + } |
| 112 | + val suffix = if (isLastIteration) "}" else null |
| 113 | + |
| 114 | + inj.addPlace(prefix, suffix, context, range) |
| 115 | + } |
| 116 | + |
| 117 | + inj.doneInjecting() |
| 118 | + } |
| 119 | + } |
| 120 | +} |
| 121 | + |
| 122 | +fun findDoctestInjectableRanges(comment: RsDocCommentImpl): Sequence<List<TextRange>> = |
| 123 | + findDoctestInjectableRanges(comment.text, comment.elementType) |
| 124 | + |
| 125 | +private fun findDoctestInjectableRanges(text: String, elementType: IElementType): Sequence<List<TextRange>> { |
| 126 | + // TODO use markdown parser |
| 127 | + val tripleBacktickIndices = text.indicesOf("```").toList() |
| 128 | + if (tripleBacktickIndices.size < 2) return emptySequence() // no code blocks in the comment |
| 129 | + |
| 130 | + val infix = RsDocKind.of(elementType).infix |
| 131 | + |
| 132 | + return tripleBacktickIndices.asSequence().chunked(2).mapNotNull { idx -> |
| 133 | + // Contains code lines inside backticks including `///` at the start and `\n` at the end. |
| 134 | + // It doesn't contain the last line with /// ``` |
| 135 | + val lines = run { |
| 136 | + val codeBlockStart = idx[0] + 3 // skip ``` |
| 137 | + val codeBlockEnd = idx.getOrNull(1) ?: return@mapNotNull null |
| 138 | + generateSequence(codeBlockStart) { text.indexOf("\n", it) + 1 } |
| 139 | + .takeWhile { it != 0 && it < codeBlockEnd } |
| 140 | + .zipWithNext() |
| 141 | + .iterator() |
| 142 | + } |
| 143 | + |
| 144 | + // ```rust, should_panic, edition2018 |
| 145 | + // ^ this text |
| 146 | + val lang = lines.nextOrNull()?.let { text.substring(it.first, it.second - 1) } ?: return@mapNotNull null |
| 147 | + if (lang.isNotEmpty()) { |
| 148 | + val parts = lang.split(LANG_SPLIT_REGEX).filter { it.isNotBlank() } |
| 149 | + if (parts.any { it !in RUST_LANG_ALIASES }) return@mapNotNull null |
| 150 | + } |
| 151 | + |
| 152 | + // skip doc comment infix (`///`, `//!` or ` * `) |
| 153 | + val ranges = lines.asSequence().mapNotNull { (lineStart, lineEnd) -> |
| 154 | + val index = text.indexOf(infix, lineStart) |
| 155 | + if (index != -1 && index < lineEnd) { |
| 156 | + val start = index + infix.length |
| 157 | + TextRange(start, lineEnd) |
| 158 | + } else { |
| 159 | + null |
| 160 | + } |
| 161 | + }.toList() |
| 162 | + |
| 163 | + if (ranges.isEmpty()) return@mapNotNull null |
| 164 | + ranges |
| 165 | + } |
| 166 | +} |
| 167 | + |
| 168 | +private fun String.indicesOf(s: String): Sequence<Int> = |
| 169 | + generateSequence(indexOf(s)) { indexOf(s, it + s.length) }.takeWhile { it != -1 } |
| 170 | + |
| 171 | +fun VirtualFile.isDoctestInjection(project: Project): Boolean { |
| 172 | + val virtualFileWindow = this as? VirtualFileWindow ?: return false |
| 173 | + val hostFile = virtualFileWindow.delegate.toPsiFile(project) as? RsFile ?: return false |
| 174 | + val hostElement = hostFile.findElementAt(virtualFileWindow.documentWindow.injectedToHost(0)) ?: return false |
| 175 | + return hostElement.elementType in RS_DOC_COMMENTS |
| 176 | +} |
| 177 | + |
| 178 | +val RsFile.isDoctestInjection: Boolean |
| 179 | + get() = virtualFile?.isDoctestInjection(project) == true |
| 180 | + |
| 181 | +val RsElement.isDoctestInjection: Boolean |
| 182 | + get() = (contextualFile as? RsFile)?.isDoctestInjection == true |
0 commit comments