Skip to content

Commit

Permalink
Use AES key per stream
Browse files Browse the repository at this point in the history
Switched to use AES key for each stream which brings key auto-rotation
  • Loading branch information
willyborankin committed Aug 18, 2023
1 parent 65b8297 commit 14dd0e1
Show file tree
Hide file tree
Showing 8 changed files with 53 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,6 @@ public class EncryptedRepository extends BlobStoreRepository {

private final EncryptedRepositorySettings encryptedRepositorySettings;

private final Cache<String, EncryptionData> encryptionDataCache;

private final EncryptionDataGenerator encryptionDataGenerator;

public EncryptedRepository(final RepositoryMetadata metadata,
Expand All @@ -71,27 +69,11 @@ public EncryptedRepository(final RepositoryMetadata metadata,
final NamedXContentRegistry namedXContentRegistry,
final ClusterService clusterService,
final RecoverySettings recoverySettings) {
this(metadata, encryptedRepositorySettings,
blobStorageRepositoryType, blobStorageRepository,
namedXContentRegistry, clusterService,
CacheBuilder.<String, EncryptionData>builder().build(),
recoverySettings);
}

public EncryptedRepository(final RepositoryMetadata metadata,
final EncryptedRepositorySettings encryptedRepositorySettings,
final String blobStorageRepositoryType,
final BlobStoreRepository blobStorageRepository,
final NamedXContentRegistry namedXContentRegistry,
final ClusterService clusterService,
final Cache<String, EncryptionData> encryptionDataCache,
final RecoverySettings recoverySettings) {
super(metadata, COMPRESS_SETTING.get(metadata.settings()),
namedXContentRegistry, clusterService, recoverySettings);
this.encryptedRepositorySettings = encryptedRepositorySettings;
this.blobStorageRepositoryType = blobStorageRepositoryType;
this.blobStorageRepository = blobStorageRepository;
this.encryptionDataCache = encryptionDataCache;
this.encryptionDataGenerator = new EncryptionDataGenerator();
}

Expand Down Expand Up @@ -119,14 +101,12 @@ protected void doStart() {
@Override
protected void doStop() {
super.doStop();
encryptionDataCache.invalidateAll();
blobStorageRepository.stop();
}

@Override
protected void doClose() {
super.doClose();
encryptionDataCache.invalidateAll();
blobStorageRepository.close();
}

Expand All @@ -135,10 +115,8 @@ protected BlobStore createBlobStore() throws Exception {
return new EncryptedBlobStore(
blobStorageRepository.blobStore(),
new CryptoIO(
encryptionDataCache.computeIfAbsent(
settingsKey(metadata.settings()),
this::createOrRestoreEncryptionData
)
new EncryptionDataSerializer(encryptedRepositorySettings.rsaKeyPair(settingsKey(metadata.settings()))),
encryptionDataGenerator
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import java.util.Map;
import java.util.Set;

class EncryptedRepositorySettings {
public class EncryptedRepositorySettings {

private static final Logger LOGGER = LogManager.getLogger(EncryptedRepositorySettings.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
Expand All @@ -26,34 +25,29 @@ public class CryptoIO implements Encryptor, Decryptor {

public static final int GCM_ENCRYPTED_BLOCK_LENGTH = 128;

public static final int GCM_IV_LENGTH = 12;

public static final String CIPHER_TRANSFORMATION = "AES/GCM/NoPadding";

private final SecretKey secretKey;

private final byte[] aad;
private final EncryptionDataSerializer encryptionDataSerializer;

private final SecureRandom secureRandom;
private final EncryptionDataGenerator encryptionDataGenerator;

public CryptoIO(final EncryptionData encryptionData) {
this.secretKey = encryptionData.encryptionKey();
this.aad = encryptionData.aad();
this.secureRandom = new SecureRandom();
public CryptoIO(final EncryptionDataSerializer encryptionDataSerializer,
final EncryptionDataGenerator encryptionDataGenerator) {
this.encryptionDataSerializer = encryptionDataSerializer;
this.encryptionDataGenerator = encryptionDataGenerator;
}

public InputStream encrypt(final InputStream in) throws IOException {
return Permissions.doPrivileged(() -> {
final byte[] iv = new byte[GCM_IV_LENGTH];
secureRandom.nextBytes(iv);
final var encryptionData = encryptionDataGenerator.generate();
final Cipher cipher = createEncryptingCipher(
secretKey,
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, iv),
encryptionData.encryptionKey(),
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, encryptionData.iv()),
CIPHER_TRANSFORMATION);
cipher.updateAAD(aad);
cipher.updateAAD(encryptionData.aad());
return new BufferedInputStream(
new SequenceInputStream(
new ByteArrayInputStream(iv),
new ByteArrayInputStream(encryptionDataSerializer.serialize(encryptionData)),
new CipherInputStream(in, cipher)
), BUFFER_SIZE
);
Expand All @@ -62,17 +56,18 @@ public InputStream encrypt(final InputStream in) throws IOException {

public InputStream decrypt(final InputStream in) throws IOException {
return Permissions.doPrivileged(() -> {
final var encryptionData = encryptionDataSerializer.deserialize(in.readNBytes(EncryptionDataSerializer.ENC_DATA_SIZE));
final Cipher cipher = createDecryptingCipher(
secretKey,
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, in.readNBytes(GCM_IV_LENGTH)),
encryptionData.encryptionKey(),
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, encryptionData.iv()),
CIPHER_TRANSFORMATION);
cipher.updateAAD(aad);
cipher.updateAAD(encryptionData.aad());
return new BufferedInputStream(new CipherInputStream(in, cipher), BUFFER_SIZE);
});
}

public long encryptedStreamSize(final long originSize) {
return originSize + GCM_TAG_LENGTH + GCM_IV_LENGTH;
return originSize + GCM_TAG_LENGTH + EncryptionDataSerializer.ENC_DATA_SIZE;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public final class EncryptionData {

private final byte[] aad;

public EncryptionData(final SecretKey encryptionKey, final byte[] aad) {
private final byte[] iv;

public EncryptionData(final SecretKey encryptionKey, final byte[] aad, final byte[] iv) {
this.encryptionKey = encryptionKey;
this.aad = aad;
this.iv = iv;
}

public SecretKey encryptionKey() {
Expand All @@ -28,18 +31,23 @@ public byte[] aad() {
return aad;
}

public byte[] iv() {
return iv;
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final EncryptionData that = (EncryptionData) o;
return Objects.equals(encryptionKey, that.encryptionKey) && Arrays.equals(aad, that.aad);
EncryptionData that = (EncryptionData) o;
return Objects.equals(encryptionKey, that.encryptionKey) && Arrays.equals(aad, that.aad) && Arrays.equals(iv, that.iv);
}

@Override
public int hashCode() {
int result = Objects.hash(encryptionKey);
result = 31 * result + Arrays.hashCode(aad);
result = 31 * result + Arrays.hashCode(iv);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ public final class EncryptionDataGenerator {

private static final int AAD_SIZE = 32;

public static final int GCM_IV_LENGTH = 12;

private final KeyGenerator aesKeyGenerator;

private final SecureRandom random;
private final SecureRandom random = new SecureRandom();

public EncryptionDataGenerator() {
try {
this.aesKeyGenerator = KeyGenerator.getInstance("AES");
this.random = new SecureRandom();
this.aesKeyGenerator.init(KEY_SIZE, this.random);
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("Couldn't create AES key generator", e);
Expand All @@ -31,8 +32,10 @@ public EncryptionDataGenerator() {

public EncryptionData generate() {
final byte[] aad = new byte[AAD_SIZE];
final byte[] iv = new byte[GCM_IV_LENGTH];
random.nextBytes(aad);
return new EncryptionData(aesKeyGenerator.generateKey(), aad);
random.nextBytes(iv);
return new EncryptionData(aesKeyGenerator.generateKey(), aad, iv);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class EncryptionDataSerializer implements Encryptor, Decryptor {

public static final int ENCRYPTED_AAD_SIZE = 256;

public static final int ENC_DATA_SIZE = ENCRYPTED_KEY_SIZE + ENCRYPTED_AAD_SIZE + SIGNATURE_SIZE + Integer.BYTES;
public static final int ENC_DATA_SIZE = EncryptionDataGenerator.GCM_IV_LENGTH + ENCRYPTED_KEY_SIZE + ENCRYPTED_AAD_SIZE + SIGNATURE_SIZE + Integer.BYTES;

public EncryptionDataSerializer(final KeyPair rsaKeyPair) {
this.rsaKeyPair = rsaKeyPair;
Expand All @@ -52,15 +52,18 @@ public byte[] serialize(final EncryptionData encryptionData) throws IOException
}
final byte[] key = encryptionData.encryptionKey().getEncoded();
final byte[] aad = encryptionData.aad();
final byte[] iv = encryptionData.iv();
final byte[] signature = sign(
ByteBuffer.allocate(key.length + aad.length)
ByteBuffer.allocate(key.length + aad.length + iv.length)
.put(key)
.put(aad)
.put(iv)
.array()
);
final byte[] encryptedKey = encrypt(key, "Couldn't encrypt " + KEY_ALGORITHM + " key");
final byte[] encryptedAad = encrypt(aad, "Couldn't encrypt AAD");
return ByteBuffer.allocate(ENC_DATA_SIZE)
.put(iv)
.put(encryptedKey)
.put(encryptedAad)
.put(signature)
Expand All @@ -75,6 +78,8 @@ public EncryptionData deserialize(final byte[] metadata) throws IOException {
final byte[] encryptedKey = new byte[256];
final byte[] encryptedAad = new byte[256];
final byte[] signature = new byte[256];
final byte[] iv = new byte[EncryptionDataGenerator.GCM_IV_LENGTH];
buffer.get(iv);
buffer.get(encryptedKey);
buffer.get(encryptedAad);
buffer.get(signature);
Expand All @@ -83,12 +88,13 @@ public EncryptionData deserialize(final byte[] metadata) throws IOException {
final byte[] decryptedAdd = decrypt(encryptedAad, "Couldn't decrypt AAD");
verifySignature(
signature,
ByteBuffer.allocate(decryptedKey.length + decryptedAdd.length)
ByteBuffer.allocate(decryptedKey.length + decryptedAdd.length + iv.length)
.put(decryptedKey)
.put(decryptedAdd)
.put(iv)
.array()
);
return new EncryptionData(new SecretKeySpec(decryptedKey, KEY_ALGORITHM), decryptedAdd);
return new EncryptionData(new SecretKeySpec(decryptedKey, KEY_ALGORITHM), decryptedAdd, iv);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,13 @@
import org.opensearch.cluster.metadata.RepositoryMetadata;
import org.opensearch.cluster.service.ClusterApplierService;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.cache.Cache;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.repositories.blobstore.BlobStoreRepository;
import org.opensearch.repository.encrypted.security.EncryptionData;
import org.opensearch.test.OpenSearchTestCase;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

Expand All @@ -35,9 +32,6 @@ public class EncryptedRepositoryTests extends OpenSearchTestCase {

final RecoverySettings mockedRecoverySettings = mock(RecoverySettings.class);

final Cache<String, EncryptionData> mockedCache =
(Cache<String, EncryptionData>) mock(Cache.class);

@Before
public void setupMocks() throws Exception {
when(mockedClusterService.getClusterApplierService()).thenReturn(mock(ClusterApplierService.class));
Expand All @@ -52,21 +46,16 @@ public void testBlobStorageLifecycle() throws Exception {
mockedBlobStoreRepository,
mockedNamedXContentRegistry,
mockedClusterService,
mockedCache,
mockedRecoverySettings);

repository.start();
verify(mockedBlobStoreRepository).start();

repository.stop();
verify(mockedBlobStoreRepository).stop();
verify(mockedCache).invalidateAll();

reset(mockedCache);

repository.close();
verify(mockedBlobStoreRepository).close();
verify(mockedCache).invalidateAll();

repository.stats();
verify(mockedBlobStoreRepository).stats();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,26 @@

import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.BeforeClass;
import org.opensearch.repository.encrypted.RsaKeyAwareTest;
import org.opensearch.test.OpenSearchTestCase;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.Security;

public class CryptoIOTests extends OpenSearchTestCase {
public class CryptoIOTests extends RsaKeyAwareTest {

private static final int MAX_BYES_SIZE = 18_192;
private static final int MAX_BYES_SIZE = 1_800_192;

private final EncryptionData encData = new EncryptionDataGenerator().generate();
private final EncryptionDataGenerator encryptionDataGenerator = new EncryptionDataGenerator();

@BeforeClass
static void setupProvider() {
static {
Security.addProvider(new BouncyCastleProvider());
}

public void testEncryptAndDecrypt() throws IOException {
final CryptoIO cryptoIo = new CryptoIO(encData);
final CryptoIO cryptoIo = new CryptoIO(new EncryptionDataSerializer(rsaKeyPair), encryptionDataGenerator);
final byte [] sequence = randomByteArrayOfLength(randomInt(MAX_BYES_SIZE));

try (InputStream encIn = cryptoIo.encrypt(new ByteArrayInputStream(sequence))) {
Expand All @@ -35,11 +35,10 @@ public void testEncryptAndDecrypt() throws IOException {
assertArrayEquals(sequence, decIn.readAllBytes());
}
}

}

public void testEncryptedStreamSize() throws IOException {
final CryptoIO cryptoIo = new CryptoIO(encData);
final CryptoIO cryptoIo = new CryptoIO(new EncryptionDataSerializer(rsaKeyPair), encryptionDataGenerator);
final byte [] sequence = randomByteArrayOfLength(randomInt(MAX_BYES_SIZE));

try (InputStream encIn = cryptoIo.encrypt(new ByteArrayInputStream(sequence))) {
Expand Down

0 comments on commit 14dd0e1

Please sign in to comment.