From 87773c81fbdea0da93e74565243711a2c94db7e2 Mon Sep 17 00:00:00 2001 From: alex Date: Mon, 13 Jan 2025 13:28:23 +0100 Subject: [PATCH] Changes: - fix foreign keys with tenancy --- docs/release-notes.md | 3 +- edgy/core/connection/schemas.py | 23 ++++---- edgy/core/db/context_vars.py | 3 ++ edgy/core/db/fields/many_to_many.py | 16 +++++- edgy/core/db/fields/types.py | 5 +- edgy/core/db/models/mixins/db.py | 53 +++++++++++++++++-- edgy/core/db/models/mixins/reflection.py | 4 +- edgy/core/db/models/types.py | 4 +- edgy/core/db/relationships/relation.py | 10 ++-- .../multi_tenancy/test_tenant_models_using.py | 27 +++++++++- 10 files changed, 125 insertions(+), 23 deletions(-) diff --git a/docs/release-notes.md b/docs/release-notes.md index 2e476b84..1170ac0e 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -15,7 +15,7 @@ hide: For models this can be passed : `class Foo(edgy.Model, on_conflict="keep"): ...`. - Passing a tuple or list of types to `replace_related_field` is now allowed. - Add `through_registry` to ManyToMany. -- Add `no_copy` to MetaInfo. +- Add `no_copy` to models MetaInfo. - Add `ModelCollisionError` exception. ### Changed @@ -32,6 +32,7 @@ hide: - Fix deleting (clearing cache) of BaseForeignKey target. - Creating two models with the same name did lead to silent replacements. - Invalidating causes schema errors. +- ManyToMany and ForeignKey fields in connection with tenancy. ## 0.24.2 diff --git a/edgy/core/connection/schemas.py b/edgy/core/connection/schemas.py index d333dc35..5975680a 100644 --- a/edgy/core/connection/schemas.py +++ b/edgy/core/connection/schemas.py @@ -7,6 +7,7 @@ from sqlalchemy.exc import DBAPIError, ProgrammingError from edgy.core.connection.database import Database +from edgy.core.db.context_vars import NO_GLOBAL_FIELD_CONSTRAINTS from edgy.exceptions import SchemaError if TYPE_CHECKING: @@ -71,14 +72,16 @@ async def create_schema( if init_models: for model_class in self.registry.models.values(): model_class.table_schema(schema=schema, update_cache=update_cache) - if init_tenant_models and init_models: - for model_class in self.registry.tenant_models.values(): - model_class.table_schema(schema=schema, update_cache=update_cache) - elif init_tenant_models: + if init_tenant_models: + token = NO_GLOBAL_FIELD_CONSTRAINTS.set(True) + try: + for model_class in self.registry.tenant_models.values(): + tenant_tables.append(model_class.build(schema=schema)) + finally: + NO_GLOBAL_FIELD_CONSTRAINTS.reset(token) + # we need two passes for model_class in self.registry.tenant_models.values(): - tenant_tables.append( - model_class.table_schema(schema=schema, update_cache=update_cache) - ) + model_class.add_global_field_constraints(schema=schema) def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> None: try: @@ -87,8 +90,10 @@ def execute_create(connection: sqlalchemy.Connection, name: Optional[str]) -> No ) except ProgrammingError as e: raise SchemaError(detail=e.orig.args[0]) from e - for table in tenant_tables: - table.create(connection, checkfirst=if_not_exists) + if tenant_tables: + self.registry.metadata_by_name[name].create_all( + connection, checkfirst=if_not_exists, tables=tenant_tables + ) if init_models: self.registry.metadata_by_name[name].create_all( connection, checkfirst=if_not_exists diff --git a/edgy/core/db/context_vars.py b/edgy/core/db/context_vars.py index 954cc381..0a7bf57a 100644 --- a/edgy/core/db/context_vars.py +++ b/edgy/core/db/context_vars.py @@ -14,6 +14,9 @@ "CURRENT_MODEL_INSTANCE", default=None ) CURRENT_PHASE: ContextVar[str] = ContextVar("CURRENT_PHASE", default="") +NO_GLOBAL_FIELD_CONSTRAINTS: ContextVar[bool] = ContextVar( + "NO_GLOBAL_FIELD_CONSTRAINTS", default=False +) EXPLICIT_SPECIFIED_VALUES: ContextVar[Optional[set[str]]] = ContextVar( "EXPLICIT_SPECIFIED_VALUES", default=None ) diff --git a/edgy/core/db/fields/many_to_many.py b/edgy/core/db/fields/many_to_many.py index f6114ff7..fc079361 100644 --- a/edgy/core/db/fields/many_to_many.py +++ b/edgy/core/db/fields/many_to_many.py @@ -168,9 +168,13 @@ def create_through_model( Generates a middle model based on the owner of the field and the field itself and adds it to the main registry to make sure it generates the proper models and migrations. """ + from edgy.contrib.multi_tenancy.base import TenantModel + from edgy.contrib.multi_tenancy.metaclasses import TenantMeta from edgy.core.db.models.metaclasses import MetaInfo - __bases__: tuple[type[BaseModelType], ...] = () + __bases__: tuple[type[BaseModelType], ...] = ( + (TenantModel,) if getattr(self.owner.meta, "is_tenant", False) else () + ) pknames = set() if self.through: through = self.through @@ -243,7 +247,15 @@ def callback(model_class: type["BaseModelType"]) -> None: if has_pknames: meta_args["unique_together"] = [(self.from_foreign_key, self.to_foreign_key)] - new_meta: MetaInfo = MetaInfo(None, registry=False, no_copy=True, **meta_args) + # TenantMeta is compatible to normal meta + new_meta: MetaInfo = TenantMeta( + None, + registry=False, + no_copy=True, + is_tenant=getattr(self.owner.meta, "is_tenant", False), + register_default=getattr(self.owner.meta, "register_default", False), + **meta_args, + ) to_related_name: Union[str, Literal[False]] if self.related_name is False: diff --git a/edgy/core/db/fields/types.py b/edgy/core/db/fields/types.py index d05bca45..bdc9bdcc 100644 --- a/edgy/core/db/fields/types.py +++ b/edgy/core/db/fields/types.py @@ -130,7 +130,10 @@ def to_model( return {field_name: value} def get_global_constraints( - self, name: str, columns: Sequence[sqlalchemy.Column], schemes: Sequence[str] = () + self, + name: str, + columns: Sequence[sqlalchemy.Column], + schemes: Sequence[str] = (), ) -> Sequence[Union[sqlalchemy.Constraint, sqlalchemy.Index]]: """Return global constraints and indexes. Useful for multicolumn fields diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index 706e963b..4b58acb3 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -17,6 +17,7 @@ CURRENT_INSTANCE, EXPLICIT_SPECIFIED_VALUES, MODEL_GETATTR_BEHAVIOR, + NO_GLOBAL_FIELD_CONSTRAINTS, get_schema, ) from edgy.core.db.datastructures import Index, UniqueConstraint @@ -382,7 +383,7 @@ def copy_edgy_model( @property def table(self) -> sqlalchemy.Table: - if getattr(self, "_table", None) is None: + if self.__dict__.get("_table", None) is None: schema = self.get_active_instance_schema() return cast( "sqlalchemy.Table", @@ -391,13 +392,22 @@ def table(self) -> sqlalchemy.Table: return self._table @table.setter - def table(self, value: sqlalchemy.Table) -> None: + def table(self, value: Optional[sqlalchemy.Table]) -> None: with contextlib.suppress(AttributeError): del self._pknames with contextlib.suppress(AttributeError): del self._pkcolumns self._table = value + @table.deleter + def table(self) -> None: + with contextlib.suppress(AttributeError): + del self._pknames + with contextlib.suppress(AttributeError): + del self._pkcolumns + with contextlib.suppress(AttributeError): + del self._table + @property def pkcolumns(self) -> Sequence[str]: if self.__dict__.get("_pkcolumns", None) is None: @@ -669,7 +679,9 @@ async def save( @classmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> sqlalchemy.Table: """ Builds the SQLAlchemy table representation from the loaded fields. @@ -697,7 +709,10 @@ def build( for name, field in cls.meta.fields.items(): current_columns = field.get_columns(name) columns.extend(current_columns) - global_constraints.extend(field.get_global_constraints(name, current_columns, schemes)) + if not NO_GLOBAL_FIELD_CONSTRAINTS.get(): + global_constraints.extend( + field.get_global_constraints(name, current_columns, schemes) + ) # Handle the uniqueness together uniques = [] @@ -723,6 +738,36 @@ def build( else cls.get_active_class_schema(check_schema=False, check_tenant=False), ) + @classmethod + def add_global_field_constraints( + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, + ) -> sqlalchemy.Table: + """ + Add global constraints to table. Required for tenants. + """ + tablename: str = cls.meta.tablename + registry = cls.meta.registry + assert registry, "registry is not set" + if metadata is None: + metadata = registry.metadata_by_url[str(cls.database.url)] + schemes: list[str] = [] + if schema: + schemes.append(schema) + if cls.__using_schema__ is not Undefined: + schemes.append(cls.__using_schema__) + db_schema = cls.get_db_schema() or "" + schemes.append(db_schema) + table = metadata.tables[tablename if not schema else f"{schema}.{tablename}"] + for name, field in cls.meta.fields.items(): + current_columns: list[sqlalchemy.Column] = [] + for column_name in cls.meta.field_to_column_names[name]: + current_columns.append(table.columns[column_name]) + for constraint in field.get_global_constraints(name, current_columns, schemes): + table.append_constraint(constraint) + return table + @classmethod def _get_unique_constraints( cls, fields: Union[Sequence, str, sqlalchemy.UniqueConstraint] diff --git a/edgy/core/db/models/mixins/reflection.py b/edgy/core/db/models/mixins/reflection.py index 9c9fdece..522397da 100644 --- a/edgy/core/db/models/mixins/reflection.py +++ b/edgy/core/db/models/mixins/reflection.py @@ -27,7 +27,9 @@ def real_add_to_registry(cls: type["BaseModelType"], **kwargs: Any) -> type["Bas @classmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> Any: """ The inspect is done in an async manner and reflects the objects from the database. diff --git a/edgy/core/db/models/types.py b/edgy/core/db/models/types.py index a39951fe..d32c0394 100644 --- a/edgy/core/db/models/types.py +++ b/edgy/core/db/models/types.py @@ -144,7 +144,9 @@ def model_dump(self, show_pk: Union[bool, None] = None, **kwargs: Any) -> dict[s @classmethod @abstractmethod def build( - cls, schema: Optional[str] = None, metadata: Optional[sqlalchemy.MetaData] = None + cls, + schema: Optional[str] = None, + metadata: Optional[sqlalchemy.MetaData] = None, ) -> sqlalchemy.Table: """ Builds the SQLAlchemy table representation from the loaded fields. diff --git a/edgy/core/db/relationships/relation.py b/edgy/core/db/relationships/relation.py index e19c0262..886c394a 100644 --- a/edgy/core/db/relationships/relation.py +++ b/edgy/core/db/relationships/relation.py @@ -95,6 +95,8 @@ def expand_relationship(self, value: Any) -> Any: **{self.from_foreign_key: self.instance, self.to_foreign_key: value} ) instance.identifying_db_fields = [self.from_foreign_key, self.to_foreign_key] # type: ignore + if getattr(through.meta, "is_tenant", False): + instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore return instance def stage(self, *children: "BaseModelType") -> None: @@ -150,9 +152,9 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None: try: child = await self.get() except ObjectNotFound: - raise RelationshipNotFound(detail="no child found") from None + raise RelationshipNotFound(detail="No child found.") from None else: - raise RelationshipNotFound(detail="no child specified") + raise RelationshipNotFound(detail="No child specified.") if not isinstance( child, (self.to, self.to.proxy_model, self.through, self.through.proxy_model), # type: ignore @@ -164,7 +166,7 @@ async def remove(self, child: Optional["BaseModelType"] = None) -> None: count = await child.query.filter(*child.identifying_clauses()).count() if count == 0: raise RelationshipNotFound( - detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child,self.to_foreign_key).pk}'." + detail=f"There is no relationship between '{self.from_foreign_key}' and '{self.to_foreign_key}: {getattr(child, self.to_foreign_key).pk}'." ) else: await child.delete() @@ -242,6 +244,8 @@ def expand_relationship(self, value: Any) -> Any: value = {next(iter(related_columns)): value} instance = target.proxy_model(**value) instance.identifying_db_fields = related_columns # type: ignore + if getattr(target.meta, "is_tenant", False): + instance.__using_schema__ = self.instance.get_active_instance_schema() # type: ignore return instance def stage(self, *children: "BaseModelType") -> None: diff --git a/tests/contrib/multi_tenancy/test_tenant_models_using.py b/tests/contrib/multi_tenancy/test_tenant_models_using.py index ce909caf..816ff821 100644 --- a/tests/contrib/multi_tenancy/test_tenant_models_using.py +++ b/tests/contrib/multi_tenancy/test_tenant_models_using.py @@ -52,15 +52,25 @@ class Meta: is_tenant = True +class Cart(TenantModel): + products = fields.ManyToMany(Product) + + class Meta: + registry = models + is_tenant = True + + @pytest.mark.parametrize("use_copy", ["false", "instant", "after"]) async def test_schema_with_using_in_different_place(use_copy): if use_copy == "instant": copied = models.__copy__() NewTenant = copied.get_model("Tenant") NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") else: NewTenant = Tenant NewProduct = Product + NewCart = Cart tenant = await NewTenant.query.create( schema_name="edgy", domain_url="https://edgy.dymmond.com", tenant_name="edgy" ) @@ -68,13 +78,28 @@ async def test_schema_with_using_in_different_place(use_copy): copied = models.__copy__() NewTenant = copied.get_model("Tenant") NewProduct = copied.get_model("Product") + NewCart = copied.get_model("Cart") + cart = await NewCart.query.using(schema=tenant.schema_name).create() + assert cart.__using_schema__ == tenant.schema_name for i in range(5): - await NewProduct.query.using(schema=tenant.schema_name).create(name=f"product-{i}") + product = await NewProduct.query.using(schema=tenant.schema_name).create( + name=f"product-{i}" + ) + if i % 2 == 0: + product_through = cart.products.through(cart=cart, product=product) + product_through.__using_schema__ = tenant.schema_name + assert await cart.products.add(product_through) + else: + assert await cart.products.add(product) total = await NewProduct.query.filter().using(schema=tenant.schema_name).all() assert len(total) == 5 + total = await cart.products.filter().using(schema=tenant.schema_name).all() + + assert len(total) == 5 + total = await NewProduct.query.all() assert len(total) == 0