Skip to content

Commit

Permalink
[xla:emitters] support CPU in common lowering passes
Browse files Browse the repository at this point in the history
This paves the way for sharing these passes with CPU.

Note that lower_tensors requires more work before CPU tests can pass.

PiperOrigin-RevId: 716807708
  • Loading branch information
cota authored and Google-ML-Automation committed Jan 18, 2025
1 parent e141532 commit 2668841
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 24 deletions.
46 changes: 34 additions & 12 deletions xla/codegen/emitters/transforms/lower_tensors.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,20 @@ namespace scf = ::mlir::scf;
namespace ml = ::mlir::LLVM;
namespace vector = ::mlir::vector;

bool IsAMD(const se::DeviceDescription& device_description) {
// Empty tag type standing in for a device description when lowering for the
// CPU backend, which has no se::DeviceDescription.
struct CpuDeviceDescription {};
// Either a real GPU device description or the CPU tag, so the same rewrite
// patterns can serve both backends.
using DeviceDescription =
std::variant<se::DeviceDescription, CpuDeviceDescription>;

bool IsCpu(const DeviceDescription& device_description) {
return std::holds_alternative<CpuDeviceDescription>(device_description);
}

// Returns true iff the target is an AMD (ROCm) GPU. CPU targets are never
// AMD; the early return also makes the std::get below safe.
bool IsAMD(const DeviceDescription& device_description) {
if (IsCpu(device_description)) return false;
const auto& gpu_device_description =
std::get<se::DeviceDescription>(device_description);
return std::holds_alternative<se::RocmComputeCapability>(
// NOTE(review): the next two lines are a diff artifact (old vs. new
// receiver); only the gpu_device_description form is current.
device_description.gpu_compute_capability());
gpu_device_description.gpu_compute_capability());
}

Value GetDestinationBuffer(Value dest) {
Expand Down Expand Up @@ -755,7 +766,7 @@ Value CreateBitcast(mlir::ImplicitLocOpBuilder& b, mlir::Operation* op,
class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
public:
RewriteAtomicRMW(mlir::MLIRContext* context,
const se::DeviceDescription* device_description)
const DeviceDescription& device_description)
: OpRewritePattern<AtomicRMWOp>(context),
device_description_(device_description) {}

Expand All @@ -764,8 +775,10 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
auto modifier_parameters = GetAtomicModifierParameters(op);
if (modifier_parameters.has_value()) {
if (mlir::isa<mlir::VectorType>(modifier_parameters->first.getType()) &&
(IsAMD(*device_description_) ||
!device_description_->cuda_compute_capability().IsAtLeastHopper())) {
(IsCpu(device_description_) || IsAMD(device_description_) ||
!std::get<se::DeviceDescription>(device_description_)
.cuda_compute_capability()
.IsAtLeastHopper())) {
return rewriter.notifyMatchFailure(
op,
"atomic vectorization currently only supported on Hopper or later");
Expand All @@ -790,12 +803,14 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
std::optional<std::pair<mlir::Value, ml::AtomicBinOp>>
modifier_parameters,
mlir::PatternRewriter& rewriter) const {
CHECK(!IsCpu(device_description_)) << "Unimplemented";

Value modifier_arg = modifier_parameters->first;
Type element_type = modifier_arg.getType();
ml::AtomicBinOp atomic_bin_op = modifier_parameters->second;

Location loc = op.getLoc();
bool is_amd = IsAMD(*device_description_);
bool is_amd = IsAMD(device_description_);
llvm::StringRef sync_scope = is_amd ? "agent" : "";
mlir::ImplicitLocOpBuilder b(loc, rewriter);
Value addr = CreateGep(op.getInput(), op.getIndices(), b);
Expand All @@ -821,13 +836,15 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
}
case ml::AtomicBinOp::fadd: {
// TODO(b/336367154): Introduce an atomic_rmw op with the binOp attr.
const auto& gpu_device_description =
std::get<se::DeviceDescription>(device_description_);
return is_amd ? emitAMDAtomicFAdd(
loc, modifier_arg, addr, sync_scope,
device_description_->rocm_compute_capability(),
gpu_device_description.rocm_compute_capability(),
rewriter)
: emitNVidiaAtomicFAdd(
loc, modifier_arg, addr, sync_scope,
device_description_->cuda_compute_capability(),
gpu_device_description.cuda_compute_capability(),
rewriter, op);
}
case ml::AtomicBinOp::fmax: {
Expand Down Expand Up @@ -1148,7 +1165,7 @@ class RewriteAtomicRMW : public OpRewritePattern<AtomicRMWOp> {
});
}

const se::DeviceDescription* device_description_;
const DeviceDescription& device_description_;
};

class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
Expand All @@ -1161,15 +1178,19 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {

void runOnOperation() override {
if (!gpu_device_info_.empty()) {
CHECK(!is_cpu_target_);
se::GpuDeviceInfoProto device_info;
CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
&device_info));
device_description_ = se::DeviceDescription(device_info);
} else if (is_cpu_target_) {
device_description_ = CpuDeviceDescription{};
}

MLIRContext* mlir_context = &getContext();
mlir::RewritePatternSet tensor_patterns(mlir_context);

tensor_patterns.add<RewriteAtomicRMW>(mlir_context, &device_description_);
tensor_patterns.add<RewriteAtomicRMW>(mlir_context, device_description_);
tensor_patterns
.add<RewriteAllocateShared, RewriteNonScalarConstants,
RewriteSyncThreads, RewriteTensorExtract, RewriteTransferRead,
Expand Down Expand Up @@ -1220,15 +1241,16 @@ class LowerTensorsPass : public impl::LowerTensorsPassBase<LowerTensorsPass> {
signalPassFailure();
});
}
se::DeviceDescription device_description_;
DeviceDescription device_description_;
};

} // namespace

// Builds a LowerTensorsPass. `gpu_device_info` is a serialized
// se::GpuDeviceInfoProto (empty for CPU); `is_cpu_target` selects the CPU
// lowering path. The pass CHECK-fails if both are set simultaneously.
std::unique_ptr<::mlir::Pass> CreateLowerTensorsPass(
// NOTE(review): diff artifact — old one-arg signature line followed by the
// current two-arg signature line.
const std::string& gpu_device_info) {
const std::string& gpu_device_info, bool is_cpu_target) {
LowerTensorsPassOptions options;
options.gpu_device_info_ = gpu_device_info;
options.is_cpu_target_ = is_cpu_target;
return std::make_unique<LowerTensorsPass>(options);
}

Expand Down
32 changes: 22 additions & 10 deletions xla/codegen/emitters/transforms/lower_to_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>
#include <variant>

Expand Down Expand Up @@ -56,6 +57,10 @@ namespace se = ::stream_executor;
#define GEN_PASS_DEF_LOWERTOLLVMPASS
#include "xla/codegen/emitters/transforms/passes.h.inc"

// Empty tag type standing in for a device description when lowering for the
// CPU backend, which has no se::DeviceDescription.
// NOTE(review): duplicated in lower_tensors.cc — consider hoisting the tag
// and the variant alias into a shared header.
struct CpuDeviceDescription {};
using DeviceDescription =
std::variant<se::DeviceDescription, CpuDeviceDescription>;

class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
public:
explicit LowerToLLVMPass(const LowerToLLVMPassOptions& options)
Expand All @@ -66,10 +71,13 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {

void runOnOperation() override {
if (!gpu_device_info_.empty()) {
CHECK(!is_cpu_target_);
se::GpuDeviceInfoProto device_info;
CHECK(tsl::protobuf::TextFormat::ParseFromString(gpu_device_info_,
&device_info));
device_description_ = se::DeviceDescription(device_info);
} else if (is_cpu_target_) {
device_description_ = CpuDeviceDescription{};
}
// Populate type conversions.
mlir::LowerToLLVMOptions llvm_opts(&getContext(),
Expand All @@ -83,14 +91,17 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
mlir::arith::populateArithExpandOpsPatterns(patterns);
mlir::arith::populateArithToLLVMConversionPatterns(type_converter,
patterns);
if (std::holds_alternative<se::RocmComputeCapability>(
device_description_.gpu_compute_capability())) {
mlir::populateGpuToROCDLConversionPatterns(
type_converter, patterns, mlir::gpu::amd::Runtime::Unknown);
mlir::configureGpuToROCDLConversionLegality(target);
} else {
mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
mlir::configureGpuToNVVMConversionLegality(target);
if (const auto* gpu_description =
std::get_if<se::DeviceDescription>(&device_description_)) {
if (std::holds_alternative<se::RocmComputeCapability>(
gpu_description->gpu_compute_capability())) {
mlir::populateGpuToROCDLConversionPatterns(
type_converter, patterns, mlir::gpu::amd::Runtime::Unknown);
mlir::configureGpuToROCDLConversionLegality(target);
} else {
mlir::populateGpuToNVVMConversionPatterns(type_converter, patterns);
mlir::configureGpuToNVVMConversionLegality(target);
}
}
mlir::populateFuncToLLVMConversionPatterns(type_converter, patterns);
mlir::populateVectorToLLVMConversionPatterns(type_converter, patterns);
Expand Down Expand Up @@ -122,15 +133,16 @@ class LowerToLLVMPass : public impl::LowerToLLVMPassBase<LowerToLLVMPass> {
}

private:
se::DeviceDescription device_description_;
DeviceDescription device_description_;
};

} // namespace

// Builds a LowerToLLVMPass. `gpu_device_info` is a serialized
// se::GpuDeviceInfoProto (empty for CPU); `is_cpu_target` selects the CPU
// lowering path. The pass CHECK-fails if both are set simultaneously.
std::unique_ptr<::mlir::Pass> CreateLowerToLLVMPass(
// NOTE(review): diff artifact — old one-arg signature line followed by the
// current two-arg signature line.
const std::string& gpu_device_info) {
const std::string& gpu_device_info, bool is_cpu_target) {
LowerToLLVMPassOptions options;
options.gpu_device_info_ = gpu_device_info;
options.is_cpu_target_ = is_cpu_target;
return std::make_unique<LowerToLLVMPass>(options);
}

Expand Down
4 changes: 2 additions & 2 deletions xla/codegen/emitters/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@ namespace emitters {
std::unique_ptr<mlir::Pass> CreateExpandFloatOpsPass();
std::unique_ptr<mlir::Pass> CreateFlattenTensorsPass();
// Fix: the new parameter was declared `is_cpu_target_` — a member-style
// trailing underscore copied from the tablegen option name. Renamed to
// `is_cpu_target` to match the definitions in the .cc files; declaration
// parameter names are non-binding, so this is fully source-compatible.
// `is_cpu_target` selects the CPU lowering path; if true, `gpu_device_info`
// must be empty.
std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
const std::string& gpu_device_info = "");
const std::string& gpu_device_info = "", bool is_cpu_target = false);
// Overload taking an already-parsed device description (GPU pipelines).
std::unique_ptr<mlir::Pass> CreateLowerTensorsPass(
const stream_executor::DeviceDescription& device_description);
std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
const std::string& gpu_device_info = "", bool is_cpu_target = false);
// Overload taking an already-parsed device description (GPU pipelines).
std::unique_ptr<mlir::Pass> CreateLowerToLLVMPass(
const stream_executor::DeviceDescription& device_description);

Expand Down
4 changes: 4 additions & 0 deletions xla/codegen/emitters/transforms/passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def LowerTensorsPass : Pass<"xla-lower-tensors", "mlir::ModuleOp"> {
let options = [
Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",
"Serialized stream_executor::GPUDeviceInfo proto.">,
Option<"is_cpu_target_", "is_cpu_target", "bool", /*default=*/"false",
"Whether the pass applies to a CPU pipeline. If true, gpu_device_info_ must be empty.">,
];
let constructor = "CreateLowerTensorsPass()";
}
Expand Down Expand Up @@ -94,6 +96,8 @@ def LowerToLLVMPass :
let options = [
Option<"gpu_device_info_", "gpu_device_info", "std::string", /*default=*/"",
"Serialized stream_executor::GPUDeviceInfo proto.">,
Option<"is_cpu_target_", "is_cpu_target", "bool", /*default=*/"false",
"Whether the pass applies to a CPU pipeline. If true, gpu_device_info_ must be empty.">,
];
let constructor = "CreateLowerToLLVMPass()";
}
Expand Down

0 comments on commit 2668841

Please sign in to comment.