Skip to content

Commit

Permalink
refactor query generation
Browse files Browse the repository at this point in the history
fixes #13, #22

Co-authored-by: Tim Graham <[email protected]>
  • Loading branch information
WaVEV and timgraham committed May 24, 2024
1 parent e3f02fc commit 3789b9f
Show file tree
Hide file tree
Showing 10 changed files with 204 additions and 260 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
lookup
model_fields
or_lookups
queries.tests.Ticket12807Tests.test_ticket_12807
sessions_tests
timezones
update
Expand Down
10 changes: 10 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,13 @@
from .utils import check_django_compatability

check_django_compatability()

from .expressions import register_expressions # noqa: E402
from .functions import register_functions # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_expressions()
register_functions()
register_lookups()
register_nodes()
25 changes: 20 additions & 5 deletions django_mongodb/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

from django.core.exceptions import ImproperlyConfigured
from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.backends.signals import connection_created
Expand All @@ -10,6 +12,7 @@
from .features import DatabaseFeatures
from .introspection import DatabaseIntrospection
from .operations import DatabaseOperations
from .query_utils import safe_regex
from .schema import DatabaseSchemaEditor
from .utils import CollectionDebugWrapper

Expand Down Expand Up @@ -52,11 +55,23 @@ class DatabaseWrapper(BaseDatabaseWrapper):
"UUIDField": "string",
}
operators = {
"exact": "= %s",
"gt": "> %s",
"gte": ">= %s",
"lt": "< %s",
"lte": "<= %s",
"exact": lambda val: val,
"gt": lambda val: {"$gt": val},
"gte": lambda val: {"$gte": val},
"lt": lambda val: {"$lt": val},
"lte": lambda val: {"$lte": val},
"in": lambda val: {"$in": val},
"range": lambda val: {"$gte": val[0], "$lte": val[1]},
"isnull": lambda val: None if val else {"$ne": None},
"iexact": safe_regex("^%s$", re.IGNORECASE),
"startswith": safe_regex("^%s"),
"istartswith": safe_regex("^%s", re.IGNORECASE),
"endswith": safe_regex("%s$"),
"iendswith": safe_regex("%s$", re.IGNORECASE),
"contains": safe_regex("%s"),
"icontains": safe_regex("%s", re.IGNORECASE),
"regex": lambda val: re.compile(val),
"iregex": lambda val: re.compile(val, re.IGNORECASE),
}

display_name = "MongoDB"
Expand Down
7 changes: 5 additions & 2 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from django.core.exceptions import EmptyResultSet
from django.core.exceptions import EmptyResultSet, FullResultSet
from django.db import DatabaseError, IntegrityError, NotSupportedError
from django.db.models import NOT_PROVIDED, Count, Expression, Value
from django.db.models.aggregates import Aggregate
Expand Down Expand Up @@ -136,7 +136,10 @@ def build_query(self, columns=None):
self.check_query()
self.setup_query()
query = self.query_class(self, columns)
query.add_filters(self.query.where)
try:
query.mongo_query = self.query.where.as_mql(self, self.connection)
except FullResultSet:
query.mongo_query = {}
query.order_by(self._get_ordering())
return query

Expand Down
9 changes: 9 additions & 0 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from django.db.models.expressions import Col


def col(self, compiler, connection): # noqa: ARG001
return self.target.column


def register_expressions():
Col.as_mql = col
17 changes: 6 additions & 11 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# cannot encode object: <django.db.models.expressions.DatabaseDefault
"basic.tests.ModelInstanceCreationTests.test_save_primary_with_db_default",
# Date lookups aren't implemented: https://github.com/mongodb-labs/django-mongodb/issues/9
# (e.g. 'ExtractMonth' object has no attribute 'alias')
# (e.g. ExtractWeekDay is not supported.)
"basic.tests.ModelLookupTest.test_does_not_exist",
"basic.tests.ModelLookupTest.test_equal_lookup",
"basic.tests.ModelLookupTest.test_rich_lookup",
"basic.tests.ModelLookupTest.test_too_many",
"basic.tests.ModelTest.test_year_lookup_edge_case",
"lookup.tests.LookupTests.test_chain_date_time_lookups",
"lookup.test_timefield.TimeFieldLookupTests.test_hour_lookups",
Expand All @@ -30,29 +29,25 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"timezones.tests.NewDatabaseTests.test_query_convert_timezones",
"timezones.tests.NewDatabaseTests.test_query_datetime_lookups",
"timezones.tests.NewDatabaseTests.test_query_datetime_lookups_in_other_timezone",
# 'NulledTransform' object has no attribute 'alias'
# 'NulledTransform' object has no attribute 'as_mql'.
"lookup.tests.LookupTests.test_exact_none_transform",
# "Save with update_fields did not affect any rows."
"basic.tests.SelectOnSaveTests.test_select_on_save_lying_update",
# filtering on large decimalfield, see https://code.djangoproject.com/ticket/34590
# for some background.
"model_fields.test_decimalfield.DecimalFieldTests.test_lookup_decimal_larger_than_max_digits",
"model_fields.test_decimalfield.DecimalFieldTests.test_lookup_really_big_value",
# 'TruncDate' object has no attribute 'alias'
# 'TruncDate' object has no attribute 'as_mql'.
"model_fields.test_datetimefield.DateTimeFieldTests.test_lookup_date_with_use_tz",
"model_fields.test_datetimefield.DateTimeFieldTests.test_lookup_date_without_use_tz",
# Incorrect empty QuerySet handling: https://github.com/mongodb-labs/django-mongodb/issues/22
"lookup.tests.LookupTests.test_in",
"or_lookups.tests.OrLookupsTests.test_empty_in",
# Slicing with QuerySet.count() doesn't work.
"lookup.tests.LookupTests.test_count",
# Custom lookups not supported.
"lookup.tests.LookupTests.test_custom_lookup_none_rhs",
# Lookup in order_by() not supported: argument of type 'LessThan' is not iterable
# Lookup in order_by() not supported:
# unsupported operand type(s) for %: 'function' and 'str'
"lookup.tests.LookupQueryingTests.test_lookup_in_order_by",
# annotate() after values() doesn't raise NotSupportedError.
"lookup.tests.LookupTests.test_exact_query_rhs_with_selected_columns",
# tuple index out of range in _normalize_lookup_value()
# tuple index out of range in process_rhs()
"lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one",
"lookup.tests.LookupTests.test_exact_sliced_queryset_limit_one_offset",
# Regex lookup doesn't work on non-string fields.
Expand Down
24 changes: 24 additions & 0 deletions django_mongodb/functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from django.db import NotSupportedError
from django.db.models.expressions import Col
from django.db.models.functions.datetime import Extract

from .query_utils import process_lhs


def extract(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
if self.lookup_name == "week":
operator = "$week"
elif self.lookup_name == "month":
operator = "$month"
elif self.lookup_name == "year":
operator = "$year"
else:
raise NotSupportedError("%s is not supported." % self.__class__.__name__)
if isinstance(self.lhs, Col):
lhs_mql = f"${lhs_mql}"
return {operator: lhs_mql}


def register_functions():
Extract.as_mql = extract
47 changes: 47 additions & 0 deletions django_mongodb/lookups.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from django.db import NotSupportedError
from django.db.models.expressions import Col
from django.db.models.fields.related_lookups import In, MultiColSource, RelatedIn
from django.db.models.lookups import BuiltinLookup, Exact, IsNull, UUIDTextMixin

from .query_utils import process_lhs, process_rhs


def builtin_lookup(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
rhs_mql = connection.operators[self.lookup_name](value)
return {lhs_mql: rhs_mql}


def exact(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
if isinstance(self.lhs, Col):
lhs_mql = f"${lhs_mql}"
return {"$expr": {"$eq": [lhs_mql, value]}}


def in_(self, compiler, connection):
if isinstance(self.lhs, MultiColSource):
raise NotImplementedError("MultiColSource is not supported.")
return builtin_lookup(self, compiler, connection)


def is_null(self, compiler, connection):
if not isinstance(self.rhs, bool):
raise ValueError("The QuerySet value for an isnull lookup must be True or False.")
lhs_mql = process_lhs(self, compiler, connection)
rhs_mql = connection.operators["isnull"](self.rhs)
return {lhs_mql: rhs_mql}


def uuid_text_mixin(self, compiler, connection): # noqa: ARG001
raise NotSupportedError("Pattern lookups on UUIDField are not supported.")


def register_lookups():
BuiltinLookup.as_mql = builtin_lookup
Exact.as_mql = exact
In.as_mql = RelatedIn.as_mql = in_
IsNull.as_mql = is_null
UUIDTextMixin.as_mql = uuid_text_mixin
Loading

0 comments on commit 3789b9f

Please sign in to comment.