-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] OpenXLA PJRT plugin #33
Conversation
Present proposal for OpenXLA PJRT plugin along with creation of new repo for its development.
Thanks for the writeup, Jacques! |
nice to have windows support. Some products need Windows implementation. I had mentioned a few times that Windows availability would make OpenXLA more attractive. |
I looked at this, and the plugin infra itself looks like it needs a bit of work for Windows compatibility still. But the plugin implementation is tested and deployed on Windows currently. So this doesn't seem too far to expect from were I see. |
What is the process here? Create an issue? |
Process for this RFC to progress or requesting windows support for the implementation generally? For the latter, I would create a dedicated issue for platform policy overall (vs asking component by component). |
Related: jax-ml/jax#438 I don't see any reason why the plugin infra and the openxla plugin can't support Windows. |
Which repository? JAX or ??? |
|
I would like to see something about how a plugin is tested. Having a reference test suite that checks a PJRT implementation is "correct" would be very useful. |
Agreed. Ime, the Jax test suite is a useful component of this, but it doesn't have sufficient coverage for more integration-heavy tests (ie. Multi device, exotic compilation modes, out of the ordinary data transfer scenarios). Probably makes sense to have an additional layer of testing, and then we call some union of test suites an interim CTS and work on unifying it in a next phase. |
What is exactly a |
It is likely not required for the plugin to compile any |
Would it be possible to comunicate with the plugin to understand if an input program could be compiled (or is it supported) without failing (fast) at "runtime" on compilation? I am asking about this with the scope of having a complete overview, or at least a preliminary check, on the PJRT plugin coverage over a specific program. I don't know if it make sense at this level but it was something that I've expected at the current framework bridge level and we still don't have it. So I don't know if the topic is still valid or not at this lower level with PJRT programs. |
That is a good suggestion. Ability to flag compatibility for a given plugin instance [there are potentially conservative and aggressive options too - e.g., an unsupported op that could be optimized away after a few rounds of optimizations]. We should keep that in mind for the compile API discussion. |
Ultimately is seems that it should be a StableHLO module in the OpenXLA architecture?
This ties pretty well with a discussion this morning in the open meeting: someone asked about dynamic shape support and how some platforms won't be able to support it. So as @jpienaar mentions above, we need to have this in mind when designing the APIs: "program capabilities" or "features" that aren't uniformly supported. There is a range of possibilities in the API design to express this... |
@jpienaar Yes exactly these arguments are all related to what I meant. |
+1, PJRT is not OpenXLA/XLA specific but within the context of OpenXLA this is the supported input format. |
Yes, I think having some of those tests would be good. A problem I've had with using the ML Frameworks tests for the IPU PJRT implementation is it's unclear what the "canonical" set of tests is. There are a lot of exceptions and exclusions for CPU, TPU, and GPU. |
If the |
We are going into the compiler API design and query functionality a bit, which is a separate discussion that needs to happen still. We can have multiple layers and expensiveness of tests. So the most simple one could be purely on op names, more advanced/expensive checks the attributes, more check the types too, then checking usage within a context (e.g., can I elide asserts?), then trying to run some initial optimizations to see. As a straw man one could add yes, no, maybe results for query "is supported model" and the cost of the query or amount of effort to try can be configured additionally - and trying to do a couple of optimizations would be on the expensive path. The more information the backend could provide the earlier and easier. For backends that are complete (which is a much more tractable target for HLOs!) one might not need to dig in too deep to get a result (exception is ops like scatter which are not a simpler ops). But there may be value in being able to give a very quick response even if conservative for some use cases, while in others conservative isn't useful. Larger discussion :) (I've seen multiple different attempts here and this will be design question in the compiler API work). |
It will be super useful also for CI jobs or other development activities if for an early check you will not need to retrieve/run the verification on the device specific resources. Then honestly we need something similar on Frameworks bridges for failures om generating StableHLO programs.. but this is another story. Also I don't know if we could have some complications with custom calls. |
This a braindump of things to consider will keep adding :
|
if (function_prt != nullptr) { | ||
PJRT_Api* api = function_prt(); | ||
plugin_name = parse(library_path); | ||
PluginInfo plugin_info(api, config_values); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to check if plugin_name is already loaded and if so, do an early return.
I wouldn't return an error. Allowing to load the same pluging many times isn't wrong.
This will also allows function_prt() to do all the initialization that it needs. (related to some open questions bellow)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the suggestion! Edited.
rfcs/20230123-pjrt-plugin.md
Outdated
PJRT TPU client is created. Shall we add another method InitializePlugin and | ||
only run it once? Alternatively, the plugin can implement it in | ||
`PJRT_Client_Create` and run it once in the first time a client was created. | ||
* Do we want to create PJRT clients for every plugin that is found? Will that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't do that automatically as this will increase some resource utilization even if the end user doesn't want to use it.
I would let frameworks decide which behavior they want. I wouldn't impose that decision at the PJRT level.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Updated the text accordingly.
pjrt_client = create_pjrt_client(plugin_name) | ||
``` | ||
|
||
For TensorFlow, discovery will be added to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if 2 frameworks create to clients for the same device. Is this supported? If so, it would be great to specify it.
If this isn't supported, it will be harder to have in the same python script different frameworks. So supporting multiple clients would be great.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is up to the specific hardware and plugin implementation. If the hardware only allows exclusive access, then the software will abide by that constraint. Otherwise, some plugins may use the hardware in an exclusive way (ie. Allocate all memory). The openxla plugin that we envision now will default to allowing multiple clients and supporting on demand memory allocation (with a possible option for more greedy access as an optimization for cases that need it).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sound good. Thanks.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is more to this than greedy access as an optimisation. A Graphcore IPU can only be owned by a single context on the host. So two processes, or indeed two clients in a single process, can't share an IPU.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that becomes a restriction to use of a Graphcore IPU then -- the PJRT API layer isn't going to do any kind of virtualization or remoting. If a software mechanism is needed to arbitrate multi-party access to a device, then that would be up to the device implementation.
Side note: this is currently an issue when using Jax by default with the XLA GPU backend as it allocates all memory, effectively making it impossible to share. There are environment variable workarounds to cause it to dynamically allocate.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree that becomes a platform specific restriction, and I don't want a virtualisation layer. My concern is this becomes an implicit assumption in all users of this API.
The openxla plugin that we envision now will default to allowing multiple clients
My only point being that greedy or exclusive access isn't necessarily an optimisation, it can be a requirement (there's a reason in silicon for it).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see us doing anything in the software stack that makes multi-tenancy either harder or easier, but we will probably seek to make the default openxla implementation more user friendly on this front by default as it is a frequent pain point.
There really should be one Client per process, and if there can only be one Client per system then that would limit to only launching one process for that category of devices.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a paragraph to summarize this discussion thread.
|
I'm having trouble tracking the exact use case that is promoting the question (and can imagine different directions). Could you elaborate? |
For those watching from the sides, could someone point to a public
documentation on what is PjRT ? Thanks!
…On Wed, Feb 1, 2023 at 5:49 PM Frédéric Bastien ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In rfcs/20230123-pjrt-plugin.md
<#33 (comment)>:
> +* Call [load_pjrt_plugins](#heading=h.782ksg6rl5bj) when
+ [initializing backends](https://github.com/google/jax/blob/a66b3dcdd378b723e275a19258de826260b4c83e/jax/_src/lib/xla_bridge.py#L381).
+* Call [get_loaded_plugin_names](#heading=h.396bmv8gkskz) to get loaded PJRT
+ `plugin_name`, have some framework specific logics to decide whether to call
+ [create_pjrt_client](#heading=h.396bmv8gkskz) to create the PJRT client.
+
+```python
+def create_pjrt_clients():
+ loaded_plugin_names = get_loaded_plugin_names()
+ for plugin_name in loaded_plugin_names:
+ # Framework specific logics to decide whether to create
+ if should_create_pjrt_client(plugin_name):
+ pjrt_client = create_pjrt_client(plugin_name)
+```
+
+For TensorFlow, discovery will be added to
That sound good. Thanks.
—
Reply to this email directly, view it on GitHub
<#33 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/ABY5KE36KPNORGZFEFODVL3WVKH2DANCNFSM6AAAAAAUES333A>
.
You are receiving this because you are subscribed to this thread.Message
ID: ***@***.***>
--
Jan Pfeifer
Research Engineer
Google Switzerland GmbH
@ : ***@***.***
@ : ***@***.***
T : +41 79 907 3855
|
Also what is the extended version of this acronym? |
I don't have a pointer to a documentation immediately, but think of PJRT right now as the public API for setting up XLA (other than building the input graph). See this basic example: https://github.com/openxla/xla/blob/main/xla/examples/axpy/stablehlo_compile_test.cc#L60-L79 The
d8c85bc19e29fdff0aa3d03065cdf79cef6c0fb9: |
I think the best on these are those listed at the bottom of the RFC and the headers here, let us know if too low level. As mentioned in intro PJRT is a device API that will provide an easy interface with which frameworks can integrate a packaged compiler and runtime solution.
From the header: 'PjRt stands for "Pretty much Just another RunTime"' (this makes a lot more sense if one considers the originally JAX expansion). |
It is important to support versioning of the program representation so that the plugin can check before processing the program. Do you consider versioning and what is the runtime behavior if the version doesn't exactly match? |
We are driving this towards StableHLO which is developing version constraints as part of its design. New implementations should use that. |
I'm sorry for the vague statement; I'm trying to jot down the key points so I can remember to bring them up when there is a higher bandwidth format. A small write-up for 7 is: For Grace Hopper + based systems, the NVL bandwidth is much higher than PCIe, opening the door for profitable offloading of tensors to the HBM. The next version of the runtime API would need to consider the programming model for such hybrid systems while tackling the traditional Host <-> GPU typical design. One example is if one is doing a "one layer at a time" style of computation for a massive model with a limited set of GPUs (model_size> All GPU memories) Current frameworks (e.g. JAX) don't have an ergonomic path for this hybrid memory model; assuming they do, it will likely need an abstraction at the runtime level too. |
Added a section to capture it. |
@jyingl3 Can you catch me up on comm channels for this work? I just got a basic CI going and found that there was some API drift around Executables/LoadedExecutables that I am adapting to. As you know, it is basically impossible to keep up with the noise in the TF repo, and without seeing the dev process, this isn't a great experience for collaborators. @theadactyl Any objection to using the IREE #pjrt-plugin channel for now to coordinate on API updates and such? Open to other options. |
Can one of the PJRT owners please advise on where to file bugs and how to reach the engineers. I have filed this one: openxla/xla#1237 but the engineers who work on this do not appear to be members of the repo. Thanks. |
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This RFC is approved. To be clear, the scope of the RFC is creating the repo, not approving any particular design decision. I created this issue in the new repo so that the design feedback can be appropriately logged: https://github.com/openxla/openxla-pjrt-plugin/issues/3
Thanks for checking! We have finalized the communication channel:
We also just added a README to https://github.com/openxla/xla/tree/main/xla/pjrt/c about the communication channel and some resources. |
Request for comment OpenXLA PJRT plugin along with creation of new repo for its development.