
[Computation Hash] Introduce deterministic hash for user computations #8539

Merged
merged 2 commits into from
Jan 10, 2025

Conversation

rpsilva-aws
Contributor

Fixes #8537

@rpsilva-aws rpsilva-aws marked this pull request as ready for review January 7, 2025 01:33
@rpsilva-aws
Contributor Author

@tengyifei

@rpsilva-aws
Contributor Author

rpsilva-aws commented Jan 7, 2025

Test results without the deterministic serialization (instead, relying on the former SerializeAsString()):

[ RUN      ] XlaUtilTest.TestDeterministicComputationSerialization
torch_xla/csrc/runtime/xla_util_test.cc:281: Failure
Expected equality of these values:
  hash1
    Which is: 43931100196028486903611743554166252076
  hash2
    Which is: 6626235922799979895316908395105923211
Hashes should match regardless of the frontend attribute ordering
[  FAILED  ] XlaUtilTest.TestDeterministicComputationSerialization (0 ms)
[----------] 5 tests from XlaUtilTest (1 ms total)

So hash1 and hash2 differ, even though both are serializations of the exact same HLO module proto.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_computation_hash branch 2 times, most recently from 5e406d7 to 4372f2e Compare January 7, 2025 06:52
@tengyifei tengyifei self-requested a review January 8, 2025 22:02
@tengyifei
Collaborator

LGTM; however, the tests fail.

From what I can tell, both failed tests involve an XlaComputation. The test_conditional feeds XlaComputation into an HLO Cond op. The scan feeds an XlaComputation into an HLO While op. So probably the hash is broken for ComputationPtr in a way that causes different graphs to hash into the same value.

@rpsilva-aws
Contributor Author

Indeed, looking into it. A first pass shows that the resulting hashes for both user computations are the same, even though the computation IRs differ in the constant parameter:

  • 0.9 case:
HloModule test_conditional.18, entry_computation_layout={(f32[], f32[2,2]{1,0}, f32[2,2]{1,0})->f32[2,2]{1,0}}

%CondTrue.7 (p0.8: (f32[2,2], f32[2,2])) -> f32[2,2] {
  %p0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.9 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.8), index=0
  %get-tuple-element.10 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.8), index=1
  ROOT %add.11 = f32[2,2]{1,0} add(f32[2,2]{1,0} %get-tuple-element.9, f32[2,2]{1,0} %get-tuple-element.10)
}

%CondFalse.12 (p0.13: (f32[2,2], f32[2,2])) -> f32[2,2] {
  %p0.13 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.14 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.13), index=0
  %get-tuple-element.15 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.13), index=1
  ROOT %subtract.16 = f32[2,2]{1,0} subtract(f32[2,2]{1,0} %get-tuple-element.14, f32[2,2]{1,0} %get-tuple-element.15)
}

ENTRY %test_conditional.18 (p0.1: f32[], p1.2: f32[2,2], p2.3: f32[2,2]) -> f32[2,2] {
  %p0.1 = f32[] parameter(0)
  %constant.4 = f32[] constant(0.9)
  %compare.5 = pred[] compare(f32[] %p0.1, f32[] %constant.4), direction=GT
  %p1.2 = f32[2,2]{1,0} parameter(1)
  %p2.3 = f32[2,2]{1,0} parameter(2)
  %tuple.6 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %p1.2, f32[2,2]{1,0} %p2.3)
  ROOT %conditional.17 = f32[2,2]{1,0} conditional(pred[] %compare.5, (f32[2,2]{1,0}, f32[2,2]{1,0}) %tuple.6, (f32[2,2]{1,0}, f32[2,2]{1,0}) %tuple.6), true_computation=%CondTrue.7, false_computation=%CondFalse.12
}

  • 0.1 case:
HloModule test_conditional.18, entry_computation_layout={(f32[], f32[2,2]{1,0}, f32[2,2]{1,0})->f32[2,2]{1,0}}

%CondTrue.7 (p0.8: (f32[2,2], f32[2,2])) -> f32[2,2] {
  %p0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.9 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.8), index=0
  %get-tuple-element.10 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.8), index=1
  ROOT %add.11 = f32[2,2]{1,0} add(f32[2,2]{1,0} %get-tuple-element.9, f32[2,2]{1,0} %get-tuple-element.10)
}

%CondFalse.12 (p0.13: (f32[2,2], f32[2,2])) -> f32[2,2] {
  %p0.13 = (f32[2,2]{1,0}, f32[2,2]{1,0}) parameter(0)
  %get-tuple-element.14 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.13), index=0
  %get-tuple-element.15 = f32[2,2]{1,0} get-tuple-element((f32[2,2]{1,0}, f32[2,2]{1,0}) %p0.13), index=1
  ROOT %subtract.16 = f32[2,2]{1,0} subtract(f32[2,2]{1,0} %get-tuple-element.14, f32[2,2]{1,0} %get-tuple-element.15)
}

ENTRY %test_conditional.18 (p0.1: f32[], p1.2: f32[2,2], p2.3: f32[2,2]) -> f32[2,2] {
  %p0.1 = f32[] parameter(0)
  %constant.4 = f32[] constant(0.1)
  %compare.5 = pred[] compare(f32[] %p0.1, f32[] %constant.4), direction=GT
  %p1.2 = f32[2,2]{1,0} parameter(1)
  %p2.3 = f32[2,2]{1,0} parameter(2)
  %tuple.6 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %p1.2, f32[2,2]{1,0} %p2.3)
  ROOT %conditional.17 = f32[2,2]{1,0} conditional(pred[] %compare.5, (f32[2,2]{1,0}, f32[2,2]{1,0}) %tuple.6, (f32[2,2]{1,0}, f32[2,2]{1,0}) %tuple.6), true_computation=%CondTrue.7, false_computation=%CondFalse.12
}

I'll see what is happening with the hash here.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_computation_hash branch from 50ed7c5 to a8b5b31 Compare January 9, 2025 01:55
@rpsilva-aws
Contributor Author

Thanks for catching it @tengyifei. I was reading the computation object argument after it had already been moved from.

@tengyifei tengyifei self-requested a review January 9, 2025 22:06
@bhavya01
Collaborator

bhavya01 commented Jan 9, 2025

> @bhavya01, it seems that the TPU CI has an issue? Seeing the same failure on other runs:

Yifei reverted the breaking change #8547. I expect it to pass now.

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_computation_hash branch from a8b5b31 to 72f8c27 Compare January 9, 2025 23:43
@tengyifei tengyifei merged commit 196cab3 into pytorch:master Jan 10, 2025
12 checks passed
@rpsilva-aws rpsilva-aws deleted the rpsilva_computation_hash branch January 10, 2025 05:27

Successfully merging this pull request may close these issues.

[Computation Hash] User Computation hash disregards protobuf requirements