From 7dd117f147ec39c8b882c3f0a45117775f34cacb Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Sat, 18 Jan 2025 21:06:52 -0500 Subject: [PATCH 1/2] fixed id of EmbeddedModelField check error --- django_mongodb_backend/fields/embedded_model.py | 2 +- tests/model_fields_/test_embedded_model.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index a5899326..fa65a8bb 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -27,7 +27,7 @@ def check(self, **kwargs): f"({self.embedded_model().__class__.__name__}.{field.name} " f"is a {field.__class__.__name__}).", obj=self, - id="django_mongodb.embedded_model.E001", + id="django_mongodb_backend.embedded_model.E001", ) ) return errors diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 8b0b53f6..48e69e8e 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -114,11 +114,9 @@ class Target(models.Model): class MyModel(models.Model): field = EmbeddedModelField(Target) - model = MyModel() - errors = model.check() + errors = MyModel().check() self.assertEqual(len(errors), 1) - # The inner CharField has a non-positive max_length. - self.assertEqual(errors[0].id, "django_mongodb.embedded_model.E001") + self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E001") msg = errors[0].msg self.assertEqual( msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)." From 87be36a2b0d12a2cd692a42cb8f3f5eb1669e44f Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Fri, 17 Jan 2025 20:03:57 -0500 Subject: [PATCH 2/2] prevent the creation of embedded models --- .../fields/embedded_model.py | 11 ++++ django_mongodb_backend/managers.py | 41 +++++++++++++ django_mongodb_backend/models.py | 16 +++++ django_mongodb_backend/schema.py | 27 +++++++++ docs/source/embedded-models.rst | 3 +- docs/source/fields.rst | 7 ++- docs/source/models.rst | 0 tests/model_fields_/models.py | 7 ++- tests/model_fields_/test_embedded_model.py | 18 +++++- tests/model_forms_/models.py | 8 +-- tests/models_/__init__.py | 0 tests/models_/models.py | 5 ++ tests/models_/test_embedded_model.py | 59 +++++++++++++++++++ tests/schema_/models.py | 5 +- tests/schema_/test_embedded_model.py | 50 ++++++++++++++-- 15 files changed, 236 insertions(+), 21 deletions(-) create mode 100644 django_mongodb_backend/models.py create mode 100644 docs/source/models.rst create mode 100644 tests/models_/__init__.py create mode 100644 tests/models_/models.py create mode 100644 tests/models_/test_embedded_model.py diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index fa65a8bb..0153b372 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -18,7 +18,18 @@ def __init__(self, embedded_model, *args, **kwargs): super().__init__(*args, **kwargs) def check(self, **kwargs): + from ..models import EmbeddedModel + errors = super().check(**kwargs) + if not issubclass(self.embedded_model, EmbeddedModel): + return [ + checks.Error( + "Embedded model must be a subclass of " + "django_mongodb_backend.models.EmbeddedModel.", + obj=self, + id="django_mongodb_backend.embedded_model.E002", + ) + ] for field in self.embedded_model._meta.fields: if field.remote_field: errors.append( diff --git a/django_mongodb_backend/managers.py b/django_mongodb_backend/managers.py index 055c8440..5a405688 100644 --- a/django_mongodb_backend/managers.py +++ b/django_mongodb_backend/managers.py @@ -1,3 +1,4 @@ +from django.db import NotSupportedError from django.db.models.manager import BaseManager from .queryset import MongoQuerySet @@ -5,3 +6,43 @@ class MongoManager(BaseManager.from_queryset(MongoQuerySet)): pass + + +class EmbeddedModelManager(BaseManager): + """ + Prevent all queryset operations on embedded models since they don't have + their own collection. + """ + + def get_queryset(self): + raise NotSupportedError("EmbeddedModels cannot be queried.") + + def all(self): + raise NotSupportedError("EmbeddedModels cannot be queried.") + + def get(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be queried.") + + def get_or_create(self, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be queried.") + + def filter(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be queried.") + + def create(self, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be created.") + + def bulk_create(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be created.") + + def update(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be updated.") + + def bulk_update(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be updated.") + + def update_or_create(self, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be updated or created.") + + def delete(self): + raise NotSupportedError("EmbeddedModels cannot be deleted.") diff --git a/django_mongodb_backend/models.py b/django_mongodb_backend/models.py new file mode 100644 index 00000000..adeba21e --- /dev/null +++ b/django_mongodb_backend/models.py @@ -0,0 +1,16 @@ +from django.db import NotSupportedError, models + +from .managers import EmbeddedModelManager + + +class EmbeddedModel(models.Model): + objects = EmbeddedModelManager() + + class Meta: + abstract = True + + def delete(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be deleted.") + + def save(self, *args, **kwargs): + raise NotSupportedError("EmbeddedModels cannot be saved.") diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index 9769df8b..db246304 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -10,6 +10,21 @@ from .utils import OperationCollector +def ignore_embedded_models(func): + """Make a SchemaEditor a no-op if model is an EmbeddedModel.""" + + def wrapper(self, model, *args, **kwargs): + # If parent_model isn't None, this is a valid recursive operation. + parent_model = kwargs.get("parent_model") + from .models import EmbeddedModel + + if parent_model is None and issubclass(model, EmbeddedModel): + return + func(self, model, *args, **kwargs) + + return wrapper + + class DatabaseSchemaEditor(BaseDatabaseSchemaEditor): def get_collection(self, name): if self.collect_sql: @@ -22,6 +37,7 @@ def get_database(self): return self.connection.get_database() @wrap_database_errors + @ignore_embedded_models def create_model(self, model): self.get_database().create_collection(model._meta.db_table) self._create_model_indexes(model) @@ -75,6 +91,7 @@ def _create_model_indexes(self, model, column_prefix="", parent_model=None): for index in model._meta.indexes: self.add_index(model, index, column_prefix=column_prefix, parent_model=parent_model) + @ignore_embedded_models def delete_model(self, model): # Delete implicit M2m tables. for field in model._meta.local_many_to_many: @@ -82,6 +99,7 @@ def delete_model(self, model): self.delete_model(field.remote_field.through) self.get_collection(model._meta.db_table).drop() + @ignore_embedded_models def add_field(self, model, field): # Create implicit M2M tables. if field.many_to_many and field.remote_field.through._meta.auto_created: @@ -103,6 +121,7 @@ def add_field(self, model, field): elif self._field_should_have_unique(field): self._add_field_unique(model, field) + @ignore_embedded_models def _alter_field( self, model, @@ -149,6 +168,7 @@ def _alter_field( if not old_field_unique and new_field_unique: self._add_field_unique(model, new_field) + @ignore_embedded_models def remove_field(self, model, field): # Remove implicit M2M tables. if field.many_to_many and field.remote_field.through._meta.auto_created: @@ -210,6 +230,7 @@ def _remove_model_indexes(self, model, column_prefix="", parent_model=None): for index in model._meta.indexes: self.remove_index(parent_model or model, index) + @ignore_embedded_models def alter_index_together(self, model, old_index_together, new_index_together, column_prefix=""): olds = {tuple(fields) for fields in old_index_together} news = {tuple(fields) for fields in new_index_together} @@ -222,6 +243,7 @@ def alter_index_together(self, model, old_index_together, new_index_together, co for field_names in news.difference(olds): self._add_composed_index(model, field_names, column_prefix=column_prefix) + @ignore_embedded_models def alter_unique_together( self, model, old_unique_together, new_unique_together, column_prefix="", parent_model=None ): @@ -249,6 +271,7 @@ def alter_unique_together( model, constraint, parent_model=parent_model, column_prefix=column_prefix ) + @ignore_embedded_models def add_index( self, model, index, *, field=None, unique=False, column_prefix="", parent_model=None ): @@ -302,6 +325,7 @@ def _add_field_index(self, model, field, *, column_prefix=""): index.name = self._create_index_name(model._meta.db_table, [column_prefix + field.column]) self.add_index(model, index, field=field, column_prefix=column_prefix) + @ignore_embedded_models def remove_index(self, model, index): if index.contains_expressions: return @@ -355,6 +379,7 @@ def _remove_field_index(self, model, field, column_prefix=""): ) collection.drop_index(index_names[0]) + @ignore_embedded_models def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( condition=constraint.condition, @@ -384,6 +409,7 @@ def _add_field_unique(self, model, field, column_prefix=""): constraint = UniqueConstraint(fields=[field.name], name=name) self.add_constraint(model, constraint, field=field, column_prefix=column_prefix) + @ignore_embedded_models def remove_constraint(self, model, constraint): if isinstance(constraint, UniqueConstraint) and self._unique_supported( condition=constraint.condition, @@ -417,6 +443,7 @@ def _remove_field_unique(self, model, field, column_prefix=""): ) self.get_collection(model._meta.db_table).drop_index(constraint_names[0]) + @ignore_embedded_models def alter_db_table(self, model, old_db_table, new_db_table): if old_db_table == new_db_table: return diff --git a/docs/source/embedded-models.rst b/docs/source/embedded-models.rst index 08e6891b..21b7d139 100644 --- a/docs/source/embedded-models.rst +++ b/docs/source/embedded-models.rst @@ -11,13 +11,14 @@ The basics Let's consider this example:: from django_mongodb_backend.fields import EmbeddedModelField + from django_mongodb_backend.models import EmbeddedModel class Customer(models.Model): name = models.CharField(...) address = EmbeddedModelField("Address") ... - class Address(models.Model): + class Address(EmbeddedModel): ... city = models.CharField(...) diff --git a/docs/source/fields.rst b/docs/source/fields.rst index 270ea7d7..9fc2371b 100644 --- a/docs/source/fields.rst +++ b/docs/source/fields.rst @@ -224,6 +224,8 @@ Stores a model of type ``embedded_model``. Specifies the model class to embed. It can be either a concrete model class or a :ref:`lazy reference ` to a model class. + The target model must be a subclass of + ``django_mongodb_backend.models.EmbeddedModel``. The embedded model cannot have relational fields (:class:`~django.db.models.ForeignKey`, @@ -234,11 +236,12 @@ Stores a model of type ``embedded_model``. from django.db import models from django_mongodb_backend.fields import EmbeddedModelField + from django_mongodb_backend.models import EmbeddedModel - class Address(models.Model): + class Address(EmbeddedModel): ... - class Author(models.Model): + class Author(EmbeddedModel): address = EmbeddedModelField(Address) class Book(models.Model): diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py index 9b00665b..c0cd6402 100644 --- a/tests/model_fields_/models.py +++ b/tests/model_fields_/models.py @@ -3,6 +3,7 @@ from django.db import models from django_mongodb_backend.fields import ArrayField, EmbeddedModelField, ObjectIdField +from django_mongodb_backend.models import EmbeddedModel # ObjectIdField @@ -98,19 +99,19 @@ class Holder(models.Model): data = EmbeddedModelField("Data", null=True, blank=True) -class Data(models.Model): +class Data(EmbeddedModel): integer = models.IntegerField(db_column="custom_column") auto_now = models.DateTimeField(auto_now=True) auto_now_add = models.DateTimeField(auto_now_add=True) -class Address(models.Model): +class Address(EmbeddedModel): city = models.CharField(max_length=20) state = models.CharField(max_length=2) zip_code = models.IntegerField(db_index=True) -class Author(models.Model): +class Author(EmbeddedModel): name = models.CharField(max_length=10) age = models.IntegerField() address = EmbeddedModelField(Address) diff --git a/tests/model_fields_/test_embedded_model.py b/tests/model_fields_/test_embedded_model.py index 48e69e8e..f9bd6b28 100644 --- a/tests/model_fields_/test_embedded_model.py +++ b/tests/model_fields_/test_embedded_model.py @@ -4,6 +4,7 @@ from django.test.utils import isolate_apps from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.models import EmbeddedModel from .models import ( Address, @@ -108,7 +109,7 @@ def test_nested(self): @isolate_apps("model_fields_") class CheckTests(SimpleTestCase): def test_no_relational_fields(self): - class Target(models.Model): + class Target(EmbeddedModel): key = models.ForeignKey("MyModel", models.CASCADE) class MyModel(models.Model): @@ -121,3 +122,18 @@ class MyModel(models.Model): self.assertEqual( msg, "Embedded models cannot have relational fields (Target.key is a ForeignKey)." ) + + def test_embedded_model_subclass(self): + class Target(models.Model): + pass + + class MyModel(models.Model): + field = EmbeddedModelField(Target) + + errors = MyModel().check() + self.assertEqual(len(errors), 1) + self.assertEqual(errors[0].id, "django_mongodb_backend.embedded_model.E002") + msg = errors[0].msg + self.assertEqual( + msg, "Embedded model must be a subclass of django_mongodb_backend.models.EmbeddedModel." + ) diff --git a/tests/model_forms_/models.py b/tests/model_forms_/models.py index d61196ab..df3bd580 100644 --- a/tests/model_forms_/models.py +++ b/tests/model_forms_/models.py @@ -1,9 +1,10 @@ from django.db import models from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.models import EmbeddedModel -class Address(models.Model): +class Address(EmbeddedModel): po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box") city = models.CharField(max_length=20) state = models.CharField(max_length=2) @@ -15,8 +16,3 @@ class Author(models.Model): age = models.IntegerField() address = EmbeddedModelField(Address) billing_address = EmbeddedModelField(Address, blank=True, null=True) - - -class Book(models.Model): - name = models.CharField(max_length=100) - author = EmbeddedModelField(Author) diff --git a/tests/models_/__init__.py b/tests/models_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models_/models.py b/tests/models_/models.py new file mode 100644 index 00000000..e02edda0 --- /dev/null +++ b/tests/models_/models.py @@ -0,0 +1,5 @@ +from django_mongodb_backend.models import EmbeddedModel + + +class Embed(EmbeddedModel): + pass diff --git a/tests/models_/test_embedded_model.py b/tests/models_/test_embedded_model.py new file mode 100644 index 00000000..a9f04f14 --- /dev/null +++ b/tests/models_/test_embedded_model.py @@ -0,0 +1,59 @@ +from django.db import NotSupportedError +from django.test import SimpleTestCase + +from .models import Embed + + +class TestMethods(SimpleTestCase): + def test_save(self): + e = Embed() + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be saved."): + e.save() + + def test_delete(self): + e = Embed() + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."): + e.delete() + + +class TestManagerMethods(SimpleTestCase): + def test_all(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."): + Embed.objects.all() + + def test_get(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."): + Embed.objects.get() + + def test_get_or_create(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."): + Embed.objects.get_or_create() + + def test_filter(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be queried."): + Embed.objects.filter(foo="bar") + + def test_create(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."): + Embed.objects.create() + + def test_bulk_create(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be created."): + Embed.objects.bulk_create() + + def test_update(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."): + Embed.objects.update(foo="bar") + + def test_bulk_update(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be updated."): + Embed.objects.bulk_update() + + def test_update_or_create(self): + msg = "EmbeddedModels cannot be updated or created." + with self.assertRaisesMessage(NotSupportedError, msg): + Embed.objects.update_or_create() + + def test_delete(self): + with self.assertRaisesMessage(NotSupportedError, "EmbeddedModels cannot be deleted."): + Embed.objects.delete() diff --git a/tests/schema_/models.py b/tests/schema_/models.py index 7c0f4533..f8f0aaa0 100644 --- a/tests/schema_/models.py +++ b/tests/schema_/models.py @@ -2,6 +2,7 @@ from django.db import models from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.models import EmbeddedModel # These models are inserted into a separate Apps so the test runner doesn't # migrate them. @@ -9,7 +10,7 @@ new_apps = Apps() -class Address(models.Model): +class Address(EmbeddedModel): city = models.CharField(max_length=20) state = models.CharField(max_length=2) zip_code = models.IntegerField(db_index=True) @@ -19,7 +20,7 @@ class Meta: apps = new_apps -class Author(models.Model): +class Author(EmbeddedModel): name = models.CharField(max_length=10) age = models.IntegerField(db_index=True) address = EmbeddedModelField(Address) diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py index faf58ce2..7bb4e211 100644 --- a/tests/schema_/test_embedded_model.py +++ b/tests/schema_/test_embedded_model.py @@ -6,11 +6,12 @@ from django.utils.deprecation import RemovedInDjango51Warning from django_mongodb_backend.fields import EmbeddedModelField +from django_mongodb_backend.models import EmbeddedModel from .models import Address, Author, Book, new_apps -class SchemaTests(TransactionTestCase): +class TestMixin: available_apps = [] models = [Address, Author, Book] @@ -88,6 +89,8 @@ def assertTableExists(self, model): def assertTableNotExists(self, model): self.assertNotIn(model._meta.db_table, connection.introspection.table_names()) + +class SchemaTests(TestMixin, TransactionTestCase): # SchemaEditor.create_model() tests def test_db_index(self): """Field(db_index=True) on an embedded model.""" @@ -133,7 +136,7 @@ def test_unique(self): def test_index_together(self): """Meta.index_together on an embedded model.""" - class Address(models.Model): + class Address(EmbeddedModel): index_together_one = models.CharField(max_length=10) index_together_two = models.CharField(max_length=10) @@ -180,7 +183,7 @@ class Meta: def test_unique_together(self): """Meta.unique_together on an embedded model.""" - class Address(models.Model): + class Address(EmbeddedModel): unique_together_one = models.CharField(max_length=10) unique_together_two = models.CharField(max_length=10) @@ -188,7 +191,7 @@ class Meta: app_label = "schema_" unique_together = [("unique_together_one", "unique_together_two")] - class Author(models.Model): + class Author(EmbeddedModel): address = EmbeddedModelField(Address) unique_together_three = models.CharField(max_length=10) unique_together_four = models.CharField(max_length=10) @@ -231,14 +234,14 @@ class Meta: def test_indexes(self): """Meta.indexes on an embedded model.""" - class Address(models.Model): + class Address(EmbeddedModel): indexed_one = models.CharField(max_length=10) class Meta: app_label = "schema_" indexes = [models.Index(fields=["indexed_one"])] - class Author(models.Model): + class Author(EmbeddedModel): address = EmbeddedModelField(Address) indexed_two = models.CharField(max_length=10) @@ -627,3 +630,38 @@ class Meta: ) editor.delete_model(Author) self.assertTableNotExists(Author) + + +class EmbeddedModelsIgnoredTests(TestMixin, TransactionTestCase): + def test_embedded_not_created(self): + """create_model() and delete_model() ignore EmbeddedModel.""" + with connection.schema_editor() as editor: + editor.create_model(Book) + editor.create_model(Address) + editor.create_model(Author) + self.assertTableExists(Book) + self.assertTableNotExists(Address) + self.assertTableNotExists(Author) + editor.delete_model(Book) + with self.assertNumQueries(0): + editor.delete_model(Address) + editor.delete_model(Author) + self.assertTableNotExists(Book) + + def test_embedded_add_field_ignored(self): + """add_field() and remove_field() ignore EmbeddedModel.""" + new_field = models.CharField(max_length=1, default="a") + new_field.set_attributes_from_name("char") + with connection.schema_editor() as editor, self.assertNumQueries(0): + editor.add_field(Author, new_field) + with connection.schema_editor() as editor, self.assertNumQueries(0): + editor.remove_field(Author, new_field) + + def test_embedded_alter_field_ignored(self): + """alter_field() ignores EmbeddedModel.""" + old_field = models.CharField(max_length=1) + old_field.set_attributes_from_name("old") + new_field = models.CharField(max_length=1) + new_field.set_attributes_from_name("new") + with connection.schema_editor() as editor, self.assertNumQueries(0): + editor.alter_field(Author, old_field, new_field)