Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added realignment to dlgp and pdm #428

Open
wants to merge 14 commits into
base: release-1.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions src/main/scala/scalismo/numerics/GramDiagonalize.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package scalismo.numerics

import breeze.linalg.{*, given}

object GramDiagonalize {

/**
* Given a non orthogonal basis nxr and the variance (squared [eigen]scalars) of that basis, returns an orthonormal
* basis with the adjusted variance. sets small eigenvalues to zero.
*/
def rediagonalizeGram(basis: DenseMatrix[Double],
s: DenseVector[Double]
): (DenseMatrix[Double], DenseVector[Double]) = {
// val l: DenseMatrix[Double] = basis(*, ::) * breeze.numerics.sqrt(s)
val l: DenseMatrix[Double] = DenseMatrix.zeros[Double](basis.rows, basis.cols)
val sqs: DenseVector[Double] = breeze.numerics.sqrt(s)
for i <- 0 until basis.cols do l(::, i) := sqs(i) * basis(::, i)

val gram = l.t * l
val svd = breeze.linalg.svd(gram)
val newS: DenseVector[Double] = breeze.numerics.sqrt(svd.S).map(d => if (d > 1e-10) 1.0 / d else 0.0)

// val newbasis: DenseMatrix[Double] = l * (svd.U(*, ::) * newS)
val inner: DenseMatrix[Double] = DenseMatrix.zeros[Double](gram.rows, gram.cols)
for i <- 0 until basis.cols do inner(::, i) := newS(i) * svd.U(::, i)
val newbasis = l * inner

(newbasis, svd.S)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ import breeze.linalg.svd.SVD
import breeze.linalg.{diag, DenseMatrix, DenseVector}
import breeze.stats.distributions.Gaussian
import scalismo.common.DiscreteField.vectorize
import scalismo.common._
import scalismo.common.*
import scalismo.common.interpolation.{FieldInterpolator, NearestNeighborInterpolator}
import scalismo.geometry._
import scalismo.geometry.*
import scalismo.image.StructuredPoints
import scalismo.kernels.{DiscreteMatrixValuedPDKernel, MatrixValuedPDKernel}
import scalismo.numerics.{PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair => DiscreteEigenpair, _}
import scalismo.numerics.{GramDiagonalize, PivotedCholesky, Sampler}
import scalismo.statisticalmodel.DiscreteLowRankGaussianProcess.{Eigenpair as DiscreteEigenpair, *}
import scalismo.statisticalmodel.LowRankGaussianProcess.Eigenpair
import scalismo.statisticalmodel.NaNStrategy.NanIsNumericValue
import scalismo.statisticalmodel.dataset.DataCollection
Expand Down Expand Up @@ -358,6 +358,26 @@ class DiscreteLowRankGaussianProcess[D: NDSpace, DDomain[DD] <: DiscreteDomain[D
)
}

/**
* realigns the model on the provided part of the domain. By default aligns over the translation and approximately
* over rotation as well. The rotation will always be calculated around the center of the provided ids. Rotations are
* around the cardinal directions.
*
* @param ids
* these define the parts of the domain that are aligned to
* @param withRotation
* True if the rotation should be included. False makes the realignment over translation exact.
* @param diagonalize
* True if a diagonal basis should be returned. False is cheaper for exclusively drawing samples.
* @return
* The resulting [[DiscreteLowRankGaussianProcess]] aligned on the provided instances of [[PointId]]
*/
def realign(ids: IndexedSeq[PointId], withRotation: Boolean = true, diagonalize: Boolean = true)(implicit
vectorizer: Vectorizer[Value]
): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
DiscreteLowRankGaussianProcess.realignment(this, ids, withRotation, diagonalize)
}

protected[statisticalmodel] def instanceVector(alpha: DenseVector[Double]): DenseVector[Double] = {
require(rank == alpha.size)

Expand Down Expand Up @@ -485,6 +505,12 @@ object DiscreteLowRankGaussianProcess {
new DiscreteLowRankGaussianProcess[D, DDomain, Value](domain, meanVec, varianceVec, basisMat)
}

def unapply[D: NDSpace, DDomain[D] <: DiscreteDomain[D], Value](
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we add an unapply method, we should also remove the private modifier from the corresponding fields, and maybe add an apply method. Like this it is a bit contradictory.
It was private so far to prevent access to it, but I don't have a strong opinion to keep it like this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i forgot to remove the unapply method. will be removed in the next version

dgp: DiscreteLowRankGaussianProcess[D, DDomain, Value]
): Option[(DiscreteDomain[D], DenseVector[Double], DenseMatrix[Double], DenseVector[Double])] = {
Option(dgp.domain, dgp.meanVector, dgp.basisMatrix, dgp.variance)
}

/**
* Discrete implementation of [[LowRankGaussianProcess.regression]]
*/
Expand Down Expand Up @@ -638,6 +664,118 @@ object DiscreteLowRankGaussianProcess {
DiscreteMatrixValuedPDKernel(domain, cov, outputDim)
}

def realignment[D: NDSpace, DDomain[DD] <: DiscreteDomain[DD], Value](
model: DiscreteLowRankGaussianProcess[D, DDomain, Value],
ids: IndexedSeq[PointId],
withRotation: Boolean,
diagonalize: Boolean
)(implicit vectorizer: Vectorizer[Value]): DiscreteLowRankGaussianProcess[D, DDomain, Value] = {
model match
case m: DiscreteLowRankGaussianProcess[`_3D`, _, EuclideanVector[`_3D`]] =>
realignment3D(m, ids, withRotation, diagonalize)
case _ if !withRotation =>
// TODO general pure translation
throw new NotImplementedError("not yet implemented")
case _ =>
throw new NotImplementedError("not yet implemented")
}

private def realignment3D[DDomain[_3D] <: DiscreteDomain[_3D]](
model: DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]],
ids: IndexedSeq[PointId],
withRotation: Boolean,
diagonalize: Boolean
): DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]] = {
// build the projection matrix for the desired pose
val p = {
val pt = breeze.linalg.tile(DenseMatrix.eye[Double](3), model.domain.pointSet.numberOfPoints, 1)
if withRotation then
val center = ids.map(id => model.domain.pointSet.point(id).toVector).reduce(_ + _).map(_ / ids.length).toPoint
val pr = getTangentialSpeedMatrix(model.domain, center)
DenseMatrix.horzcat(pt, pr)
else pt
}
// call the realignment implementation
val (nmean, nbasis, nvar) = realignmentComputation(model.meanVector,
model.basisMatrix,
model.variance,
p,
ids.map(_.id),
dim = 3,
diagonalize = diagonalize,
projectMean = false
)

new DiscreteLowRankGaussianProcess[_3D, DDomain, EuclideanVector[_3D]](model.domain, nmean, nvar, nbasis)
}

private def realignmentComputation(mean: DenseVector[Double],
basis: DenseMatrix[Double],
s: DenseVector[Double],
p: DenseMatrix[Double],
ids: IndexedSeq[Int],
dim: Int,
diagonalize: Boolean,
projectMean: Boolean
): (DenseVector[Double], DenseMatrix[Double], DenseVector[Double]) = {
val x = for // prepare indices
id <- ids
d <- 0 until dim
yield id * dim + d
// prepare the majority of the projection matrix
val px = p(x, ::).toDenseMatrix
val ptpipt = breeze.linalg.pinv(px.t * px) * px.t

// performs the actual projection. batches all basis vectors
// p -> projection rank, n number of indexes*dim, r cols of basis, N rows of basis
val alignedC = ptpipt * basis(x, ::).toDenseMatrix // pxn * nxr
val alignedEigf = basis - p * alignedC // Nxr - Nxp * pxr
val alignedMean = if projectMean then // if desired projects the mean vector as well
val alignedMc = ptpipt * mean // same projection with r==1
mean - p * alignedMc
else mean

// rediagonalize. You can skip this if you ONLY sample from the resulting model
val (newbasis, news) =
if diagonalize then GramDiagonalize.rediagonalizeGram(alignedEigf, s)
else (alignedEigf, s)

(alignedMean, newbasis, news)
}

private def getTangentialSpeedMatrix[D: NDSpace](domain: DiscreteDomain[D], center: Point[D]): DenseMatrix[Double] = {
val dim = center.dimensionality
val np = domain.pointSet.numberOfPoints
require(dim >= 2, "requires at least two dimensions to calculate a tangential speed")
// build centered data matrix
val x = DenseMatrix.zeros[Double](dim, np)
val c = center.toBreezeVector
for (p, i) <- domain.pointSet.points.zipWithIndex do x(::, i) := p.toBreezeVector - c

val pr = if dim == 3 then
val pr = DenseMatrix.zeros[Double](np * dim, 3)
// the derivative of the rotation matrix
val dr = new DenseMatrix[Double](9,
3,
Array(1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0,
0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, -1.0, 0.0, 0.0, 0.0, 0.0, 1.0)
)
// get tangential speed
val dx = dr * x
for i <- 0 until 3 do
val v = dx(3 * i until 3 * i + 3, ::).toDenseVector
pr(::, i) := v / breeze.linalg.norm(v)
pr
else if dim == 2 then
val dr = new DenseMatrix[Double](2, 2, Array(0.0, -1.0, 1.0, 0.0))
val dx = (dr * x).reshape(dim * np, 1)
val n = breeze.linalg.norm(dx, breeze.linalg.Axis._0)
dx / n(0)
else throw new NotImplementedError("tangential speed only implemented for 2d and 3d space")

pr
}

}

// Explicit variants for 1D, 2D and 3D
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,16 @@ case class PointDistributionModel[D: NDSpace, DDomain[D] <: DiscreteDomain[D]](
PointDistributionModel(newGP)
}

/**
* realigns the [[DiscreteLowRankGaussianProcess]] and returns the resulting [[PointDistributionModel]]
*/
def realign(ids: IndexedSeq[PointId],
withRotation: Boolean = true,
diagonalize: Boolean = true
): PointDistributionModel[D, DDomain] = {
new PointDistributionModel[D, DDomain](this.gp.realign(ids, withRotation, diagonalize))
}

}

object PointDistributionModel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
*/
package scalismo.statisticalmodel

import breeze.linalg.{DenseMatrix, DenseVector}
import breeze.linalg.{rand, DenseMatrix, DenseVector}
import breeze.stats.distributions.Gaussian
import scalismo.ScalismoTestSuite
import scalismo.common.*
Expand All @@ -28,6 +28,7 @@ import scalismo.io.statisticalmodel.StatismoIO
import scalismo.kernels.{DiagonalKernel, GaussianKernel, MatrixValuedPDKernel}
import scalismo.numerics.PivotedCholesky.RelativeTolerance
import scalismo.numerics.{GridSampler, UniformSampler}
import scalismo.transformations.TranslationAfterRotation
import scalismo.utils.Random

import java.net.URLDecoder
Expand Down Expand Up @@ -723,6 +724,54 @@ class GaussianProcessTests extends ScalismoTestSuite {
}
}

describe("when realigning a model") {
it("the translation aligned model should be exactly orthogonal to translations on the aligned ids.") {
val f = Fixture
val dgp = f.discreteLowRankGp

val alignedDgp = dgp.realign(dgp.mean.pointsWithIds.map(t => t._2).toIndexedSeq, withRotation = true)

val shifts: IndexedSeq[Double] = alignedDgp.klBasis
.map(klp => {
val ef = klp.eigenfunction
ef.data.reduce(_ + _).norm
})
.toIndexedSeq
val res = shifts.sum

res shouldBe 0.0 +- 1e-7
}

it("the rotation aligned model should exhibit only small rotations on the aligned ids.") {
val f = Fixture
val dgp = f.discreteLowRankGp

val ids = {
val sorted = dgp.mean.pointsWithIds.toIndexedSeq.sortBy(t => t._1.toArray.sum)
sorted.take(10).map(_._2)
}
val coef = (0 until 5).map(_ => this.random.scalaRandom.nextInt(100))
val alignedDgp = dgp.realign(ids)
val res = IndexedSeq(dgp, alignedDgp).map(model => {
val samples = coef
.map(i => model.instance(DenseVector.tabulate[Double](model.rank) { j => if i == j then 0.1 else 0.0 }))
.map(_ => model.sample())
val rp = ids.map(id => model.mean.domain.pointSet.point(id).toVector).reduce(_ + _).map(d => d / ids.length)
val rotations = samples.map(sample => {
val ldms = ids.map(id =>
(model.domain.pointSet.point(id) + model.mean.data(id.id),
sample.domain.pointSet.point(id.id) + sample.data(id.id)
)
)
val rigidTransform: TranslationAfterRotation[_3D] =
scalismo.registration.LandmarkRegistration.rigid3DLandmarkRegistration(ldms, rp.toPoint)
rigidTransform.rotation.parameters
})
rotations.map(m => m.data.map(math.abs).sum).sum
})
res(1) shouldBe <(res(0) * 0.9)
}
}
}

describe("when comparing marginalLikelihoods") {
Expand Down
Loading