From 55c1228e376556642e1388084b04907149ef25a4 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Mon, 14 Oct 2024 09:05:47 +0000 Subject: [PATCH] refactor: compatability with go1.21 --- crypto/smtongsuo.go | 5 +- crypto/smtongsuo_test.go | 16 +++- crypto/x509.go | 166 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 179 insertions(+), 8 deletions(-) diff --git a/crypto/smtongsuo.go b/crypto/smtongsuo.go index c70da99..c7564ff 100644 --- a/crypto/smtongsuo.go +++ b/crypto/smtongsuo.go @@ -261,7 +261,10 @@ func (t *Tongsuo) ShowCertInfo(ctx context.Context, return "", nil, errors.Wrap(err, "parse policy") } - cert.Policies = append(cert.Policies, oid) + // cert.Policies = append(cert.Policies, oid) + if ansiOid, ok := oid.toASN1OID(); ok { + cert.PolicyIdentifiers = append(cert.PolicyIdentifiers, ansiOid) + } } } diff --git a/crypto/smtongsuo_test.go b/crypto/smtongsuo_test.go index ac60597..d212a55 100644 --- a/crypto/smtongsuo_test.go +++ b/crypto/smtongsuo_test.go @@ -249,10 +249,14 @@ func TestTongsuo_NewPrikeyAndCert(t *testing.T) { require.True(t, cert.IsCA) oid, err := OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 3}) require.NoError(t, err) - require.Contains(t, cert.Policies, oid) + ansiOid, ok := oid.toASN1OID() + require.True(t, ok) + require.Contains(t, cert.PolicyIdentifiers, ansiOid) oid, err = OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 2, 3}) require.NoError(t, err) - require.Contains(t, cert.Policies, oid) + ansiOid, ok = oid.toASN1OID() + require.True(t, ok) + require.Contains(t, cert.PolicyIdentifiers, ansiOid) }) t.Run("not ca", func(t *testing.T) { @@ -289,10 +293,14 @@ func TestTongsuo_NewPrikeyAndCert(t *testing.T) { require.Equal(t, "test-common-name", cert.Subject.CommonName) oid, err := OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 3}) require.NoError(t, err) - require.Contains(t, cert.Policies, oid) + ansiOid, ok := oid.toASN1OID() + require.True(t, ok) + require.Contains(t, cert.PolicyIdentifiers, ansiOid) oid, err = OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 4}) require.NoError(t, err) - require.Contains(t, cert.Policies, oid) + ansiOid, ok = oid.toASN1OID() + require.True(t, ok) + require.Contains(t, cert.PolicyIdentifiers, ansiOid) }) } diff --git a/crypto/x509.go b/crypto/x509.go index 094fd09..9977eac 100644 --- a/crypto/x509.go +++ b/crypto/x509.go @@ -11,7 +11,9 @@ import ( "crypto/x509/pkix" "encoding/asn1" "fmt" + "math" "math/big" + "math/bits" "net" "net/mail" "net/url" @@ -1569,18 +1571,176 @@ func X509CertSubjectKeyID(pubkey crypto.PublicKey) ([]byte, error) { return hasher.Sum(nil), nil } +// An OID represents an ASN.1 OBJECT IDENTIFIER. +type OID struct { + der []byte +} + +// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If +// asn1.ObjectIdentifier cannot represent the OID specified by oid, because +// a component of OID requires more than 31 bits, it returns false. +func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool { + if len(other) < 2 { + return false + } + v, offset, failed := parseBase128Int(oid.der, 0) + if failed { + // This should never happen, since we've already parsed the OID, + // but just in case. + return false + } + if v < 80 { + a, b := v/40, v%40 + if other[0] != a || other[1] != b { + return false + } + } else { + a, b := 2, v-80 + if other[0] != a || other[1] != b { + return false + } + } + + i := 2 + for ; offset < len(oid.der); i++ { + v, offset, failed = parseBase128Int(oid.der, offset) + if failed { + // Again, shouldn't happen, since we've already parsed + // the OID, but better safe than sorry. + return false + } + if i >= len(other) || v != other[i] { + return false + } + } + + return i == len(other) +} + +func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) { + out := make([]int, 0, len(oid.der)+1) + + const ( + valSize = 31 // amount of usable bits of val for OIDs. + bitsPerByte = 7 + maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1 + ) + + val := 0 + + for _, v := range oid.der { + if val > maxValSafeShift { + return nil, false + } + + val <<= bitsPerByte + val |= int(v & 0x7F) + + if v&0x80 == 0 { + if len(out) == 0 { + if val < 80 { + out = append(out, val/40) + out = append(out, val%40) + } else { + out = append(out, 2) + out = append(out, val-80) + } + val = 0 + continue + } + out = append(out, val) + val = 0 + } + } + + return out, true +} + +var errInvalidOID = errors.New("invalid oid") + +func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) { + offset = initOffset + var ret64 int64 + for shifted := 0; offset < len(bytes); shifted++ { + // 5 * 7 bits per byte == 35 bits of data + // Thus the representation is either non-minimal or too large for an int32 + if shifted == 5 { + failed = true + return + } + ret64 <<= 7 + b := bytes[offset] + // integers should be minimally encoded, so the leading octet should + // never be 0x80 + if shifted == 0 && b == 0x80 { + failed = true + return + } + ret64 |= int64(b & 0x7f) + offset++ + if b&0x80 == 0 { + ret = int(ret64) + // Ensure that the returned value fits in an int on all platforms + if ret64 > math.MaxInt32 { + failed = true + } + return + } + } + failed = true + return +} + +func appendBase128Int(dst []byte, n uint64) []byte { + for i := base128IntLength(n) - 1; i >= 0; i-- { + o := byte(n >> uint(i*7)) + o &= 0x7f + if i != 0 { + o |= 0x80 + } + dst = append(dst, o) + } + return dst +} + +func base128IntLength(n uint64) int { + if n == 0 { + return 1 + } + return (bits.Len64(n) + 6) / 7 +} + +// OIDFromInts creates a new OID using ints, each integer is a separate component. +func OIDFromInts(oid []uint64) (OID, error) { + if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) { + return OID{}, errInvalidOID + } + + length := base128IntLength(oid[0]*40 + oid[1]) + for _, v := range oid[2:] { + length += base128IntLength(v) + } + + der := make([]byte, 0, length) + der = appendBase128Int(der, oid[0]*40+oid[1]) + for _, v := range oid[2:] { + der = appendBase128Int(der, v) + } + return OID{der}, nil +} + // OidAsn2X509 convert asn1 object identifier to x509 object identifier -func OidAsn2X509(oid asn1.ObjectIdentifier) (x509.OID, error) { +func OidAsn2X509(oid asn1.ObjectIdentifier) (OID, error) { oids := make([]uint64, 0, len(oid)) for i := range oid { oids = append(oids, uint64(oid[i])) } - return x509.OIDFromInts(oids) + return OIDFromInts(oids) } // OidFromString convert string to x509 object identifier -func OidFromString(val string) (x509Oid x509.OID, err error) { +func OidFromString(val string) (x509Oid OID, err error) { asnOid, err := gutils.ParseObjectIdentifier(val) if err != nil { return x509Oid, errors.Wrapf(err, "parse oid %s", val)