Node2Vec: Spark-Based Sampling to Generate Walk Sequences
Preface
I recently became interested in node2vec, and with the original source code at hand I wanted to reproduce it in a production environment. While doing so I found a few bugs (the construction of the directed graph, and the separator used when concatenating an edge's source and destination IDs), which are fixed in this article. References for the alias sampling method are given as well, and every function and important step is annotated. So go ahead and copy it over, and remember to like, bookmark and follow!
1. Defining Vertex and Edge Attributes
// Edge attributes: the destination vertex's neighbor IDs plus the alias table (J, q) used to pick the next step
case class EdgeAttr(var dstNeighbors: Array[Long] = Array.empty[Long],
                    var J: Array[Int] = Array.empty[Int],
                    var q: Array[Double] = Array.empty[Double])

// Vertex attributes: the vertex's weighted neighbors and the walk path started from it
case class NodeAttr(var neighbors: Array[(Long, Double)] = Array.empty[(Long, Double)],
                    var path: Array[Long] = Array.empty[Long])
2. Implementing the Sampling Method and Defining Directed and Undirected Graphs
2.1 Principle
The paper designs a flexible sampling strategy to balance BFS and DFS: a biased random walk whose parameters let it explore a node's neighborhood in either a BFS-like or a DFS-like fashion.
When learning Node2Vec, computing the transition probabilities relies on non-uniform random sampling (the alias method): based on the weights of the current node's edges, it decides which neighbor to visit next.
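For reference, the quantity implemented in setupEdgeAlias below is the standard node2vec search bias. When the walk has just moved from node t to node v, the unnormalized probability of stepping from v to its neighbor x is

π(v, x) = α_pq(t, x) · w_vx, where α_pq(t, x) = 1/p if x = t (returning to the previous node), 1 if x is also a neighbor of t, and 1/q otherwise.

Building an alias table over these unnormalized weights is what allows each step of the walk to be drawn in O(1).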
2.2 Code
import scala.collection.mutable.ArrayBuffer

object GraphOps {

  /**
   * Build the alias table (J, q) for the given weighted neighbors, so that each draw takes O(1).
   */
  def setupAlias(nodeWeights: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
    val K = nodeWeights.length
    val J = Array.fill(K)(0)
    val q = Array.fill(K)(0.0)

    val smaller = new ArrayBuffer[Int]()
    val larger = new ArrayBuffer[Int]()

    val sum = nodeWeights.map(_._2).sum
    nodeWeights.zipWithIndex.foreach { case ((nodeId, weight), i) =>
      q(i) = K * weight / sum
      if (q(i) < 1.0) {
        smaller.append(i)
      } else {
        larger.append(i)
      }
    }

    while (smaller.nonEmpty && larger.nonEmpty) {
      val small = smaller.remove(smaller.length - 1)
      val large = larger.remove(larger.length - 1)

      J(small) = large
      q(large) = q(large) + q(small) - 1.0
      if (q(large) < 1.0) smaller.append(large)
      else larger.append(large)
    }

    (J, q)
  }

  /**
   * Build the alias table for a (srcId -> current node) step, applying the node2vec search bias:
   * weight / p when returning to srcId, weight when the candidate is also a neighbor of srcId,
   * and weight / q otherwise.
   */
  def setupEdgeAlias(p: Double = 1.0, q: Double = 1.0)(srcId: Long, srcNeighbors: Array[(Long, Double)], dstNeighbors: Array[(Long, Double)]): (Array[Int], Array[Double]) = {
    val neighbors_ = dstNeighbors.map { case (dstNeighborId, weight) =>
      var unnormProb = weight / q
      if (srcId == dstNeighborId) unnormProb = weight / p
      else if (srcNeighbors.exists(_._1 == dstNeighborId)) unnormProb = weight

      (dstNeighborId, unnormProb)
    }

    setupAlias(neighbors_)
  }

  /** Draw one index from the alias table in O(1). */
  def drawAlias(J: Array[Int], q: Array[Double]): Int = {
    val K = J.length
    val kk = math.floor(math.random * K).toInt

    if (math.random < q(kk)) kk
    else J(kk)
  }

  /** An undirected edge contributes the neighbor to both endpoints. */
  lazy val createUndirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
    Array(
      (srcId, Array((dstId, weight))),
      (dstId, Array((srcId, weight)))
    )
  }

  /** A directed edge only contributes the neighbor to its source vertex. */
  lazy val createDirectedEdge = (srcId: Long, dstId: Long, weight: Double) => {
    Array(
      (srcId, Array((dstId, weight)))
    )
  }
}
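As a quick sanity check of the alias table, you can build it for a few weighted neighbors and verify that the empirical draw frequencies approach the normalized weights. The snippet below is an illustrative sketch, not part of the original code; it only assumes the GraphOps object above is on the classpath.

object AliasDemo {
  def main(args: Array[String]): Unit = {
    // Three neighbors with weights 1.0, 2.0 and 7.0 -> expected frequencies 0.1, 0.2 and 0.7
    val neighbors = Array((101L, 1.0), (102L, 2.0), (103L, 7.0))
    val (j, q) = GraphOps.setupAlias(neighbors)

    // Draw many samples and count how often each neighbor index comes up
    val draws = 100000
    val counts = Array.fill(neighbors.length)(0)
    (1 to draws).foreach(_ => counts(GraphOps.drawAlias(j, q)) += 1)

    counts.zipWithIndex.foreach { case (c, i) =>
      println(s"neighbor ${neighbors(i)._1}: ${c.toDouble / draws}")
    }
  }
}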
2.3 References
Alias Method: discrete sampling in O(1) time
Alias Method: a non-uniform random sampling algorithm
[Math] An O(1)-time discrete sampling algorithm: the Alias method
[Graph Embedding] node2vec: algorithm principle, implementation and applications
浅梦's study notes (浅梦的学习笔记)
3. Generation Process and Final Results
Following the original source code and its issues, several bugs were fixed; the result was verified locally in IDEA and on a cluster with a large data volume.
3.1 Code Logic
- Load the raw data
- Convert the raw sequences into raw edge triples of the form (srcId, dstId, weight), where srcId is the edge's source vertex, dstId its destination vertex, and weight the number of times the (source, destination) pair occurs; the aggregation uses reduceByKey (see the sketch after this list)
- Index the raw vertices
- Build an index -> original-vertex map and broadcast it
- Generate the indexed edge triples
- From the indexed triples, of type RDD[(Long, Long, Double)], build the graph's vertices and edges
- Initialize the graph's vertex attributes and edge attributes
- Random walk: sample to generate sequences; bug fixed, see https://github.com/aditya-grover/node2vec/issues/29
- Map the sampled index sequences back to the original vertices
- Display the sampling results
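To make the second step concrete, here is a small local sketch (illustrative only; it assumes a SparkContext named sc and the Node2vec object defined in 3.2 below). Adjacent items in each sequence become directed edges, and reduceByKey sums how often each pair occurs.

// Two toy sequences: the pair (v1, v2) appears twice, (v2, v3) once
val demoSequences = sc.parallelize(Seq("v1,v2,v3", "v1,v2"))
Node2vec.sequenceProcess(demoSequences).collect().foreach(println)
// expected output (order may vary): (v1,v2,2.0) and (v2,v3,1.0)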
3.2 Code
/**
 * Configuration class.
 *
 * @param numPartition number of partitions
 * @param walkLength   length of the sampled sequence for each vertex
 * @param numWalks     number of walks sampled per vertex
 * @param p            return parameter
 * @param q            in-out parameter
 * @param directed     whether the graph is directed; the directed case had a bug, fixed here, see https://github.com/aditya-grover/node2vec/issues/29
 * @param degree       maximum vertex degree kept (neighbors beyond this are truncated by weight)
 * @param input        path of the input txt data; no sample file is uploaded, but the code shows an example; each sequence is comma-separated, e.g. v1,v2,v3,v4
 */
case class Config(var numPartition: Int = 10,
                  var walkLength: Int = 8,
                  var numWalks: Int = 5,
                  var p: Double = 1.0,
                  var q: Double = 1.0,
                  var directed: Boolean = true,
                  var degree: Int = 30,
                  var input: String = "./data")
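Since Config is a plain case class with defaults, a run can override individual parameters by name. The values below are only a hypothetical example (the Node2vec object further down simply uses Config() with all defaults):

// Hypothetical override: undirected graph, longer walks, more likely to return (low p) and to stay local (high q)
val localConfig = Config(walkLength = 20, numWalks = 10, p = 0.5, q = 2.0, directed = false, input = "./data/sequences.txt")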
package com.test.graph

import org.apache.spark.SparkContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.graphx.{Edge, EdgeTriplet, Graph, VertexId}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.mutable.ArrayBuffer

object Node2vec {

  val config: Config = Config()

  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("Node2vec").getOrCreate()
    val sc = spark.sparkContext
    import spark.implicits._

    // 1. Load the raw data
    /**
     * Sample data:
     * v1,v2,v3,v4
     * v3,v2,v1,v6
     * v1,v2,v6,v7
     * v1,v10,v8,v4
     * v1,v3,v8,v4
     * v1,v10,v9,v4
     * v1
     * v1,v10,v9,v11
     */
    val sequenceRDD: RDD[String] = sc.textFile(config.input)

    // 2. Convert the raw sequences into raw edge triples (srcId, dstId, weight), where srcId is the edge's
    //    source vertex, dstId its destination vertex, and weight the co-occurrence count; aggregated with reduceByKey
    val rawEdges: RDD[(String, String, Double)] = sequenceProcess(sequenceRDD)

    // 3. Index the raw vertices
    val node2Id: RDD[(String, VertexId)] = createNode2Id(rawEdges)

    // 4. Build an index -> original-vertex map and broadcast it
    val id2NodeMap: collection.Map[VertexId, String] = node2Id.map {
      case (node_id, node_index) => (node_index, node_id)
    }.collectAsMap()
    val id2NodeMapBC: Broadcast[collection.Map[VertexId, String]] = sc.broadcast(id2NodeMap)

    // 5. Generate the indexed edge triples
    val inputTriplets: RDD[(VertexId, VertexId, Double)] = indexingGraph(rawEdges, node2Id)

    // Show the intermediate results
    rawEdges.toDF("src_id", "dst_id", "weight").show(false)
    /**
     * +------+------+------+
     * |src_id|dst_id|weight|
     * +------+------+------+
     * |v1    |v3    |1.0   |
     * |v10   |v9    |2.0   |
     * |v1    |v10   |3.0   |
     * |v8    |v4    |2.0   |
     * |v2    |v6    |1.0   |
     * |v3    |v2    |1.0   |
     * |v2    |v3    |1.0   |
     * |v1    |v6    |1.0   |
     * |v3    |v4    |1.0   |
     * |v1    |v2    |2.0   |
     * |v9    |v11   |1.0   |
     * |v9    |v4    |1.0   |
     * |v10   |v8    |1.0   |
     * |v6    |v7    |1.0   |
     * |v2    |v1    |1.0   |
     * |v3    |v8    |1.0   |
     * +------+------+------+
     */
    inputTriplets.toDF("src_index", "dst_index", "weight").show(false)
    /**
     * +---------+---------+------+
     * |src_index|dst_index|weight|
     * +---------+---------+------+
     * |3        |0        |2.0   |
     * |6        |0        |1.0   |
     * |7        |0        |1.0   |
     * |7        |1        |1.0   |
     * |4        |2        |1.0   |
     * |8        |2        |1.0   |
     * |5        |3        |1.0   |
     * |6        |3        |1.0   |
     * |6        |4        |1.0   |
     * |8        |4        |2.0   |
     * |8        |5        |3.0   |
     * |4        |6        |1.0   |
     * |8        |6        |1.0   |
     * |5        |7        |2.0   |
     * |4        |8        |1.0   |
     * |2        |9        |1.0   |
     * +---------+---------+------+
     */

    // 6. From the indexed triples, of type RDD[(Long, Long, Double)], build the graph's vertices and edges
    val (indexedNodes, indexedEdges) = buildGraph(inputTriplets)

    // 7. Initialize the graph's vertex attributes and edge attributes
    val graph: Graph[NodeAttr, EdgeAttr] = initTransitionProb(indexedNodes = indexedNodes, indexedEdges = indexedEdges)

    // 8. Random walk: sample to generate sequences; bug fixed, see https://github.com/aditya-grover/node2vec/issues/29
    val indexedSequenceRDD: RDD[(VertexId, ArrayBuffer[VertexId])] = randomWalk(graph)

    // 9. Map the sampled index sequences back to the original vertices
    val sampledSequenceDF: DataFrame = indexedSequenceRDD.map {
      case (vertexId, path) => {
        path.map(elem => id2NodeMapBC.value.getOrElse(elem, "")).mkString(",")
      }
    }.toDF("sampled_sequence")

    // 10. Display the sampling results
    sampledSequenceDF.show(1000, false)
    /**
     * +------------------------+
     * |sampled_sequence        |
     * +------------------------+
     * |v9,v4                   |
     * |v10,v9,v4               |
     * |v1,v6,v7                |
     * |v6,v7                   |
     * |v2,v3,v8,v4             |
     * |v8,v4                   |
     * |v3,v2,v1,v2,v3,v4       |
     * |v9,v4                   |
     * |v10,v9,v11              |
     * |v3,v2,v1,v6,v7          |
     * |v1,v6,v7                |
     * |v6,v7                   |
     * |v8,v4                   |
     * |v2,v3,v4                |
     * |v9,v4                   |
     * |v6,v7                   |
     * |v3,v2,v1,v10,v8,v4      |
     * |v1,v10,v8,v4            |
     * |v10,v8,v4               |
     * |v8,v4                   |
     * |v2,v3,v4                |
     * |v9,v4                   |
     * |v1,v2,v3,v2,v1,v10,v9,v4|
     * |v6,v7                   |
     * |v2,v6,v7                |
     * |v10,v8,v4               |
     * |v3,v8,v4                |
     * |v8,v4                   |
     * |v10,v9,v4               |
     * |v1,v10,v9,v4            |
     * |v9,v11                  |
     * |v2,v1,v10,v9,v11        |
     * |v6,v7                   |
     * |v3,v8,v4                |
     * |v8,v4                   |
     * +------------------------+
     */
  }

  /**
   * Convert the raw sequences into raw edge triples (srcId, dstId, weight), where srcId is the edge's
   * source vertex, dstId its destination vertex, and weight the number of times the pair occurs;
   * the aggregation uses reduceByKey.
   *
   * @param sequenceRDD user sequences, comma separated
   * @return (srcId, dstId, weight) triples of type RDD[(String, String, Double)]
   */
  def sequenceProcess(sequenceRDD: RDD[String]): RDD[(String, String, Double)] = {
    sequenceRDD.flatMap(line => {
      val sequenceArray: Array[String] = line.split(",")
      val pairSeq = ArrayBuffer[(String, Int)]()
      var previousItem: String = null
      sequenceArray.foreach((element: String) => {
        if (previousItem != null) {
          pairSeq.append((previousItem + ":" + element, 1))
        }
        previousItem = element
      })
      pairSeq
    }).reduceByKey(_ + _).map { case (pair: String, weight: Int) =>
      val arr: Array[String] = pair.split(":")
      (arr(0), arr(1), weight.toDouble)
    }
  }

  /**
   * Generate the indexed edge triples.
   *
   * @param rawEdges raw edge triples, of type RDD[(String, String, Double)]
   * @param node2Id  index of every vertex, of type RDD[(String, VertexId)]
   * @return indexed triples, of type RDD[(Long, Long, Double)]
   */
  def indexingGraph(rawEdges: RDD[(String, String, Double)], node2Id: RDD[(String, VertexId)]): RDD[(Long, Long, Double)] = {
    rawEdges.map { case (src, dst, weight) =>
      (src, (dst, weight))
    }.join(node2Id).map { case (src, (edge: (String, Double), srcIndex: Long)) =>
      try {
        val (dst: String, weight: Double) = edge
        (dst, (srcIndex, weight))
      } catch {
        case e: Exception => null
      }
    }.filter(_ != null).join(node2Id).map { case (dst, (edge: (Long, Double), dstIndex: Long)) =>
      try {
        val (srcIndex, weight) = edge
        (srcIndex, dstIndex, weight)
      } catch {
        case e: Exception => null
      }
    }.filter(_ != null)
  }

  /**
   * Index the raw vertices.
   *
   * @param rawEdges raw edge triples, of type RDD[(String, String, T)]
   * @tparam T generic weight type
   * @return index of every vertex, of type RDD[(String, VertexId)]
   */
  def createNode2Id[T <: Any](rawEdges: RDD[(String, String, T)]): RDD[(String, VertexId)] =
    rawEdges.flatMap { case (src, dst, weight) =>
      val strings: Array[String] = Array(src, dst)
      strings
    }.distinct().zipWithIndex()

  /**
   * From the indexed triples, of type RDD[(Long, Long, Double)], build the graph's vertices and edges.
   * Uses the object-level config for directedness, maximum degree and partitioning.
   *
   * @param inputTriplets indexed triples, of type RDD[(Long, Long, Double)]
   * @return the graph's vertices and edges
   */
  def buildGraph(inputTriplets: RDD[(VertexId, VertexId, Double)]): (RDD[(VertexId, NodeAttr)], RDD[Edge[EdgeAttr]]) = {
    val sc: SparkContext = inputTriplets.sparkContext
    val bcMaxDegree = sc.broadcast(config.degree)
    val bcEdgeCreator = config.directed match {
      case true => sc.broadcast(GraphOps.createDirectedEdge)
      case false => sc.broadcast(GraphOps.createUndirectedEdge)
    }

    val indexedNodes = inputTriplets.flatMap { case (srcId, dstId, weight) =>
      bcEdgeCreator.value.apply(srcId, dstId, weight)
    }.reduceByKey(_ ++ _).map { case (nodeId, neighbors: Array[(VertexId, Double)]) =>
      var neighbors_ = neighbors
      if (neighbors_.length > bcMaxDegree.value) {
        neighbors_ = neighbors.sortWith { case (left, right) => left._2 > right._2 }.slice(0, bcMaxDegree.value)
      }
      (nodeId, NodeAttr(neighbors = neighbors_.distinct))
    }.repartition(config.numPartition).cache

    val indexedEdges = indexedNodes.flatMap { case (srcId, clickNode) =>
      clickNode.neighbors.map { case (dstId, weight) =>
        Edge(srcId, dstId, EdgeAttr())
      }
    }.repartition(config.numPartition).cache

    (indexedNodes, indexedEdges)
  }

  /**
   * Initialize the graph's vertex attributes and edge attributes.
   *
   * @param indexedNodes the graph's vertices
   * @param indexedEdges the graph's edges
   * @return the constructed graph
   */
  def initTransitionProb(indexedNodes: RDD[(VertexId, NodeAttr)], indexedEdges: RDD[Edge[EdgeAttr]]): Graph[NodeAttr, EdgeAttr] = {
    val sc = indexedEdges.sparkContext
    val bcP = sc.broadcast(config.p)
    val bcQ = sc.broadcast(config.q)

    Graph(indexedNodes, indexedEdges)
      .mapVertices[NodeAttr] { case (vertexId, nodeAttr) =>
        if (nodeAttr != null) {
          val (j, q) = GraphOps.setupAlias(nodeAttr.neighbors)
          val nextNodeIndex = GraphOps.drawAlias(j, q)
          nodeAttr.path = Array(vertexId, nodeAttr.neighbors(nextNodeIndex)._1)
          nodeAttr
        } else {
          NodeAttr()
        }
      }
      .mapTriplets { edgeTriplet: EdgeTriplet[NodeAttr, EdgeAttr] =>
        val (j, q) = GraphOps.setupEdgeAlias(bcP.value, bcQ.value)(edgeTriplet.srcId, edgeTriplet.srcAttr.neighbors, edgeTriplet.dstAttr.neighbors)
        edgeTriplet.attr.J = j
        edgeTriplet.attr.q = q
        edgeTriplet.attr.dstNeighbors = edgeTriplet.dstAttr.neighbors.map(_._1)
        edgeTriplet.attr
      }.cache
  }

  /**
   * Random walk: sample to generate sequences; bug fixed, see https://github.com/aditya-grover/node2vec/issues/29
   *
   * @param graph the graph
   * @return the sampled sequences
   */
  def randomWalk(graph: Graph[NodeAttr, EdgeAttr]): RDD[(VertexId, ArrayBuffer[VertexId])] = {
    var randomWalkPaths: RDD[(Long, ArrayBuffer[Long])] = null

    val edge2attr = graph.triplets.map { edgeTriplet =>
      // Put a separator between source and destination so that, e.g., (11, 13) and (111, 3) do not produce the same key
      (s"${edgeTriplet.srcId}->${edgeTriplet.dstId}", edgeTriplet.attr)
    }.repartition(config.numPartition).cache

    for (iter <- 0 until config.numWalks) {
      var prevWalk: RDD[(Long, ArrayBuffer[Long])] = null
      // Keep only vertices with a non-empty path, otherwise a NullPointerException is thrown later
      var randomWalk = graph.vertices.filter(_._2.path.nonEmpty).map { case (nodeId, clickNode) =>
        val pathBuffer = new ArrayBuffer[Long]()
        pathBuffer.append(clickNode.path: _*)
        (nodeId, pathBuffer)
      }.cache

      // On every iteration the old RDD is kept until the new one is built, then released from memory.
      // initTransitionProb cached the graph, so it is unpersisted here to make every iteration sample from scratch
      graph.unpersist(blocking = false)
      graph.edges.unpersist(blocking = false)

      for (walkCount <- 0 until config.walkLength) {
        // Keep the old RDD until the new one is built, then release it from memory
        prevWalk = randomWalk

        randomWalk = randomWalk.map { case (srcNodeId, pathBuffer) =>
          val prevNodeId = pathBuffer(pathBuffer.length - 2)
          val currentNodeId = pathBuffer.last
          (s"$prevNodeId->$currentNodeId", (srcNodeId, pathBuffer))
        }.join(edge2attr).map { case (edge, ((srcNodeId, pathBuffer), attr)) =>
          try {
            if (pathBuffer != null && pathBuffer.nonEmpty && attr.dstNeighbors != null && attr.dstNeighbors.nonEmpty) {
              val nextNodeIndex = GraphOps.drawAlias(attr.J, attr.q)
              val nextNodeId = attr.dstNeighbors(nextNodeIndex)
              pathBuffer.append(nextNodeId)
            }

            (srcNodeId, pathBuffer)
          } catch {
            case e: Exception => throw new RuntimeException(e.getMessage)
          }
        }.cache

        // Release the old RDD from memory
        prevWalk.unpersist(blocking = false)
      }

      if (randomWalkPaths != null) {
        // Keep the old RDD until the new one is built, then release it from memory
        val prevRandomWalkPaths = randomWalkPaths
        randomWalkPaths = randomWalkPaths.union(randomWalk).cache()
        // Release the old RDD from memory
        prevRandomWalkPaths.unpersist(blocking = false)
      } else {
        randomWalkPaths = randomWalk
      }
    }

    randomWalkPaths
  }
}
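If the walks are to be fed into a downstream embedding trainer, a minimal way to persist them (the output path here is hypothetical, not from the original code) is to append a write at the end of main:

// Hypothetical: write one walk per line; DataFrameWriter.text works because sampledSequenceDF has a single string column
sampledSequenceDF.coalesce(1).write.mode("overwrite").text("./output/walks")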