Skip to content

Commit

Permalink
unify relation queries
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-martian committed Oct 28, 2024
1 parent 904f1a1 commit 5ce0e78
Showing 1 changed file with 34 additions and 31 deletions.
65 changes: 34 additions & 31 deletions rebabel_format/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@ class Unit:
type: str
index: int
name: Any
parent: Optional['Unit'] = None
children: Sequence['Unit'] = field(default_factory=list)

def __getitem__(self, key):
return Condition(self.index, key, 'feature')

def parent(self, other):
if not isinstance(other, Unit):
raise ValueError(f'Parent must be a Unit, not {other.__type__.__name__}')
return Condition(self.index, other.index, 'parent')

def child(self, other):
if not isinstance(other, Unit):
raise ValueError(f'Child must be a Unit, not {other.__type__.__name__}')
return Condition(self.index, other.index, 'child')

@dataclass
class Condition:
left: Any
Expand All @@ -42,6 +50,10 @@ def toSQL(self, feature_index: dict):
elif self.is_compare_ref_feature():
ql, al, sl = self.left.toSQL(feature_index)
return f'{ql} = U{self.right.index}', ids, False
elif self.operator == 'parent':
return f'EXISTS (SELECT NULL FROM relations WHERE parent = U{self.right} AND child = U{self.left} AND isprimary = ? AND active = ?)', [True, True], False
elif self.operator == 'child':
return f'EXISTS (SELECT NULL FROM relations WHERE child = U{self.right} AND parent = U{self.left} AND isprimary = ? AND active = ?)', [False, True], False
def _toSQL(obj):
nonlocal self, feature_index
if isinstance(obj, Condition):
Expand Down Expand Up @@ -69,6 +81,8 @@ def features(self):
return {(self.left, self.right, True)}
elif self.operator == 'exists':
return {(self.left, self.right, False)}
elif self.operator in ['parent', 'child']:
return {(self.left, None, False), (self.right, None, False)}
ret = set()
if isinstance(self.left, Condition):
ret |= self.left.features()
Expand Down Expand Up @@ -192,7 +206,7 @@ def add(self, condition):
is_str = any(x[1] == 'str' for x in ids)
self.features[fkey] = (n, [x[0] for x in ids], is_str)
feat_index.add(fkey)
ckey = tuple(sorted(x[0] for x in feats if x[2]))
ckey = tuple(sorted(x[0] for x in feats))
if ckey not in self.conditionals:
self.conditionals[ckey] = (condition, feat_index)
else:
Expand Down Expand Up @@ -225,35 +239,13 @@ def search(self):

intersect = IntersectionTracker(dict(enumerate(unit_ids)))

def restrict_edge(parent, child, primary):
nonlocal self, intersect, unit_ids
self.db.execute_clauses(
'SELECT parent, child FROM relations',
WhereClause('parent', list(unit_ids[parent])),
WhereClause('child', list(unit_ids[child])),
WhereClause('active', True),
WhereClause('isprimary', True))
pairs = self.db.cur.fetchall()
if not pairs:
return False
intersect.restrict(parent, child, pairs)
return True

for u in self.units:
if u.parent is not None:
if not restrict_edge(u.parent.index, u.index, True):
return
for c in u.children:
if not restrict_edge(u.index, c.index, False):
return

for ckey in self.conditionals:
if len(ckey) < 2:
continue
cond, feats = self.conditionals[ckey]
where, params, _ = cond.toSQL(self.features)
where = f'({where})'
select = [None] * len(ckey)
select = {k: None for k in set(ckey)}
tables = []
for fkey in feats:
if not fkey[2]:
Expand All @@ -264,18 +256,29 @@ def restrict_edge(parent, child, primary):
qs = ','.join(['?']*len(ids))
where += f' AND F{n}.feature IN ({qs})'
i = fkey[0]
if not select[i]:
if fkey[2] and not select[i]:
select[i] = f'F{n}.unit AS U{i}'
else:
where += f' AND F{n}.unit = U{i}'
query = f'SELECT {", ".join(select)} FROM {", ".join(tables)} WHERE {where}'
sl = []
for k in sorted(select):
if select[k]:
sl.append(select[k])
else:
sl.append(f'TU{k}.id AS U{k}')
tables.append(f'units TU{k}')
qs = ', '.join(['?']*len(unit_ids[k]))
where += f' AND U{k} IN ({qs})'
params += unit_ids[k]
query = f'SELECT {", ".join(sl)} FROM {", ".join(tables)} WHERE {where}'
self.db.cur.execute(query, params)
sets = self.db.cur.fetchall()
if not sets:
return
for i in range(len(ckey)):
for j in range(i+1, len(ckey)):
intersect.restrict(ckey[i], ckey[j], [(s[i], s[j]) for s in sets])
ids = sorted(select)
for i in range(len(ids)):
for j in range(i+1, len(ids)):
intersect.restrict(ids[i], ids[j], [(s[i], s[j]) for s in sets])

intersect.make_dict()

Expand Down

0 comments on commit 5ce0e78

Please sign in to comment.