Skip to content

Commit

Permalink
fixes for insert_slice
Browse files Browse the repository at this point in the history
  • Loading branch information
fschlimb committed Dec 13, 2024
1 parent f474abb commit b7a2259
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 18 deletions.
34 changes: 24 additions & 10 deletions lib/Dialect/NDArray/Extensions/MeshShardingExtensions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ static std::array<Value, 2> getShardSliceOffAndSz(
ValueRange myIdx, int64_t dim, ArrayRef<int64_t> meshShape,
ArrayRef<MeshAxesAttr> splitAxes, Value targetOffs,
ArrayRef<int64_t> srcShape, const SmallVector<OpFoldResult> &slcOffs,
const SmallVector<OpFoldResult> &slcSizes,
const SmallVector<OpFoldResult> &slcStrides,
const SmallVector<OpFoldResult> &haloSizes, const EasyI64 &zero,
const EasyI64 &one, OpBuilder &builder, Location loc) {
Expand All @@ -214,8 +215,12 @@ static std::array<Value, 2> getShardSliceOffAndSz(
std::tie(myOff, mySize) =
getOffsetAndSize(myID, zero, one, targetOffs, currPos, builder, loc);
} else {
myOff = getBaseShardDimOff(myID, numShards, extend, zero).get();
mySize = getBaseShardDimSize(myID, numShards, extend, one, zero).get();
auto myOff_ = getBaseShardDimOff(myID, numShards, extend, zero);
auto mySize_ = getBaseShardDimSize(myID, numShards, extend, one, zero);
auto slcSz = easyI64(loc, builder, slcSizes[dim]);
mySize_ = zero.max(slcSz - myOff_).min(mySize_);
myOff = myOff_.get();
mySize = mySize_.get();
}

// the global offset of the local shard is slice offset plus the computed
Expand Down Expand Up @@ -290,7 +295,7 @@ getLocalOffSzAndStrFromSlice(OP op, ArrayRef<int64_t> srcShape,
} else {
auto offAndSz = getShardSliceOffAndSz(
myIdx, dim, mesh.getShape(), splitAxes, targetOffs, srcShape, slcOffs,
slcStrides, haloSizes, zero, one, builder, loc);
slcSizes, slcStrides, haloSizes, zero, one, builder, loc);
lShardOffs.emplace_back(offAndSz[0]);
lShardSizes.emplace_back(offAndSz[1]);
}
Expand Down Expand Up @@ -439,6 +444,7 @@ struct InsertSliceShardingInterface
}
auto dstSharding = mlir::mesh::MeshSharding::get(shardingOption.mesh, res);
maybeInsertSourceShardingAnnotation(dstSharding, op->getOpOperand(0), b);
maybeInsertTargetShardingAnnotation(dstSharding, op->getResult(0), b);

return success();
}
Expand All @@ -449,7 +455,8 @@ struct InsertSliceShardingInterface
IRMapping &spmdizationMap,
SymbolTableCollection &symbolTableCollection,
OpBuilder &builder) const {
if (resultShardings.size() != 0) {
if (resultShardings.size() != 1 || operandShardings.size() < 2 ||
resultShardings[0] != operandShardings[0]) {
return failure();
}

Expand Down Expand Up @@ -493,22 +500,29 @@ struct InsertSliceShardingInterface
}

scf::IfOp ifOp = builder.create<scf::IfOp>(
loc, hasSize.get(), [&](OpBuilder &b, Location loc) {
(void)b.create<imex::ndarray::InsertSliceOp>(
loc, hasSize.get(),
[&](OpBuilder &b, Location loc) {
auto res = b.create<imex::ndarray::InsertSliceOp>(
loc, spmdizedOperands[0], spmdizedOperands[1], lShardOffs,
lShardSizes, lShardStrides);
b.create<scf::YieldOp>(loc);
b.create<scf::YieldOp>(loc, res.getResult());
},
[&](OpBuilder &b, Location loc) {
b.create<scf::YieldOp>(loc, spmdizedOperands[0]);
});
spmdizationMap.map(op, ifOp.getOperation());

builder.create<mlir::mesh::UpdateHaloOp>(
loc, spmdizedOperands[0].getType(), spmdizedOperands[0],
auto res = builder.create<mlir::mesh::UpdateHaloOp>(
loc, spmdizedOperands[0].getType(), ifOp.getResult(0),
dstSharding.getMeshAttr(),
mlir::mesh::MeshAxesArrayAttr::get(op->getContext(),
dstSharding.getSplitAxes()),
dstSharding.getDynamicHaloSizes(),
DenseI64ArrayAttr::get(op->getContext(),
dstSharding.getStaticHaloSizes()));

spmdizationMap.map(op->getResult(0), res->getResult(0));
spmdizationMap.map(op, res.getOperation());

return success();
}
};
Expand Down
19 changes: 14 additions & 5 deletions lib/Dialect/NDArray/IR/InsertSliceOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,31 @@ class InsertSliceOpConstantArgumentFolder final
return mlir::failure();

auto sourceType = insertSliceOp.getSourceType();
auto dstTnsrType = insertSliceOp.getDestinationType(); //.getTensorType();
auto dstTnsrType = insertSliceOp.getDestinationType();

// Create the new op in canonical form.
auto sourceTnsrType =
mlir::tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
insertSliceOp.getSourceType().getRank(), dstTnsrType, mixedOffsets,
mixedSizes, mixedStrides);
auto newSourceType = sourceType.cloneWith(sourceTnsrType.getShape(),
sourceTnsrType.getElementType());

mlir::Value toInsert = insertSliceOp.getSource();
if (newSourceType != sourceType) {
if (newSourceType.getRank() != sourceType.getRank())
if (sourceType.getRank() == 0) {
if (newSourceType.getRank() > 1) {
return mlir::failure();
}
} else if (newSourceType.getRank() != sourceType.getRank()) {
return mlir::failure();
mlir::OpBuilder::InsertionGuard g(rewriter);
toInsert = rewriter.create<mlir::tensor::CastOp>(insertSliceOp.getLoc(),
newSourceType, toInsert);
} else {
mlir::OpBuilder::InsertionGuard g(rewriter);
toInsert = rewriter.create<mlir::tensor::CastOp>(
insertSliceOp.getLoc(), newSourceType, toInsert);
}
}

rewriter.replaceOpWithNewOp<InsertOpTy>(
insertSliceOp, insertSliceOp.getDestination(), toInsert, mixedOffsets,
mixedSizes, mixedStrides);
Expand Down
9 changes: 6 additions & 3 deletions lib/Dialect/NDArray/Transforms/CoalesceShardOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ struct CoalesceShardOpsPass
return defOp;
} else if (auto op = ::mlir::dyn_cast<::mlir::DestinationStyleOpInterface>(
defOp)) {
return op.getNumDpsInputs() == 1 ? op.getDpsInits()[0].getDefiningOp()
: defOp;
return op.getNumDpsInits() == 1 ? getBaseArray(op.getDpsInits()[0])
: defOp;
} else if (auto op = ::mlir::dyn_cast<::imex::ndarray::SubviewOp>(defOp)) {
return getBaseArray(op.getSource());
} else if (auto op =
Expand Down Expand Up @@ -479,7 +479,10 @@ struct CoalesceShardOpsPass

// update shardOps of dependent Subview/InsertSliceOps
for (auto svShardOp : shardOps) {
svShardOp.getSrcMutable().assign(newShardOp.getResult());
assert(svShardOp->hasOneUse());
if (mlir::isa<::imex::ndarray::SubviewOp>(*svShardOp->user_begin())) {
svShardOp.getSrcMutable().assign(newShardOp.getResult());
}
svShardOp.getShardingMutable().assign(newSharding);
}
// barriers/halo-updates get inserted when InsertSliceOps (or other write
Expand Down

0 comments on commit b7a2259

Please sign in to comment.