From 19f863840ff4489b4f3f27b1fd72421cfb647249 Mon Sep 17 00:00:00 2001 From: LTLA Date: Thu, 18 Jan 2024 18:49:51 -0800 Subject: [PATCH] Propagate names correctly. --- lib/src/load_list.cpp | 34 +++++++-------------------- src/dolomite_base/save_simple_list.py | 14 ++++++++--- tests/test_simple_list.py | 28 ++++++++++++++++++++++ 3 files changed, 48 insertions(+), 28 deletions(-) diff --git a/lib/src/load_list.cpp b/lib/src/load_list.cpp index 18bb8f6..c603537 100644 --- a/lib/src/load_list.cpp +++ b/lib/src/load_list.cpp @@ -120,19 +120,14 @@ struct PythonStringVector : public uzuki2::StringVector, public PythonBase { return bu.attr("StringList")(storage); } } else { - // Numpy arrays don't have direct support for names, so we - // just convert it into a dict. - pybind11::dict output; - for (size_t i = 0, end = storage.size(); i < end; ++i) { - output[names[i].c_str()] = storage[i]; - } - return output; + pybind11::module bu = pybind11::module::import("biocutils"); + using namespace pybind11::literals; + return bu.attr("StringList")(storage, "names"_a = names); } } pybind11::list storage; - std::vector names; - std::vector missing; + pybind11::list names; bool is_scalar; }; @@ -160,28 +155,17 @@ struct PythonFactor : public uzuki2::Factor, public PythonBase { } pybind11::object extract() const { - if (names.empty()) { - pybind11::module bu = pybind11::module::import("biocutils"); - using namespace pybind11::literals; + pybind11::module bu = pybind11::module::import("biocutils"); + using namespace pybind11::literals; + if (names.size() == 0) { return bu.attr("Factor")(storage, levels, "ordered"_a = ordered); - } else { - // Factor doesn't have direct support for names, so we - // just convert it into a dict. - pybind11::dict output; - for (size_t i = 0, end = storage.size(); i < end; ++i) { - if (storage.at(i) >= 0) { - output[names[i].c_str()] = levels[storage.at(i)]; - } else { - output[names[i].c_str()] = pybind11::none(); - } - } - return output; + return bu.attr("Factor")(storage, levels, "ordered"_a = ordered, "names"_a = names); } } pybind11::array_t storage; - std::vector names; + pybind11::list names; bool is_scalar; pybind11::list levels; bool ordered; diff --git a/src/dolomite_base/save_simple_list.py b/src/dolomite_base/save_simple_list.py index 1448ea3..0ff2b53 100644 --- a/src/dolomite_base/save_simple_list.py +++ b/src/dolomite_base/save_simple_list.py @@ -111,7 +111,7 @@ def _save_simple_list_recursive_stringlist(x: StringList, externals: list, handl if handle is None: output = { "type": "string", "values": x.as_list() } if nms is not None: - output["names"] = nms + output["names"] = nms.as_list() return output has_none = any(y is None for y in x) @@ -125,7 +125,7 @@ def _save_simple_list_recursive_stringlist(x: StringList, externals: list, handl if has_none: dset.attrs["missing-value-placeholder"] = placeholder if nms is not None: - ut._save_fixed_length_strings(handle, "names", nms) + ut._save_fixed_length_strings(handle, "names", nms.as_list()) return @@ -335,13 +335,18 @@ def _save_simple_list_recursive_numpy_generic(x: np.generic, externals: list, ha @_save_simple_list_recursive.register def _save_simple_list_recursive_factor(x: Factor, externals: list, handle): + nms = x.get_names() + if handle is None: - return { + output = { "type": "factor", "values": [(None if y == -1 else int(y)) for y in x.get_codes()], "levels": x.get_levels().as_list(), "ordered": x.get_ordered(), } + if not nms is None: + output["names"] = nms.as_list() + return output else: handle.attrs["uzuki_object"] = "vector" @@ -354,6 +359,9 @@ def _save_simple_list_recursive_factor(x: Factor, externals: list, handle): ut._save_fixed_length_strings(handle, "levels", x.get_levels().as_list()) if x.get_ordered(): handle.create_dataset("ordered", data=x.get_ordered(), dtype="i1") + + if not nms is None: + ut._save_fixed_length_strings(handle, "names", nms.as_list()) return diff --git a/tests/test_simple_list.py b/tests/test_simple_list.py index a570c96..1e2ce6f 100644 --- a/tests/test_simple_list.py +++ b/tests/test_simple_list.py @@ -318,3 +318,31 @@ def test_simple_list_factor(): assert list(roundtrip["missing"]) == list(everything["missing"]) assert list(roundtrip["ordered"]) == list(everything["ordered"]) assert roundtrip["ordered"].get_ordered() + + +def test_simple_list_named(): + everything = { + "factor": Factor.from_sequence([ "sydney", "brisbane", "sydney", "melbourne"]), + "string": StringList(["Aria", "Akari", "Akira", "Aika"], names=["1", "2", "3", "4"]) + } + everything["factor"].set_names(["A", "B", "C", "D"], in_place=True) # TODO: enable this in the constructor. + + # Stage as JSON. + dir = os.path.join(mkdtemp(), "json") + meta = dl.save_object(everything, dir, simple_list_mode="json") + + roundtrip = dl.read_object(dir) + assert list(roundtrip["factor"]) == list(everything["factor"]) + assert roundtrip["factor"].get_names() == everything["factor"].get_names() + assert list(roundtrip["string"]) == list(everything["string"]) + assert roundtrip["string"].get_names() == everything["string"].get_names() + + # Stage as HDF5. + dir = os.path.join(mkdtemp(), "hdf5") + meta = dl.save_object(everything, dir, simple_list_mode="hdf5") + + roundtrip = dl.read_object(dir) + assert list(roundtrip["factor"]) == list(everything["factor"]) + assert roundtrip["factor"].get_names() == everything["factor"].get_names() + assert list(roundtrip["string"]) == list(everything["string"]) + assert roundtrip["string"].get_names() == everything["string"].get_names()