Skip to content

Commit

Permalink
Propagate names correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
LTLA committed Jan 19, 2024
1 parent 7fef2d0 commit 19f8638
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 28 deletions.
34 changes: 9 additions & 25 deletions lib/src/load_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,19 +120,14 @@ struct PythonStringVector : public uzuki2::StringVector, public PythonBase {
return bu.attr("StringList")(storage);
}
} else {
// Numpy arrays don't have direct support for names, so we
// just convert it into a dict.
pybind11::dict output;
for (size_t i = 0, end = storage.size(); i < end; ++i) {
output[names[i].c_str()] = storage[i];
}
return output;
pybind11::module bu = pybind11::module::import("biocutils");
using namespace pybind11::literals;
return bu.attr("StringList")(storage, "names"_a = names);
}
}

pybind11::list storage;
std::vector<std::string> names;
std::vector<size_t> missing;
pybind11::list names;
bool is_scalar;
};

Expand Down Expand Up @@ -160,28 +155,17 @@ struct PythonFactor : public uzuki2::Factor, public PythonBase {
}

pybind11::object extract() const {
if (names.empty()) {
pybind11::module bu = pybind11::module::import("biocutils");
using namespace pybind11::literals;
pybind11::module bu = pybind11::module::import("biocutils");
using namespace pybind11::literals;
if (names.size() == 0) {
return bu.attr("Factor")(storage, levels, "ordered"_a = ordered);

} else {
// Factor doesn't have direct support for names, so we
// just convert it into a dict.
pybind11::dict output;
for (size_t i = 0, end = storage.size(); i < end; ++i) {
if (storage.at(i) >= 0) {
output[names[i].c_str()] = levels[storage.at(i)];
} else {
output[names[i].c_str()] = pybind11::none();
}
}
return output;
return bu.attr("Factor")(storage, levels, "ordered"_a = ordered, "names"_a = names);
}
}

pybind11::array_t<int32_t> storage;
std::vector<std::string> names;
pybind11::list names;
bool is_scalar;
pybind11::list levels;
bool ordered;
Expand Down
14 changes: 11 additions & 3 deletions src/dolomite_base/save_simple_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _save_simple_list_recursive_stringlist(x: StringList, externals: list, handl
if handle is None:
output = { "type": "string", "values": x.as_list() }
if nms is not None:
output["names"] = nms
output["names"] = nms.as_list()
return output

has_none = any(y is None for y in x)
Expand All @@ -125,7 +125,7 @@ def _save_simple_list_recursive_stringlist(x: StringList, externals: list, handl
if has_none:
dset.attrs["missing-value-placeholder"] = placeholder
if nms is not None:
ut._save_fixed_length_strings(handle, "names", nms)
ut._save_fixed_length_strings(handle, "names", nms.as_list())

return

Expand Down Expand Up @@ -335,13 +335,18 @@ def _save_simple_list_recursive_numpy_generic(x: np.generic, externals: list, ha

@_save_simple_list_recursive.register
def _save_simple_list_recursive_factor(x: Factor, externals: list, handle):
nms = x.get_names()

if handle is None:
return {
output = {
"type": "factor",
"values": [(None if y == -1 else int(y)) for y in x.get_codes()],
"levels": x.get_levels().as_list(),
"ordered": x.get_ordered(),
}
if not nms is None:
output["names"] = nms.as_list()
return output

else:
handle.attrs["uzuki_object"] = "vector"
Expand All @@ -354,6 +359,9 @@ def _save_simple_list_recursive_factor(x: Factor, externals: list, handle):
ut._save_fixed_length_strings(handle, "levels", x.get_levels().as_list())
if x.get_ordered():
handle.create_dataset("ordered", data=x.get_ordered(), dtype="i1")

if not nms is None:
ut._save_fixed_length_strings(handle, "names", nms.as_list())
return


Expand Down
28 changes: 28 additions & 0 deletions tests/test_simple_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,3 +318,31 @@ def test_simple_list_factor():
assert list(roundtrip["missing"]) == list(everything["missing"])
assert list(roundtrip["ordered"]) == list(everything["ordered"])
assert roundtrip["ordered"].get_ordered()


def test_simple_list_named():
everything = {
"factor": Factor.from_sequence([ "sydney", "brisbane", "sydney", "melbourne"]),
"string": StringList(["Aria", "Akari", "Akira", "Aika"], names=["1", "2", "3", "4"])
}
everything["factor"].set_names(["A", "B", "C", "D"], in_place=True) # TODO: enable this in the constructor.

# Stage as JSON.
dir = os.path.join(mkdtemp(), "json")
meta = dl.save_object(everything, dir, simple_list_mode="json")

roundtrip = dl.read_object(dir)
assert list(roundtrip["factor"]) == list(everything["factor"])
assert roundtrip["factor"].get_names() == everything["factor"].get_names()
assert list(roundtrip["string"]) == list(everything["string"])
assert roundtrip["string"].get_names() == everything["string"].get_names()

# Stage as HDF5.
dir = os.path.join(mkdtemp(), "hdf5")
meta = dl.save_object(everything, dir, simple_list_mode="hdf5")

roundtrip = dl.read_object(dir)
assert list(roundtrip["factor"]) == list(everything["factor"])
assert roundtrip["factor"].get_names() == everything["factor"].get_names()
assert list(roundtrip["string"]) == list(everything["string"])
assert roundtrip["string"].get_names() == everything["string"].get_names()

0 comments on commit 19f8638

Please sign in to comment.