Skip to content

Commit

Permalink
#2105 Fix temporary variables with wrong type
Browse files Browse the repository at this point in the history
  • Loading branch information
sergisiso committed Dec 21, 2023
1 parent 6b816d0 commit 9a0f120
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from psyclone.psyir.nodes import (
Assignment, Reference, ArrayReference, IfBlock, Loop,
IntrinsicCall, Node, UnaryOperation, BinaryOperation)
from psyclone.psyir.symbols import ArrayType, DataSymbol
from psyclone.psyir.symbols import ArrayType, DataSymbol, ScalarType
from psyclone.psyGen import Transformation
from psyclone.psyir.transformations.reference2arrayrange_trans import \
Reference2ArrayRangeTrans
Expand Down Expand Up @@ -170,6 +170,19 @@ def validate(self, node, options=None):
f"Unexpected shape for array. Expecting one of "
f"Deferred, Attribute or Bounds but found '{shape}'.")

# If the lhs symbol is used anywhere on the assignment rhs, we need
# to create a temporary, and for this we need to resolve its datatype
for rhs_reference in assignment.rhs.walk(Reference):
if rhs_reference.symbol is assignment.lhs.symbol:
if not (isinstance(assignment.lhs.symbol, DataSymbol) and
isinstance(assignment.lhs.datatype, ScalarType)):
line = assignment.debug_string().strip('\n')
raise TransformationError(
f"To loopify '{line}'"
f" we need a temporary variable, but the type of "
f"'{assignment.lhs.debug_string()}' can not be "
f"resolved or is unsupported.")

# pylint: disable=too-many-locals
def apply(self, node, options=None):
'''Apply the array-reduction intrinsic conversion transformation to
Expand Down Expand Up @@ -199,7 +212,7 @@ def apply(self, node, options=None):
if increment:
new_lhs_symbol = node.scope.symbol_table.new_symbol(
root_name="tmp_var", symbol_type=DataSymbol,
datatype=lhs_symbol.datatype)
datatype=orig_lhs.datatype)
new_lhs = Reference(new_lhs_symbol)
else:
new_lhs = orig_lhs.copy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,27 @@ def test_not_assignment(fortran_reader):
"of an Assignment" in str(info.value))


def test_validate_increment_with_unsupported_type(fortran_reader):
'''Check that the expected error is produced when the resulting code
needs a temporary variable but the lhs type can not be resolved.
'''
code = (
"subroutine test()\n"
"use othermod\n"
"real :: a(10)\n"
"x(1) = x(1) + maxval(a)\n"
"end subroutine\n")
psyir = fortran_reader.psyir_from_source(code)
trans = Maxval2LoopTrans()
node = psyir.walk(IntrinsicCall)[0]
with pytest.raises(TransformationError) as info:
trans.apply(node)
assert ("To loopify 'x(1) = x(1) + MAXVAL(a)' we need a temporary "
"variable, but the type of 'x(1)' can not be resolved or is "
"unsupported." in str(info.value))


# apply

@pytest.mark.parametrize("idim1,idim2,rdim11,rdim12,rdim21,rdim22",
Expand Down Expand Up @@ -648,10 +669,38 @@ def test_increment(fortran_reader, fortran_writer, tmpdir):
assert Compile(tmpdir).string_compiles(result)


def test_reduce_to_struct_and_array_accessors(fortran_reader, fortran_writer,
tmpdir):
def test_increment_with_accessor(fortran_reader, fortran_writer, tmpdir):
'''Check that the expected code is produced when the variable being
assigned to is an increment e.g. x = x + ...
assigned needs a temporary variable that is not the same type as the
lhs symbol because some accessor expression is used.
'''
code = (
"subroutine test()\n"
"real :: a(10)\n"
"real, dimension(1) :: x\n"
"x(1) = x(1) + maxval(a)\n"
"end subroutine\n")
expected_decl = "real :: tmp_var"
expected = (
" tmp_var = -HUGE(tmp_var)\n"
" do idx = 1, 10, 1\n"
" tmp_var = MAX(tmp_var, a(idx))\n"
" enddo\n"
" x(1) = x(1) + tmp_var\n")
psyir = fortran_reader.psyir_from_source(code)
trans = Maxval2LoopTrans()
node = psyir.walk(IntrinsicCall)[0]
trans.apply(node)
result = fortran_writer(psyir)
assert expected_decl in result
assert expected in result
assert Compile(tmpdir).string_compiles(result)


def test_reduce_to_struct_and_array_accessors(fortran_reader, fortran_writer):
'''Check that the expected code is produced when the variable being
assigned to is has array and structure accessors.
'''
code = (
Expand All @@ -671,7 +720,6 @@ def test_reduce_to_struct_and_array_accessors(fortran_reader, fortran_writer,
trans.apply(node)
result = fortran_writer(psyir)
assert expected in result
assert Compile(tmpdir).string_compiles(result)


def test_range2loop_fails(fortran_reader, fortran_writer):
Expand Down

0 comments on commit 9a0f120

Please sign in to comment.