Skip to content

Commit

Permalink
enrich-kafka: authenticate with Event Hubs using OAuth2 (close #863)
Browse files Browse the repository at this point in the history
  • Loading branch information
spenes committed Jan 31, 2024
1 parent 164a68a commit 7f32979
Show file tree
Hide file tree
Showing 8 changed files with 187 additions and 23 deletions.
9 changes: 9 additions & 0 deletions integration-tests/enrich-kafka/config/enrich-kafka.hocon
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"consumerConf": {
"enable.auto.commit": "false"
"auto.offset.reset" : "earliest"
"security.protocol": "PLAINTEXT"
"sasl.mechanism": "GSSAPI"
}
}

Expand All @@ -20,12 +22,19 @@
"bootstrapServers": "broker:29092"
"partitionKey": "app_id"
"headers": ["app_id"]
"producerConf": {
"acks": "all"
"security.protocol": "PLAINTEXT"
"sasl.mechanism": "GSSAPI"
}
}

"bad": {
"type": "Kafka"
"topicName": "it-enrich-kinesis-bad"
"bootstrapServers": "broker:29092"
"security.protocol": "PLAINTEXT"
"sasl.mechanism": "GSSAPI"
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import com.snowplowanalytics.snowplow.enrich.common.fs2.config.io.Output.{Kafka

import com.snowplowanalytics.snowplow.enrich.common.fs2.test.CollectorPayloadGen

import com.snowplowanalytics.snowplow.enrich.kafka.{Sink, Source}
import com.snowplowanalytics.snowplow.enrich.kafka._

class EnrichKafkaSpec extends Specification with CatsEffect {

Expand Down Expand Up @@ -67,7 +67,7 @@ class EnrichKafkaSpec extends Specification with CatsEffect {

def run(): IO[Aggregates] = {

val resources = Sink.init[IO](OutKafka(collectorPayloadsStream, bootstrapServers, "", Set.empty, producerConf))
val resources = Sink.init[IO](OutKafka(collectorPayloadsStream, bootstrapServers, "", Set.empty, producerConf), classOf[SourceAuthHandler].getName)

resources.use { sink =>
val generate =
Expand All @@ -79,10 +79,10 @@ class EnrichKafkaSpec extends Specification with CatsEffect {
consumeGood(refGood).merge(consumeBad(refBad))

def consumeGood(ref: Ref[IO, AggregateGood]): Stream[IO, Unit] =
Source.init[IO](InKafka(enrichedStream, bootstrapServers, consumerConf)).map(_.record.value).evalMap(aggregateGood(_, ref))
Source.init[IO](InKafka(enrichedStream, bootstrapServers, consumerConf), classOf[GoodSinkAuthHandler].getName).map(_.record.value).evalMap(aggregateGood(_, ref))

def consumeBad(ref: Ref[IO, AggregateBad]): Stream[IO, Unit] =
Source.init[IO](InKafka(badRowsStream, bootstrapServers, consumerConf)).map(_.record.value).evalMap(aggregateBad(_, ref))
Source.init[IO](InKafka(badRowsStream, bootstrapServers, consumerConf), classOf[BadSinkAuthHandler].getName).map(_.record.value).evalMap(aggregateBad(_, ref))

def aggregateGood(r: Array[Byte], ref: Ref[IO, AggregateGood]): IO[Unit] =
for {
Expand Down
12 changes: 12 additions & 0 deletions modules/kafka/src/main/resources/application.conf
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
"enable.auto.commit": "false"
"auto.offset.reset" : "earliest"
"group.id": "enrich"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
}

Expand All @@ -13,6 +16,9 @@
"type": "Kafka"
"producerConf": {
"acks": "all"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
"partitionKey": ""
"headers": []
Expand All @@ -24,6 +30,9 @@
"bootstrapServers": ""
"producerConf": {
"acks": "all"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
"partitionKey": ""
"headers": []
Expand All @@ -33,6 +42,9 @@
"type": "Kafka"
"producerConf": {
"acks": "all"
"security.protocol": "SASL_SSL"
"sasl.mechanism": "OAUTHBEARER"
"sasl.jaas.config": "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
}
"partitionKey": ""
"headers": []
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Copyright (c) 2023-present Snowplow Analytics Ltd.
* All rights reserved.
*
* This software is made available by Snowplow Analytics, Ltd.,
* under the terms of the Snowplow Limited Use License Agreement, Version 1.0
* located at https://docs.snowplow.io/limited-use-license-1.0
* BY INSTALLING, DOWNLOADING, ACCESSING, USING OR DISTRIBUTING ANY PORTION
* OF THE SOFTWARE, YOU AGREE TO THE TERMS OF SUCH LICENSE AGREEMENT.
*/
package com.snowplowanalytics.snowplow.enrich.kafka

import java.net.URI
import java.{lang, util}

import com.nimbusds.jwt.JWTParser

import javax.security.auth.callback.Callback
import javax.security.auth.callback.UnsupportedCallbackException
import javax.security.auth.login.AppConfigurationEntry

import org.apache.kafka.clients.CommonClientConfigs
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback

import com.azure.identity.DefaultAzureCredentialBuilder
import com.azure.core.credential.TokenRequestContext

// We need separate instances of callback handler with separate source and
// sinks because they need different tokens to authenticate. However we are
// only giving class name to Kafka and it initializes the class itself and if
// we pass same class name for all source and sinks, Kafka initializes and uses
// only one instance of the callback handler. To create separate instances, we
// created multiple different classes and pass their names to respective sink
// and source properties. With this way, all the source and sinks will have their
// own callback handler instance.

class SourceAuthHandler extends AzureAuthenticationCallbackHandler

class GoodSinkAuthHandler extends AzureAuthenticationCallbackHandler

class BadSinkAuthHandler extends AzureAuthenticationCallbackHandler

class PiiSinkAuthHandler extends AzureAuthenticationCallbackHandler

class AzureAuthenticationCallbackHandler extends AuthenticateCallbackHandler {

val credentials = new DefaultAzureCredentialBuilder().build()

var sbUri: String = ""

override def configure(
configs: util.Map[String, _],
saslMechanism: String,
jaasConfigEntries: util.List[AppConfigurationEntry]
): Unit = {
val bootstrapServer =
configs
.get(CommonClientConfigs.BOOTSTRAP_SERVERS_CONFIG)
.toString
.replaceAll("\\[|\\]", "")
.split(",")
.toList
.headOption match {
case Some(s) => s
case None => throw new Exception("Empty bootstrap servers list")
}
val uri = URI.create("https://" + bootstrapServer)
// Workload identity works with '.default' scope
this.sbUri = s"${uri.getScheme}://${uri.getHost}/.default"
}

override def handle(callbacks: Array[Callback]): Unit =
callbacks.foreach {
case callback: OAuthBearerTokenCallback =>
val token = getOAuthBearerToken()
callback.token(token)
case callback => throw new UnsupportedCallbackException(callback)
}

def getOAuthBearerToken(): OAuthBearerToken = {
val reqContext = new TokenRequestContext()
reqContext.addScopes(sbUri)
val accessToken = credentials.getTokenSync(reqContext).getToken
val jwt = JWTParser.parse(accessToken)
val claims = jwt.getJWTClaimsSet

new OAuthBearerToken {
override def value(): String = accessToken

override def lifetimeMs(): Long = claims.getExpirationTime.getTime

override def scope(): util.Set[String] = null

override def principalName(): String = null

override def startTimeMs(): lang.Long = null
}
}

override def close(): Unit = ()
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ object Main extends IOApp {
BuildInfo.version,
BuildInfo.description,
cliConfig => IO.pure(cliConfig),
(input, _) => Source.init[IO](input),
out => Sink.initAttributed(out),
out => Sink.initAttributed(out),
out => Sink.init(out),
(input, _) => Source.init[IO](input, classOf[SourceAuthHandler].getName),
out => Sink.initAttributed(out, classOf[GoodSinkAuthHandler].getName),
out => Sink.initAttributed(out, classOf[PiiSinkAuthHandler].getName),
out => Sink.init(out, classOf[BadSinkAuthHandler].getName),
checkpoint,
createBlobStorageClient,
_.record.value,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,20 @@ import com.snowplowanalytics.snowplow.enrich.common.fs2.config.io.Output
object Sink {

def init[F[_]: Async: Parallel](
output: Output
output: Output,
authCallbackClass: String
): Resource[F, ByteSink[F]] =
for {
sink <- initAttributed(output)
sink <- initAttributed(output, authCallbackClass)
} yield (records: List[Array[Byte]]) => sink(records.map(AttributedData(_, UUID.randomUUID().toString, Map.empty)))

def initAttributed[F[_]: Async: Parallel](
output: Output
output: Output,
authCallbackClass: String
): Resource[F, AttributedByteSink[F]] =
output match {
case k: Output.Kafka =>
mkProducer(k).map { producer => records =>
mkProducer(k, authCallbackClass).map { producer => records =>
records.parTraverse_ { record =>
producer
.produceOne_(toProducerRecord(k.topicName, record))
Expand All @@ -49,11 +51,14 @@ object Sink {
}

private def mkProducer[F[_]: Async](
output: Output.Kafka
output: Output.Kafka,
authCallbackClass: String
): Resource[F, KafkaProducer[F, String, Array[Byte]]] = {
val producerSettings =
ProducerSettings[F, String, Array[Byte]]
.withBootstrapServers(output.bootstrapServers)
// set before user-provided config to make it possible to override it via config
.withProperty("sasl.login.callback.handler.class", authCallbackClass)
.withProperties(output.producerConf)
.withProperties(
("key.serializer", "org.apache.kafka.common.serialization.StringSerializer"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,23 @@ import com.snowplowanalytics.snowplow.enrich.common.fs2.config.io.Input
object Source {

def init[F[_]: Async](
input: Input
input: Input,
authCallbackClass: String
): Stream[F, CommittableConsumerRecord[F, String, Array[Byte]]] =
input match {
case k: Input.Kafka => kafka(k)
case k: Input.Kafka => kafka(k, authCallbackClass)
case i => Stream.raiseError[F](new IllegalArgumentException(s"Input $i is not Kafka"))
}

def kafka[F[_]: Async](
input: Input.Kafka
input: Input.Kafka,
authCallbackClass: String
): Stream[F, CommittableConsumerRecord[F, String, Array[Byte]]] = {
val consumerSettings =
ConsumerSettings[F, String, Array[Byte]]
.withBootstrapServers(input.bootstrapServers)
// set before user-provided config to make it possible to override it via config
.withProperty("sasl.login.callback.handler.class", authCallbackClass)
.withProperties(input.consumerConf)
.withEnableAutoCommit(false) // prevent enabling auto-commits by setting this after user-provided config
.withProperties(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ class ConfigSpec extends Specification with CatsEffect {
"auto.offset.reset" -> "earliest",
"session.timeout.ms" -> "45000",
"enable.auto.commit" -> "false",
"group.id" -> "enrich"
"group.id" -> "enrich",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
),
io.Outputs(
Expand All @@ -55,23 +58,38 @@ class ConfigSpec extends Specification with CatsEffect {
"localhost:9092",
"app_id",
Set("app_id"),
Map("acks" -> "all")
Map(
"acks" -> "all",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
),
Some(
io.Output.Kafka(
"pii",
"localhost:9092",
"app_id",
Set("app_id"),
Map("acks" -> "all")
Map(
"acks" -> "all",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)
),
io.Output.Kafka(
"bad",
"localhost:9092",
"",
Set(),
Map("acks" -> "all")
Map(
"acks" -> "all",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)
),
io.Concurrency(256, 1),
Expand Down Expand Up @@ -151,7 +169,10 @@ class ConfigSpec extends Specification with CatsEffect {
Map(
"auto.offset.reset" -> "earliest",
"enable.auto.commit" -> "false",
"group.id" -> "enrich"
"group.id" -> "enrich",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
),
io.Outputs(
Expand All @@ -160,15 +181,25 @@ class ConfigSpec extends Specification with CatsEffect {
"localhost:9092",
"",
Set(),
Map("acks" -> "all")
Map(
"acks" -> "all",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
),
None,
io.Output.Kafka(
"bad",
"localhost:9092",
"",
Set(),
Map("acks" -> "all")
Map(
"acks" -> "all",
"security.protocol" -> "SASL_SSL",
"sasl.mechanism" -> "OAUTHBEARER",
"sasl.jaas.config" -> "org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule required;"
)
)
),
io.Concurrency(256, 1),
Expand Down

0 comments on commit 7f32979

Please sign in to comment.