Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prevent the creation of embedded models #227

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion django_mongodb_backend/fields/embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs):
super().__init__(*args, **kwargs)

def check(self, **kwargs):
from ..models import EmbeddedModel

errors = super().check(**kwargs)
if not issubclass(self.embedded_model, EmbeddedModel):
return [
checks.Error(
"Embedded model must be a subclass of "
"django_mongodb_backend.models.EmbeddedModel.",
obj=self,
id="django_mongodb_backend.embedded_model.E002",
)
]
for field in self.embedded_model._meta.fields:
if field.remote_field:
errors.append(
Expand All @@ -27,7 +38,7 @@ def check(self, **kwargs):
f"({self.embedded_model().__class__.__name__}.{field.name} "
f"is a {field.__class__.__name__}).",
obj=self,
id="django_mongodb.embedded_model.E001",
id="django_mongodb_backend.embedded_model.E001",
)
)
return errors
Expand Down
41 changes: 41 additions & 0 deletions django_mongodb_backend/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,48 @@
from django.db import NotSupportedError
from django.db.models.manager import BaseManager

from .queryset import MongoQuerySet


class MongoManager(BaseManager.from_queryset(MongoQuerySet)):
pass


class EmbeddedModelManager(BaseManager):
"""
Prevent all queryset operations on embedded models since they don't have
their own collection.
"""

def get_queryset(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def all(self):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def get_or_create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def filter(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be queried.")

def create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def bulk_create(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be created.")

def update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def bulk_update(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated.")

def update_or_create(self, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be updated or created.")

def delete(self):
raise NotSupportedError("EmbeddedModels cannot be deleted.")
16 changes: 16 additions & 0 deletions django_mongodb_backend/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from django.db import NotSupportedError, models

from .managers import EmbeddedModelManager


class EmbeddedModel(models.Model):
objects = EmbeddedModelManager()

class Meta:
abstract = True

def delete(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be deleted.")

def save(self, *args, **kwargs):
raise NotSupportedError("EmbeddedModels cannot be saved.")
27 changes: 27 additions & 0 deletions django_mongodb_backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
from .utils import OperationCollector


def ignore_embedded_models(func):
"""Make a SchemaEditor a no-op if model is an EmbeddedModel."""

def wrapper(self, model, *args, **kwargs):
# If parent_model isn't None, this is a valid recursive operation.
parent_model = kwargs.get("parent_model")
from .models import EmbeddedModel

if parent_model is None and issubclass(model, EmbeddedModel):
return
func(self, model, *args, **kwargs)

return wrapper


class DatabaseSchemaEditor(BaseDatabaseSchemaEditor):
def get_collection(self, name):
if self.collect_sql:
Expand All @@ -22,6 +37,7 @@ def get_database(self):
return self.connection.get_database()

@wrap_database_errors
@ignore_embedded_models
def create_model(self, model):
self.get_database().create_collection(model._meta.db_table)
self._create_model_indexes(model)
Expand Down Expand Up @@ -75,13 +91,15 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model)

@ignore_embedded_models
def delete_model(self, model):
# Delete implicit M2m tables.
for field in model._meta.local_many_to_many:
if field.remote_field.through._meta.auto_created:
self.delete_model(field.remote_field.through)
self.get_collection(model._meta.db_table).drop()

@ignore_embedded_models
def add_field(self, model, field):
# Create implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand All @@ -103,6 +121,7 @@ def add_field(self, model, field):
elif self._field_should_have_unique(field):
self._add_field_unique(model, field)

@ignore_embedded_models
def _alter_field(
self,
model,
Expand Down Expand Up @@ -149,6 +168,7 @@ def _alter_field(
if not old_field_unique and new_field_unique:
self._add_field_unique(model, new_field)

@ignore_embedded_models
def remove_field(self, model, field):
# Remove implicit M2M tables.
if field.many_to_many and field.remote_field.through._meta.auto_created:
Expand Down Expand Up @@ -210,6 +230,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None):
for index in model._meta.indexes:
self.remove_index(parent_model or model, index)

@ignore_embedded_models
def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""):
olds = {tuple(fields) for fields in old_index_together}
news = {tuple(fields) for fields in new_index_together}
Expand All @@ -222,6 +243,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co
for field_names in news.difference(olds):
self._add_composed_index(model, field_names, column_prefix=column_prefix)

@ignore_embedded_models
def alter_unique_together(
self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -249,6 +271,7 @@ def alter_unique_together(
model, constraint, parent_model=parent_model, column_prefix=column_prefix
)

@ignore_embedded_models
def add_index(
self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None
):
Expand Down Expand Up @@ -302,6 +325,7 @@ def _add_field_index(self, model, field, *, column_prefix=""):
index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column])
self.add_index(model, index, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_index(self, model, index):
if index.contains_expressions:
return
Expand Down Expand Up @@ -355,6 +379,7 @@ def _remove_field_index(self, model, field, column_prefix=""):
)
collection.drop_index(index_names[0])

@ignore_embedded_models
def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -384,6 +409,7 @@ def _add_field_unique(self, model, field, column_prefix=""):
constraint = UniqueConstraint(fields=[field.name], name=name)
self.add_constraint(model, constraint, field=field, column_prefix=column_prefix)

@ignore_embedded_models
def remove_constraint(self, model, constraint):
if isinstance(constraint, UniqueConstraint) and self._unique_supported(
condition=constraint.condition,
Expand Down Expand Up @@ -417,6 +443,7 @@ def _remove_field_unique(self, model, field, column_prefix=""):
)
self.get_collection(model._meta.db_table).drop_index(constraint_names[0])

@ignore_embedded_models
def alter_db_table(self, model, old_db_table, new_db_table):
if old_db_table == new_db_table:
return
Expand Down
3 changes: 2 additions & 1 deletion docs/source/embedded-models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ The basics
Let's consider this example::

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Customer(models.Model):
name = models.CharField(...)
address = EmbeddedModelField("Address")
...

class Address(models.Model):
class Address(EmbeddedModel):
...
city = models.CharField(...)

Expand Down
7 changes: 5 additions & 2 deletions docs/source/fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,8 @@ Stores a model of type ``embedded_model``.

Specifies the model class to embed. It can be either a concrete model
class or a :ref:`lazy reference <lazy-relationships>` to a model class.
The target model must be a subclass of
``django_mongodb_backend.models.EmbeddedModel``.

The embedded model cannot have relational fields
(:class:`~django.db.models.ForeignKey`,
Expand All @@ -234,11 +236,12 @@ Stores a model of type ``embedded_model``.

from django.db import models
from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

class Address(models.Model):
class Address(EmbeddedModel):
...

class Author(models.Model):
class Author(EmbeddedModel):
address = EmbeddedModelField(Address)

class Book(models.Model):
Expand Down
Empty file added docs/source/models.rst
Empty file.
7 changes: 4 additions & 3 deletions tests/model_fields_/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.db import models

from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField
from django_mongodb_backend.models import EmbeddedModel


# ObjectIdField
Expand Down Expand Up @@ -98,19 +99,19 @@ class Holder(models.Model):
data = EmbeddedModelField("Data", null=True, blank=True)


class Data(models.Model):
class Data(EmbeddedModel):
integer = models.IntegerField(db_column="custom_column")
auto_now = models.DateTimeField(auto_now=True)
auto_now_add = models.DateTimeField(auto_now_add=True)


class Address(models.Model):
class Address(EmbeddedModel):
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
zip_code = models.IntegerField(db_index=True)


class Author(models.Model):
class Author(EmbeddedModel):
name = models.CharField(max_length=10)
age = models.IntegerField()
address = EmbeddedModelField(Address)
Expand Down
24 changes: 19 additions & 5 deletions tests/model_fields_/test_embedded_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from django.test.utils import isolate_apps

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel

from .models import (
Address,
Expand Down Expand Up @@ -108,18 +109,31 @@ def test_nested(self):
@isolate_apps("model_fields_")
class CheckTests(SimpleTestCase):
def test_no_relational_fields(self):
class Target(models.Model):
class Target(EmbeddedModel):
key = models.ForeignKey("MyModel", models.CASCADE)

class MyModel(models.Model):
field = EmbeddedModelField(Target)

model = MyModel()
errors = model.check()
errors = MyModel().check()
self.assertEqual(len(errors), 1)
# The inner CharField has a non-positive max_length.
self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001")
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E001")
msg = errors[0].msg
self.assertEqual(
msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)."
)

def test_embedded_model_subclass(self):
class Target(models.Model):
pass

class MyModel(models.Model):
field = EmbeddedModelField(Target)

errors = MyModel().check()
self.assertEqual(len(errors), 1)
self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002")
msg = errors[0].msg
self.assertEqual(
msg, "Embedded model must be a subclass of django_mongodb_backend.models.EmbeddedModel."
)
8 changes: 2 additions & 6 deletions tests/model_forms_/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from django.db import models

from django_mongodb_backend.fields import EmbeddedModelField
from django_mongodb_backend.models import EmbeddedModel


class Address(models.Model):
class Address(EmbeddedModel):
po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box")
city = models.CharField(max_length=20)
state = models.CharField(max_length=2)
Expand All @@ -15,8 +16,3 @@ class Author(models.Model):
age = models.IntegerField()
address = EmbeddedModelField(Address)
billing_address = EmbeddedModelField(Address, blank=True, null=True)


class Book(models.Model):
name = models.CharField(max_length=100)
author = EmbeddedModelField(Author)
Empty file added tests/models_/__init__.py
Empty file.
5 changes: 5 additions & 0 deletions tests/models_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from django_mongodb_backend.models import EmbeddedModel


class Embed(EmbeddedModel):
pass
Loading
Loading