Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add APX register support #261

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions gematria/datasets/block_wrapper.S
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
// See the "WrappedFunc" typedef for the function signature this code has. Since
// it doesn't return we make no guarantees about preserving registers / stack
// frame, but we do use the normal calling convention for input parameters.
// TODO(orodley): Update to support r16-r31.
gematria_prologue:
mov rax, [rdi] // rax = vector_reg_width
mov rbx, [rdi + 8] // rbx = uses_upper_vector_regs
mov rcx, [rdi + 16] // rcx = uses_apx_regs
cmp rax, 0
je set_int_registers
cmp rax, 1
Expand All @@ -40,55 +40,64 @@ gematria_prologue:

set_xmm_registers:
.irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
vpbroadcastq xmm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq xmm\n, [rdi + 0x118 + (8 * \n)]
.endr
cmp rbx, 0
je set_int_registers
set_upper_xmm_registers:
.irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
vpbroadcastq xmm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq xmm\n, [rdi + 0x118 + (8 * \n)]
.endr
jmp set_int_registers

set_ymm_registers:
.irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
vpbroadcastq ymm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq ymm\n, [rdi + 0x118 + (8 * \n)]
.endr
cmp rbx, 0
je set_int_registers
.irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
vpbroadcastq ymm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq ymm\n, [rdi + 0x118 + (8 * \n)]
.endr
jmp set_int_registers

set_zmm_registers:
.irp n, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
vpbroadcastq zmm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq zmm\n, [rdi + 0x118 + (8 * \n)]
.endr
cmp rbx, 0
je set_int_registers
.irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
vpbroadcastq zmm\n, [rdi + 0x90 + (8 * \n)]
vpbroadcastq zmm\n, [rdi + 0x118 + (8 * \n)]
.endr

set_int_registers:
mov r15, rdi
mov rax, [r15 + 0x10]
mov rbx, [r15 + 0x18]
mov rcx, [r15 + 0x20]
mov rdx, [r15 + 0x28]
mov rsi, [r15 + 0x30]
mov rdi, [r15 + 0x38]
mov rsp, [r15 + 0x40]
mov rbp, [r15 + 0x48]
mov r8, [r15 + 0x50]
mov r9, [r15 + 0x58]
mov r10, [r15 + 0x60]
mov r11, [r15 + 0x68]
mov r12, [r15 + 0x70]
mov r13, [r15 + 0x78]
mov r14, [r15 + 0x80]
mov r15, [r15 + 0x88]
// Set APX registers first if necessary, as we need to overwrite RDI as part
// of setting the base registers.
//
// rcx still holds uses_apx_regs (loaded at the top of the prologue); if it is
// zero the APX registers are left untouched.
//
// apx_regs[0] lives at offset 0x98 and holds the value for r16, so the
// register number must be rebased by 16 before it is scaled to a byte offset.
// Without the rebase the first load would read 0x98 + 8 * 16 = 0x118, which is
// the start of vector_regs rather than apx_regs.
cmp rcx, 0
je set_base_registers
.irp n, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31
mov r\n, [rdi + 0x98 + (8 * (\n - 16))]
.endr

// Load the 16 base general purpose registers from the register struct. r15
// holds the struct's address (copied from rdi before this point), so every
// load below is relative to r15. rax is stored at offset 0x18 and each
// following register is 8 bytes further on.
set_base_registers:
mov rax, [r15 + 0x18]
mov rbx, [r15 + 0x20]
mov rcx, [r15 + 0x28]
mov rdx, [r15 + 0x30]
mov rsi, [r15 + 0x38]
mov rdi, [r15 + 0x40]
mov rsp, [r15 + 0x48]
mov rbp, [r15 + 0x50]
mov r8, [r15 + 0x58]
mov r9, [r15 + 0x60]
mov r10, [r15 + 0x68]
mov r11, [r15 + 0x70]
mov r12, [r15 + 0x78]
mov r13, [r15 + 0x80]
mov r14, [r15 + 0x88]
// r15 is the base pointer for all of the loads above, so it has to be
// overwritten last.
mov r15, [r15 + 0x90]

_gematria_prologue_size = . - gematria_prologue
.size gematria_prologue, _gematria_prologue_size
Expand Down
6 changes: 5 additions & 1 deletion gematria/datasets/find_accessed_addrs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ RawX64Regs ToRawRegs(
RawX64Regs raw_regs;
raw_regs.max_vector_reg_width = VectorRegWidth::NONE;
raw_regs.uses_upper_vector_regs = 0;
raw_regs.uses_apx_regs = 0;

for (const RegisterAndValue& reg_and_value : regs) {
if (reg_and_value.register_name() == "RAX") {
Expand Down Expand Up @@ -229,7 +230,10 @@ RawX64Regs ToRawRegs(
}

VectorRegWidth vector_width = VectorRegWidth::NONE;
if (absl::StartsWith(reg_and_value.register_name(), "XMM")) {
if (reg_and_value.register_name()[0] == 'R') {
raw_regs.apx_regs[number_suffix - 16] = reg_and_value.register_value();
raw_regs.uses_apx_regs = 1;
} else if (absl::StartsWith(reg_and_value.register_name(), "XMM")) {
vector_width = VectorRegWidth::XMM;
raw_regs.vector_regs[number_suffix] = reg_and_value.register_value();
} else if (absl::StartsWith(reg_and_value.register_name(), "YMM")) {
Expand Down
52 changes: 33 additions & 19 deletions gematria/datasets/find_accessed_addrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,26 +39,40 @@ enum class VectorRegWidth : uint64_t {
// bytes large, so that there will be no padding and calculating offsets by hand
// is easy (as is required in our assembly prologue code).
struct RawX64Regs {
VectorRegWidth max_vector_reg_width;
uint64_t uses_upper_vector_regs;
int64_t rax;
int64_t rbx;
int64_t rcx;
int64_t rdx;
int64_t rsi;
int64_t rdi;
int64_t rsp;
int64_t rbp;
int64_t r8;
int64_t r9;
int64_t r10;
int64_t r11;
int64_t r12;
int64_t r13;
int64_t r14;
int64_t r15;
VectorRegWidth max_vector_reg_width; // offset 0x0
// If true, the code uses at least one of the 16 extra vector registers
// defined in AVX-512. This is interpreted in combination with the max width.
// For example, if max_vector_reg_width is XMM and uses_upper_vector_regs is
// true, then the code uses XMM0-XMM31 but no YMM or ZMM registers.
//
// If this is false, then the latter 16 elements of vector_regs are unset and
// should be ignored.
uint64_t uses_upper_vector_regs; // offset 0x8
// If true, the code uses at least one of the 16 extra general purpose
// registers defined in APX.
//
// If this is false, then the elements of apx_regs are unset and should be
// ignored.
//
// This is deliberately a uint64_t rather than a bool: keeping every field 8
// bytes wide avoids padding and keeps the hand-computed offsets used by the
// assembly prologue simple.
uint64_t uses_apx_regs; // offset 0x10
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed internally: Let's add a note that this is uint64_t and not bool to make the layout of the struct simpler.

int64_t rax; // offset 0x18
int64_t rbx; // offset 0x20
int64_t rcx; // offset 0x28
int64_t rdx; // offset 0x30
int64_t rsi; // offset 0x38
int64_t rdi; // offset 0x40
int64_t rsp; // offset 0x48
int64_t rbp; // offset 0x50
int64_t r8; // offset 0x58
int64_t r9; // offset 0x60
int64_t r10; // offset 0x68
int64_t r11; // offset 0x70
int64_t r12; // offset 0x78
int64_t r13; // offset 0x80
int64_t r14; // offset 0x88
int64_t r15; // offset 0x90

int64_t vector_regs[32];
int64_t apx_regs[16]; // offset 0x98
int64_t vector_regs[32]; // offset 0x118
};

// Given a basic block of code, attempt to determine what addresses that code
Expand Down
22 changes: 22 additions & 0 deletions gematria/datasets/find_accessed_addrs_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,5 +289,27 @@ TEST_F(FindAccessedAddrsAvx512Test, UpperZmmRegister) {
)pb"))));
}

// Fixture for tests that exercise the APX extended general purpose registers.
// Every test using it is skipped when the host CPU does not report APX
// support; otherwise the regular FindAccessedAddrsTest setup runs.
class FindAccessedAddrsApxTest : public FindAccessedAddrsTest {
 protected:
  void SetUp() override {
    const bool host_has_apx = __builtin_cpu_supports("apxf");
    if (host_has_apx) {
      FindAccessedAddrsTest::SetUp();
      return;
    }

    GTEST_SKIP() << "Host doesn't support APX";
  }
};

// Checks that a memory load whose address comes from an APX register (R30) is
// handled: the result must report the accessed block 0x15000 and list R30 with
// the value 0x15000 among the initial registers. The match is partial, so
// other fields of the result are not constrained here.
TEST_F(FindAccessedAddrsApxTest, UpperGpr) {
EXPECT_THAT(
FindAccessedAddrsAsm(R"asm(
mov r23, [r30]
)asm"),
IsOkAndHolds(Partially(EqualsProto(R"pb(
accessed_blocks: 0x15000
initial_registers: { register_name: "R30" register_value: 0x15000 }
)pb"))));
}

} // namespace
} // namespace gematria
Loading