Rewrite hash calculation code in rust
twizmwazin committed Jun 24, 2024
1 parent 7a4eb8a commit 4f1c6e0
Showing 2 changed files with 205 additions and 142 deletions.
150 changes: 12 additions & 138 deletions claripy/ast/base.py
@@ -1,3 +1,4 @@
import builtins
import itertools
import logging
import math
@@ -17,17 +18,6 @@
if TYPE_CHECKING:
from claripy.annotation import Annotation

try:
import _pickle as pickle
except ImportError:
import pickle

try:
# Python's build-in MD5 is about 2x faster than hashlib.md5 on short bytestrings
import _md5 as md5
except ImportError:
import hashlib as md5

l = logging.getLogger("claripy.ast")

WORKER = bool(os.environ.get("WORKER", False))
@@ -69,7 +59,7 @@ def _make_name(name: str, size: int, explicit_name: bool = False, prefix: str =
return name


def _d(h, cls, state):
def _unpickle(h, cls, state):
"""
This function is the deserializer for ASTs.
It exists to work around the fact that pickle will (normally) call __new__() with no arguments during
@@ -132,9 +122,6 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d
:param annotations: A frozenset of annotations applied onto this AST.
"""

# if any(isinstance(a, BackendObject) for a in args):
# raise Exception('asdf')

a_args = args if type(args) is tuple else tuple(args)

# initialize the following properties: symbolic, variables and errored
@@ -252,17 +239,17 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d
elif op in {"BVS", "BVV", "BoolS", "BoolV", "FPS", "FPV"} and not annotations:
if op == "FPV" and a_args[0] == 0.0 and math.copysign(1, a_args[0]) < 0:
# Python does not distinguish between +0.0 and -0.0 so we add sign to tuple to distinguish
h = (op, kwargs.get("length", None), ("-", *a_args))
h = builtins.hash((op, kwargs.get("length", None), ("-", *a_args)))
elif op == "FPV" and math.isnan(a_args[0]):
# cannot compare nans
h = (op, kwargs.get("length", None), ("nan",) + a_args[1:])
h = builtins.hash((op, kwargs.get("length", None), ("nan",) + a_args[1:]))
else:
h = (op, kwargs.get("length", None), a_args)
h = builtins.hash((op, kwargs.get("length", None), a_args))

cache = cls._leaf_cache
else:
h = Base._calc_hash(op, a_args, kwargs) if hash is None else hash
self = cache.get(h, None)
self = cache.get(h & 0x7FFF_FFFF_FFFF_FFFF, None)
if self is None:
self = super().__new__(
cls,
@@ -282,8 +269,8 @@ def __new__(cls, op, args, add_variables=None, hash=None, **kwargs): # pylint:d
relocatable_annotations=relocatable_annotations,
**kwargs,
)
self._hash = h
cache[h] = self
self._hash = h & 0x7FFF_FFFF_FFFF_FFFF
cache[self._hash] = self
# else:
# if self.args != a_args or self.op != op or self.variables != kwargs['variables']:
# raise Exception("CRAP -- hash collision")
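For reference, here is a minimal standalone sketch of the leaf-node fast path introduced in the two hunks above. It is illustrative only (the helper name is made up): it shows why FPV zero/NaN arguments get a disambiguating tag before hashing, and how masking the possibly negative result of builtins.hash yields the non-negative cache key used above.

import builtins
import math

def _leaf_cache_key(op, length, args):
    # Illustrative helper, not part of the commit.
    # Python hashes +0.0 and -0.0 to the same value, and NaN != NaN, so FPV
    # leaves get an explicit tag folded into the tuple before hashing.
    if op == "FPV" and args[0] == 0.0 and math.copysign(1, args[0]) < 0:
        key = (op, length, ("-", *args))
    elif op == "FPV" and math.isnan(args[0]):
        key = (op, length, ("nan",) + args[1:])
    else:
        key = (op, length, args)
    # builtins.hash can be negative; masking to 63 bits matches the
    # non-negative keys stored in cls._leaf_cache above.
    return builtins.hash(key) & 0x7FFF_FFFF_FFFF_FFFF

# e.g. _leaf_cache_key("BVV", 32, (42, 32)) == _leaf_cache_key("BVV", 32, (42, 32))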
@@ -296,7 +283,7 @@ def __init_with_annotations__(
):
cache = cls._hash_cache
h = Base._calc_hash(op, a_args, kwargs)
self = cache.get(h, None)
self = cache.get(h & 0x7FFF_FFFF_FFFF_FFFF, None)
if self is not None:
return self

@@ -318,15 +305,15 @@ def __init_with_annotations__(
**kwargs,
)

self._hash = h
cache[h] = self
self._hash = h & 0x7FFF_FFFF_FFFF_FFFF
cache[self._hash] = self

return self

def __reduce__(self):
# HASHCONS: these attributes key the cache
# BEFORE CHANGING THIS, SEE ALL OTHER INSTANCES OF "HASHCONS" IN THIS FILE
return _d, (
return _unpickle, (
self._hash,
self.__class__,
(self.op, self.args, self.length, self.variables, self.symbolic, self.annotations),
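As a usage-level illustration of the __reduce__/_unpickle pair above, the round trip below is a hypothetical session assuming a claripy build that includes this commit; the exact identity of the restored object depends on the hash-cons cache, so only the op and the hash are checked.

import pickle

import claripy  # assumes a build containing this change

x = claripy.BVS("x", 32)
expr = x + 1

blob = pickle.dumps(expr)      # routed through __reduce__ -> (_unpickle, (hash, cls, state))
restored = pickle.loads(blob)  # _unpickle reconstructs the node from (op, args, length, ...)

# The restored AST keys the same hash-cons cache as the original.
assert restored.op == expr.op
assert hash(restored) == hash(expr)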
@@ -335,113 +322,6 @@ def __reduce__(self):
def __init__(self, *args, **kwargs):
pass

@staticmethod
def _calc_hash(op, args, keywords):
"""
Calculates the hash of an AST, given the operation, args, and kwargs.
:param op: The operation.
:param args: The arguments to the operation.
:param keywords: A dict including the 'symbolic', 'variables', and 'length' items.
:returns: a hash.
We do it using md5 to avoid hash collisions.
(hash(-1) == hash(-2), for example)
"""
args_tup = tuple(a if type(a) in (int, float) else getattr(a, "_hash", hash(a)) for a in args)
# HASHCONS: these attributes key the cache
# BEFORE CHANGING THIS, SEE ALL OTHER INSTANCES OF "HASHCONS" IN THIS FILE

to_hash = Base._ast_serialize(op, args_tup, keywords)
if to_hash is None:
# fall back to pickle.dumps
to_hash = (
op,
args_tup,
str(keywords.get("length", None)),
hash(keywords["variables"]),
keywords["symbolic"],
hash(keywords.get("annotations", None)),
)
to_hash = pickle.dumps(to_hash, -1)

# Why do we use md5 when it's broken? Because speed is more important
# than cryptographic integrity here. Then again, look at all those
# allocations we're doing here... fast python is painful.
hd = md5.md5(to_hash).digest()
return md5_unpacker.unpack(hd)[0] # 64 bits

@staticmethod
def _arg_serialize(arg) -> bytes | None:
if arg is None:
return b"\x0f"
elif arg is True:
return b"\x1f"
elif arg is False:
return b"\x2e"
elif isinstance(arg, int):
if arg < 0:
if arg >= -0x7FFF:
return b"-" + struct.pack("<h", arg)
elif arg >= -0x7FFF_FFFF:
return b"-" + struct.pack("<i", arg)
elif arg >= -0x7FFF_FFFF_FFFF_FFFF:
return b"-" + struct.pack("<q", arg)
return None
else:
if arg <= 0xFFFF:
return struct.pack("<H", arg)
elif arg <= 0xFFFF_FFFF:
return struct.pack("<I", arg)
elif arg <= 0xFFFF_FFFF_FFFF_FFFF:
return struct.pack("<Q", arg)
return None
elif isinstance(arg, str):
return arg.encode()
elif isinstance(arg, float):
return struct.pack("f", arg)
elif isinstance(arg, tuple):
arr = []
for elem in arg:
b = Base._arg_serialize(elem)
if b is None:
return None
arr.append(b)
return b"".join(arr)

return None

@staticmethod
def _ast_serialize(op: str, args_tup, keywords) -> bytes | None:
"""
Serialize the AST and get a bytestring for hashing.
:param op: The operator.
:param args_tup: A tuple of arguments.
:param keywords: A dict of keywords.
:return: The serialized bytestring.
"""

serialized_args = Base._arg_serialize(args_tup)
if serialized_args is None:
return None

if "length" in keywords:
length = Base._arg_serialize(keywords["length"])
if length is None:
return None
else:
length = b"none"

variables = struct.pack("<Q", hash(keywords["variables"]) & 0xFFFF_FFFF_FFFF_FFFF)
symbolic = b"\x01" if keywords["symbolic"] else b"\x00"
if "annotations" in keywords:
annotations = struct.pack("<Q", hash(keywords["annotations"]) & 0xFFFF_FFFF_FFFF_FFFF)
else:
annotations = b"\xf9"

return op.encode() + serialized_args + length + variables + symbolic + annotations

# pylint:disable=attribute-defined-outside-init
def __a_init__(
self,
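The pure-Python _calc_hash, _arg_serialize, and _ast_serialize implementations are deleted here; per the commit title, the hash calculation presumably moves into Rust (the second changed file, whose diff did not load in this view). The sketch below is a Python stand-in with the same shape as the removed helper, shown only to illustrate the contract that Base._calc_hash, still called from __new__ and __init_with_annotations__ above, must satisfy. It is not the commit's actual Rust code, and its md5-based serialization mirrors the deleted Python rather than whatever the Rust side does.

import hashlib
import struct

def calc_hash_sketch(op, args, keywords):
    # HASHCONS: key on the same attributes the deleted code used:
    # op, args, length, variables, symbolic, and annotations.
    args_tup = tuple(a if type(a) in (int, float) else getattr(a, "_hash", hash(a)) for a in args)
    to_hash = repr((
        op,
        args_tup,
        str(keywords.get("length", None)),
        hash(keywords["variables"]),
        keywords["symbolic"],
        hash(keywords.get("annotations", None)),
    )).encode()
    # 64 bits of an md5 digest, as in the removed Python implementation ...
    (h,) = struct.unpack("<Q", hashlib.md5(to_hash).digest()[:8])
    # ... masked to 63 bits so it lines up with the cache keys used above.
    return h & 0x7FFF_FFFF_FFFF_FFFF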
@@ -523,12 +403,6 @@ def _encoded_name(self):
# Collapsing and simplification
#

# def _models_for(self, backend):
# for a in self.args:
# backend.convert_expr(a)
# else:
# yield backend.convert(a)

def make_like(self: T, op: str, args: Iterable, **kwargs) -> T:
# Try to simplify the expression again
simplified = simplifications.simpleton.simplify(op, args) if kwargs.pop("simplify", False) is True else None
(The diff for the second changed file, which accounts for the remaining 193 additions and 4 deletions and presumably contains the new Rust hash implementation, did not load in this view.)
