注册 登录  
 加关注
   显示下一条  |  关闭
温馨提示!由于新浪微博认证机制调整,您的新浪微博帐号绑定已过期,请重新绑定!立即重新绑定新浪微博》  |  关闭

阿弥陀佛

街树飘影未见尘 潭月潜水了无声 般若观照心空静...

 
 
 

日志

 
 
关于我

一直从事气象预报、服务建模的实践应用。注重气象物理场、实况场、地理信息、本体知识库以及分布式气象内容管理系统的建立。对Barnes客观分析、小波、计算神经网络、信任传播、贝叶斯推理、专家系统、网络本体语言有一定体会。一直使用Java、Delphi、Prolog、SQL编程。

网易考拉推荐

scala kd-tree  

2015-02-10 08:39:44|  分类: Scala |  标签: |举报 |字号 订阅

  下载LOFTER 我的照片书  |
package common
import scala.collection.SortedMap
/**
 * k-d tree over points of any dimension, with nearest-neighbour search.
 *
 * Created by 何险峰 on 15-2-5.
 * Based on the Rosetta Code Scala solution: http://rosettacode.org/wiki/K-d_tree
 *
 * To list as many neighbouring points as possible, `KDNode.nearest` was modified from that code so
 * that both subtrees are always searched and recorded in `visited`.
 * `allnearest` keys its result by squared distance, so it keeps only one point per distinct
 * distance. `allnearestGrouped` keeps every equidistant point.
 */
object KDTree {

  /**
   * Builds a k-d tree from `points` (task 1A, translated from Wikipedia).
   *
   * The split axis cycles with depth (`depth % dim`). Points whose coordinate on that axis is
   * strictly below the median go to the left subtree. The remaining points go right, and the first
   * of them (after sorting) becomes this node.
   *
   * @param points points to index; the dimension is taken from the first point
   * @param depth  recursion depth, used only to choose the split axis
   * @return the root node, or `None` when `points` is empty or its first point has no coordinates
   */
  def apply[T](points: Seq[Seq[T]], depth: Int = 0)(implicit num: Numeric[T]): Option[KDNode[T]] = {
    val dim = points.headOption.map(_.size) getOrElse 0
    if (points.isEmpty || dim < 1) None
    else {
      val axis = depth % dim
      val sorted = points.sortBy(f => f(axis))
      val median = sorted(sorted.size / 2)(axis)
      // The median point itself always lands in `right`, so `right.head` cannot fail.
      val (left, right) = sorted.partition(v => num.lt(v(axis), median))
      Some(KDNode(right.head, apply(left, depth + 1), apply(right.tail, depth + 1), axis))
    }
  }

  /**
   * Lists every point of the tree built from `haystack`, keyed by squared distance to `needles`.
   *
   * Points sharing a squared distance collide on the same key and only one of them is kept; use
   * [[allnearestGrouped]] to keep them all.
   *
   * @param haystack points to search
   * @param needles  the query point
   * @return squared distance -> point, in ascending distance order; empty when the tree is empty
   */
  def allnearest[T](haystack: Seq[Seq[T]], needles: Seq[T])(implicit num: Numeric[T]): SortedMap[T, Seq[T]] = {
    // Walk the whole tree instead of relying on `nearest(...).visited`, which stops descending at a
    // node equal to the needle and would silently drop that node's descendants.
    val pairs = allNodes(KDTree(haystack)).map(node => distsq(needles, node.value) -> node.value)
    SortedMap(pairs: _*)
  }

  /**
   * Like [[allnearest]], but keeps every point that shares a squared distance.
   *
   * @param haystack points to search
   * @param needles  the query point
   * @return squared distance -> all points at that distance, in ascending distance order
   */
  def allnearestGrouped[T](haystack: Seq[Seq[T]], needles: Seq[T])(implicit num: Numeric[T]): SortedMap[T, Seq[Seq[T]]] = {
    val grouped = allNodes(KDTree(haystack)).map(_.value).groupBy(p => distsq(needles, p))
    SortedMap(grouped.toSeq: _*)
  }

  // Pre-order list of every node in the (possibly empty) tree.
  private def allNodes[T](tree: Option[KDNode[T]]): List[KDNode[T]] = tree match {
    case None    => Nil
    case Some(n) => n :: (allNodes(n.left) ++ allNodes(n.right))
  }

  /**
   * A k-d tree node.
   *
   * @param value the point stored at this node
   * @param left  subtree of points whose `axis` coordinate is strictly below `value(axis)` (as built by `apply`)
   * @param right subtree of the remaining points, whose `axis` coordinate is >= `value(axis)`
   * @param axis  index of the coordinate this node splits on
   */
  case class KDNode[T](value: Seq[T], left: Option[KDNode[T]], right: Option[KDNode[T]], axis: Int)(implicit num: Numeric[T]) {

    /**
     * Finds the point in this subtree closest to `to` (task 1B, translated from Wikipedia).
     *
     * Unlike the textbook search, both subtrees are always explored (the `visited` line forces both
     * lazy results). The returned `visited` set therefore holds every node of the subtree, except
     * the descendants of a node that equals `to` exactly.
     */
    def nearest(to: Seq[T]): Nearest[T] = {
      val default = Nearest[T](value, to, Set(this))
      compare(to, value) match {
        case 0 => default // exact match: stop without descending further
        case t =>
          lazy val bestL = left.map(_ nearest to).getOrElse(default)
          lazy val bestR = right.map(_ nearest to).getOrElse(default)
          // Forces both lazy vals, so both subtrees are searched regardless of pruning below.
          val visited = bestL.visited ++ bestR.visited + this
          // NOTE(review): `t` comes from a lexicographic comparison over all coordinates, not from
          // `axis`, so `branch1` is not always the side of the split that contains `to`. When it is
          // the far side, its best distance is never below `d2`, so `branch2` is still compared.
          val branch1 = if (t <= 0) bestL else bestR
          val best = if (num.lteq(branch1.distsq, default.distsq)) branch1 else default
          val splitDist = num.minus(to(axis), value(axis))
          val d2 = num.times(splitDist, splitDist)
          // The other side can only hold an equally close or closer point when the splitting plane
          // is no farther away than the current best.
          val lteq = num.lteq(d2, best.distsq)
          if (lteq) {
            val branch2 = if (t <= 0) bestR else bestL
            val visi = branch2.visited ++ visited
            if (num.lteq(branch2.distsq, best.distsq)) branch2.copy(visited = visi) else best.copy(visited = visi)
          } else best.copy(visited = visited)
      }
    }
  }

  /**
   * Result of a nearest-neighbour search: the point found (`value`), the query (`to`) and the
   * nodes visited. Pretty-printable.
   */
  case class Nearest[T](value: Seq[T], to: Seq[T], visited: Set[KDNode[T]] = Set[KDNode[T]]())(implicit num: Numeric[T]) {
    import scala.math._

    /** Squared Euclidean distance between `value` and `to`. */
    lazy val distsq: T = KDTree.distsq(value, to)

    /** Euclidean distance between `value` and `to`. */
    lazy val distance: Double = sqrt(num.toDouble(distsq))

    override def toString = f"Searched for=${to} found=${value} distance=${distance}%.4f visited=${visited.size}"
  }

  /** Squared Euclidean distance. Coordinates beyond the shorter sequence are ignored (`zip`). */
  def distsq[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]): T =
    a.zip(b).map { case (x, y) =>
      val d = num.minus(x, y)
      num.times(d, d)
    }.sum

  /** Lexicographic comparison: the sign of the first differing coordinate, or 0 if none differ. */
  def compare[T](a: Seq[T], b: Seq[T])(implicit num: Numeric[T]): Int =
    a.zip(b).find(c => num.compare(c._1, c._2) != 0).map(c => num.compare(c._1, c._2)).getOrElse(0)
}

=======================================================================
package common

/**
 * Demo and self-check driver for `KDTree`.
 *
 * Prints neighbour listings (k-d tree vs. brute force) for several small 2-d point sets, then
 * checks `KDNode.nearest` against brute force on 3-d grid and random point sets.
 */
object KDTreeTest extends App {

  /**
   * Prints the tree (only when `haystack` has fewer than 20 points) and the nearest point found for
   * each needle. Asserts that brute force finds no strictly closer point.
   */
  def test[T](haystack: Seq[Seq[T]], needles: Seq[T]*)(implicit num: Numeric[T]): Unit = {
    println()
    val tree = KDTree(haystack)
    if (haystack.size < 20) tree.foreach(println)
    for (kd <- tree; needle <- needles; nearest = kd.nearest(needle)) {
      println(nearest)
      // Brute force proof: no haystack point may be strictly closer than the one found.
      val better = haystack
        .map(KDTree.Nearest(_, needle))
        .filter(n => num.lt(n.distsq, nearest.distsq))
        .sortBy(_.distsq)
      assert(better.isEmpty, s"Found ${better.size} closer than ${nearest.value} e.g. ${better.head}")
    }
  }

  /** Prints `squaredDistance,point` for each entry of `KDTree.allnearest`, then a blank line. */
  def testNear[T](haystack: Seq[Seq[T]], needle: Seq[T])(implicit num: Numeric[T]): Unit = {
    val ss = KDTree.allnearest[T](haystack, needle)
    ss.foreach { case (d, p) => println(s"$d,$p") }
    println()
  }

  /** Prints `squaredDistance,point` for every haystack point, sorted by distance (brute force). */
  def testBrute[T](haystack: Seq[Seq[T]], to: Seq[T])(implicit num: Numeric[T]): Unit =
    haystack
      .map(KDTree.Nearest(_, to))
      .sortBy(_.distsq)
      .foreach(n => println(s"${n.distsq},${n.value}"))

  // Results 1
  val wikitest = List(List(2,3), List(5,4), List(9,6), List(4,7), List(8,1), List(7,2))
  //test(wikitest, List(9,6))
  testNear(wikitest, List(9,2))

  // Line
  println("线 ,List(5,5)")
  val linetest = List(List(5,3), List(5,4), List(5,6), List(5,7), List(5,1), List(5,9))
  testNear(linetest, List(5,5))
  println("Brute线 ,List(5,5)")
  testBrute(linetest, List(5,5))

  // Circle
  println("圆 ,List(8,8)")
  val circletest = List(List(1,5), List(1,6), List(2,4), List(2,7), List(2,8), List(3,3),List(3,9),List(4,2),List(4,9),List(5,2),List(5,9),List(5,10),List(6,2),List(6,9), List(7,2),List(7,9),List(8,3),List(8,4),List(9,5),List(9,6),List(9,7))
  testNear(circletest, List(8,8))
  println("Brute圆 ,List(8,8)")
  testBrute(circletest, List(8,8))

  // Spiral
  println("螺旋 ,List(9,8)")
  val spiretest = List(List(1.0,9.0), List(1.0,10.0), List(1.0,11.0), List(1.0,12.0),
    List(2.0,6.0),List(2.0,7.0),List(2.0,13.0),List(2.0,14.0),
    List(3.0,5.0),List(3.0,15.0),
    List(4.0,4.0),List(4.0,9.0),List(4.0,10.0),List(4.0,11.0),List(4.0,12.0),List(4.0,16.0),
    List(5.0,3.0),List(5.0,7.0),List(5.0,8.0),List(5.0,13.0),List(5.0,17.0),
    List(6.0,3.0),List(6.0,6.0),List(6.0,14.0),List(6.0,17.0),
    List(7.0,2.0),List(7.0,5.0),List(7.0,9.0),List(7.0,10.0),List(7.0,11.0),List(7.0,15.0),List(7.0,18.0),
    List(8.0,2.0),List(8.0,5.0),List(8.0,8.0),List(8.0,12.0),List(8.0,15.0),List(8.0,18.0),
    List(9.0,2.0),List(9.0,5.0),List(9.0,12.0),List(9.0,15.0),List(9.0,18.0),
    List(10.0,2.0),List(10.0,5.0),List(10.0,8.0),List(10.0,11.0),List(10.0,14.0),List(10.0,17.0),
    List(11.0,2.0),List(11.0,5.0),List(11.0,9.0),List(11.0,14.0),List(11.0,17.0),
    List(12.0,2.0),List(12.0,5.0),List(12.0,9.0),List(12.0,10.0),List(12.0,11.0),List(12.0,12.0),List(12.0,13.0),List(12.0,17.0),
    List(13.0,3.0),List(13.0,6.0),List(13.0,16.0),
    List(14.0,3.0),List(14.0,7.0),List(14.0,8.0),List(14.0,14.0),List(14.0,15.0),
    List(15.0,4.0),List(15.0,9.0),List(15.0,10.0),List(15.0,11.0),List(15.0,12.0),List(15.0,13.0),
    List(16.0,5.0),List(17.0,6.0),List(17.0,7.0),List(17.0,8.0),List(17.0,9.0),
    List(18.0,10.0)
  )
  testNear(spiretest, List(9.0,8.0))

  // The label previously said List(8,8) although the query is List(9.0,8.0).
  println("Brute螺旋 ,List(9,8)")
  testBrute(spiretest, List(9.0,8.0))

  // Results 2 (1000 points uniformly distributed in 3-d cube coordinates, sides 2 to 20)
  val uniform = for (x <- 1 to 10; y <- 1 to 10; z <- 1 to 10) yield List(x*2, y*2, z*2)
  assume(uniform.size == 1000)
  test(uniform, List(0, 0, 0), List(2, 2, 20), List(9, 10, 11))

  // Results 3 (1000 points randomly distributed in 3-d cube coordinates, sides -1.0 to 1.0)
  scala.util.Random.setSeed(0)
  def random(n: Int) = (1 to n).map(_ => (scala.util.Random.nextDouble() - 0.5) * 2)
  test((1 to 1000).map(_ => random(3)), random(3))

  // Results 4 (27 points uniformly distributed in 3-d cube coordinates, sides 3...9)
  val small = for (x <- 1 to 3; y <- 1 to 3; z <- 1 to 3) yield List(x*3, y*3, z*3)
  assume(small.size == 27)
  test(small, List(0, 0, 0), List(4, 5, 6))
}
=============================================================
输出结果
2,List(8, 1)
4,List(7, 2)
16,List(9, 6)
20,List(5, 4)
50,List(2, 3)

线 ,List(5,5)
1,List(5, 4)
4,List(5, 7)
16,List(5, 9)

Brute线 ,List(5,5)
1,List(5, 4)
1,List(5, 6)
4,List(5, 3)
4,List(5, 7)
16,List(5, 1)
16,List(5, 9)
圆 ,List(8,8)
2,List(7, 9)
5,List(6, 9)
10,List(9, 5)
13,List(5, 10)
16,List(8, 4)
17,List(4, 9)
25,List(8, 3)
26,List(3, 9)
36,List(2, 8)
37,List(7, 2)
40,List(6, 2)
45,List(5, 2)
50,List(3, 3)
52,List(2, 4)
53,List(1, 6)
58,List(1, 5)

Brute圆 ,List(8,8)
2,List(7, 9)
2,List(9, 7)
5,List(6, 9)
5,List(9, 6)
10,List(5, 9)
10,List(9, 5)
13,List(5, 10)
16,List(8, 4)
17,List(4, 9)
25,List(8, 3)
26,List(3, 9)
36,List(2, 8)
37,List(2, 7)
37,List(7, 2)
40,List(6, 2)
45,List(5, 2)
50,List(3, 3)
52,List(2, 4)
52,List(4, 2)
53,List(1, 6)
58,List(1, 5)
螺旋 ,List(9,8)
1.0,List(10.0, 8.0)
5.0,List(11.0, 9.0)
8.0,List(7.0, 10.0)
9.0,List(9.0, 5.0)
10.0,List(8.0, 5.0)
13.0,List(6.0, 6.0)
16.0,List(9.0, 12.0)
17.0,List(8.0, 12.0)
18.0,List(12.0, 11.0)
20.0,List(13.0, 6.0)
25.0,List(14.0, 8.0)
26.0,List(14.0, 7.0)
29.0,List(4.0, 10.0)
34.0,List(4.0, 11.0)
36.0,List(9.0, 2.0)
37.0,List(8.0, 2.0)
40.0,List(7.0, 2.0)
41.0,List(4.0, 12.0)
45.0,List(12.0, 2.0)
49.0,List(9.0, 15.0)
50.0,List(14.0, 3.0)
52.0,List(15.0, 12.0)
53.0,List(2.0, 6.0)
58.0,List(16.0, 5.0)
61.0,List(15.0, 13.0)
64.0,List(17.0, 8.0)
65.0,List(17.0, 9.0)
68.0,List(1.0, 10.0)
73.0,List(1.0, 11.0)
74.0,List(14.0, 15.0)
80.0,List(1.0, 12.0)
82.0,List(10.0, 17.0)
85.0,List(11.0, 17.0)
89.0,List(4.0, 16.0)
90.0,List(6.0, 17.0)
97.0,List(5.0, 17.0)
100.0,List(9.0, 18.0)
101.0,List(8.0, 18.0)
104.0,List(7.0, 18.0)

Brute螺旋 ,List(8,8)
1.0,List(8.0, 8.0)
1.0,List(10.0, 8.0)
5.0,List(7.0, 9.0)
5.0,List(11.0, 9.0)
8.0,List(7.0, 10.0)
9.0,List(9.0, 5.0)
10.0,List(8.0, 5.0)
10.0,List(10.0, 5.0)
10.0,List(10.0, 11.0)
10.0,List(12.0, 9.0)
13.0,List(6.0, 6.0)
13.0,List(7.0, 5.0)
13.0,List(7.0, 11.0)
13.0,List(11.0, 5.0)
13.0,List(12.0, 10.0)
16.0,List(5.0, 8.0)
16.0,List(9.0, 12.0)
17.0,List(5.0, 7.0)
17.0,List(8.0, 12.0)
18.0,List(12.0, 5.0)
18.0,List(12.0, 11.0)
20.0,List(13.0, 6.0)
25.0,List(12.0, 12.0)
25.0,List(14.0, 8.0)
26.0,List(4.0, 9.0)
26.0,List(14.0, 7.0)
29.0,List(4.0, 10.0)
34.0,List(4.0, 11.0)
34.0,List(6.0, 3.0)
34.0,List(12.0, 13.0)
36.0,List(9.0, 2.0)
37.0,List(8.0, 2.0)
37.0,List(10.0, 2.0)
37.0,List(10.0, 14.0)
37.0,List(15.0, 9.0)
40.0,List(7.0, 2.0)
40.0,List(11.0, 2.0)
40.0,List(11.0, 14.0)
40.0,List(15.0, 10.0)
41.0,List(4.0, 4.0)
41.0,List(4.0, 12.0)
41.0,List(5.0, 3.0)
41.0,List(5.0, 13.0)
41.0,List(13.0, 3.0)
45.0,List(3.0, 5.0)
45.0,List(6.0, 14.0)
45.0,List(12.0, 2.0)
45.0,List(15.0, 11.0)
49.0,List(9.0, 15.0)
50.0,List(2.0, 7.0)
50.0,List(8.0, 15.0)
50.0,List(14.0, 3.0)
52.0,List(15.0, 4.0)
52.0,List(15.0, 12.0)
53.0,List(2.0, 6.0)
53.0,List(7.0, 15.0)
58.0,List(16.0, 5.0)
61.0,List(14.0, 14.0)
61.0,List(15.0, 13.0)
64.0,List(17.0, 8.0)
65.0,List(1.0, 9.0)
65.0,List(17.0, 7.0)
65.0,List(17.0, 9.0)
68.0,List(1.0, 10.0)
68.0,List(17.0, 6.0)
73.0,List(1.0, 11.0)
74.0,List(2.0, 13.0)
74.0,List(14.0, 15.0)
80.0,List(1.0, 12.0)
80.0,List(13.0, 16.0)
82.0,List(10.0, 17.0)
85.0,List(2.0, 14.0)
85.0,List(3.0, 15.0)
85.0,List(11.0, 17.0)
85.0,List(18.0, 10.0)
89.0,List(4.0, 16.0)
90.0,List(6.0, 17.0)
90.0,List(12.0, 17.0)
97.0,List(5.0, 17.0)
100.0,List(9.0, 18.0)
101.0,List(8.0, 18.0)
104.0,List(7.0, 18.0)

Searched for=List(0, 0, 0) found=List(2, 2, 2) distance=3.4641 visited=1000
Searched for=List(2, 2, 20) found=List(2, 2, 20) distance=0.0000 visited=999
Searched for=List(9, 10, 11) found=List(10, 10, 10) distance=1.4142 visited=1000

Searched for=Vector(0.19269603520919643, -0.25958512078298535, -0.2572864045762784) found=Vector(0.07811099409527977, -0.2477618820196814, -0.20252227622550611) distance=0.1275 visited=1000

Searched for=List(0, 0, 0) found=List(3, 3, 3) distance=5.1962 visited=27
Searched for=List(4, 5, 6) found=List(3, 6, 6) distance=1.4142 visited=27

Process finished with exit code 0
scala kd-tree - 险峰 - 阿弥陀佛
 

  评论这张
 
阅读(477)| 评论(0)
推荐 转载

历史上的今天

在LOFTER的更多文章

评论

<#--最新日志,群博日志--> <#--推荐日志--> <#--引用记录--> <#--博主推荐--> <#--随机阅读--> <#--首页推荐--> <#--历史上的今天--> <#--被推荐日志--> <#--上一篇,下一篇--> <#-- 热度 --> <#-- 网易新闻广告 --> <#--右边模块结构--> <#--评论模块结构--> <#--引用模块结构--> <#--博主发起的投票-->
 
 
 
 
 
 
 
 
 
 
 
 
 
 

页脚

网易公司版权所有 ©1997-2017