Skip to content

Commit

Permalink
remove the support of pyarrow (#1151)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomorrowIsAnOtherDay authored May 17, 2024
1 parent 39d210b commit a779f51
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 72 deletions.
15 changes: 0 additions & 15 deletions parl/remote/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,6 @@ def check_env_consistency(self):
to_str(message[1]), to_str(message[2]), to_str(message[3]),
client_parl_version, client_python_version_major, client_python_version_minor
)
client_pyarrow_version = str(get_version('pyarrow'))
master_pyarrow_version = to_str(message[4])
if client_pyarrow_version != master_pyarrow_version:
if master_pyarrow_version == 'None':
error_message = """"pyarrow" is provided in your current environment, however, it is not \
found in "master"'s environment. To use "pyarrow" for serialization, please install \
"pyarrow={}" in "master"'s environment!""".format(client_pyarrow_version)
elif client_pyarrow_version == 'None':
error_message = """"pyarrow" is provided in "master"'s environment, however, it is not \
found in your current environment. To use "pyarrow" for serialization, please install \
"pyarrow={}" in your current environment!""".format(master_pyarrow_version)
else:
error_message = '''Version mismatch: the 'master' is of version 'pyarrow={}'. However, \
'pyarrow={}'is provided in your current environment.'''.format(master_pyarrow_version, client_pyarrow_version)
raise Exception(error_message)
else:
raise NotImplementedError

Expand Down
37 changes: 2 additions & 35 deletions parl/remote/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,8 @@

__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']

try:
import pyarrow
pyarrow_installed = True
except ImportError:
pyarrow_installed = False

if pyarrow_installed:
# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682
def _serialize_serializable(obj):
return {"type": type(obj), "data": obj.__dict__}

def _deserialize_serializable(obj):
val = obj["type"].__new__(obj["type"])
val.__dict__.update(obj["data"])
return val

context = pyarrow.default_serialization_context()

# support deserialize in another environment
context.set_pickle(cloudpickle.dumps, cloudpickle.loads)

# support serialize and deserialize custom class
context.register_type(
object,
"object",
custom_serializer=_serialize_serializable,
custom_deserializer=_deserialize_serializable)

# if pyarrow is installed, parl will use pyarrow to serialize/deserialize objects.
serialize = lambda data: pyarrow.serialize(data, context=context).to_buffer()
deserialize = lambda data: pyarrow.deserialize(data, context=context)
else:
# if pyarrow is not installed, parl will use cloudpickle to serialize/deserialize objects.
serialize = lambda data: cloudpickle.dumps(data)
deserialize = lambda data: cloudpickle.loads(data)
serialize = lambda data: cloudpickle.dumps(data)
deserialize = lambda data: cloudpickle.loads(data)


def dumps_argument(*args, **kwargs):
Expand Down
3 changes: 1 addition & 2 deletions parl/remote/master.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def heartbeat_exit_callback_func(client_heartbeat_address):
remote_constants.NORMAL_TAG,
to_byte(parl.__version__),
to_byte(str(sys.version_info.major)),
to_byte(str(sys.version_info.minor)),
to_byte(str(get_version('pyarrow')))
to_byte(str(sys.version_info.minor))
])

# a client submits a job to the master
Expand Down
15 changes: 0 additions & 15 deletions parl/remote/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,21 +150,6 @@ def check_env_consistency(self):
to_str(message[1]), to_str(message[2]), to_str(message[3]),
worker_parl_version, worker_python_version_major, worker_python_version_minor
)
worker_pyarrow_version = str(get_version('pyarrow'))
master_pyarrow_version = to_str(message[4])
if worker_pyarrow_version != master_pyarrow_version:
if master_pyarrow_version == 'None':
error_message = """"pyarrow" is provided in your current environment, however, it is not \
found in "master"'s environment. To use "pyarrow" for serialization, please install \
"pyarrow={}" in "master"'s environment!""".format(worker_pyarrow_version)
elif worker_pyarrow_version == 'None':
error_message = """"pyarrow" is provided in "master"'s environment, however, it is not \
found in your current environment. To use "pyarrow" for serialization, please install \
"pyarrow={}" in your current environment!""".format(master_pyarrow_version)
else:
error_message = '''Version mismatch: the 'master' is of version 'pyarrow={}'. However, \
'pyarrow={}'is provided in your current environment.'''.format(master_pyarrow_version, worker_pyarrow_version)
raise Exception(error_message)
else:
raise NotImplementedError

Expand Down
9 changes: 4 additions & 5 deletions parl/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


class UtilsError(Exception):
"""
Super class of exceptions in utils module.
Expand All @@ -24,13 +23,13 @@ def __init__(self, error_info):

class SerializeError(UtilsError):
"""
Serialize error raised by pyarrow.
Serialize error raised by the serialization library.
"""

def __init__(self, error_info):
error_info = (
'Serialize error, you may have provided an object that cannot be '
+ 'serialized by pyarrow. Detailed error:\n{}'.format(error_info))
+ 'serialized by the serialization library. Detailed error:\n{}'.format(error_info))
super(SerializeError, self).__init__(error_info)

def __str__(self):
Expand All @@ -39,14 +38,14 @@ def __str__(self):

class DeserializeError(UtilsError):
"""
Deserialize error raised by pyarrow.
Deserialize error raised by the serialization library.
"""

def __init__(self, error_info):
error_info = (
'Deserialize error, you may have provided an object that cannot be '
+
'deserialized by pyarrow. Detailed error:\n{}'.format(error_info))
'deserialized by the serialization library. Detailed error:\n{}'.format(error_info))
super(DeserializeError, self).__init__(error_info)

def __str__(self):
Expand Down

0 comments on commit a779f51

Please sign in to comment.