Skip to content

Commit

Permalink
subqueries
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-martian committed Oct 29, 2024
1 parent 5ce0e78 commit 47bce2d
Showing 1 changed file with 59 additions and 13 deletions.
72 changes: 59 additions & 13 deletions rebabel_format/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from rebabel_format.db import RBBLFile, WhereClause
from rebabel_format import utils

from collections import defaultdict
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from typing import Any, Optional, Sequence

Expand All @@ -27,6 +27,12 @@ def child(self, other):
raise ValueError(f'Child must be a Unit, not {other.__type__.__name__}')
return Condition(self.index, other.index, 'child')

def subquery(self, min=1, max=None):
ret = Query(self.query.db, self.query.type_map, self.query.feat_map)
u = ret.unit(self.type, None)
self.query.subqueries.append((ret, self.index, min, max))
return ret, u

@dataclass
class Condition:
left: Any
Expand Down Expand Up @@ -176,6 +182,9 @@ def __init__(self, db, type_map=None, feat_map=None):
self.units = []
self.conditionals = {}
self.features = {} # (uidx, name) => (fidx, ids, is_str)
self.subqueries = [] # [(query, uidx, min, max), ...]

self.intersect = None

def unit(self, utype, name):
ret = Unit(self, utype, len(self.units), name)
Expand Down Expand Up @@ -213,7 +222,7 @@ def add(self, condition):
oldcond, oldfeats = self.conditionals[ckey]
self.conditionals[ckey] = (oldcond & condition, oldfeats | feat_index)

def search(self):
def prepare_search(self, parent_ids=None):
unit_ids = []
for i in range(len(self.units)):
cond, feats = self.conditionals[(i,)]
Expand All @@ -234,10 +243,12 @@ def search(self):
query = f'SELECT {select} FROM {", ".join(tables)} WHERE {where}'
self.db.cur.execute(query, params)
unit_ids.append(set(x[0] for x in self.db.cur.fetchall()))
if i == 0 and parent_ids is not None:
unit_ids[0] = unit_ids[0] & set(parent_ids)
if not unit_ids[-1]:
return

intersect = IntersectionTracker(dict(enumerate(unit_ids)))
self.intersect = IntersectionTracker(dict(enumerate(unit_ids)))

for ckey in self.conditionals:
if len(ckey) < 2:
Expand Down Expand Up @@ -278,19 +289,54 @@ def search(self):
ids = sorted(select)
for i in range(len(ids)):
for j in range(i+1, len(ids)):
intersect.restrict(ids[i], ids[j], [(s[i], s[j]) for s in sets])
self.intersect.restrict(ids[i], ids[j],
[(s[i], s[j]) for s in sets])

self.intersect.make_dict()

intersect.make_dict()
def combine(self, cur, names):
if len(cur) == len(self.units):
yield dict(zip(names, cur))
else:
for i in sorted(self.intersect.possible(dict(enumerate(cur)), len(cur))):
yield from self.combine(cur + [i], names)

def get_results(self, parent=None):
if self.intersect is None:
return
names = [u.name for u in self.units]
def combine(cur):
nonlocal names, intersect
if len(cur) == len(names):
yield dict(zip(names, cur))
else:
for i in sorted(intersect.possible(dict(enumerate(cur)), len(cur))):
yield from combine(cur + [i])
yield from combine([])
initial = [parent] if parent is not None else []
if not self.subqueries:
yield from self.combine(initial, names)
else:
indexes = list(range(len(self.units)))
for result in self.combine(initial, indexes):
ret = {names[i]: result[i] for i in indexes}
count = Counter()
ok = True
for sub, idx, mn, mx in self.subqueries:
n = count[idx]
count[idx] += 1
results = list(sub.get_results(result[idx]))
if mn is not None and len(results) < mn:
ok = False
break
if mx is not None and len(results) > mx:
ok = False
break
ret[(names[idx], n)] = results
if ok:
yield ret

def search(self):
self.prepare_search()
if self.intersect is None:
return
for sub, idx, mn, mx in self.subqueries:
sub.prepare_search(self.intersect.units[idx])
if mn is not None and mn > 0 and sub.intersect is None:
return
yield from self.get_results()

class FeatureQuery:
def __init__(self, featid, value=None, operator=None):
Expand Down

0 comments on commit 47bce2d

Please sign in to comment.