From d4899f7b9badb3a622d8e79405339d8f0796e149 Mon Sep 17 00:00:00 2001
From: Bixia Zheng
Date: Tue, 10 Dec 2024 13:05:05 -0800
Subject: [PATCH] [jax:custom_partitioning] Make SdyShardingRule a user facing
 class.

Move the parsing of a sharding rule string to a free function
str_to_sdy_sharding_rule.

Move the building of the MLIR sharding rule to a free function
sdy_sharding_rule_to_mlir.

PiperOrigin-RevId: 704818640
---
 jax/_src/custom_partitioning_sharding_rule.py | 529 ++++++++++--------
 .../custom_partitioning_sharding_rule_test.py | 218 +++++---
 2 files changed, 437 insertions(+), 310 deletions(-)

diff --git a/jax/_src/custom_partitioning_sharding_rule.py b/jax/_src/custom_partitioning_sharding_rule.py
index 5193c9126bb7..1e3e7fe60683 100644
--- a/jax/_src/custom_partitioning_sharding_rule.py
+++ b/jax/_src/custom_partitioning_sharding_rule.py
@@ -20,16 +20,124 @@
 from jax._src.lib.mlir.dialects import sdy
 
 
-_CompoundFactor = tuple[str, ...]
-_DimMapping = tuple[str | _CompoundFactor, ...]
-
 # A single character replacement for ... to simplify parsing.
-_ELLIPSIS: str = "…"
+BATCHING: str = "…"
 
 # A prefix for names of batching dimension factors, used for expanding the
 # leading ... into factors.
 _BATCHING_DIM_FACTOR_PREFIX = "?"
 
+
+def _check_factor(factor: str):
+  """Validates a factor.
+
+  A factor is a string starting with a letter and containing only letters,
+  digits, or underscores.
+  """
+  if not factor[0].isalpha():
+    raise ValueError(f"Factor names have to start with a letter, but got '{factor[0]}'")
+  for char in factor[1:]:
+    if char != "_" and not char.isdigit() and not char.isalpha():
+      raise ValueError(f"Unknown character '{char}'")
+
+
+class CompoundFactor(tuple):
+  """Describes the factors for a compound factor.
+
+  A compound factor should contain at least two factors, e.g.
+  * CompoundFactor('b', 'c').
+  """
+  def __init__(self, *factors):
+    if len(factors) < 2:
+      raise ValueError("A compound factor should contain at least two factors")
+    for factor in factors:
+      if not isinstance(factor, str):
+        raise ValueError(f"Each element of CompoundFactor must be a str, but got {type(factor)}")
+      if factor == BATCHING:
+        raise ValueError("Ellipsis can't be used in a compound factor")
+      else:
+        _check_factor(factor)
+
+  def __new__(cls, *factors):
+    return tuple.__new__(CompoundFactor, factors)
+
+
+class ArrayMapping(tuple):
+  """Describes the factors for an operand or result.
+
+  Each element is either a factor or a CompoundFactor. A leading element can
+  also be BATCHING, which represents batching dimensions. Examples:
+  * ArrayMapping('a')
+  * ArrayMapping('b', 'c')
+  * ArrayMapping(CompoundFactor('b', 'c'), 'd')
+  * ArrayMapping(BATCHING, CompoundFactor('b', 'c'), 'd')
+  """
+  def __init__(self, *dim_mappings):
+    for i, d in enumerate(dim_mappings):
+      if not isinstance(d, str) and not isinstance(d, CompoundFactor):
+        raise ValueError(
+            "Each element of ArrayMapping must be a str or CompoundFactor, but"
+            f" got {type(d)}")
+      if isinstance(d, str):
+        if d == BATCHING:
+          if i != 0:
+            raise ValueError("Ellipsis can only be used at the beginning of a dimension")
+        else:
+          _check_factor(d)
+
+  def __new__(cls, *dim_mappings):
+    return tuple.__new__(ArrayMapping, dim_mappings)
+
+
+class SdyShardingRule:
+  """Represents a Shardy sharding rule.
+
+  An SdyShardingRule contains the ArrayMappings for operands and results, and
+  an optional list of factor sizes. A factor is a name used in the
+  ArrayMappings. If a factor is only used in CompoundFactors, its size must be
+  specified.
+ """ + operand_mappings: tuple[ArrayMapping, ...] + result_mappings: tuple[ArrayMapping, ...] + factor_sizes: dict[str, int] + + def __init__(self, operand_mappings: tuple[ArrayMapping, ...], + result_mappings: tuple[ArrayMapping, ...], **factor_sizes): + # Find all factors and mark whether their size can be inferred. + factors_inferrable = dict() + for value in operand_mappings + result_mappings: + for dim in value: + if isinstance(dim, str): + factors_inferrable[dim] = True + else: + for factor in dim: + if factor not in factors_inferrable.keys(): + factors_inferrable[factor] = False + + # Check that factors in factor_sizes are used in the rule. + for factor in factor_sizes: + if factor not in factors_inferrable: + raise ValueError( + f"Factor {factor} is not used in the rule, but size is provided") + + # Check that factors that are used for a whole dimension aren't in + # factor_sizes and factors that are never used for a whole dimension are + # in factor_sizes. + for factor, inferrable in factors_inferrable.items(): + if factor not in factor_sizes and not inferrable: + raise ValueError( + f"Factor {factor} is only used in compound factors; must specify" + " its size") + if factor in factor_sizes and inferrable: + raise ValueError( + f"Factor {factor} represents a whole dimension; do not specify its" + " size") + + self.operand_mappings = operand_mappings + self.result_mappings = result_mappings + self.factor_sizes = factor_sizes + + def __str__(self): + return f"SdyShardingRule({self.operand_mappings}, {self.result_mappings}, {self.factor_sizes})" + + def _get_batching_dim_factor_name(batch_dim_order : int): """Constructs a factor name for a batching dimension. @@ -42,18 +150,18 @@ def _get_batching_dim_factor_name(batch_dim_order : int): def _parse_values( rule: str, -) -> tuple[_DimMapping, ...]: +) -> tuple[ArrayMapping, ...]: """Parses the LHS or RHS of an Einsum notation like string. Converts each operand or result in the Einsum notation like string to a tuple - of _DimMapping. This very closely follows how einops parses their rules in + of ArrayMapping. This very closely follows how einops parses their rules in einops/parsing.py. Args: rule: The Einsum notation for the operands or results of an operation. Returns: - The tuple of values. + The tuple of ArrayMapping. Raises: ValueError: If the rule is not balanced or contains unknown characters. @@ -65,10 +173,10 @@ def _parse_values( # Similar to einops rules, an empty LHS/RHS has a single scalar value. if not rule: - return ((),) + return (ArrayMapping(),) all_values = [] - # Represent all dimensions of an value. When an value[0]==_ELLIPSIS, the + # Represent all dimensions of an value. When an value[0]==BATCHING, the # value may have 0 or more leading dimensions. 
  value = []
   current_factor = None
   current_compound_dim: list[str] | None = None
@@ -84,12 +192,12 @@ def add_factor(x):
       current_compound_dim.append(x)
 
   for char in rule:
-    if char == _ELLIPSIS:
+    if char == BATCHING:
       if (current_factor is not None or current_compound_dim is not None
           or value):
         raise ValueError(
             "Ellipsis can only be used at the beginning of a dimension")
-      add_factor(_ELLIPSIS)
+      add_factor(BATCHING)
       continue
     if char in "(), ":
       if current_factor is not None:
         add_factor(current_factor)
         current_factor = None
@@ -106,10 +214,10 @@ def add_factor(x):
           raise ValueError("Brackets are not balanced")
         if len(current_compound_dim) <= 1:
           raise ValueError("Brackets should contain at least two factors")
-        value.append(tuple(current_compound_dim))
+        value.append(CompoundFactor(*current_compound_dim))
         current_compound_dim = None
       elif char == ",":
-        all_values.append(tuple(value))
+        all_values.append(ArrayMapping(*value))
         value = []
     elif char == "_" or char.isdigit() or char.isalpha():
       if current_factor is None:
@@ -125,256 +233,203 @@ def add_factor(x):
     raise ValueError(f"Brackets are not balanced in rule: '{rule}'")
   if current_factor is not None:
     add_factor(current_factor)
-  all_values.append(tuple(value))
+  all_values.append(ArrayMapping(*value))
 
   return tuple(all_values)
 
+def str_to_sdy_sharding_rule(rule: str, **factor_sizes) -> SdyShardingRule:
+  """Constructs a SdyShardingRule object from the Einsum notation like string.
 
-class SdyShardingRule:
-  """A representation for Shardy sharding rule.
-
-  A SdyShardingRule includes an Enisum notation like string and an optional
-  list of factor sizes. A factor is a name in the Einsum notation. If a factor
-  is only used in compound factors, its size must be specified.
+  This is done by verifying that the input Einsum notation like string,
+  together with the optional factor sizes, represents a valid sharding rule,
+  and by converting it to an internal representation.
 
-  SdyShardingRule examples:
+  Args:
+    rule: The Einsum notation like string for an operation.
+    **factor_sizes: The optional factor sizes.
 
-  * Contracting dim matmul AB@BC->AC: SdyShardingRule('i j, j k -> i k')
-  * Batching matmul: SdyShardingRule('... i j, ... j k -> ... i k')
-  * A reshape (8,) -> (4, 2): SdyShardingRule('(i j) -> i j')
-  * Another reshape (4, 2) -> (2, 4): SdyShardingRule('(i j) -> (j i)`, i=4, j=2)
-  * An elementwise add of any dimensions x + y -> z: SdyShardingRule('..., ... -> ...')
+  Raises:
+    ValueError: If there is any problem with the rule or factor_sizes.
   """
-
-  def __init__(self, rule: str, **factor_sizes):
-    """Constructs a SdyShardingRule object from the Einsum notation like string.
-
-    This is done by verifying that the input Einsum notation like string and
-    with optional factor sizes represents a valid sharding rule and converting
-    it to an internal representation.
-
-    Args:
-      rule: The Einsum notation like string for an operation.
-      **factor_sizes: The optional factor sizes.
-
-    Raises:
-      ValueError: If there is any problem with the rule or factor_sizes.
-    """
-    if not isinstance(rule, str):
-      raise TypeError(f"rule must be a str, but got {type(rule)}")
-    if not all(isinstance(size, int) for size in factor_sizes.values()):
-      raise TypeError(
-          f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
-
-    # Replace ... with a single char to simplify parsing.
-    if _ELLIPSIS in rule:
-      raise ValueError(f"Unknown character '{_ELLIPSIS}'")
+  if not isinstance(rule, str):
+    raise TypeError(f"rule must be a str, but got {type(rule)}")
+  if not all(isinstance(size, int) for size in factor_sizes.values()):
+    raise TypeError(
+        f"factor_sizes must be a dict of str to int, but got {factor_sizes}")
+
+  # Replace ... with a single char to simplify parsing.
+  if BATCHING in rule:
+    raise ValueError(f"Unknown character '{BATCHING}'")
+  if "." in rule:
+    rule = rule.replace("...", BATCHING)
   if "." in rule:
-    rule = rule.replace("...", _ELLIPSIS)
-    if "." in rule:
-      raise ValueError("Character '.' must be used inside ellipsis '...'")
+    raise ValueError("Character '.' must be used inside ellipsis '...'")
 
-    try:
-      operands, results = rule.split("->")
-    except ValueError as e:
-      raise ValueError(f"There is no -> in rule: '{rule}'") from e
+  try:
+    operands, results = rule.split("->")
+  except ValueError as e:
+    raise ValueError(f"There is no -> in rule: '{rule}'") from e
 
-    self.operands = _parse_values(operands)
-    self.results = _parse_values(results)
+  operand_mappings = _parse_values(operands)
+  result_mappings = _parse_values(results)
 
-    # Find all factors and mark whether their size can be inferred.
-    factors_inferrable = dict()
-    for value in self.operands + self.results:
-      for dim in value:
-        if dim == _ELLIPSIS:
-          continue
-        if isinstance(dim, str):
-          factors_inferrable[dim] = True
-        else:
-          for factor in dim:
-            if factor not in factors_inferrable.keys():
-              factors_inferrable[factor] = False
+  return SdyShardingRule(operand_mappings, result_mappings, **factor_sizes)
 
-    # Check that factors in factor_sizes are used in the rule.
-    for factor in factor_sizes:
-      if factor not in factors_inferrable:
-        raise ValueError(
-            f"Factor {factor} is not used in the rule, but size is provided")
 
-    # Check that factors that are used for a whole dimension aren't in
-    # factor_sizes and factors that are never used for a whole dimension are
-    # in factor_sizes.
-    for factor, inferrable in factors_inferrable.items():
-      if factor not in factor_sizes and not inferrable:
-        raise ValueError(
-            f"Factor {factor} is only used in compound factors; must specify"
-            " its size")
-      if factor in factor_sizes and inferrable:
-        raise ValueError(
-            f"Factor {factor} represents a whole dimension; do not specify its"
-            " size")
+def sdy_sharding_rule_to_mlir(
+    rule: SdyShardingRule,
+    operand_types: list[ir.Type],
+    result_types: list[ir.Type],) -> ir.Attribute:
+  """Builds the MLIR representation for the sharding rule.
 
-    self.factor_sizes = factor_sizes
+  This is done by verifying that the rule is consistent with the types of
+  the operation and converting the internal sharding rule representation to
+  an OpShardingRuleAttr.
   """
+  if len(rule.operand_mappings) != len(operand_types):
+    raise ValueError(
+        f"Sharding rule has {len(rule.operand_mappings)} operands, but the operation"
+        f" has {len(operand_types)} operands")
+  if len(rule.result_mappings) != len(result_types):
+    raise ValueError(
+        f"Sharding rule has {len(rule.result_mappings)} results, but the operation"
+        f" has {len(result_types)} results")
+
+  factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
+  types = operand_types + result_types
+  UNKNOWN = -1  # Representation for unknown factor size or factor index.
+
+  def get_message_for_value(i):
+    if i >= len(operand_types):
+      return f"{i - len(operand_types)}th result"
+    else:
+      return f"{i}th operand"
 
-  def __str__(self):
-    return f"SdyShardingRule({self.operands}, {self.results}, {self.factor_sizes})"
+  def get_rank_for_value(i):
+    return ir.ShapedType(types[i]).rank
+
+  def get_size_for_value_dim(i, j):
+    return ir.ShapedType(types[i]).shape[j]
 
-  def build(
-      self,
-      operand_types: list[ir.Type],
-      result_types: list[ir.Type],) -> ir.Attribute:
-    """Builds the MLIR representation for the sharding rule.
+  def add_factor(factor, size):
+    """Adds a factor to factors_to_indices_sizes.
 
-    This is done by verifying that the rule is consistent with the types of
-    the operation and converting the Einsum notation like string to
-    OpShardingRuleAttr.
+    `size` may be a dimension size, a user-specified factor size, or UNKNOWN
+    if a factor is first used in a compound factor and then used for a
+    whole dimension.
     """
-    if len(self.operands) != len(operand_types):
+    factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
+    if factor_index != UNKNOWN:
+      # Not the first time seeing the factor.
+      if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
+        factor_or_batching_dim = (
+            f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
+            else f"Batching dimension {factor[1:]}")
+        raise ValueError(
+            f"{factor_or_batching_dim} corresponds to two sizes:"
+            f" {factor_size} and {size}")
+      if size != UNKNOWN and factor_size == UNKNOWN:
+        factors_to_indices_sizes[factor] = [factor_index, size]
+    else:
+      # First time seeing the factor.
+      factor_index = len(factors_to_indices_sizes)
+      factors_to_indices_sizes[factor] = [factor_index, size]
+
+  def add_batching_dim_factor(batch_dim_order, factor_size):
+    ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
+    add_factor(ellipsis_batch_dim_name, factor_size)
+
+  def build_dim_mapping_for_compound_factors(i, j, factors):
+    accumulated_size = 1
+    all_indices = []
+    for factor in factors:
+      factor_index, factor_size = factors_to_indices_sizes[factor]
+      accumulated_size *= factor_size
+      all_indices.append(factor_index)
+
+    dim_size = get_size_for_value_dim(i, j)
+    if accumulated_size != dim_size:
       raise ValueError(
-          f"Sharding rule has {len(self.operands)} operands, but the operation"
-          f" has {len(operand_types)} operands"
-      )
-    if len(self.results) != len(result_types):
+          f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
+          f" the size {accumulated_size} derived from the compound factors"
+          f" {factors}")
+
+    return sdy.DimMappingAttr.get(factor_indices=all_indices)
+
+  # Add factors and their sizes in the order they appear in the rule,
+  # including the batching dimensions represented by ellipsis.
+  ellipsis_rank = None
+  for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
+    value = tuple(mapping)
+    if value and value[0] == BATCHING:
+      has_batching = True
+      value = value[1:]
+    else:
+      has_batching = False
+    rule_rank = len(value)
+    op_rank = get_rank_for_value(i)
+    # The number of dimensions represented by ellipsis.
+    current_batching_rank = 0
+    if has_batching and op_rank >= rule_rank:
+      current_batching_rank = op_rank - rule_rank
+    if has_batching:
+      if ellipsis_rank is None:
+        ellipsis_rank = current_batching_rank
+      elif ellipsis_rank != current_batching_rank:
+        raise ValueError(
+            "Ellipsis represents different number of leading dimensions"
+            f" {ellipsis_rank} and {current_batching_rank}")
+    rule_rank += current_batching_rank
+    if rule_rank != op_rank:
+      msg = get_message_for_value(i)
       raise ValueError(
-          f"Sharding rule has {len(self.results)} results, but the operation"
-          f" has {len(result_types)} results"
-      )
+          f"Sharding rule {msg} has rank {rule_rank}, but the operation"
+          f" {msg} has rank {op_rank}")
 
-    factors_to_indices_sizes: OrderedDict[str, list[int]] = OrderedDict()
-    types = operand_types + result_types
-    UNKNOWN = -1  # Representation for unknown factor size or factor index.
+    for j in range(current_batching_rank):
+      add_batching_dim_factor(j, get_size_for_value_dim(i, j))
 
-    def get_message_for_value(i):
-      if i >= len(operand_types):
-        return f"{i - len(operand_types)}th result"
+    for j, dim in enumerate(value):
+      if isinstance(dim, str):
+        add_factor(dim, get_size_for_value_dim(i, j + current_batching_rank))
       else:
-        return f"{i}th operand"
-
-    def get_rank_for_value(i):
-      return ir.ShapedType(types[i]).rank
-
-    def get_size_for_value_dim(i, j):
-      return ir.ShapedType(types[i]).shape[j]
-
-    def add_factor(factor, size):
-      """Adds a factor to factors_to_indices_sizes.
-
-      `size` may be a dimensions size, a user specified factor size, or UNKNOWN
-      if a factor is first used as in a compound factor and then used for a
-      whole dimension.
-      """
-      factor_index, factor_size = factors_to_indices_sizes.get(factor, [UNKNOWN, UNKNOWN])
-      if factor_index != UNKNOWN:
-        # Not the first time seeing the factor.
-        if size != UNKNOWN and factor_size != UNKNOWN and factor_size != size:
-          factor_or_batching_dim = (
-              f"Factor {factor}" if _BATCHING_DIM_FACTOR_PREFIX not in factor
-              else f"Batching dimension {factor[1:]}")
-          raise ValueError(
-              f"{factor_or_batching_dim} corresponds to two sizes:"
-              f" {factor_size} and {size}")
-        if size != UNKNOWN and factor_size == UNKNOWN:
-          factors_to_indices_sizes[factor] = [factor_index, size]
-      else:
-        # First time seeing the factor.
-        factor_index = len(factors_to_indices_sizes)
-        factors_to_indices_sizes[factor] = [factor_index, size]
-
-    def add_batching_dim_factor(batch_dim_order, factor_size):
-      ellipsis_batch_dim_name = _get_batching_dim_factor_name(batch_dim_order)
-      add_factor(ellipsis_batch_dim_name, factor_size)
-
-    def build_dim_mapping_for_compound_factors(i, j, factors):
-      accumulated_size = 1
-      all_indices = []
-      for factor in factors:
-        factor_index, factor_size = factors_to_indices_sizes[factor]
-        accumulated_size *= factor_size
-        all_indices.append(factor_index)
-
-      dim_size = get_size_for_value_dim(i, j)
-      if accumulated_size != dim_size:
-        raise ValueError(
-            f"{get_message_for_value(i)} actual size {dim_size} doesn't match"
-            f" the size {accumulated_size} derived from the compound factors"
-            f" {factors}")
-
-      return sdy.DimMappingAttr.get(factor_indices=all_indices)
-
-    # Add factors and their sizes in the order they appear in the rule,
-    # including the batching dimensions represented by ellipsis.
-    ellipsis_rank = None
-    for i, value in enumerate(self.operands + self.results):
-      if value and value[0] == _ELLIPSIS:
-        has_ellipsis = True
-        value = value[1:]
+      for factor in dim:
+        add_factor(factor, rule.factor_sizes.get(factor, UNKNOWN))
+
+  # Build the tensor mappings for each operand and result.
+  tensor_mappings = []
+  for i, mapping in enumerate(rule.operand_mappings + rule.result_mappings):
+    value = tuple(mapping)
+    dim_mappings = []
+
+    if value and value[0] == BATCHING:
+      value = value[1:]
+      if ellipsis_rank is None:
+        current_batching_rank = 0
       else:
-        has_ellipsis = False
-      rule_rank = len(value)
-      op_rank = get_rank_for_value(i)
-      # The number of dimensions represented by ellipsis.
-      current_ellipsis_rank = 0
-      if has_ellipsis and op_rank > rule_rank:
-        current_ellipsis_rank = op_rank - rule_rank
-      if has_ellipsis:
-        if ellipsis_rank is None:
-          ellipsis_rank = current_ellipsis_rank
-        elif ellipsis_rank != current_ellipsis_rank:
-          raise ValueError(
-              "Ellipsis represents different number of leading dimensions"
-              f" {ellipsis_rank} and {current_ellipsis_rank}")
-      rule_rank += current_ellipsis_rank
-      if rule_rank != op_rank:
-        msg = get_message_for_value(i)
-        raise ValueError(
-            f"Sharding rule {msg} has rank {rule_rank}, but the operation"
-            f" {msg} has rank {op_rank}")
-
-      for j in range(current_ellipsis_rank):
-        add_batching_dim_factor(j, get_size_for_value_dim(i, j))
-
-      for j, dim in enumerate(value):
-        if isinstance(dim, str):
-          add_factor(
-              dim, get_size_for_value_dim(i, j + current_ellipsis_rank))
-        else:
-          for factor in dim:
-            add_factor(factor, self.factor_sizes.get(factor, UNKNOWN))
+        current_batching_rank = ellipsis_rank
     else:
+      current_batching_rank = 0
 
-    # Build the tensor mappings for each operand and result.
-    tensor_mappings = []
-    for i, value in enumerate(self.operands + self.results):
-      dim_mappings = []
+    for j in range(current_batching_rank):
+      dim_mappings.append(
+          sdy.DimMappingAttr.get(factor_indices=[
+              factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
 
-      if value and value[0] == _ELLIPSIS:
-        value = value[1:]
-        if ellipsis_rank is None:
-          current_ellipsis_rank = 0
-        else:
-          current_ellipsis_rank = ellipsis_rank
+    for j, dim in enumerate(value):
+      if isinstance(dim, str):
+        dim_mappings.append(
+            sdy.DimMappingAttr.get(
+                factor_indices=[factors_to_indices_sizes[dim][0]]))
       else:
-        current_ellipsis_rank = 0
-
-      for j in range(current_ellipsis_rank):
         dim_mappings.append(
-            sdy.DimMappingAttr.get(factor_indices=[
-                factors_to_indices_sizes[_get_batching_dim_factor_name(j)][0]]))
+            build_dim_mapping_for_compound_factors(
+                i, j + current_batching_rank, dim))
 
-      for j, dim in enumerate(value):
-        if isinstance(dim, str):
-          dim_mappings.append(
-              sdy.DimMappingAttr.get(
-                  factor_indices=[factors_to_indices_sizes[dim][0]]))
-        else:
-          dim_mappings.append(
-              build_dim_mapping_for_compound_factors(
-                  i, j + current_ellipsis_rank, dim))
-
-      tensor_mappings.append(
-          sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
-
-    op_sharding_rule = sdy.OpShardingRuleAttr.get(
-        factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
-        operand_mappings=tensor_mappings[0:len(operand_types)],
-        result_mappings=tensor_mappings[len(operand_types):])
-    return op_sharding_rule
+    tensor_mappings.append(
+        sdy.TensorMappingAttr.get(dim_mappings=dim_mappings))
+
+  return sdy.OpShardingRuleAttr.get(
+      factor_sizes=[item[1] for item in factors_to_indices_sizes.values()],
+      operand_mappings=tensor_mappings[0:len(operand_types)],
+      result_mappings=tensor_mappings[len(operand_types):])
diff --git a/tests/custom_partitioning_sharding_rule_test.py b/tests/custom_partitioning_sharding_rule_test.py
index 2aac4e04862f..3aed16510a4f 100644
--- a/tests/custom_partitioning_sharding_rule_test.py
+++ b/tests/custom_partitioning_sharding_rule_test.py
@@ -16,148 +16,189 @@
 from jax._src import test_util as jtu
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import sdy
-from jax._src.custom_partitioning_sharding_rule import SdyShardingRule
+from jax._src.custom_partitioning_sharding_rule import ArrayMapping, BATCHING, CompoundFactor, sdy_sharding_rule_to_mlir, str_to_sdy_sharding_rule, SdyShardingRule
 from jax._src.lib.mlir.dialects import hlo as stablehlo
 
 
 class SdyShardingRuleTest(jtu.JaxTestCase):
 
+  def test_compound_factor_not_enough_factors(self):
+    with self.assertRaisesRegex(ValueError, "A compound factor should contain at least two factors"):
+      CompoundFactor("i")
+
+  def test_compound_factor_batching_not_allowed(self):
+    with self.assertRaisesRegex(ValueError, "Ellipsis can't be used in a compound factor"):
+      CompoundFactor(BATCHING, "i")
+
+  def test_compound_factor_element_not_a_str(self):
+    with self.assertRaisesRegex(ValueError, "Each element of CompoundFactor must be a str"):
+      CompoundFactor("i", 2)
+
+  def test_compound_factor_str(self):
+    c = CompoundFactor("i", "j", "k")
+    self.assertEqual(str(c), "('i', 'j', 'k')")
+
+  def test_value_mapping_element_not_a_str_or_compound_factor(self):
+    with self.assertRaisesRegex(ValueError, "Each element of ArrayMapping must be a str or CompoundFactor"):
+      ArrayMapping(CompoundFactor("i", "j"), 3)
+
+  def test_value_mapping_factor_name_not_start_with_letter(self):
+    with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"):
+      ArrayMapping("3i", "j")
+
+  def test_value_mapping_ellipsis_not_first(self):
+    with self.assertRaisesRegex(ValueError, "Ellipsis can only be used at the beginning of a dimension"):
+      ArrayMapping("i_j", BATCHING)
+
+  def test_value_mapping_str(self):
+    v = ArrayMapping(BATCHING, "m", CompoundFactor("i", "j"), "k")
+    self.assertEqual(str(v), f"('{BATCHING}', 'm', ('i', 'j'), 'k')")
+
+  def test_sdy_sharding_rule_factor_size_not_used(self):
+    with self.assertRaisesRegex(ValueError, "Factor k is not used"):
+      SdyShardingRule(("i",), ("j",), k=10)
+
+  def test_sdy_sharding_rule_factor_sizes_missing(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Factor k is only used in compound factors; must specify its size"):
+      SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")),
+                      (ArrayMapping(CompoundFactor("j", "k")),))
+
+  def test_sdy_sharding_rule_factor_size_not_necessary(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Factor i represents a whole dimension; do not specify its size"):
+      SdyShardingRule((ArrayMapping("i"),), (ArrayMapping("j"),), i=10)
+
+  def test_sdy_sharding_rule_compound_factor_size_not_necessary(self):
+    with self.assertRaisesRegex(
+        ValueError,
+        "Factor i represents a whole dimension; do not specify its size"):
+      SdyShardingRule((ArrayMapping(CompoundFactor("i", "j")),),
+                      (ArrayMapping("i"),), i=10, j=20)
+
+  def test_sdy_sharding_rule_str(self):
+    r = SdyShardingRule((ArrayMapping("i"), ArrayMapping("j")),
+                        (ArrayMapping(CompoundFactor("j", "k")),), k=10)
+    self.assertEqual(str(r), "SdyShardingRule((('i',), ('j',)), ((('j', 'k'),),), {'k': 10})")
+
+
+class StrToSdyShardingRuleTest(jtu.JaxTestCase):
 
   def test_rule_is_not_a_str(self):
     with self.assertRaisesRegex(TypeError, "rule must be a str"):
be a str"): - SdyShardingRule(1) + str_to_sdy_sharding_rule(1) def test_factor_sizes_is_not_a_proper_dict(self): with self.assertRaisesRegex( TypeError, "factor_sizes must be a dict of str to int"): - SdyShardingRule("i->j", i="j") + str_to_sdy_sharding_rule("i->j", i="j") def test_sharding_rule_ellipsis_not_complete(self): with self.assertRaisesRegex( ValueError, "Character '.' must be used inside ellipsis '...'"): - SdyShardingRule(".i -> j") + str_to_sdy_sharding_rule(".i -> j") def test_sharding_rule_invalid_factor_name(self): with self.assertRaisesRegex(ValueError, "Factor names have to start with a letter"): - SdyShardingRule("2i -> j") + str_to_sdy_sharding_rule("2i -> j") def test_sharding_rule_missing_results(self): with self.assertRaisesRegex(ValueError, "There is no -> in rule"): - SdyShardingRule("i") + str_to_sdy_sharding_rule("i") def test_sharding_rule_inbalenced_brackets(self): with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): - SdyShardingRule("i j, k)->j") + str_to_sdy_sharding_rule("i j, k)->j") def test_sharding_rule_inbalenced_brackets2(self): with self.assertRaisesRegex(ValueError, "Brackets are not balanced"): - SdyShardingRule("i (j k->j") + str_to_sdy_sharding_rule("i (j k->j") def test_sharding_rule_empty_compound_dim(self): with self.assertRaisesRegex( ValueError, "Brackets should contain at least two factors"): - SdyShardingRule("i ( ) j k->j") + str_to_sdy_sharding_rule("i ( ) j k->j") def test_sharding_rule_one_factorcompound_dim(self): with self.assertRaisesRegex( ValueError, "Brackets should contain at least two factors"): - SdyShardingRule("i (j ) k->j") + str_to_sdy_sharding_rule("i (j ) k->j") def test_sharding_rule_nested_brackets(self): with self.assertRaisesRegex( ValueError, "Compound factors should be one level"): - SdyShardingRule("i (j (k))->j") + str_to_sdy_sharding_rule("i (j (k))->j") def test_sharding_rule_unknown_char(self): with self.assertRaisesRegex(ValueError, "Unknown character"): - SdyShardingRule("i; j->j") + str_to_sdy_sharding_rule("i; j->j") def test_sharding_rule_unknown_single_char_ellipse(self): with self.assertRaisesRegex(ValueError, "Unknown character"): - SdyShardingRule("…j->…j") + str_to_sdy_sharding_rule("…j->…j") def test_sharding_rule_ellipsis_not_leading_dim(self): with self.assertRaisesRegex( ValueError, "Ellipsis can only be used at the beginning of a dimension"): - SdyShardingRule("i ... -> j") + str_to_sdy_sharding_rule("i ... 
-> j") def test_sharding_rule_ellipsis_inside_compound_dim(self): with self.assertRaisesRegex( ValueError, "Ellipsis can only be used at the beginning of a dimension"): - SdyShardingRule("i, (..., j) -> j") + str_to_sdy_sharding_rule("i, (..., j) -> j") def test_sharding_rule_scalar_operand_scalar_result(self): - rule = SdyShardingRule("->") + rule = str_to_sdy_sharding_rule("->") self.assertEqual(str(rule), "SdyShardingRule(((),), ((),), {})") def test_sharding_rule_one_scalar_operand(self): - rule = SdyShardingRule("i j, , k->j") + rule = str_to_sdy_sharding_rule("i j, , k->j") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j'), (), ('k',)), (('j',),), {})") - def test_sharding_rule_factor_size_not_used(self): - with self.assertRaisesRegex(ValueError, "Factor k is not used"): - SdyShardingRule("i->j", k=10) - - def test_sharding_rule_factor_size_not_necessary(self): - with self.assertRaisesRegex( - ValueError, - "Factor i represents a whole dimension; do not specify its size"): - SdyShardingRule("i->j", i=10) - - def test_sharding_rule_compound_factor_size_not_necessary(self): - with self.assertRaisesRegex( - ValueError, - "Factor i represents a whole dimension; do not specify its size"): - SdyShardingRule("(i j) -> i", i=10, j=20) - - def test_sharding_rule_factor_sizes_missing(self): - with self.assertRaisesRegex( - ValueError, - "Factor k is only used in compound factors; must specify its size"): - SdyShardingRule("i j -> (j k)") - def test_sharding_rule_factor_elementwise_add(self): - rule = SdyShardingRule("... i j, ...i j -> ...i j") + rule = str_to_sdy_sharding_rule("... i j, ...i j -> ...i j") self.assertEqual( str(rule), "SdyShardingRule((('…', 'i', 'j'), ('…', 'i', 'j')), (('…', 'i'," " 'j'),), {})") def test_sharding_rule_factor_vector_scalar_add(self): - rule = SdyShardingRule("...i, -> ...i") + rule = str_to_sdy_sharding_rule("...i, -> ...i") self.assertEqual( str(rule), "SdyShardingRule((('…', 'i'), ()), (('…', 'i'),), {})") def test_sharding_rule_factor_reshape_combining(self): - rule = SdyShardingRule("i j -> (i j)") + rule = str_to_sdy_sharding_rule("i j -> (i j)") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j'),), ((('i', 'j'),),), {})") def test_sharding_rule_factor_reshape_reordering(self): - rule = SdyShardingRule("(j i) -> (i j)", i=10, j=20) + rule = str_to_sdy_sharding_rule("(j i) -> (i j)", i=10, j=20) self.assertEqual( str(rule), "SdyShardingRule(((('j', 'i'),),), ((('i', 'j'),),), {'i': 10, 'j':" " 20})") def test_sharding_rule_factor_compound_then_individual(self): - rule = SdyShardingRule("(i j) (j k) i -> j k") + rule = str_to_sdy_sharding_rule("(i j) (j k) i -> j k") self.assertEqual( str(rule), "SdyShardingRule(((('i', 'j'), ('j', 'k'), 'i'),), (('j', 'k'),), {})") def test_sharding_rule_factor_individual_then_compound(self): - rule = SdyShardingRule("i j k -> (i j) (j k)") + rule = str_to_sdy_sharding_rule("i j k -> (i j) (j k)") self.assertEqual( str(rule), "SdyShardingRule((('i', 'j', 'k'),), ((('i', 'j'), ('j', 'k')),), {})") def test_sharding_rule_factor_infer_k(self): - rule = SdyShardingRule("_i (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) + rule = str_to_sdy_sharding_rule("i_ (j k)-> j foo (m bar_24)", k=10, m=10, bar_24=20) self.assertEqual( str(rule), - "SdyShardingRule((('_i', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" + "SdyShardingRule((('i_', ('j', 'k')),), (('j', 'foo', ('m', 'bar_24'))" ",), {'k': 10, 'm': 10, 'bar_24': 20})") @@ -189,11 +230,11 @@ def test_conversion_rule_op_mismatch_in_operands_num(self): 
         results=[self.get_tensor_type((16, 32))],
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
-    rule = SdyShardingRule("i j-> i j")
+    rule = str_to_sdy_sharding_rule("i j-> i j")
     with self.assertRaisesRegex(
         ValueError,
         "Sharding rule has 1 operands, but the operation has 2 operands"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
           [result.operands[0].type, result.operands[1].type],
           [result.result.type,])
 
@@ -205,12 +246,12 @@ def test_conversion_rule_op_mismatch_in_operands_rank(self):
         results=[self.get_tensor_type((16, 32))],
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
-    rule = SdyShardingRule("i j, i j k-> i j")
+    rule = str_to_sdy_sharding_rule("i j, i j k-> i j")
     with self.assertRaisesRegex(
         ValueError,
         "Sharding rule 1th operand has rank 3, but the operation 1th "
         "operand has rank 2"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
           [result.operands[0].type, result.operands[1].type],
           [result.result.type,])
 
@@ -223,11 +264,11 @@ def test_conversion_rule_op_mismatch_in_results_num(self):
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
 
-    rule = SdyShardingRule("i j, i j -> i j, i j")
+    rule = str_to_sdy_sharding_rule("i j, i j -> i j, i j")
     with self.assertRaisesRegex(
         ValueError,
         "Sharding rule has 2 results, but the operation has 1 results"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
           [result.operands[0].type, result.operands[1].type],
           [result.result.type,])
 
@@ -239,12 +280,12 @@ def test_conversion_rule_op_mismatch_in_results_dim(self):
         results=[self.get_tensor_type((16, 32))],
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")))
-    rule = SdyShardingRule("i j, i j -> i j k")
+    rule = str_to_sdy_sharding_rule("i j, i j -> i j k")
     with self.assertRaisesRegex(
         ValueError,
         "Sharding rule 0th result has rank 3, but the operation 0th "
         "result has rank 2"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
           [result.operands[0].type, result.operands[1].type],
           [result.result.type,])
 
@@ -256,11 +297,11 @@ def test_conversion_factor_has_two_sizes(self):
         results=[self.get_tensor_type((16, 64))],
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")))
-    rule = SdyShardingRule("i j, i j -> i j")
+    rule = str_to_sdy_sharding_rule("i j, i j -> i j")
     with self.assertRaisesRegex(
         ValueError,
         "Factor j corresponds to two sizes: 32 and 64"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
          [result.operands[0].type, result.operands[1].type],
          [result.result.type,])
 
@@ -272,14 +313,30 @@ def test_conversion_batching_dim_has_two_sizes(self):
         results=[self.get_tensor_type((16, 64))],
         operands=[opnd0, opnd1,],
         attributes=dict(call_target_name=ir.StringAttr.get("foo")))
-    rule = SdyShardingRule("..., ... -> ...")
+    rule = str_to_sdy_sharding_rule("..., ... -> ...")
     with self.assertRaisesRegex(
         ValueError,
         "Batching dimension 1 corresponds to two sizes: 32 and 64"):
-      rule.build(
+      sdy_sharding_rule_to_mlir(rule,
          [result.operands[0].type, result.operands[1].type],
          [result.result.type,],)
 
+  def test_conversion_invalid_batching_dim(self):
+    opnd0 = self.create_tensor_value((16, 32))
+    opnd1 = self.create_tensor_value((16, 32))
+    result = ir.Operation.create(
+        "stablehlo.custom_call",
+        results=[self.get_tensor_type((16, 32))],
+        operands=[opnd0, opnd1,],
+        attributes=dict(call_target_name=ir.StringAttr.get("foo")),)
+    rule = str_to_sdy_sharding_rule("... i j k, ... i j k -> ... i j k")
i j k") + with self.assertRaisesRegex( + ValueError, + "Sharding rule 0th operand has rank 3, but the operation 0th operand has rank 2"): + sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type, result.operands[1].type], + [result.result.type,]) + def test_conversion_compound_dimension_size_mismatch(self): opnd = self.create_tensor_value((2, 4)) result = ir.Operation.create( @@ -287,12 +344,12 @@ def test_conversion_compound_dimension_size_mismatch(self): results=[self.get_tensor_type((9,))], operands=[opnd,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j -> (i j)") + rule = str_to_sdy_sharding_rule("i j -> (i j)") with self.assertRaisesRegex( ValueError, "0th result actual size 9 doesn't match the size 8 derived from the" " compound factors"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type], [result.result.type,]) @@ -304,14 +361,29 @@ def test_conversion_elementwise_rule_mismatching_ellipsis_rank(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("..., ... -> ...") + rule = str_to_sdy_sharding_rule("..., ... -> ...") with self.assertRaisesRegex( ValueError, "Ellipsis represents different number of leading dimensions 2 and 1"): - rule.build( + sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) + def test_conversion_compound_then_individual(self): + opnd = self.create_tensor_value((8,)) + result = ir.Operation.create( + "stablehlo.custom_call", + results=[self.get_tensor_type((2,4))], + operands=[opnd,], + attributes=dict(call_target_name=ir.StringAttr.get("foo"))) + rule = str_to_sdy_sharding_rule("(i j) -> i j") + mlir_rule = sdy_sharding_rule_to_mlir(rule, + [result.operands[0].type], + [result.result.type,]) + self.assertEqual( + str(mlir_rule), + "#sdy.op_sharding_rule<([ij])->([i, j]) {i=2, j=4}>") + def test_conversion_elementwise_rule_scalar_instance(self): opnd0 = self.create_tensor_value(()) opnd1 = self.create_tensor_value(()) @@ -320,8 +392,8 @@ def test_conversion_elementwise_rule_scalar_instance(self): results=[self.get_tensor_type(())], operands=[opnd0, opnd1], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., ... -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., ... -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -336,8 +408,8 @@ def test_conversion_elementwise_rule_2D_instance(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., ... -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., ... 
-> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -352,8 +424,8 @@ def test_conversion_vector_scalar_add_2D_instance(self): results=[self.get_tensor_type((16, 32))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo")),) - rule = SdyShardingRule("..., -> ...") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("..., -> ...") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual( @@ -367,8 +439,8 @@ def test_conversion_reshape_rule(self): results=[self.get_tensor_type((8,))], operands=[opnd0,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("i j -> (i j)") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("i j -> (i j)") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type], [result.result.type,]) self.assertEqual( @@ -383,8 +455,8 @@ def test_conversion_contracting_dim_matmul(self): results=[self.get_tensor_type((16, 8))], operands=[opnd0, opnd1,], attributes=dict(call_target_name=ir.StringAttr.get("foo"))) - rule = SdyShardingRule("... contracting_dim, contracting_dim k -> ... k") - mlir_rule = rule.build( + rule = str_to_sdy_sharding_rule("... contracting_dim, contracting_dim k -> ... k") + mlir_rule = sdy_sharding_rule_to_mlir(rule, [result.operands[0].type, result.operands[1].type], [result.result.type,]) self.assertEqual(