diff --git a/src/main/scala/BIDMach/updaters/GradCollide.scala b/src/main/scala/BIDMach/updaters/GradCollide.scala index 8172a772..7641e2af 100755 --- a/src/main/scala/BIDMach/updaters/GradCollide.scala +++ b/src/main/scala/BIDMach/updaters/GradCollide.scala @@ -204,7 +204,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val meanp = lp * lp; val meanq = lq * lq; val cosp = dp / (p.length * lp * lq + epsilon); @@ -236,7 +236,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) // Now find the vector e which captures the "excess energy" in p and q, i.e. // the component of p or -q which is orthogonal to x=p+q, and the squared magnitude of e (energy). val pq2 = pp + 2*pq + qq; - val pcoeff = (pq + qq) / (pq2 + epsilon); // e = pcoeff * p + qcoeff * q; + val pcoeff = (pq + qq) / (pq2 + epsilon); // e = pcoeff * p - qcoeff * q; val qcoeff = (pq + pp) / (pq2 + epsilon); val energy = pcoeff * pcoeff * pp + qcoeff * qcoeff * qq - 2 * pcoeff * qcoeff * pq; @@ -245,7 +245,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ q * aelem.set(qcoeff * h); c ~ c - tmp; - // Scale the random vector to match the energy diffference + // Scale the random vector to match the energy difference between p and p-h*e (or equivalently, between q and q+h*e) if (energy > 0) { x ~ x * aelem.set(math.sqrt(energy * (2 * h - h * h) / dotprod(x, x)).toFloat); } else { @@ -264,7 +264,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val meanp = lp * lp; val meanq = lq * lq; val cosp = dp / (p.length * lp * lq + epsilon); @@ -292,7 +292,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val meanp = lp * lp; val meanq = lq * lq; val cosp = dp / (p.length * lp * lq + epsilon); @@ -338,7 +338,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val meanp = lp * lp; val meanq = lq * lq; val cosp = dp / (p.length * lp * lq + epsilon); @@ -363,7 +363,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val cosp = dotprod(p, q) / (p.length * math.sqrt(meanp * meanq).toFloat + epsilon); Mat.logger.info("before: i=%d, cos(p,q)=%g, tote=%g, meanp=%g, meanq=%g, varp=%g, varq=%g" format (i, cosp, meanp+meanq, meanp, meanq, meansqp - meanp * meanp, meansqq - meanq * meanq)); } @@ -385,7 +385,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val cosp = dotprod(p, q) / (p.length * math.sqrt(meanp * meanq).toFloat + epsilon); Mat.logger.info("after : i=%d, cos(p,q)=%g, tote=%g, meanp=%g, meanq=%g, varp=%g, varq=%g" format (i, cosp, meanp+meanq, meanp, meanq, meansqp - meanp * meanp, meansqq - meanq * meanq)); } @@ -409,7 +409,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val cosp = dotprod(p, q) / (p.length * math.sqrt(meanp * meanq).toFloat + epsilon); val varp = meansqp - meanp * meanp; val varq = meansqq - meanq * meanq; @@ -445,7 +445,7 @@ class GradCollide(override val opts:GradCollide.Opts = new GradCollide.Options) tmp ~ p *@ p; val meansqp = dotprod(tmp, tmp) / p.length; tmp ~ q *@ q; - val meansqq = dotprod(tmp, tmp) / p.length; + val meansqq = dotprod(tmp, tmp) / q.length; val cosp = dotprod(p, q) / (p.length * math.sqrt(meanp * meanq).toFloat + epsilon); val varp = meansqp - meanp * meanp; val varq = meansqq - meanq * meanq;