Torch 2.2 breakage on bfloat16 and float16 #8

Open
proger opened this issue Mar 19, 2024 · 0 comments

proger (Owner) commented Mar 19, 2024

Running the Triton implementation (`accelerated_scan/triton.py`) with torch 2.2 on float16 or bfloat16 inputs fails while compiling the kernel, with the following error:

```
  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 144, in <module>
    out = scan(gates, tokens)
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 129, in scan
    return Scan.apply(gates, tokens)
  File "/home/proger/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/proger/accelerated-scan/accelerated_scan/triton.py", line 87, in forward
    forward_scan[(B,C)](gates, tokens, states, SEQUENCE_LENGTH=T, enable_fp_fusion=False)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/runtime/jit.py", line 532, in run
    self.cache[device][key] = compile(
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 543, in compile
    next_module = compile_kernel(module)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/compiler.py", line 435, in <lambda>
    ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target))
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1228, in ast_to_ttir
    generator.visit(fn.parse())
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 428, in visit_Assign
    values = self.visit(node.value)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1021, in visit_Call
    return self.call_JitFunction(fn, args, kws)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 989, in call_JitFunction
    generator.visit(fn.parse())
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 303, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 426, in generic_visit
    self.visit(item)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 376, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 298, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1069, in visit_Expr
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.10/ast.py", line 428, in generic_visit
    self.visit(value)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1105, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.10/ast.py", line 418, in visit
    return visitor(node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1012, in visit_Call
    return static_implementation(self, node)
  File "/home/proger/.local/lib/python3.10/site-packages/triton/compiler/code_generator.py", line 1140, in execute_static_assert
    raise CompileTimeAssertionFailure(None, node, _unwrap_if_constexpr(message))
triton.compiler.errors.CompileTimeAssertionFailure: at 2:4:def forward_scan(
    gates,
    ^
```
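Until the kernel compiles for half-precision inputs, one possible stopgap is to upcast float16/bfloat16 inputs to float32 before calling the scan and cast the result back. This is a hypothetical sketch, not part of the project: `scan_via_float32` and `reference_scan` are illustrative names, and it assumes float32 inputs are unaffected (implied by the report, not stated). `reference_scan` is a slow CPU stand-in that assumes the kernel computes the recurrence h[t] = gates[t] * h[t-1] + tokens[t] along the last dimension of a (B, C, T) tensor; that layout is only inferred from the `forward_scan[(B,C)]` launch with `SEQUENCE_LENGTH=T`.

```python
import torch

# Per this issue, the Triton kernel fails to compile for these dtypes under torch 2.2.
HALF_DTYPES = (torch.float16, torch.bfloat16)


def scan_via_float32(scan_fn, gates, tokens):
    """Hypothetical workaround: run `scan_fn` in float32, then cast back.

    `scan_fn` is assumed to take (gates, tokens) tensors of matching shape and
    dtype, like `scan(gates, tokens)` in accelerated_scan/triton.py.
    """
    orig_dtype = tokens.dtype
    if orig_dtype not in HALF_DTYPES:
        # float32 (and any other dtype) is passed through untouched.
        return scan_fn(gates, tokens)
    # Upcast both inputs so the float32 code path is used instead.
    out = scan_fn(gates.float(), tokens.float())
    # Return the result in the caller's original precision.
    return out.to(orig_dtype)


def reference_scan(gates, tokens):
    """Slow CPU stand-in for the kernel (assumed semantics):
    h[..., t] = gates[..., t] * h[..., t-1] + tokens[..., t], with h before t=0 being 0,
    scanning along the last (sequence) dimension.
    """
    h = torch.zeros_like(tokens[..., 0])
    outs = []
    for t in range(tokens.shape[-1]):
        h = gates[..., t] * h + tokens[..., t]
        outs.append(h)
    return torch.stack(outs, dim=-1)


# Usage with the CPU stand-in; on a GPU one would pass the Triton `scan` instead.
torch.manual_seed(0)
B, C, T = 2, 4, 16  # (batch, channels, sequence length), assumed layout
gates = torch.rand(B, C, T).to(torch.bfloat16)
tokens = torch.rand(B, C, T).to(torch.bfloat16)
out = scan_via_float32(reference_scan, gates, tokens)
print(out.dtype, tuple(out.shape))
```

Upcasting doubles the input memory traffic and does not address the underlying compile-time assertion; it only avoids the half-precision code path.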