From 47bce2d6751167d7ca8e147242d750ed07c5a493 Mon Sep 17 00:00:00 2001 From: Daniel Swanson Date: Tue, 29 Oct 2024 10:48:27 -0400 Subject: [PATCH] subqueries --- rebabel_format/query.py | 72 +++++++++++++++++++++++++++++++++-------- 1 file changed, 59 insertions(+), 13 deletions(-) diff --git a/rebabel_format/query.py b/rebabel_format/query.py index 4d4e33a..e61528f 100644 --- a/rebabel_format/query.py +++ b/rebabel_format/query.py @@ -3,7 +3,7 @@ from rebabel_format.db import RBBLFile, WhereClause from rebabel_format import utils -from collections import defaultdict +from collections import Counter, defaultdict from dataclasses import dataclass, field from typing import Any, Optional, Sequence @@ -27,6 +27,12 @@ def child(self, other): raise ValueError(f'Child must be a Unit, not {other.__type__.__name__}') return Condition(self.index, other.index, 'child') + def subquery(self, min=1, max=None): + ret = Query(self.query.db, self.query.type_map, self.query.feat_map) + u = ret.unit(self.type, None) + self.query.subqueries.append((ret, self.index, min, max)) + return ret, u + @dataclass class Condition: left: Any @@ -176,6 +182,9 @@ def __init__(self, db, type_map=None, feat_map=None): self.units = [] self.conditionals = {} self.features = {} # (uidx, name) => (fidx, ids, is_str) + self.subqueries = [] # [(query, uidx, min, max), ...] + + self.intersect = None def unit(self, utype, name): ret = Unit(self, utype, len(self.units), name) @@ -213,7 +222,7 @@ def add(self, condition): oldcond, oldfeats = self.conditionals[ckey] self.conditionals[ckey] = (oldcond & condition, oldfeats | feat_index) - def search(self): + def prepare_search(self, parent_ids=None): unit_ids = [] for i in range(len(self.units)): cond, feats = self.conditionals[(i,)] @@ -234,10 +243,12 @@ def search(self): query = f'SELECT {select} FROM {", ".join(tables)} WHERE {where}' self.db.cur.execute(query, params) unit_ids.append(set(x[0] for x in self.db.cur.fetchall())) + if i == 0 and parent_ids is not None: + unit_ids[0] = unit_ids[0] & set(parent_ids) if not unit_ids[-1]: return - intersect = IntersectionTracker(dict(enumerate(unit_ids))) + self.intersect = IntersectionTracker(dict(enumerate(unit_ids))) for ckey in self.conditionals: if len(ckey) < 2: @@ -278,19 +289,54 @@ def search(self): ids = sorted(select) for i in range(len(ids)): for j in range(i+1, len(ids)): - intersect.restrict(ids[i], ids[j], [(s[i], s[j]) for s in sets]) + self.intersect.restrict(ids[i], ids[j], + [(s[i], s[j]) for s in sets]) + + self.intersect.make_dict() - intersect.make_dict() + def combine(self, cur, names): + if len(cur) == len(self.units): + yield dict(zip(names, cur)) + else: + for i in sorted(self.intersect.possible(dict(enumerate(cur)), len(cur))): + yield from self.combine(cur + [i], names) + def get_results(self, parent=None): + if self.intersect is None: + return names = [u.name for u in self.units] - def combine(cur): - nonlocal names, intersect - if len(cur) == len(names): - yield dict(zip(names, cur)) - else: - for i in sorted(intersect.possible(dict(enumerate(cur)), len(cur))): - yield from combine(cur + [i]) - yield from combine([]) + initial = [parent] if parent is not None else [] + if not self.subqueries: + yield from self.combine(initial, names) + else: + indexes = list(range(len(self.units))) + for result in self.combine(initial, indexes): + ret = {names[i]: result[i] for i in indexes} + count = Counter() + ok = True + for sub, idx, mn, mx in self.subqueries: + n = count[idx] + count[idx] += 1 + results = list(sub.get_results(result[idx])) + if mn is not None and len(results) < mn: + ok = False + break + if mx is not None and len(results) > mx: + ok = False + break + ret[(names[idx], n)] = results + if ok: + yield ret + + def search(self): + self.prepare_search() + if self.intersect is None: + return + for sub, idx, mn, mx in self.subqueries: + sub.prepare_search(self.intersect.units[idx]) + if mn is not None and mn > 0 and sub.intersect is None: + return + yield from self.get_results() class FeatureQuery: def __init__(self, featid, value=None, operator=None):