Skip to content

Commit

Permalink
Merge pull request #82 from angr/fix/arm_lr
Browse files Browse the repository at this point in the history
Fix/arm lr
  • Loading branch information
Kyle-Kyle authored Feb 14, 2024
2 parents f32174e + b252761 commit 496229a
Show file tree
Hide file tree
Showing 9 changed files with 359 additions and 69 deletions.
21 changes: 21 additions & 0 deletions angrop/chain_builder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .func_caller import FuncCaller
from .sys_caller import SysCaller
from .pivot import Pivot
from .shifter import Shifter
from .. import rop_utils

l = logging.getLogger("angrop.chain_builder")
Expand Down Expand Up @@ -46,6 +47,7 @@ def __init__(self, project, rop_gadgets, pivot_gadgets, syscall_gadgets, arch, b
if not SysCaller.supported_os(self.project.loader.main_object.os):
l.warning("%s is not a fully supported OS, SysCaller may not work on this OS",
self.project.loader.main_object.os)
self._shifter = Shifter(self)

def set_regs(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -133,6 +135,24 @@ def execve(self, path=None, path_addr=None):
return None
return self._sys_caller.execve(path=path, path_addr=path_addr)

def shift(self, length, preserve_regs=None):
"""
build a rop chain to shift the stack to a specific value
:param length: the length of sp you want to shift
:param preserve_regs: set of registers to preserve, e.g. ('eax', 'ebx')
"""
return self._shifter.shift(length, preserve_regs=preserve_regs)

def retsled(self, size, preserve_regs=None):
"""
create a ret-sled ROP chain where if the control flow falls into any point of the chain,
the control flow will be captured and maintained.
for example, a series of ret gadgets in x86/x86_64
:param size: the size of the retsled chain
:param preserve_regs: set of registers to preserve, e.g. ('eax', 'ebx')
"""
return self._shifter.retsled(size, preserve_regs=preserve_regs)

def set_badbytes(self, badbytes):
self.badbytes = badbytes

Expand All @@ -148,6 +168,7 @@ def update(self):
if self._sys_caller:
self._sys_caller.update()
self._pivot.update()
self._shifter.update()

# should also be able to do execve by providing writable memory
# todo pass values to setregs as symbolic variables
4 changes: 4 additions & 0 deletions angrop/chain_builder/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def _build_reg_setting_chain(self, gadgets, modifiable_memory_range, register_di
badbytes=self.badbytes)

# iterate through the stack values that need to be in the chain
# HACK: handle jump register separately because of angrop's broken
# assumptions on x86's ret behavior
if gadgets[-1].transit_type == 'jmp_reg':
stack_change += arch_bytes
for i in range(stack_change // bytes_per_pop):
sym_word = test_symbolic_state.memory.load(sp + bytes_per_pop*i, bytes_per_pop,
endness=self.project.arch.memory_endness)
Expand Down
83 changes: 41 additions & 42 deletions angrop/chain_builder/func_caller.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import logging

import angr
from angr.calling_conventions import SimRegArg, SimStackArg

from .builder import Builder
from .. import rop_utils
from ..errors import RopException
from ..rop_gadget import RopGadget

Expand All @@ -14,13 +16,23 @@ class FuncCaller(Builder):
calling convention
"""

def _func_call(self, func_gadget, cc, args, extra_regs=None, modifiable_memory_range=None, preserve_regs=None,
use_partial_controllers=False, needs_return=True):
def _func_call(self, func_gadget, cc, args, extra_regs=None, preserve_regs=None,
needs_return=True, **kwargs):
"""
func_gadget: the address of the function to invoke
cc: calling convention
args: the arguments to the function
extra_regs: what extra registers to set besides the function arguments, useful for invoking system calls
preserve_res: what registers preserve
needs_return: whether we need to cleanup stack after the function invocation,
setting this to False will result in a shorter chain
"""
assert type(args) in [list, tuple], "function arguments must be a list or tuple!"
if kwargs:
l.warning("passing deprecated arguments %s to angrop.chain_builder.FuncCaller", kwargs)

preserve_regs = set(preserve_regs) if preserve_regs else set()
arch_bytes = self.project.arch.bytes
registers = {} if extra_regs is None else extra_regs
if preserve_regs is None:
preserve_regs = []

# distinguish register and stack arguments
register_arguments = args
Expand All @@ -30,55 +42,42 @@ def _func_call(self, func_gadget, cc, args, extra_regs=None, modifiable_memory_r
stack_arguments = args[len(cc.ARG_REGS):]

# set register arguments
registers = {} if extra_regs is None else extra_regs
for arg, reg in zip(register_arguments, cc.ARG_REGS):
registers[reg] = arg
for reg in preserve_regs:
registers.pop(reg, None)
chain = self.chain_builder.set_regs(modifiable_memory_range=modifiable_memory_range,
use_partial_controllers=use_partial_controllers,
**registers)
chain = self.chain_builder.set_regs(**registers)

# invoke the function
chain.add_gadget(func_gadget)
for _ in range(func_gadget.stack_change//arch_bytes-1):
chain.add_value(self._get_fill_val())

# we are done here if there is no stack arguments
if not stack_arguments:
# we are done here if we don't need to return
if not needs_return:
return chain

# handle stack arguments:
# 1. we need to pop the arguments after use
# 2. push the stack arguments

# step 1: find a stack cleaner (a gadget that can pop all the stack args)
# with the smallest stack change
stack_cleaner = None
if needs_return:
for g in self.chain_builder.gadgets:
# just pop plz
if g.mem_reads or g.mem_writes or g.mem_changes:
continue
# at least can pop all the args
if g.stack_change < arch_bytes * (len(stack_arguments)+1):
continue

if stack_cleaner is None or g.stack_change < stack_cleaner.stack_change:
stack_cleaner = g

if stack_cleaner is None:
raise RopException(f"Fail to find a stack cleaner that can pop {len(stack_arguments)} words!")

# in case we can't find a stack_cleaner and we don't need to return
if stack_cleaner is None:
stack_cleaner = RopGadget(self._get_fill_val())
stack_cleaner.stack_change = arch_bytes * (len(stack_arguments)+1)

chain.add_gadget(stack_cleaner)
stack_arguments += [self._get_fill_val()]*(stack_cleaner.stack_change//arch_bytes - len(stack_arguments)-1)
for arg in stack_arguments:
chain.add_value(arg)

# now we need to cleanly finish the calling convention
# 1. handle stack arguments
# 2. handle function return address to maintain the control flow
if stack_arguments:
cleaner = self.chain_builder.shift((len(stack_arguments)+1)*arch_bytes) # +1 for itself
chain.add_gadget(cleaner._gadgets[0])
for arg in stack_arguments:
chain.add_value(arg)

# handle return address
if not isinstance(cc.RETURN_ADDR, (SimStackArg, SimRegArg)):
raise RopException(f"What is the calling convention {cc} I'm dealing with?")
if isinstance(cc.RETURN_ADDR, SimRegArg) and cc.RETURN_ADDR.reg_name != 'ip_at_syscall':
# now we know this function will return to a specific register
# so we need to set the return address before invoking the function
reg_name = cc.RETURN_ADDR.reg_name
shifter = self.chain_builder._shifter.shift(self.project.arch.bytes)
next_ip = rop_utils.cast_rop_value(shifter._gadgets[0].addr, self.project)
pre_chain = self.chain_builder.set_regs(**{reg_name: next_ip})
chain = pre_chain + chain
return chain

def func_call(self, address, args, **kwargs):
Expand Down
25 changes: 14 additions & 11 deletions angrop/chain_builder/reg_setter.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,18 @@ def verify(self, chain, preserve_regs, registers):
return False
return True

def _maybe_fix_jump_chain(self, chain, preserve_regs):
all_changed_regs = set()
for g in chain._gadgets[:-1]:
all_changed_regs.update(g.changed_regs)
jump_reg = chain._gadgets[-1].jump_reg
if jump_reg in all_changed_regs:
return chain
shifter = self.chain_builder._shifter.shift(self.project.arch.bytes)
next_ip = rop_utils.cast_rop_value(shifter._gadgets[0].addr, self.project)
new = self.run(preserve_regs=preserve_regs, **{jump_reg: next_ip})
return new + chain

def run(self, modifiable_memory_range=None, use_partial_controllers=False, preserve_regs=None, **registers):
if len(registers) == 0:
return RopChain(self.project, None, badbytes=self.badbytes)
Expand Down Expand Up @@ -94,6 +106,8 @@ def run(self, modifiable_memory_range=None, use_partial_controllers=False, pres
chain = self._build_reg_setting_chain(gadgets, modifiable_memory_range,
registers, stack_change)
chain._concretize_chain_values()
if chain._gadgets[-1].transit_type == 'jmp_reg':
chain = self._maybe_fix_jump_chain(chain, preserve_regs)
if self.verify(chain, preserve_regs, registers):
#self._chain_cache[reg_tuple].append(gadgets)
return chain
Expand Down Expand Up @@ -487,14 +501,3 @@ def _check_if_sufficient_partial_control(self, gadget, reg, value):
return False
return True
return False

def _get_single_ret(self):
# start with a ret instruction
ret_addr = None
for g in self._reg_setting_gadgets:
if len(g.changed_regs) == 0 and len(g.mem_writes) == 0 and \
len(g.mem_reads) == 0 and len(g.mem_changes) == 0 and \
g.stack_change == self.project.arch.bytes:
ret_addr = g.addr
break
return ret_addr
140 changes: 140 additions & 0 deletions angrop/chain_builder/shifter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
import logging
from collections import defaultdict

from .builder import Builder
from ..rop_chain import RopChain
from ..errors import RopException

l = logging.getLogger(__name__)

class Shifter(Builder):
"""
A class to find stack shifting gadgets, like add rsp; ret or pop chains
"""
def __init__(self, chain_builder):
super().__init__(chain_builder)

self.shift_gadgets = None
self.update()

def update(self):
self.shift_gadgets = self._filter_gadgets(self.chain_builder.gadgets)

def verify_shift(self, chain, length, preserve_regs):
arch_bytes = self.project.arch.bytes
init_sp = chain._blank_state.regs.sp.concrete_value - len(chain._values) * arch_bytes
state = chain.exec()
if state.regs.sp.concrete_value != init_sp + length + arch_bytes:
return False
for act in state.history.actions:
if act.type != 'reg' or act.action != 'write':
continue
reg_name = self.project.arch.register_size_names[act.offset, self.project.arch.bytes]
if reg_name in preserve_regs:
chain_str = '\n-----\n'.join([str(self.project.factory.block(g.addr).capstone)for g in chain._gadgets])
l.exception("Somehow angrop thinks \n%s\n can be used for the chain generation.", chain_str)
return False
return True

def verify_retsled(self, chain, size, preserve_regs):
if len(chain.payload_str()) != size:
return False
state = chain.exec()
for act in state.history.actions:
if act.type != 'reg' or act.action != 'write':
continue
reg_name = self.project.arch.register_size_names[act.offset, self.project.arch.bytes]
if reg_name == self.arch.stack_pointer:
continue
if reg_name in preserve_regs:
chain_str = '\n-----\n'.join([str(self.project.factory.block(g.addr).capstone)for g in chain._gadgets])
l.exception("Somehow angrop thinks \n%s\n can be used for the chain generation.", chain_str)
return False
return True

@staticmethod
def same_effect(g1, g2):
if g1.stack_change != g2.stack_change:
return False
if g1.transit_type != g2.transit_type:
return False
return True

def shift(self, length, preserve_regs=None):
preserve_regs = set(preserve_regs) if preserve_regs else set()
arch_bytes = self.project.arch.bytes

if length % arch_bytes != 0:
raise RopException("Currently, we do not support shifting misaligned sp change")
if length not in self.shift_gadgets or \
all(preserve_regs.intersection(x.changed_regs) for x in self.shift_gadgets[length]):
raise RopException("Encounter a shifting request that requires chaining multiple shifting gadgets " +
"together which is not support atm. Plz create an issue on GitHub " +
"so we can add the support!")
for g in self.shift_gadgets[length]:
if preserve_regs.intersection(g.changed_regs):
continue
try:
chain = RopChain(self.project, self.chain_builder)
chain.add_gadget(g)
for _ in range(g.stack_change//arch_bytes-1):
chain.add_value(self._get_fill_val())
if self.verify_shift(chain, length, preserve_regs):
return chain
except RopException:
continue

raise RopException(f"Failed to shift sp for {length:#x} bytes while preserving {preserve_regs}")

def retsled(self, size, preserve_regs=None):
preserve_regs = set(preserve_regs) if preserve_regs else set()
arch_bytes = self.project.arch.bytes

if size % arch_bytes != 0:
raise RopException("the size of a retsled must be word aligned")
if not self.shift_gadgets[arch_bytes]:
raise RopException("fail to find a ret-equivalent gadget in this binary!")
for g in self.shift_gadgets[arch_bytes]:
try:
chain = RopChain(self.project, self.chain_builder)
for _ in range(size//arch_bytes):
chain.add_gadget(g)
if self.verify_retsled(chain, size, preserve_regs):
return chain
except RopException:
continue

raise RopException(f"Failed to create a ret-sled sp for {size:#x} bytes while preserving {preserve_regs}")

def better_than(self, g1, g2):
if not self.same_effect(g1, g2):
return False
return g1.changed_regs.issubset(g2.changed_regs)

def _filter_gadgets(self, gadgets):
"""
filter gadgets having the same effect
"""
# we don't like gadgets with any memory accesses or jump gadgets
gadgets = [x for x in gadgets if x.num_mem_access == 0 and x.transit_type != 'jmp_reg']

# now do the standard filtering
gadgets = set(gadgets)
skip = set({})
while True:
to_remove = set({})
for g in gadgets-skip:
to_remove.update({x for x in gadgets-{g} if self.better_than(g, x)})
if to_remove:
break
skip.add(g)
if not to_remove:
break
gadgets -= to_remove

d = defaultdict(list)
for g in gadgets:
d[g.stack_change].append(g)
for x in d:
d[x] = sorted(d[x], key=lambda g: len(g.changed_regs))
return d
9 changes: 4 additions & 5 deletions angrop/rop_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,14 @@ def payload_code(self, constraints=None, print_instructions=True):
payload = ""
payload += 'chain = b""\n'

gadget_dict = {g.addr:g for g in self._gadgets}
concrete_vals = self._concretize_chain_values(constraints)
for value, rebased in concrete_vals:

instruction_code = ""
if print_instructions:
value_in_gadget = value
if value_in_gadget in gadget_dict:
asmstring = rop_utils.gadget_to_asmstring(self._p, gadget_dict[value_in_gadget])
if print_instructions :
sec = self._p.loader.find_section_containing(value)
if sec and sec.is_executable:
asmstring = rop_utils.addr_to_asmstring(self._p, value)
if asmstring != "":
instruction_code = "\t# " + asmstring

Expand Down
2 changes: 2 additions & 0 deletions angrop/rop_gadget.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@ def reg_set_same_effect(self, other):
return False
if self.reg_dependencies != other.reg_dependencies:
return False
if self.transit_type != other.transit_type:
return False
return True

def reg_set_better_than(self, other):
Expand Down
Loading

0 comments on commit 496229a

Please sign in to comment.