From a3876ba6442b463031dbaa6d71229d2fe123c8e5 Mon Sep 17 00:00:00 2001 From: Martin Date: Sat, 20 Jul 2024 21:11:05 +0200 Subject: [PATCH] Fix errors in deeply nested models implementation --- .gitignore | 4 +- CHANGELOG.md | 11 + README.md | 55 ++-- pydantic_redis/_shared/lua_scripts.py | 288 ++++++++----------- pydantic_redis/_shared/model/base.py | 42 ++- pydantic_redis/_shared/model/prop_utils.py | 7 +- pydantic_redis/_shared/model/select_utils.py | 72 +++-- pydantic_redis/_shared/utils.py | 27 +- test/conftest.py | 4 +- test/test_pydantic_redis.py | 35 ++- 10 files changed, 314 insertions(+), 231 deletions(-) diff --git a/.gitignore b/.gitignore index fb3d7f30..18f6d6c5 100644 --- a/.gitignore +++ b/.gitignore @@ -145,4 +145,6 @@ cython_debug/ env3.7/ /lua_scripts/ -.DS_Store \ No newline at end of file +.DS_Store + +*.lua \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 36c0476c..3440ecee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,17 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +### Added + +### Changed + +- Added support for deeply nested models beyond level-1 deep including: + - dictionaries of lists of ... of nested models + - lists of tuples of lists .... of nested models + +### Fixed + + ## [0.6.0] - 2024-07-01 ### Added diff --git a/README.md b/README.md index 15094d58..6a459927 100644 --- a/README.md +++ b/README.md @@ -88,36 +88,35 @@ benchmark_bulk_insert[redis_store] 721.2247 (6.19) 6 --------------------------------------------------------------------------------------------------------------------- ``` -# >=v0.7 (with fully-fledged nested models) +# >=v0.7 (with deeply nested models) ``` ---------------------------------------------------- benchmark: 22 tests ---------------------------------------------------- -Name (time in us) Mean Min Max ----------------------------------------------------------------------------------------------------------------------------- -test_benchmark_delete[redis_store-Wuthering Heights] 124.5440 (1.01) 109.3710 (1.0) 579.7810 (1.39) -test_benchmark_bulk_delete[redis_store] 122.9285 (1.0) 113.7120 (1.04) 492.2730 (1.18) -test_benchmark_select_columns_for_one_id[redis_store-book1] 182.3891 (1.48) 154.9150 (1.42) 441.2820 (1.06) -test_benchmark_select_columns_for_one_id[redis_store-book2] 183.2679 (1.49) 156.6830 (1.43) 462.6000 (1.11) -test_benchmark_select_columns_for_one_id[redis_store-book0] 181.6972 (1.48) 157.2330 (1.44) 459.2930 (1.10) -test_benchmark_select_columns_for_one_id[redis_store-book3] 183.0834 (1.49) 160.1250 (1.46) 416.8570 (1.0) -test_benchmark_select_all_for_one_id[redis_store-book1] 203.9491 (1.66) 183.3080 (1.68) 469.4700 (1.13) -test_benchmark_select_all_for_one_id[redis_store-book2] 206.7124 (1.68) 184.1920 (1.68) 490.6700 (1.18) -test_benchmark_select_all_for_one_id[redis_store-book0] 207.3341 (1.69) 184.2210 (1.68) 443.9260 (1.06) -test_benchmark_select_all_for_one_id[redis_store-book3] 210.6874 (1.71) 185.0600 (1.69) 696.9330 (1.67) -test_benchmark_select_columns_for_some_items[redis_store] 236.5783 (1.92) 215.7490 (1.97) 496.0540 (1.19) -test_benchmark_select_columns_paginated[redis_store] 248.5335 (2.02) 218.3450 (2.00) 522.1270 (1.25) -test_benchmark_update[redis_store-Wuthering Heights-data0] 282.1803 (2.30) 239.5410 (2.19) 541.5220 (1.30) -test_benchmark_select_some_items[redis_store] 298.2036 (2.43) 264.0860 (2.41) 599.3010 (1.44) -test_benchmark_single_insert[redis_store-book0] 316.0245 (2.57) 269.8110 (2.47) 596.0940 (1.43) -test_benchmark_single_insert[redis_store-book2] 314.1899 (2.56) 270.9780 (2.48) 560.5280 (1.34) -test_benchmark_select_default_paginated[redis_store] 305.2798 (2.48) 277.8170 (2.54) 550.5110 (1.32) -test_benchmark_single_insert[redis_store-book1] 312.5839 (2.54) 279.5660 (2.56) 578.7070 (1.39) -test_benchmark_single_insert[redis_store-book3] 316.9207 (2.58) 284.8630 (2.60) 567.0120 (1.36) -test_benchmark_select_columns[redis_store] 369.1538 (3.00) 331.5770 (3.03) 666.0470 (1.60) -test_benchmark_select_default[redis_store] 553.9420 (4.51) 485.3700 (4.44) 1,235.8540 (2.96) -test_benchmark_bulk_insert[redis_store] 777.4058 (6.32) 730.4280 (6.68) 1,012.7780 (2.43) ----------------------------------------------------------------------------------------------------------------------------- - +------------------------------------------------- benchmark: 22 tests ------------------------------------------------- +Name (time in us) Mean Min Max +----------------------------------------------------------------------------------------------------------------------- +benchmark_delete[redis_store-Wuthering Heights] 123.2946 (1.02) 107.9690 (1.0) 502.6140 (1.33) +benchmark_bulk_delete[redis_store] 120.5815 (1.0) 111.9320 (1.04) 378.8660 (1.0) +benchmark_select_columns_for_one_id[redis_store-book2] 208.2612 (1.73) 180.4660 (1.67) 470.9860 (1.24) +benchmark_select_columns_for_one_id[redis_store-book1] 207.9143 (1.72) 180.6440 (1.67) 489.6890 (1.29) +benchmark_select_columns_for_one_id[redis_store-book0] 204.2471 (1.69) 183.4360 (1.70) 485.2500 (1.28) +benchmark_select_columns_for_one_id[redis_store-book3] 209.5764 (1.74) 189.5780 (1.76) 462.5650 (1.22) +benchmark_select_all_for_one_id[redis_store-book0] 226.4569 (1.88) 207.4920 (1.92) 499.9470 (1.32) +benchmark_select_all_for_one_id[redis_store-book3] 241.5488 (2.00) 210.5230 (1.95) 504.5150 (1.33) +benchmark_select_all_for_one_id[redis_store-book1] 234.4014 (1.94) 210.6420 (1.95) 501.2470 (1.32) +benchmark_select_all_for_one_id[redis_store-book2] 228.9277 (1.90) 212.0090 (1.96) 509.5740 (1.34) +benchmark_update[redis_store-Wuthering Heights-data0] 276.3908 (2.29) 238.3390 (2.21) 704.9450 (1.86) +benchmark_single_insert[redis_store-book3] 311.0476 (2.58) 262.2940 (2.43) 589.3940 (1.56) +benchmark_select_columns_for_some_items[redis_store] 291.2779 (2.42) 266.0960 (2.46) 564.3510 (1.49) +benchmark_select_columns_paginated[redis_store] 300.4108 (2.49) 269.4740 (2.50) 552.8510 (1.46) +benchmark_single_insert[redis_store-book1] 304.5771 (2.53) 274.1740 (2.54) 547.5210 (1.45) +benchmark_single_insert[redis_store-book2] 317.2681 (2.63) 275.6170 (2.55) 641.5440 (1.69) +benchmark_single_insert[redis_store-book0] 313.0004 (2.60) 277.3190 (2.57) 558.2160 (1.47) +benchmark_select_some_items[redis_store] 343.2569 (2.85) 311.9140 (2.89) 624.6600 (1.65) +benchmark_select_default_paginated[redis_store] 359.8463 (2.98) 325.8310 (3.02) 623.2360 (1.65) +benchmark_select_columns[redis_store] 486.6047 (4.04) 429.3250 (3.98) 867.8780 (2.29) +benchmark_select_default[redis_store] 631.3835 (5.24) 584.7630 (5.42) 1,033.5990 (2.73) +benchmark_bulk_insert[redis_store] 761.0832 (6.31) 724.1240 (6.71) 1,034.2950 (2.73) +----------------------------------------------------------------------------------------------------------------------- ``` ## Contributions diff --git a/pydantic_redis/_shared/lua_scripts.py b/pydantic_redis/_shared/lua_scripts.py index 5b052fee..8eee732b 100644 --- a/pydantic_redis/_shared/lua_scripts.py +++ b/pydantic_redis/_shared/lua_scripts.py @@ -1,5 +1,8 @@ """Exposes the redis lua scripts to be used in select queries. +These scripts always return a list of tuples of [record, index] where the index is a flat list of nested models +for that record + Attributes: SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT: the script for selecting all records from redis PAGINATED_SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT: the script for selecting a slice of all records from redis, @@ -14,6 +17,8 @@ but returning only a subset of the fields in each record. """ +# What if instead of constructing tables, we return obj as a JSON string + SELECT_ALL_FIELDS_FOR_ALL_IDS_SCRIPT = """ local s_find = string.find local s_gmatch = string.gmatch @@ -29,48 +34,40 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end repeat local result = redis_call('SCAN', cursor, 'MATCH', ARGV[1]) for _, key in ipairs(result[2]) do if redis_call('TYPE', key).ok == 'hash' then - table_insert(filtered, get_obj(key)) + local value, idx = get_obj_and_index(key) + if type(value) == 'table' then + table_insert(filtered, {value, idx}) + end end end cursor = result[1] @@ -90,41 +87,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end local table_index_key = ARGV[1] @@ -135,9 +120,9 @@ local ids = redis_call('ZRANGE', table_index_key, start, stop) for _, key in ipairs(ids) do - local value = get_obj(key) + local value, idx = get_obj_and_index(key) if type(value) == 'table' then - table_insert(result, value) + table_insert(result, {value, idx}) end end @@ -158,47 +143,35 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for _, key in ipairs(KEYS) do - local value = get_obj(key) + local value, idx = get_obj_and_index(key) if type(value) == 'table' then - table_insert(result, value) + table_insert(result, {value, idx}) end end @@ -222,40 +195,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for i, k in ipairs(ARGV) do @@ -270,13 +232,20 @@ if redis_call('TYPE', key).ok == 'hash' then local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end - - table_insert(filtered, parsed_data) + + table_insert(filtered, {parsed_data, index}) end end cursor = result[1] @@ -297,41 +266,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end local result = {} @@ -351,15 +308,23 @@ for _, key in ipairs(ids) do local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do if v then - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end end - table_insert(result, parsed_data) + table_insert(result, {parsed_data, index}) + end return result @@ -381,41 +346,29 @@ return s_find(s, prefix, 1, true) == 1 end -local function trim_dunder(s) - return s:match '^_*(.-)$' -end - -local function get_obj(id) +local function get_obj_and_index(id) local value = redis_call('HGETALL', id) + local idx = {} for i, k in ipairs(value) do if not (i % 2 == 0) then - if startswith(k, '___') or startswith(k, '____') then - if value[i + 1] == 'null' then - value[i + 1] = 'null' - else - local nested = {} - - for v in s_gmatch(value[i + 1], '([^%[^,^%]^\"]+)') do - table_insert(nested, get_obj(v)) - end - - value[i + 1] = nested - end - - value[i] = trim_dunder(k) + if value[i + 1] == 'null' then + elseif startswith(k, '___') then + for v in s_gmatch(value[i + 1], '\"([%w_]+_%%&_[^\"%[%]]+)\"') do + table_insert(idx, v) + table_insert(idx, {get_obj_and_index(v)}) + end elseif startswith(k, '__') then - value[i + 1] = get_obj(value[i + 1]) - value[i] = trim_dunder(k) + table_insert(idx, value[i + 1]) + table_insert(idx, {get_obj_and_index(value[i + 1])}) end - end + end end if next(value) == nil then - return id + return id, nil end - - return value + return value, idx end for _, k in ipairs(ARGV) do @@ -425,15 +378,22 @@ for _, key in ipairs(KEYS) do local data = redis_call('HMGET', key, table_unpack(columns)) local parsed_data = {} + local index = {} for i, v in ipairs(data) do if v then - table_insert(parsed_data, trim_dunder(columns[i])) - table_insert(parsed_data, get_obj(v)) + table_insert(parsed_data, columns[i]) + table_insert(parsed_data, v) + + local value, idx = get_obj_and_index(v) + if type(idx) == 'table' then + table_insert(index, v) + table_insert(index, {value, idx}) + end end end - table_insert(result, parsed_data) + table_insert(result, {parsed_data, index}) end return result """ diff --git a/pydantic_redis/_shared/model/base.py b/pydantic_redis/_shared/model/base.py index b7b7c580..00c1f545 100644 --- a/pydantic_redis/_shared/model/base.py +++ b/pydantic_redis/_shared/model/base.py @@ -20,6 +20,7 @@ from_dict_to_key_value_list, from_bytes_to_str, from_str_or_bytes_to_any, + groups_of_n, ) @@ -161,19 +162,23 @@ def serialize_partially(cls, data: Optional[Dict[str, Any]]) -> Dict[str, Any]: @classmethod def deserialize_partially( - cls, data: Union[List[Any], Dict[Any, Any]] = () + cls, data: Union[List[Any], Dict[Any, Any]] = (), index: Dict[Any, Any] = None ) -> Dict[str, Any]: - """Casts str or bytes in a dict or flattened key-value list to expected data types. + """Casts str or bytes in a dict to expected data types. Converts str or bytes to their expected data types Args: data: flattened list of key-values or dictionary of data to cast. Keeping it as potentially a dictionary ensures backward compatibility. + index: dictionary of the index of nested models potentially present Returns: the dictionary of properly parsed key-values. """ + if index is None: + index = {} + if isinstance(data, dict): # for backward compatibility data = from_dict_to_key_value_list(data) @@ -182,13 +187,16 @@ def deserialize_partially( field_type_trees = cls.get_field_type_trees() - for i in range(0, len(data), 2): - key = from_bytes_to_str(data[i]) + for k, v in groups_of_n(data, 2): + # remove the dunders for nested model fields + key = from_bytes_to_str(k).lstrip("_") field_type = cls._field_types.get(key) - value = from_str_or_bytes_to_any(value=data[i + 1], field_type=field_type) + value = from_str_or_bytes_to_any(value=v, field_type=field_type) type_tree = field_type_trees.get(key) - parsed_dict[key] = _cast_by_type_tree(value=value, type_tree=type_tree) + parsed_dict[key] = _cast_by_type_tree( + value=value, type_tree=type_tree, index=index + ) return parsed_dict @@ -240,13 +248,16 @@ def _generate_field_type_tree(field_type: Any, strict: bool = False) -> AggTypeT return None, (field_type,) -def _cast_by_type_tree(value: Any, type_tree: Optional[AggTypeTree]) -> Any: +def _cast_by_type_tree( + value: Any, type_tree: Optional[AggTypeTree], index: Dict[Any, Any] = None +) -> Any: """Casts a given value into a value basing on the tree of its aggregate type Args: value: the value to be cast basing on the type tree type_tree: the tree representing the nested hierarchy of types for the aggregate type that the value is to be cast into + index: dictionary of the index of nested models potentially present Returns: the parsed value @@ -259,26 +270,33 @@ def _cast_by_type_tree(value: Any, type_tree: Optional[AggTypeTree]) -> Any: if nesting_type is NestingType.ON_ROOT: _type = type_args[0] - return _type(**_type.deserialize_partially(value)) + nested_model_data = value + if isinstance(value, str): + # load the nested model if it is not yet loaded + nested_model_data = index.get(value, value) + return _type(**_type.deserialize_partially(nested_model_data)) if nesting_type is NestingType.IN_LIST: _type = type_args[0] - return [_cast_by_type_tree(item, _type) for item in value] + return [_cast_by_type_tree(item, _type, index) for item in value] if nesting_type is NestingType.IN_TUPLE: return tuple( - [_cast_by_type_tree(item, _type) for _type, item in zip(type_args, value)] + [ + _cast_by_type_tree(item, _type, index) + for _type, item in zip(type_args, value) + ] ) if nesting_type is NestingType.IN_DICT: _, value_type = type_args - return {k: _cast_by_type_tree(v, value_type) for k, v in value.items()} + return {k: _cast_by_type_tree(v, value_type, index) for k, v in value.items()} if nesting_type is NestingType.IN_UNION: # the value can be any of the types in type_args for _type in type_args: try: - parsed_value = _cast_by_type_tree(value, _type) + parsed_value = _cast_by_type_tree(value, _type, index) # return the first successfully parsed value # that is not equal to the original value if parsed_value != value: diff --git a/pydantic_redis/_shared/model/prop_utils.py b/pydantic_redis/_shared/model/prop_utils.py index b1adcf00..8eb01544 100644 --- a/pydantic_redis/_shared/model/prop_utils.py +++ b/pydantic_redis/_shared/model/prop_utils.py @@ -2,11 +2,16 @@ """ +import re from typing import Type, Any from .base import AbstractModel +NESTED_MODEL_SEPARATOR = "_%&_" +NESTED_MODEL_VALUE_REGEX = re.compile(f"^([\\w_]+{NESTED_MODEL_SEPARATOR}[\\w_]+)$") + + def get_redis_key(model: Type[AbstractModel], primary_key_value: Any): """Gets the key used internally in redis for the `primary_key_value` of `model`. @@ -30,7 +35,7 @@ def get_redis_key_prefix(model: Type[AbstractModel]): the prefix of the all the redis keys that are associated with this model """ model_name = model.__name__.lower() - return f"{model_name}_%&_" + return f"{model_name}{NESTED_MODEL_SEPARATOR}" def get_redis_keys_regex(model: Type[AbstractModel]): diff --git a/pydantic_redis/_shared/model/select_utils.py b/pydantic_redis/_shared/model/select_utils.py index f509e74a..114cf94c 100644 --- a/pydantic_redis/_shared/model/select_utils.py +++ b/pydantic_redis/_shared/model/select_utils.py @@ -2,7 +2,7 @@ """ -from typing import List, Any, Type, Union, Awaitable, Optional +from typing import List, Any, Type, Union, Awaitable, Optional, Dict, Tuple from pydantic_redis._shared.model.prop_utils import ( get_redis_keys_regex, @@ -10,11 +10,14 @@ get_model_index_key, ) - from .base import AbstractModel +from ..utils import groups_of_n + + +RawRedisSelectData = List[Tuple[List[Any], List[Any]]] -def get_select_fields(model: Type[AbstractModel], columns: List[str] = []) -> List[str]: +def get_select_fields(model: Type[AbstractModel], columns: List[str] = ()) -> List[str]: """Gets the fields to be used for selecting HMAP fields in Redis. It replaces any fields in `columns` that correspond to nested records with their @@ -35,7 +38,7 @@ def select_all_fields_all_ids( model: Type[AbstractModel], skip: int = 0, limit: Optional[int] = None, -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves all records of the given model in the redis database. Args: @@ -44,7 +47,7 @@ def select_all_fields_all_ids( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ if isinstance(limit, int): @@ -58,7 +61,7 @@ def select_all_fields_all_ids( def select_all_fields_some_ids( model: Type[AbstractModel], ids: List[str] -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves some records from redis. Args: @@ -66,7 +69,7 @@ def select_all_fields_some_ids( ids: the list of primary keys of the records to be retrieved. Returns: - the list of records where each record is a flattened key-value list. + list of tuple of [record, index-of-nested-models] with each record is a flattened key-value list. In case we are using async, an Awaitable of that list is returned instead. """ table_prefix = get_redis_key_prefix(model=model) @@ -80,7 +83,7 @@ def select_some_fields_all_ids( fields: List[str], skip: int = 0, limit: Optional[int] = None, -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves records of model from redis, each as with a subset of the fields. Args: @@ -90,7 +93,7 @@ def select_some_fields_all_ids( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ columns = get_select_fields(model=model, columns=fields) @@ -108,7 +111,7 @@ def select_some_fields_all_ids( def select_some_fields_some_ids( model: Type[AbstractModel], fields: List[str], ids: List[str] -) -> Union[List[List[Any]], Awaitable[List[List[Any]]]]: +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves some records of current model from redis, each as with a subset of the fields. Args: @@ -117,7 +120,7 @@ def select_some_fields_some_ids( ids: the list of primary keys of the records to be retrieved. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ table_prefix = get_redis_key_prefix(model=model) @@ -128,7 +131,7 @@ def select_some_fields_some_ids( def parse_select_response( - model: Type[AbstractModel], response: List[List], as_models: bool + model: Type[AbstractModel], response: RawRedisSelectData, as_models: bool ): """Casts a list of flattened key-value lists into a list of models or dicts. @@ -150,17 +153,23 @@ def parse_select_response( if as_models: return [ - model(**model.deserialize_partially(record)) - for record in response - if record != [] + model( + **model.deserialize_partially(record, index=_construct_index(raw_index)) + ) + for record, raw_index in response + if len(response) != 0 ] - return [model.deserialize_partially(record) for record in response if record != []] + return [ + model.deserialize_partially(record, index=_construct_index(raw_index)) + for record, raw_index in response + if len(response) != 0 + ] def _select_all_ids_all_fields_paginated( model: Type[AbstractModel], limit: int, skip: Optional[int] -): +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves a slice of all records of the given model in the redis database. Args: @@ -169,7 +178,7 @@ def _select_all_ids_all_fields_paginated( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ if skip is None: @@ -182,7 +191,7 @@ def _select_all_ids_all_fields_paginated( def _select_some_fields_all_ids_paginated( model: Type[AbstractModel], columns: List[str], limit: int, skip: int -): +) -> Union[RawRedisSelectData, Awaitable[RawRedisSelectData]]: """Retrieves a slice of all records of model from redis, each as with a subset of the fields. Args: @@ -192,7 +201,7 @@ def _select_some_fields_all_ids_paginated( limit: the maximum number of records to return. If None, limit is infinity. Returns: - the list of records from redis, each record being a flattened list of key-values. + list of tuple of [record, index-of-nested-models] with each record being a flattened list of key-values. In case we are using async, an Awaitable of that list is returned instead. """ if skip is None: @@ -201,3 +210,26 @@ def _select_some_fields_all_ids_paginated( args = [table_index_key, skip, limit, *columns] store = model.get_store() return store.paginated_select_some_fields_for_all_ids_script(args=args) + + +def _construct_index(index_list: List[Any]) -> Dict[str, Any]: + """Constructs the index dict from the index list of nested models returned from redis + + Args: + index_list: the flat list of the index of nested models, with key followed by [model, index] tuple + [key1, [model1_flat_list, index1_flat_list], key2, [model2_flat_list, index2_flat_list]...] + + Returns: + the index as a dict + """ + index = {} + for k, model_and_index in groups_of_n(index_list, 2): + model_as_list, index_as_list = model_and_index + model_index = _construct_index(index_as_list) + index[k] = { + # remove the dunders for nested model fields + key.lstrip("_"): model_index.get(value, value) + for key, value in groups_of_n(model_as_list, 2) + } + + return index diff --git a/pydantic_redis/_shared/utils.py b/pydantic_redis/_shared/utils.py index 8fa95db7..7c1e87dd 100644 --- a/pydantic_redis/_shared/utils.py +++ b/pydantic_redis/_shared/utils.py @@ -3,10 +3,12 @@ """ import typing -from typing import Any, Tuple, Optional, Union, Dict, Type, List +from typing import Any, Tuple, Optional, Union, Dict, Type, List, Iterable, TypeVar import orjson +T = TypeVar("T") + def strip_leading(word: str, substring: str) -> str: """Strips the leading substring if it exists. @@ -96,8 +98,12 @@ def from_str_or_bytes_to_any(value: Any, field_type: Type) -> Any: elif not isinstance(value, str): return value - # JSON parse all other values that are str - return orjson.loads(value) + try: + # JSON parse all other values that are str + return orjson.loads(value) + except orjson.JSONDecodeError: + # try to be as fault-tolerant as sanely possible + return value def from_any_to_valid_redis_type(value: Any) -> Union[str, bytes, List[Any]]: @@ -155,3 +161,18 @@ def from_dict_to_key_value_list(data: Dict[str, Any]) -> List[Any]: parsed_list.append(v) return parsed_list + + +def groups_of_n(items: Iterable[T], n: int) -> Iterable[Tuple[T, ...]]: + """Returns an iterable of tuples of size n from the given list of items + + Note that it might ignore the last items if n does not fit nicely into the items list + + Args: + items: the list of items from which to extract the tuples + n: the size of the tuples + + Returns: + the iterable of tuples of n size from the list of items + """ + return zip(*[iter(items)] * n) diff --git a/test/conftest.py b/test/conftest.py index 98356cdc..3263a7f2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,6 +1,6 @@ import socket from datetime import date -from typing import Tuple, List, Optional +from typing import Tuple, List, Optional, Dict import pytest import pytest_asyncio @@ -56,6 +56,8 @@ class Library(syn.Model): lost: Optional[List[Book]] = None popular: Optional[Tuple[Book, Book]] = None new: Optional[Tuple[Book, Author, Book, int]] = None + list_of_tuples: Optional[List[Tuple[str, Book]]] = None + dict_of_models: Optional[Dict[str, Book]] = None class AsyncLibrary(asy.Model): diff --git a/test/test_pydantic_redis.py b/test/test_pydantic_redis.py index a174936e..c1b05cb9 100644 --- a/test/test_pydantic_redis.py +++ b/test/test_pydantic_redis.py @@ -6,6 +6,7 @@ import pytest from pydantic_redis import Store +from pydantic_redis._shared.model.prop_utils import NESTED_MODEL_SEPARATOR from pydantic_redis.config import RedisConfig # noqa from pydantic_redis._shared.model.base import NESTED_MODEL_PREFIX # noqa from pydantic_redis._shared.utils import strip_leading # noqa @@ -199,6 +200,38 @@ def test_update_optional_nested_tuple_of_models(store: Store): assert got == expected +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_list_of_tuples_of_nested_models(store: Store): + list_of_tuples = [("some book", books[0]), ("book2", books[2])] + data = [Library(name="Babel Library", address="In a book", list_of_tuples=list_of_tuples)] + Library.insert(data) + # the tuple of nested models is automatically inserted + got = sorted(Book.select(), key=lambda x: x.title) + expected_books = [book for _, book in list_of_tuples] + expected = sorted(expected_books, key=lambda x: x.title) + assert expected == got + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + +@pytest.mark.parametrize("store", redis_store_fixture) +def test_update_dict_of_models(store: Store): + dict_of_models = {"some book": books[0], "book2": books[2]} + data = [Library(name="Babel Library", address="In a book", dict_of_models=dict_of_models)] + Library.insert(data) + # the tuple of nested models is automatically inserted + got = sorted(Book.select(), key=lambda x: x.title) + expected_books = [book for _, book in dict_of_models.items()] + expected = sorted(expected_books, key=lambda x: x.title) + assert expected == got + + got = sorted(Library.select(), key=lambda x: x.name) + expected = sorted(data, key=lambda x: x.name) + assert got == expected + + @pytest.mark.parametrize("store", redis_store_fixture) def test_select_default(store: Store): """Selecting without arguments returns all the book models""" @@ -432,7 +465,7 @@ def test_delete_multiple(store: Store): def __deserialize_book_data(raw_book_data: Dict[str, Any]) -> Book: """Deserializes the raw book data returning a book instance""" author_id = raw_book_data.pop(f"{NESTED_MODEL_PREFIX}author") - author_id = strip_leading(author_id, "author_%&_") + author_id = strip_leading(author_id, f"author{NESTED_MODEL_SEPARATOR}") data = Book.deserialize_partially(raw_book_data)