Skip to content

Commit

Permalink
refactor: compatability with go1.21
Browse files Browse the repository at this point in the history
  • Loading branch information
Laisky committed Oct 14, 2024
1 parent 29ccb0e commit 55c1228
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 8 deletions.
5 changes: 4 additions & 1 deletion crypto/smtongsuo.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,10 @@ func (t *Tongsuo) ShowCertInfo(ctx context.Context,
return "", nil, errors.Wrap(err, "parse policy")
}

cert.Policies = append(cert.Policies, oid)
// cert.Policies = append(cert.Policies, oid)
if ansiOid, ok := oid.toASN1OID(); ok {
cert.PolicyIdentifiers = append(cert.PolicyIdentifiers, ansiOid)
}
}
}

Expand Down
16 changes: 12 additions & 4 deletions crypto/smtongsuo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,14 @@ func TestTongsuo_NewPrikeyAndCert(t *testing.T) {
require.True(t, cert.IsCA)
oid, err := OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 3})
require.NoError(t, err)
require.Contains(t, cert.Policies, oid)
ansiOid, ok := oid.toASN1OID()
require.True(t, ok)
require.Contains(t, cert.PolicyIdentifiers, ansiOid)
oid, err = OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 2, 3})
require.NoError(t, err)
require.Contains(t, cert.Policies, oid)
ansiOid, ok = oid.toASN1OID()
require.True(t, ok)
require.Contains(t, cert.PolicyIdentifiers, ansiOid)
})

t.Run("not ca", func(t *testing.T) {
Expand Down Expand Up @@ -289,10 +293,14 @@ func TestTongsuo_NewPrikeyAndCert(t *testing.T) {
require.Equal(t, "test-common-name", cert.Subject.CommonName)
oid, err := OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 3})
require.NoError(t, err)
require.Contains(t, cert.Policies, oid)
ansiOid, ok := oid.toASN1OID()
require.True(t, ok)
require.Contains(t, cert.PolicyIdentifiers, ansiOid)
oid, err = OidAsn2X509(asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 59936, 1, 1, 4})
require.NoError(t, err)
require.Contains(t, cert.Policies, oid)
ansiOid, ok = oid.toASN1OID()
require.True(t, ok)
require.Contains(t, cert.PolicyIdentifiers, ansiOid)
})
}

Expand Down
166 changes: 163 additions & 3 deletions crypto/x509.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"crypto/x509/pkix"
"encoding/asn1"
"fmt"
"math"
"math/big"
"math/bits"
"net"
"net/mail"
"net/url"
Expand Down Expand Up @@ -1569,18 +1571,176 @@ func X509CertSubjectKeyID(pubkey crypto.PublicKey) ([]byte, error) {
return hasher.Sum(nil), nil
}

// An OID represents an ASN.1 OBJECT IDENTIFIER.
type OID struct {
der []byte
}

// EqualASN1OID returns whether an OID equals an asn1.ObjectIdentifier. If
// asn1.ObjectIdentifier cannot represent the OID specified by oid, because
// a component of OID requires more than 31 bits, it returns false.
func (oid OID) EqualASN1OID(other asn1.ObjectIdentifier) bool {
if len(other) < 2 {
return false
}
v, offset, failed := parseBase128Int(oid.der, 0)
if failed {
// This should never happen, since we've already parsed the OID,
// but just in case.
return false
}
if v < 80 {
a, b := v/40, v%40
if other[0] != a || other[1] != b {
return false
}
} else {
a, b := 2, v-80
if other[0] != a || other[1] != b {
return false
}
}

i := 2
for ; offset < len(oid.der); i++ {
v, offset, failed = parseBase128Int(oid.der, offset)
if failed {
// Again, shouldn't happen, since we've already parsed
// the OID, but better safe than sorry.
return false
}
if i >= len(other) || v != other[i] {
return false
}
}

return i == len(other)
}

func (oid OID) toASN1OID() (asn1.ObjectIdentifier, bool) {
out := make([]int, 0, len(oid.der)+1)

const (
valSize = 31 // amount of usable bits of val for OIDs.
bitsPerByte = 7
maxValSafeShift = (1 << (valSize - bitsPerByte)) - 1
)

val := 0

for _, v := range oid.der {
if val > maxValSafeShift {
return nil, false
}

val <<= bitsPerByte
val |= int(v & 0x7F)

if v&0x80 == 0 {
if len(out) == 0 {
if val < 80 {
out = append(out, val/40)
out = append(out, val%40)
} else {
out = append(out, 2)
out = append(out, val-80)
}
val = 0
continue
}
out = append(out, val)
val = 0
}
}

return out, true
}

var errInvalidOID = errors.New("invalid oid")

func parseBase128Int(bytes []byte, initOffset int) (ret, offset int, failed bool) {
offset = initOffset
var ret64 int64
for shifted := 0; offset < len(bytes); shifted++ {
// 5 * 7 bits per byte == 35 bits of data
// Thus the representation is either non-minimal or too large for an int32
if shifted == 5 {
failed = true
return
}
ret64 <<= 7
b := bytes[offset]
// integers should be minimally encoded, so the leading octet should
// never be 0x80
if shifted == 0 && b == 0x80 {
failed = true
return
}
ret64 |= int64(b & 0x7f)
offset++
if b&0x80 == 0 {
ret = int(ret64)
// Ensure that the returned value fits in an int on all platforms
if ret64 > math.MaxInt32 {
failed = true
}
return
}
}
failed = true
return
}

func appendBase128Int(dst []byte, n uint64) []byte {
for i := base128IntLength(n) - 1; i >= 0; i-- {
o := byte(n >> uint(i*7))
o &= 0x7f
if i != 0 {
o |= 0x80
}
dst = append(dst, o)
}
return dst
}

func base128IntLength(n uint64) int {
if n == 0 {
return 1
}
return (bits.Len64(n) + 6) / 7
}

// OIDFromInts creates a new OID using ints, each integer is a separate component.
func OIDFromInts(oid []uint64) (OID, error) {
if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
return OID{}, errInvalidOID
}

length := base128IntLength(oid[0]*40 + oid[1])
for _, v := range oid[2:] {
length += base128IntLength(v)
}

der := make([]byte, 0, length)
der = appendBase128Int(der, oid[0]*40+oid[1])
for _, v := range oid[2:] {
der = appendBase128Int(der, v)
}
return OID{der}, nil
}

// OidAsn2X509 convert asn1 object identifier to x509 object identifier
func OidAsn2X509(oid asn1.ObjectIdentifier) (x509.OID, error) {
func OidAsn2X509(oid asn1.ObjectIdentifier) (OID, error) {
oids := make([]uint64, 0, len(oid))
for i := range oid {
oids = append(oids, uint64(oid[i]))
}

return x509.OIDFromInts(oids)
return OIDFromInts(oids)
}

// OidFromString convert string to x509 object identifier
func OidFromString(val string) (x509Oid x509.OID, err error) {
func OidFromString(val string) (x509Oid OID, err error) {
asnOid, err := gutils.ParseObjectIdentifier(val)
if err != nil {
return x509Oid, errors.Wrapf(err, "parse oid %s", val)
Expand Down

0 comments on commit 55c1228

Please sign in to comment.