[FIX] Disable the scaled dot product attention fusion in ATSM platform. (#2458)

yisonzhu authored Oct 25, 2023
1 parent a28e0c6 commit 2ba07bb
Showing 2 changed files with 19 additions and 8 deletions.
11 changes: 11 additions & 0 deletions itex/core/graph/remapper/mha_pattern.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "itex/core/graph/config_util.h"
#include "itex/core/graph/remapper/constant_names.h"
#include "itex/core/graph/remapper/fusion.h"
#include "itex/core/graph/remapper/remapper.h"
@@ -85,6 +86,11 @@ class MHAFusionWithReshapeMatmul : public Fusion {
MatchedProperties ret =
FillProperties(&graph_view, graph_view.GetNode(node_index), pattern_);

#ifndef INTEL_CPU_ONLY
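// XeTLA-backed MHA fusion is only supported on XeHPC; bail out on other GPUs such as ATSM.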
if (!isxehpc_value) {
return ret.ToEmpty();
}
#endif
bool is_ok = !ret.Empty() && CheckShapes(ctx, ret);

if (!is_ok) return ret.ToEmpty();
@@ -267,6 +273,11 @@ class MHAPatternWithMulAndAdd : public Fusion {
MatchedProperties ret = FillProperties(
&graph_view, graph_view.GetNode(node_index), pattern_, false);

#ifndef INTEL_CPU_ONLY
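// Same XeHPC-only guard as in MHAFusionWithReshapeMatmul above.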
if (!isxehpc_value) {
return ret.ToEmpty();
}
#endif
bool is_ok = !ret.Empty() && CheckShapes(ctx, ret);

if (!is_ok) return ret.ToEmpty();
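The guard above is compiled only into GPU builds (#ifndef INTEL_CPU_ONLY) and makes both fusion checks return an empty match unless the device is XeHPC, so on ATSM the scaled dot product attention pattern is left unfused instead of being lowered to the XeTLA kernel. The same condition can be probed from Python with the is_xehpc helper that the updated test imports; a minimal sketch (the script itself is illustrative, not part of this commit):

    import tensorflow as tf
    from intel_extension_for_tensorflow.python.device import is_xehpc

    # With intel-extension-for-tensorflow installed, the XPU plugin is
    # registered when TensorFlow loads.
    has_xpu = bool(tf.config.list_logical_devices('XPU'))
    print("XPU present:", has_xpu)
    # Mirrors the C++ guard: on GPU builds the fusion requires XeHPC.
    print("MHA fusion eligible on this XPU:", has_xpu and is_xehpc())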
16 changes: 8 additions & 8 deletions test/pattern/flash_attention.py
@@ -14,21 +14,17 @@
# ==============================================================================


import os

import numpy as np
import tensorflow as tf
from keras import backend as K

from tensorflow import keras
from tensorflow.python.framework import test_util
from tensorflow.python.framework import config
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
import time
import os
import subprocess
import sys
from intel_extension_for_tensorflow.python.device import is_xehpc

tf.compat.v1.disable_eager_execution()

@@ -81,6 +77,8 @@ def td_dot(self, a, b):
class MHAFusionWithReshapeMatmulTest(test_util.TensorFlowTestCase):

def testMHAFusionWithReshapeMatmul(self):
if config.list_logical_devices('XPU') and not is_xehpc():
self.skipTest("Only xehpc supports xetla.")
tf.random.set_seed(0)
datatypes = [tf.float32, tf.bfloat16]
if config.list_logical_devices('XPU'):
@@ -116,6 +114,8 @@ def testMHAFusionWithReshapeMatmul(self):
class MHAPatternWithMulAndAddTest(test_util.TensorFlowTestCase):

def testMHAPatternWithMulAndAdd(self):
if config.list_logical_devices('XPU') and not is_xehpc():
self.skipTest("Only xehpc supports xetla.")
tf.random.set_seed(0)
datatypes = [tf.float32, tf.bfloat16]
if config.list_logical_devices('XPU'):
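Taken together, the two test changes follow one pattern: detect an XPU device, and skip when it is not XeHPC rather than assert on a fusion that the remapper now refuses to apply. A standalone skeleton of that guard (class and test names here are illustrative, not from the commit):

    from tensorflow.python.framework import config
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import test
    from intel_extension_for_tensorflow.python.device import is_xehpc

    class XetlaGuardedFusionTest(test_util.TensorFlowTestCase):

      def testGuardedFusion(self):
        # Skip on non-XeHPC XPUs (e.g. ATSM), where the fusion is disabled.
        if config.list_logical_devices('XPU') and not is_xehpc():
          self.skipTest("Only xehpc supports xetla.")
        # ... assertions on the fused graph would follow here ...

    if __name__ == '__main__':
      test.main()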
