Skip to content

Commit

Permalink
Updated checker and added tests for strict mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr3zee committed Dec 3, 2024
1 parent 1ed006e commit 9847059
Show file tree
Hide file tree
Showing 6 changed files with 871 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ fun AbstractCliOption.processAsStrictModeOption(value: String, configuration: Co
return false
}

private fun String.toStrictMode(): StrictMode? {
fun String.toStrictMode(): StrictMode? {
return when (lowercase()) {
"none" -> StrictMode.NONE
"warning" -> StrictMode.WARNING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
import org.jetbrains.kotlin.fir.analysis.checkers.declaration.FirClassChecker
import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
import org.jetbrains.kotlin.fir.analysis.checkers.extractArgumentsTypeRefAndSource
import org.jetbrains.kotlin.fir.analysis.checkers.toClassLikeSymbol
import org.jetbrains.kotlin.fir.declarations.FirClass
import org.jetbrains.kotlin.fir.declarations.FirProperty
import org.jetbrains.kotlin.fir.declarations.FirSimpleFunction
import org.jetbrains.kotlin.fir.declarations.utils.isSuspend
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
import org.jetbrains.kotlin.fir.scopes.impl.toConeType
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
import org.jetbrains.kotlin.fir.types.FirTypeRef
import org.jetbrains.kotlin.fir.types.coneType
import org.jetbrains.kotlin.utils.memoryOptimizedMap
import org.jetbrains.kotlin.utils.memoryOptimizedPlus
Expand Down Expand Up @@ -101,16 +106,14 @@ class FirRpcStrictModeClassChecker(private val ctx: FirCheckersContext) : FirCla

val types = function.valueParameters.memoryOptimizedMap { parameter ->
parameter.source to vsApi {
parameter.returnTypeRef.coneType.toClassSymbolVS(context.session)
parameter.returnTypeRef
}
} memoryOptimizedPlus (function.returnTypeRef.source to returnClassSymbol)
} memoryOptimizedPlus (function.returnTypeRef.source to function.returnTypeRef)

types.filter { (_, symbol) ->
symbol != null
}.forEach { (source, symbol) ->
checkSerializableTypes<FirClassSymbol<*>>(
types.forEach { (source, symbol) ->
checkSerializableTypes<FirClassLikeSymbol<*>>(
context = context,
clazz = symbol!!,
typeRef = symbol,
serializablePropertiesProvider = serializablePropertiesProvider,
) { symbol, parents ->
when (symbol.classId) {
Expand Down Expand Up @@ -142,28 +145,54 @@ class FirRpcStrictModeClassChecker(private val ctx: FirCheckersContext) : FirCla

private fun <ContextElement> checkSerializableTypes(
context: CheckerContext,
clazz: FirClassSymbol<*>,
typeRef: FirTypeRef,
serializablePropertiesProvider: FirSerializablePropertiesProvider,
parentContext: List<ContextElement> = emptyList(),
checker: (FirClassSymbol<*>, List<ContextElement>) -> ContextElement?,
checker: (FirClassLikeSymbol<*>, List<ContextElement>) -> ContextElement?,
) {
val newElement = checker(clazz, parentContext)
val symbol = typeRef.toClassLikeSymbol(context.session) ?: return
val newElement = checker(symbol, parentContext)
val nextContext = if (newElement != null) {
parentContext memoryOptimizedPlus newElement
} else {
parentContext
}

serializablePropertiesProvider.getSerializablePropertiesForClass(clazz)
if (symbol !is FirClassSymbol<*>) {
return
}

val extracted = extractArgumentsTypeRefAndSource(typeRef)
.orEmpty()
.withIndex()
.associate { (i, refSource) ->
symbol.typeParameterSymbols[i].toConeType() to refSource.typeRef
}

val flowProps: List<FirTypeRef> = if (symbol.classId == RpcClassId.flow) {
listOf<FirTypeRef>(extracted.values.toList()[0]!!)
} else {
emptyList()
}

serializablePropertiesProvider.getSerializablePropertiesForClass(symbol)
.serializableProperties
.mapNotNull { property ->
vsApi {
property.propertySymbol.resolvedReturnType.toClassSymbolVS(context.session)
val resolvedTypeRef = property.propertySymbol.resolvedReturnTypeRef
val result = if (resolvedTypeRef.toClassLikeSymbol(context.session) != null) {
resolvedTypeRef
} else {
extracted[property.propertySymbol.resolvedReturnType]
}
if (result == null) {
print(1)
}
}.forEach { symbol ->
result
}.memoryOptimizedPlus(flowProps)
.forEach { symbol ->
checkSerializableTypes(
context = context,
clazz = symbol,
typeRef = symbol,
serializablePropertiesProvider = serializablePropertiesProvider,
parentContext = nextContext,
checker = checker,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ public void testCheckedAnnotation() {
public void testRpcChecked() {
runTest("src/testData/diagnostics/rpcChecked.kt");
}

@Test
@TestMetadata("strictMode.kt")
public void testStrictMode() {
runTest("src/testData/diagnostics/strictMode.kt");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,46 @@

package kotlinx.rpc.codegen.test.services

import kotlinx.rpc.codegen.StrictMode
import kotlinx.rpc.codegen.StrictModeConfigurationKeys
import kotlinx.rpc.codegen.registerRpcExtensions
import kotlinx.rpc.codegen.toStrictMode
import org.jetbrains.kotlin.compiler.plugin.CompilerPluginRegistrar
import org.jetbrains.kotlin.config.CompilerConfiguration
import org.jetbrains.kotlin.test.directives.model.DirectiveApplicability
import org.jetbrains.kotlin.test.directives.model.DirectivesContainer
import org.jetbrains.kotlin.test.directives.model.SimpleDirectivesContainer
import org.jetbrains.kotlin.test.model.TestModule
import org.jetbrains.kotlin.test.services.EnvironmentConfigurator
import org.jetbrains.kotlin.test.services.TestServices
import org.jetbrains.kotlinx.serialization.compiler.extensions.SerializationComponentRegistrar

class ExtensionRegistrarConfigurator(testServices: TestServices) : EnvironmentConfigurator(testServices) {
override val directiveContainers: List<DirectivesContainer> = listOf(RpcDirectives)

override fun CompilerPluginRegistrar.ExtensionStorage.registerCompilerExtensions(
module: TestModule,
configuration: CompilerConfiguration
) {
val strictMode = module.directives[RpcDirectives.RPC_STRICT_MODE]
if (strictMode.isNotEmpty()) {
val mode = strictMode.single().toStrictMode() ?: StrictMode.WARNING
configuration.put(StrictModeConfigurationKeys.STATE_FLOW, mode)
configuration.put(StrictModeConfigurationKeys.SHARED_FLOW, mode)
configuration.put(StrictModeConfigurationKeys.NESTED_FLOW, mode)
configuration.put(StrictModeConfigurationKeys.STREAM_SCOPED_FUNCTIONS, mode)
configuration.put(StrictModeConfigurationKeys.SUSPENDING_SERVER_STREAMING, mode)
configuration.put(StrictModeConfigurationKeys.NOT_TOP_LEVEL_SERVER_FLOW, mode)
configuration.put(StrictModeConfigurationKeys.FIELDS, mode)
}

registerRpcExtensions(configuration)

// libs
SerializationComponentRegistrar.registerExtensions(this)
}
}

object RpcDirectives : SimpleDirectivesContainer() {
val RPC_STRICT_MODE by stringDirective("none, warning or error", DirectiveApplicability.Module)
}
Loading

0 comments on commit 9847059

Please sign in to comment.