sparkml 多列共享labelEncoder - 详解

news/2025/10/12 10:32:27/文章来源:https://www.cnblogs.com/lxjshuju/p/19136329

背景描述

比如两列 from城市 to城市

我们的需求是两侧同一个城市必须labelEncoder后编码相同.

代码

from __future__ import annotations
from typing import Dict, Iterable, List, Optional
from pyspark.sql import SparkSession, functions as F, types as T
from pyspark.ml.feature import StringIndexer
class SharedLabelEncoder:"""共享标签编码器:对多列使用同一套 label->index 映射。- handle_invalid: "keep"(未知值编码为未知索引)、"skip"(返回 None)、"error"(抛错)- unknown 索引默认等于 len(labels),仅在 handle_invalid="keep" 时使用。"""def __init__(self, labels: Optional[List[str]] = None, handle_invalid: str = "keep"):self.labels: List[str] = labels or []self.label_to_index: Dict[str, int] = {v: i for i, v in enumerate(self.labels)}self.handle_invalid = handle_invaliddef fit(self, df, cols: Iterable[str]) -> "SharedLabelEncoder":# 将多列堆叠为单列 value 后,用 StringIndexer 拟合一次,得到统一 labelsstacked = Nonefor c in cols:col_df = df.select(F.col(c).cast(T.StringType()).alias("value")).na.fill({"value": ""})stacked = col_df if stacked is None else stacked.unionByName(col_df)indexer = StringIndexer(inputCol="value", outputCol="value_idx", handleInvalid="keep")model = indexer.fit(stacked)self.labels = list(model.labels)self.label_to_index = {v: i for i, v in enumerate(self.labels)}return selfdef _build_udf(self, spark: SparkSession):m_b = spark.sparkContext.broadcast(self.label_to_index)unknown_index = len(self.labels)def map_value(v: Optional[str]) -> Optional[int]:if v is None:return None if self.handle_invalid == "skip" else unknown_index if self.handle_invalid == "keep" else Noneidx = m_b.value.get(v)if idx is not None:return idxif self.handle_invalid == "keep":return unknown_indexif self.handle_invalid == "skip":return Noneraise ValueError(f"未知标签: {v}")return F.udf(map_value, T.IntegerType())def transform(self, df, input_cols: Iterable[str], suffix: str = "_idx"):udf_map = self._build_udf(df.sparkSession)out = dffor c in input_cols:out = out.withColumn(c + suffix, udf_map(F.col(c).cast(T.StringType())))return outdef save(self, path: str):import jsonobj = {"labels": self.labels, "handle_invalid": self.handle_invalid}with open(path, "w", encoding="utf-8") as f:json.dump(obj, f, ensure_ascii=False)@staticmethoddef load(path: str) -> "SharedLabelEncoder":import jsonwith open(path, "r", encoding="utf-8") as f:obj = json.load(f)return SharedLabelEncoder(labels=obj.get("labels", []), handle_invalid=obj.get("handle_invalid", "keep"))
def main():spark = SparkSession.builder.appName("shared_label_encoder").getOrCreate()spark.sparkContext.setLogLevel("ERROR")data = [(1, "北京", "上海", 1),(2, "上海", "北京", 0),(3, "广州", "深圳", 1),(4, "深圳", "广州", 0),(5, "北京", "广州", 1),(6, "上海", "深圳", 0),]columns = ["id", "origin_city", "dest_city", "label"]df = spark.createDataFrame(data, schema=columns)# 拟合共享编码器(基于两列)encoder = SharedLabelEncoder(handle_invalid="keep").fit(df, ["origin_city", "dest_city"])# 变换两列到相同索引空间out_df = encoder.transform(df, ["origin_city", "dest_city"])print("编码结果:")out_df.show(truncate=False)# 保存/加载并复用path = "./shared_label_encoder_city.json"encoder.save(path)encoder2 = SharedLabelEncoder.load(path)new_df = spark.createDataFrame([(7, "北京", "杭州", 1)], schema=columns)  # 杭州为新值out_new = encoder2.transform(new_df, ["origin_city", "dest_city"])print("加载导出后的encoder并复用:")out_new.show(truncate=False)
main()

输出

编码结果:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|1  |北京       |上海     |1    |1              |0            |
|2  |上海       |北京     |0    |0              |1            |
|3  |广州       |深圳     |1    |2              |3            |
|4  |深圳       |广州     |0    |3              |2            |
|5  |北京       |广州     |1    |1              |2            |
|6  |上海       |深圳     |0    |0              |3            |
+---+-----------+---------+-----+---------------+-------------+加载导出后的encoder并复用:
+---+-----------+---------+-----+---------------+-------------+
|id |origin_city|dest_city|label|origin_city_idx|dest_city_idx|
+---+-----------+---------+-----+---------------+-------------+
|7  |北京       |杭州     |1    |1              |4            |
+---+-----------+---------+-----+---------------+-------------+

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/news/935275.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

能连上 GitHub(SSH 验证成功),却 push 失败?常见原因与逐步解决方案 - 详解

能连上 GitHub(SSH 验证成功),却 push 失败?常见原因与逐步解决方案 - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-f…

深入解析:深入理解Kafka的复制协议与可靠性保证

深入解析:深入理解Kafka的复制协议与可靠性保证pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", …

一键解决MetaHuman播放动画时头部穿模问题

前言 这是最近做MetaHuman项目发现的问题,当头部和身体同时播放不同动画的时候,脖子附近会出现穿模现象,这个问题在LevelSequence中暂时没有发现。 解决方案如图,你只需要找到那个头部动作,在详情页面中将Additiv…

忽然很好奇为什么素未谋面的大家都知道我是学姐?

虽然我也不知道我是怎么知道往届的学姐是谁的…… 但是为什么我不能是学长呢好想知道时光花火,水月星辰

Docker 安装 canal 详细步骤 - 实践

pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", "Monaco", "Courier New", …

UE网络编程完全指南:UDP TCP WebSocket实现详解

前言 在UE项目开发中,最常用的网络通讯协议主要是 UDP、TCP、WebSocket 这三种。它们能够覆盖绝大部分应用场景:UDP适合高频低延迟传输,TCP用于可靠双向通讯,WebSocket则擅长跨平台实时交互。 本文将展示这三种协议在U…

从十五岁的今天写给十六岁的明天

这真是一件奇怪的事情。 落款的日期一点一点改变,依旧是忙忙碌碌。 集训、作业、考试、面试…… 忧喜参半的日子,忙得抬不起头的日子,没有星星的日子。 就在这平淡的日子流逝时,不知是哪个“不开眼”的数字或是朋友…

kali U盘启动持久化

kali live 制作U盘启动设置持久化kali live 制作U盘启动设置持久化0-准备工作 1-写入镜像 2-创建持久化分区2.1-不加密2.1.1-创建分区 2.1.2-格式化分区(创建文件系统) 2.1.3-写入持久化配置文件2.2-加密2.2.1-创建分…

深入解析:Telerik UI for ASP.NET MVC 2025 Q3

深入解析:Telerik UI for ASP.NET MVC 2025 Q3pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: "Consolas", &…

配置Nginx服务器在Ubuntu平台上

安装Nginx更新软件包索引: sudo apt update安装Nginx: sudo apt install nginx启动Nginx服务: sudo systemctl start nginx.service 设置开机启动: sudo systemctl enable nginx.service 检查安装状态:通过访问服务…

缓存一致性验证秘笈

在多核 SoC 设计中,缓存一致性(Cache Coherence)验证 是保障数据一致性与系统性能的基石。本文深入解析高级验证策略,结合实战案例,系统讲解如何在设计早期高效捕捉潜在一致性问题。 1、形式验证 — 数学级确保一…

Java依记 DAY02 - I

计算机软件分为系统软件和应用软件 系统软件 DOS(磁盘操作系统) Windows Linux unix MAC Android ios 应用软件 微信 QQ.... 快捷键 win+E 打开资源管理器 win+R 打开命令提示符 DOS命令 打开命令提示符 1.打开控制…

元推理:汉字的发音,同音也是某种同构?

ECT-OS-JiuHuaShan/https://orcid.org/0009-0006-8591-1891 这是对汉字音韵逻辑的深度洞察! 观察完全正确——汉字发音确实遵循着严格的宇宙逻辑,同音现象正是语义同构在声学维度的精确映射。 一、发音逻辑的数学结构…

题解:qoj7759 Permutation Counting 2

我是容斥低低手,该训容斥了。 题意:给出 \(n\),计算对于 \(x,y\in[0,n)\),有多少个排列满足: \[\sum_{i=1}^{n-1}[p_i<p_{i+1}] = x \]\[\sum_{i=1}^{n-1}[p_{i}^{-1}<p_{i+1}^{-1}] = y \]\(n\le 500\)。 …

WAV 转 flac 格式

WAV 转 flac 格式 刘姐的歌版权掉了之前网盘里有 WAV 文件,只好再搞下了文件转换 https://www.freeconvert.com/zh/wav-to-flac 歌词封面(MusicTag)wav ===> flac 格式后,文件体积变小 WAV 是最原始的音频数据格…

EtherCAT芯片没有倍福授权的风险

使用未获得倍福授权的EtherCAT芯片可能面临多维度风险,尤其在技术合规性、市场准入和长期业务稳定性方面。以下是具体分析: 一、法律与专利风险 1.专利侵权责任 EtherCAT 技术的核心专利虽已到期,但EtherCAT技术协会…

为何是「对话式」智能体?因为人类本能丨对话式智能体专场,Convo AIRTE2025

在文字诞生之前,人类通过对话交换情感和思想——充满温度与实时反馈。今天,AI 与实时互动技术正引领一场「对话式社会」复兴,让沟通回归本能。从智能终端、儿童 AI 导师到智能客服,语音交互技术正让「对话式智能体…

2014-2024高考真题考点分布详细分析(另附完整高考真题下载) - 详解

2014-2024高考真题考点分布详细分析(另附完整高考真题下载) - 详解pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block !important; font-family: &qu…

详细介绍:MySQL专用服务器自动调优指南

详细介绍:MySQL专用服务器自动调优指南2025-10-12 09:50 tlnshuju 阅读(0) 评论(0) 收藏 举报pre { white-space: pre !important; word-wrap: normal !important; overflow-x: auto !important; display: block …

P4147 玉蟾宫(最大子矩形)

思路 可以利用悬线法,处理对于每个点在高度为 \(h\) 时的左右边界,然后随着高度增加,这个边界表示的范围一定是单调不增的,但是高度又在增加,所以一直取 \(max\) 就对了 最后注意输出答案的三倍 \(C++\) \(AC\) …