diff --git a/crypto/aes.go b/crypto/aes.go index 4bf6e1b..2105e94 100644 --- a/crypto/aes.go +++ b/crypto/aes.go @@ -3,6 +3,8 @@ package crypto import ( "crypto/aes" "crypto/cipher" + "crypto/hmac" + "crypto/sha256" "io" "os" "strings" @@ -325,3 +327,13 @@ func AESEncryptFilesInDir(dir string, secret []byte, opts ...AESEncryptFilesInDi return pool.Wait() } + +// HMAC calculate HMAC +func HMAC(key, data []byte) ([]byte, error) { + h := hmac.New(sha256.New, key) + if _, err := h.Write(data); err != nil { + return nil, errors.Wrap(err, "write data") + } + + return h.Sum(nil), nil +} diff --git a/crypto/aes_test.go b/crypto/aes_test.go index 7772df8..857d4c4 100644 --- a/crypto/aes_test.go +++ b/crypto/aes_test.go @@ -241,3 +241,18 @@ func TestGcmIvLength(t *testing.T) { require.Equal(t, AesGcmTagLen, gcm.Overhead()) } } + +func TestHMAC(t *testing.T) { + t.Parallel() + + key := []byte("secret-key") + data := []byte("hello world") + + expected := []byte{0x9, 0x5d, 0x5a, 0x21, 0xfe, 0x6d, 0x6, 0x46, 0xdb, 0x22, 0x3f, 0xdf, 0x3d, 0xe6, 0x43, 0x6b, 0xb8, 0xdf, 0xb2, 0xfa, 0xb0, 0xb5, 0x16, 0x77, 0xec, 0xf6, 0x44, 0x1f, 0xcf, 0x5f, 0x2a, 0x67} + + for i := 0; i < 5; i++ { + result, err := HMAC(key, data) + require.NoError(t, err) + require.Equal(t, expected, result) + } +} diff --git a/crypto/smtongsuo.go b/crypto/smtongsuo.go index 3e140dd..12ae881 100644 --- a/crypto/smtongsuo.go +++ b/crypto/smtongsuo.go @@ -3,6 +3,7 @@ package crypto import ( "bytes" "context" + "encoding/hex" "os" "os/exec" "path/filepath" @@ -253,3 +254,68 @@ func (t *Tongsuo) NewX509CertByCSR(ctx context.Context, return certDer, nil } + +func (t *Tongsuo) EncryptBySm4Baisc(ctx context.Context, key, plaintext, iv []byte) (ciphertext, hmac []byte, err error) { + dir, err := os.MkdirTemp("", "tongsuo*") + if err != nil { + return nil, nil, errors.Wrap(err, "generate tem dir") + } + defer t.removeAll(dir) + + cipherPath := filepath.Join(dir, "cipher") + if _, err = t.runCMD(ctx, []string{ + "enc", "-sm4-cbc", "-e", + "-in", "/dev/stdin", "-out", cipherPath, + "-K", hex.EncodeToString(key), "-iv", hex.EncodeToString(iv), + }, plaintext); err != nil { + return nil, nil, errors.Wrap(err, "encrypt") + } + + if ciphertext, err = os.ReadFile(cipherPath); err != nil { + return nil, nil, errors.Wrap(err, "read cipher") + } + + if hmac, err = HMAC(key, ciphertext); err != nil { + return nil, nil, errors.Wrap(err, "calculate hmac") + } + + return ciphertext, hmac, nil +} + +// DecryptBySm4Baisc decrypt by sm4 +// +// # Args +// - key: sm4 key +// - ciphertext: sm4 encrypted data +// - iv: sm4 iv +// - hmac: if not nil, will check ciphertext's integrity by hmac +func (t *Tongsuo) DecryptBySm4Baisc(ctx context.Context, key, ciphertext, iv, hmac []byte) (plaintext []byte, err error) { + if len(hmac) != 0 { // check hmac + if expectedHmac, err := HMAC(key, ciphertext); err != nil { + return nil, errors.Wrap(err, "calculate hmac") + } else if !bytes.Equal(hmac, expectedHmac) { + return nil, errors.Errorf("hmac not match") + } + } + + dir, err := os.MkdirTemp("", "tongsuo*") + if err != nil { + return nil, errors.Wrap(err, "generate tem dir") + } + defer t.removeAll(dir) + + cipherPath := filepath.Join(dir, "cipher") + if err = os.WriteFile(cipherPath, ciphertext, 0600); err != nil { + return nil, errors.Wrap(err, "write cipher") + } + + if plaintext, err = t.runCMD(ctx, []string{ + "enc", "-sm4-cbc", "-d", + "-in", cipherPath, "-out", "/dev/stdout", + "-K", hex.EncodeToString(key), "-iv", hex.EncodeToString(iv), + }, ciphertext); err != nil { + return nil, errors.Wrap(err, "decrypt") + } + + return plaintext, nil +} diff --git a/crypto/smtongsuo_test.go b/crypto/smtongsuo_test.go index de47d6e..57924d6 100644 --- a/crypto/smtongsuo_test.go +++ b/crypto/smtongsuo_test.go @@ -155,3 +155,55 @@ func TestTongsuo_NewIntermediaCaByCsr(t *testing.T) { require.Contains(t, certinfo, "Issuer: CN = test-common-name") }) } + +func TestTongsuo_EncryptBySm4Baisc(t *testing.T) { + t.Parallel() + if testSkipSmTongsuo(t) { + return + } + + ctx := context.Background() + ins, err := NewTongsuo("/usr/local/bin/tongsuo") + require.NoError(t, err) + + key, err := Salt(16) + require.NoError(t, err) + incorrectKey, err := Salt(16) + require.NoError(t, err) + plaintext := []byte("Hello, World!") + iv, err := Salt(16) + require.NoError(t, err) + + t.Run("correct passphare", func(t *testing.T) { + t.Parallel() + + ciphertext, tag, err := ins.EncryptBySm4Baisc(ctx, key, plaintext, iv) + require.NoError(t, err) + require.NotNil(t, ciphertext) + + // Decrypt the ciphertext to verify the encryption + decrypted, err := ins.DecryptBySm4Baisc(ctx, key, ciphertext, iv, tag) + require.NoError(t, err) + require.Equal(t, plaintext, decrypted) + }) + + t.Run("incorrect passphare", func(t *testing.T) { + t.Parallel() + + ciphertext, tag, err := ins.EncryptBySm4Baisc(ctx, key, plaintext, iv) + require.NoError(t, err) + require.NotNil(t, ciphertext) + + // Decrypt the ciphertext with incorrect key + _, err = ins.DecryptBySm4Baisc(ctx, []byte("incorrect key"), ciphertext, iv, tag) + require.ErrorContains(t, err, "hmac not match") + + // Decrypt the ciphertext with incorrect tag + _, err = ins.DecryptBySm4Baisc(ctx, key, ciphertext, iv, []byte("incorrect tag")) + require.ErrorContains(t, err, "hmac not match") + + // Decrypt the ciphertext with incorrect key and empty tag + _, err = ins.DecryptBySm4Baisc(ctx, incorrectKey, ciphertext, iv, nil) + require.ErrorContains(t, err, "got bad decrypt") + }) +}