diff --git a/src/main/scala/scalismo/numerics/GramDiagonalize.scala b/src/main/scala/scalismo/numerics/GramDiagonalize.scala new file mode 100644 index 00000000..cb739c1c --- /dev/null +++ b/src/main/scala/scalismo/numerics/GramDiagonalize.scala @@ -0,0 +1,31 @@ +package scalismo.numerics + +import breeze.linalg.{*, given} + +object GramDiagonalize { + + /** + * Given a non orthogonal basis nxr and the variance (squared [eigen]scalars) of that basis, returns an orthonormal + * basis with the adjusted variance. sets small eigenvalues to zero. + */ + def rediagonalizeGram(basis: DenseMatrix[Double], + s: DenseVector[Double] + ): (DenseMatrix[Double], DenseVector[Double]) = { + // val l: DenseMatrix[Double] = basis(*, ::) * breeze.numerics.sqrt(s) + val l: DenseMatrix[Double] = DenseMatrix.zeros[Double](basis.rows, basis.cols) + val sqs: DenseVector[Double] = breeze.numerics.sqrt(s) + for i <- 0 until basis.cols do l(::, i) := sqs(i) * basis(::, i) + + val gram = l.t * l + val svd = breeze.linalg.svd(gram) + val newS: DenseVector[Double] = breeze.numerics.sqrt(svd.S).map(d => if (d > 1e-10) 1.0 / d else 0.0) + + // val newbasis: DenseMatrix[Double] = l * (svd.U(*, ::) * newS) + val inner: DenseMatrix[Double] = DenseMatrix.zeros[Double](gram.rows, gram.cols) + for i <- 0 until basis.cols do inner(::, i) := newS(i) * svd.U(::, i) + val newbasis = l * inner + + (newbasis, svd.S) + } + +} diff --git a/src/main/scala/scalismo/statisticalmodel/DiscreteLowRankGaussianProcess.scala b/src/main/scala/scalismo/statisticalmodel/DiscreteLowRankGaussianProcess.scala index bc3cb2ca..6375b183 100644 --- a/src/main/scala/scalismo/statisticalmodel/DiscreteLowRankGaussianProcess.scala +++ b/src/main/scala/scalismo/statisticalmodel/DiscreteLowRankGaussianProcess.scala @@ -19,18 +19,19 @@ import breeze.linalg.svd.SVD import breeze.linalg.{diag, DenseMatrix, DenseVector} import breeze.stats.distributions.Gaussian import scalismo.common.DiscreteField.vectorize -import scalismo.common._ +import scalismo.common.* import scalismo.common.interpolation.{FieldInterpolator, NearestNeighborInterpolator} -import scalismo.geometry._ +import scalismo.geometry.* import scalismo.image.StructuredPoints import scalismo.kernels.{DiscreteMatrixValuedPDKernel, MatrixValuedPDKernel} -import scalismo.numerics.{PivotedCholesky, Sampler} -import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair => DiscreteEigenpair, _} +import scalismo.numerics.{GramDiagonalize, PivotedCholesky, Sampler} +import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair as DiscreteEigenpair, *} import scalismo.statisticalmodel.LowRankGaussianProcess.Eigenpair import scalismo.statisticalmodel.NaNStrategy.NanIsNumericValue import scalismo.statisticalmodel.dataset.DataCollection import scalismo.utils.{Memoize, Random} +import scala.annotation.threadUnsafe import scala.language.higherKinds import scala.collection.parallel.immutable.ParVector @@ -358,6 +359,37 @@ class DiscreteLowRankGaussianProcess[D: NDSpace, DDomain[DD] <: DiscreteDomain[D ) } + /** + * realigns the model on the provided part of the domain. Aligns over the translation and, when using + * withExtendedBasis = true, over the extended basis (the default implicit [[RealignExtendedBasis]] adds rotation. + * This rotation will always be calculated around the center of the provided ids. Rotations are around the cardinal + * directions.). + * + * @param ids + * these define the parts of the domain that are aligned to. Depending on the withExtendedBasis parameter has a + * minimum length requirements (default basis extension in 3D should be used with >=4 provided ids for example) + * @param withExtendedBasis + * True if the extended basis should be included. By default this uses a rotation extension. False makes the + * realignment only over translation. Translational alignment can be done exactly. For more information see + * [[RealignExtendedBasis]]. + * @param diagonalize + * True if a diagonal basis should be returned. In general, it is strongly recommended to use a orthonormal basis - + * here referred to as diagonal. This does not increase complexity and is a more intuitive formulation of the model. + * If internal fields are accessed diagonalize should be set to true. This option can be set to false to make the + * same coefficient lead to very similar shapes in the pre- and after realignment model (or exactly the same shapes + * if withExtendedBasis = false). + * @return + * The resulting [[DiscreteLowRankGaussianProcess]] aligned on the provided instances of [[PointId]]. If + * withExtendedBasis = false then the original and the returned model can produce the same mesh with different + * translations. That means the shape spaces are the same but the fieldsa are translated. + */ + def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using + vectorizer: Vectorizer[Value], + realigning: RealignExtendedBasis[D, Value] + ): DiscreteLowRankGaussianProcess[D, DDomain, Value] = { + DiscreteLowRankGaussianProcess.realignment(this, ids, withExtendedBasis, diagonalize) + } + protected[statisticalmodel] def instanceVector(alpha: DenseVector[Double]): DenseVector[Double] = { require(rank == alpha.size) @@ -638,6 +670,75 @@ object DiscreteLowRankGaussianProcess { DiscreteMatrixValuedPDKernel(domain, cov, outputDim) } + def realignment[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD], Value]( + model: DiscreteLowRankGaussianProcess[D, DDomain, Value], + ids: IndexedSeq[PointId], + withExtendedBasis: Boolean, + diagonalize: Boolean + )(using + vectorizer: Vectorizer[Value], + realigning: RealignExtendedBasis[D, Value] + ): DiscreteLowRankGaussianProcess[D, DDomain, Value] = { + val d = NDSpace.apply[D].dimensionality + // build the projection matrix for the desired pose + val p = { + @threadUnsafe + lazy val pt = breeze.linalg.tile(DenseMatrix.eye[Double](d), model.domain.pointSet.numberOfPoints, 1) + if withExtendedBasis then + val center = ids.map(id => model.domain.pointSet.point(id).toVector).reduce(_ + _).map(_ / ids.length).toPoint + val pr = realigning.getBasis[DDomain](model, center) + if realigning.useTranslation then DenseMatrix.horzcat(pt, pr) + else pr + else pt + } + // call the realignment implementation + val (nmean, nbasis, nvar) = realignmentComputation(model.meanVector, + model.basisMatrix, + model.variance, + p, + ids.map(_.id), + dim = d, + diagonalize = diagonalize, + projectMean = false + ) + + new DiscreteLowRankGaussianProcess[D, DDomain, Value](model.domain, nmean, nvar, nbasis) + } + + private def realignmentComputation(mean: DenseVector[Double], + basis: DenseMatrix[Double], + s: DenseVector[Double], + p: DenseMatrix[Double], + ids: IndexedSeq[Int], + dim: Int, + diagonalize: Boolean, + projectMean: Boolean + ): (DenseVector[Double], DenseMatrix[Double], DenseVector[Double]) = { + val x = for // prepare indices + id <- ids + d <- 0 until dim + yield id * dim + d + // prepare the majority of the projection matrix + val px = p(x, ::).toDenseMatrix + val ptpipt = breeze.linalg.pinv(px.t * px) * px.t + + // performs the actual projection. batches all basis vectors + // p -> projection rank, n number of indexes*dim, r cols of basis, N rows of basis + val alignedC = ptpipt * basis(x, ::).toDenseMatrix // pxn * nxr + val alignedEigf = basis - p * alignedC // Nxr - Nxp * pxr + val alignedMean = if projectMean then // if desired projects the mean vector as well + val alignedMc = ptpipt * mean // same projection with r==1 + mean - p * alignedMc + else mean + + // rediagonalize. You can skip this if you ONLY sample from the resulting model + val (newbasis, news) = + if diagonalize then GramDiagonalize.rediagonalizeGram(alignedEigf, s) + else (alignedEigf, s) + + (alignedMean, newbasis, news) + } + } // Explicit variants for 1D, 2D and 3D diff --git a/src/main/scala/scalismo/statisticalmodel/PointDistributionModel.scala b/src/main/scala/scalismo/statisticalmodel/PointDistributionModel.scala index 7375c59b..18a9bdc2 100644 --- a/src/main/scala/scalismo/statisticalmodel/PointDistributionModel.scala +++ b/src/main/scala/scalismo/statisticalmodel/PointDistributionModel.scala @@ -257,6 +257,17 @@ case class PointDistributionModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D]]( PointDistributionModel(newGP) } + /** + * realigns the internal [[DiscreteLowRankGaussianProcess]] and returns the resulting [[PointDistributionModel]]. this + * calls [[DiscreteLowRankGaussianProcess.realign]]. + */ + def realign(ids: IndexedSeq[PointId], withExtendedBasis: Boolean = true, diagonalize: Boolean = true)(using + vectorizer: Vectorizer[EuclideanVector[D]], + realign: RealignExtendedBasis[D, EuclideanVector[D]] + ): PointDistributionModel[D, DDomain] = { + new PointDistributionModel[D, DDomain](this.gp.realign(ids, withExtendedBasis, diagonalize)) + } + } object PointDistributionModel { diff --git a/src/main/scala/scalismo/statisticalmodel/RealignExtendedBasis.scala b/src/main/scala/scalismo/statisticalmodel/RealignExtendedBasis.scala new file mode 100644 index 00000000..e45b3674 --- /dev/null +++ b/src/main/scala/scalismo/statisticalmodel/RealignExtendedBasis.scala @@ -0,0 +1,88 @@ +package scalismo.statisticalmodel + +import breeze.linalg.DenseMatrix +import scalismo.common.DiscreteDomain +import scalismo.geometry.* + +/** + * ideally used to represent linear effects that should be normalized between the training data. The realignment process + * then builds a linear projection matrix that is applied to an existing model. the space of the effects that should be + * normalized needs to be spanned by the returned matrix + */ +trait RealignExtendedBasis[D, Value]: + + /** + * whether or not the default translation basis should also be used. that means false does not perform a translation + * realignment. This in combination with getBasis allows for complete control of the projection matrix. + */ + def useTranslation: Boolean + + /** + * basis to span the kernel of the projection. for example, a translation alignment could be performed by spanning + * that space with constant vectors for each cardinal direction. + */ + def getBasis[DDomain[DD] <: DiscreteDomain[DD]](model: DiscreteLowRankGaussianProcess[D, DDomain, Value], + center: Point[D] + ): DenseMatrix[Double] + +/** + * includes the additional default rotation centerpoint implementation which is useful to calculate the rotation basis. + */ +trait RealignExtendedBasisRotation[D, Value] extends RealignExtendedBasis[D, Value]: + def centeredP[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD]](domain: DDomain[D], + center: Point[D] + ): DenseMatrix[Double] = { + // build centered data matrix + val x = DenseMatrix.zeros[Double](center.dimensionality, domain.pointSet.numberOfPoints) + val c = center.toBreezeVector + for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c + x + } + +object RealignExtendedBasis: + /** + * returns a projection basis for rotation - the tangential speed for the rotations around the three cardinal + * directions. + */ + given realignBasis3D: RealignExtendedBasisRotation[_3D, EuclideanVector[_3D]] with + def useTranslation: Boolean = true + def getBasis[DDomain[DD] <: DiscreteDomain[DD]]( + model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]], + center: Point[_3D] + ): DenseMatrix[Double] = { + val np = model.domain.pointSet.numberOfPoints + val x = centeredP(model.domain, center) + + val pr = DenseMatrix.zeros[Double](np * 3, 3) + // the derivative of the rotation matrices + val dr = new DenseMatrix[Double](9, + 3, + Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0, + 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0) + ) + // get tangential speed + val dx = dr * x + for i <- 0 until 3 do + val v = dx(3 * i until 3 * i + 3, ::).toDenseVector + pr(::, i) := v / breeze.linalg.norm(v) + pr + } + + /** + * returns a projection basis for rotation - the tangential speed for the single 2d rotation. + */ + given realignBasis2D: RealignExtendedBasisRotation[_2D, EuclideanVector[_2D]] with + def useTranslation: Boolean = true + def getBasis[DDomain[DD] <: DiscreteDomain[DD]]( + model: DiscreteLowRankGaussianProcess[_2D, DDomain, EuclideanVector[_2D]], + center: Point[_2D] + ): DenseMatrix[Double] = { + val np = model.domain.pointSet.numberOfPoints + val x = centeredP(model.domain, center) + + // derivative of the rotation matrix + val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0)) + val dx = (dr * x).reshape(2 * np, 1) + val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0) + dx / n(0) + } diff --git a/src/test/scala/scalismo/statisticalmodel/GaussianProcessTests.scala b/src/test/scala/scalismo/statisticalmodel/GaussianProcessTests.scala index c6b5003d..253d8fe0 100644 --- a/src/test/scala/scalismo/statisticalmodel/GaussianProcessTests.scala +++ b/src/test/scala/scalismo/statisticalmodel/GaussianProcessTests.scala @@ -15,7 +15,7 @@ */ package scalismo.statisticalmodel -import breeze.linalg.{DenseMatrix, DenseVector} +import breeze.linalg.{rand, DenseMatrix, DenseVector} import breeze.stats.distributions.Gaussian import scalismo.ScalismoTestSuite import scalismo.common.* @@ -28,6 +28,7 @@ import scalismo.io.statisticalmodel.StatismoIO import scalismo.kernels.{DiagonalKernel, GaussianKernel, MatrixValuedPDKernel} import scalismo.numerics.PivotedCholesky.RelativeTolerance import scalismo.numerics.{GridSampler, UniformSampler} +import scalismo.transformations.TranslationAfterRotation import scalismo.utils.Random import java.net.URLDecoder @@ -507,11 +508,11 @@ class GaussianProcessTests extends ScalismoTestSuite { object Fixture { val domain = BoxDomain((-5.0, -5.0, -5.0), (5.0, 5.0, 5.0)) - val sampler = UniformSampler(domain, 6 * 6 * 6) + val sampler = UniformSampler(domain, 5 * 5 * 5) val mean = Field[_3D, EuclideanVector[_3D]](EuclideanSpace[_3D], _ => EuclideanVector(0.0, 0.0, 0.0)) val gp = GaussianProcess(mean, DiagonalKernel(GaussianKernel[_3D](5), 3)) - val lowRankGp = LowRankGaussianProcess.approximateGPNystrom(gp, sampler, 200) + val lowRankGp = LowRankGaussianProcess.approximateGPNystrom(gp, sampler, 120) val discretizationPoints = sampler.sample().map(_._1) val discreteLowRankGp = DiscreteLowRankGaussianProcess(UnstructuredPointsDomain(discretizationPoints), lowRankGp) @@ -723,6 +724,99 @@ class GaussianProcessTests extends ScalismoTestSuite { } } + describe("when realigning a model") { + it("the translation aligned model should be exactly orthogonal to translations on the aligned ids.") { + val f = Fixture + val dgp = f.discreteLowRankGp + + val alignedDgp = dgp.realign(dgp.mean.pointsWithIds.map(t => t._2).toIndexedSeq) + + val shifts: IndexedSeq[Double] = alignedDgp.klBasis + .map(klp => { + val ef = klp.eigenfunction + ef.data.reduce(_ + _).norm + }) + .toIndexedSeq + val res = shifts.sum + + res shouldBe 0.0 +- 1e-7 + } + + it("the rotation aligned model should exhibit only small rotations on the aligned ids.") { + val f = Fixture + val dgp = f.discreteLowRankGp + + val ids = { + val sorted = dgp.mean.pointsWithIds.toIndexedSeq.sortBy(t => t._1.toArray.sum) + sorted.take(10).map(_._2) + } + val coef = (0 until 5).map(_ => this.random.scalaRandom.nextInt(100)) + val alignedDgp = dgp.realign(ids) + val res = IndexedSeq(dgp, alignedDgp).map(model => { + val samples = coef + .map(i => model.instance(DenseVector.tabulate[Double](model.rank) { j => if i == j then 0.1 else 0.0 })) + .map(_ => model.sample()) + val rp = ids.map(id => model.mean.domain.pointSet.point(id).toVector).reduce(_ + _).map(d => d / ids.length) + val rotations = samples.map(sample => { + val ldms = ids.map(id => + (model.domain.pointSet.point(id) + model.mean.data(id.id), + sample.domain.pointSet.point(id.id) + sample.data(id.id) + ) + ) + val rigidTransform: TranslationAfterRotation[_3D] = + scalismo.registration.LandmarkRegistration.rigid3DLandmarkRegistration(ldms, rp.toPoint) + rigidTransform.rotation.parameters + }) + rotations.map(m => m.data.map(math.abs).sum).sum + }) + res(1) shouldBe <(res(0) * 0.6) + } + + it("the non diagonalized model should be equivalent to the diagonalized one") { + val f = Fixture + val dgp = f.discreteLowRankGp + + val ids = { + val sorted = dgp.mean.pointsWithIds.toIndexedSeq.sortBy(t => t._1.toArray.sum) + sorted.take(10).map(_._2) + } + val coef = (0 until 5).map(_ => this.random.scalaRandom.nextInt(100)) + val adDgp = dgp.realign(ids, withExtendedBasis = false) + val aDgp = dgp.realign(ids, withExtendedBasis = false, diagonalize = false) + + // true variance is not the same as naive reading of variance field + val varad = adDgp.variance.data.sum + val vara = aDgp.variance.data + .zip(breeze.linalg.norm(aDgp.basisMatrix, breeze.linalg.Axis._0).t.data) + .map(t => t._1 * (t._2 * t._2)) + .sum + varad shouldBe vara +- 1e-5 + + // projection works as it relies on underlying regression calculations + val sample = adDgp.sample() + val proj = aDgp.project(sample) + val dif = sample.data.zip(proj.data).map(t => (t._1 - t._2).norm).sum + dif shouldBe 0.0 +- 1e-1 + + // marginal likelihood works + val marglad = adDgp.marginalLikelihood(f.trainingDataDiscreteGP) + val margla = aDgp.marginalLikelihood(f.trainingDataDiscreteGP) + marglad shouldBe margla +- 1e-5 + + // regression mean and variance are also the same + val td = f.trainingDataDiscreteGP.map(t => (t._1, EuclideanVector3D.ones, t._3)) + val regmeanad = adDgp.posterior(td) + val regmeana = aDgp.posterior(td) + val difmeans = regmeanad.mean.data.zip(regmeana.mean.data).map(t => (t._1 - t._2).norm).sum + difmeans shouldBe 0.0 +- 1e-1 + val corvarad = regmeanad.variance.data.sum + val corvara = regmeana.variance.data + .zip(breeze.linalg.norm(regmeana.basisMatrix, breeze.linalg.Axis._0).t.data) + .map(t => t._1 * (t._2 * t._2)) + .sum + corvarad shouldBe corvara +- 1e-5 + } + } } describe("when comparing marginalLikelihoods") {