【推荐算法】Embedding+MLP:TensorFlow实现经典深度学习推荐模型详解

Embedding+MLP:TensorFlow实现经典深度学习模型详解

    • 1. 算法逻辑
      • 模型结构和工作流程
      • 关键组件
    • 2. 算法原理与数学推导
      • Embedding层原理
      • MLP前向传播
      • 反向传播与优化
    • 3. 模型评估
      • 常用评估指标
      • 评估方法
    • 4. 应用案例:推荐系统CTR预测
      • 问题描述
      • 模型架构
      • 性能优化
    • 5. 常见面试题
    • 6. 优缺点分析
      • 优点
      • 缺点
      • 改进方向
    • 7. TensorFlow实现详解
      • 完整实现代码
      • 关键实现细节
      • 训练优化技巧
    • 总结

1. 算法逻辑

模型结构和工作流程

Embedding+MLP模型是处理高维稀疏特征的经典架构,广泛应用于推荐系统、广告点击率预测(CTR)等场景。其核心思想是将类别型特征通过Embedding层转换为低维稠密向量,再与连续特征拼接后输入多层感知机(MLP)。

类别特征
连续特征
输入特征
特征类型
Embedding层
归一化处理
特征向量
特征拼接
多层感知机 MLP
输出层
预测结果

关键组件

  1. Embedding层:将高维稀疏的类别特征映射为低维稠密向量

    • 输入:类别ID(整数索引)
    • 输出:固定维度的浮点数向量
      在这里插入图片描述
  2. 特征拼接层:将多个Embedding向量和连续特征拼接为单一特征向量

  3. 多层感知机(MLP):由多个全连接层组成,学习特征间的非线性关系
    在这里插入图片描述

2. 算法原理与数学推导

Embedding层原理

给定类别特征 i i i,其one-hot编码为 x ∈ { 0 , 1 } n \mathbf{x} \in \{0,1\}^n x{0,1}n,Embedding层本质是一个查找操作:

e i = W e T x \mathbf{e}_i = \mathbf{W}_e^T \mathbf{x} ei=WeTx

其中 W e ∈ R n × d \mathbf{W}_e \in \mathbb{R}^{n \times d} WeRn×d是嵌入矩阵, d d d是嵌入维度。实际实现中,等价于:

e i = W e [ i ] \mathbf{e}_i = \mathbf{W}_e[i] ei=We[i]

MLP前向传播

设拼接后的特征向量为 h ( 0 ) \mathbf{h}^{(0)} h(0),MLP的第 l l l层计算为:

z ( l ) = W ( l ) h ( l − 1 ) + b ( l ) \mathbf{z}^{(l)} = \mathbf{W}^{(l)} \mathbf{h}^{(l-1)} + \mathbf{b}^{(l)} z(l)=W(l)h(l1)+b(l)
h ( l ) = σ ( z ( l ) ) \mathbf{h}^{(l)} = \sigma(\mathbf{z}^{(l)}) h(l)=σ(z(l))

其中 σ \sigma σ是激活函数(如ReLU),最后一层使用sigmoid或softmax激活。

反向传播与优化

使用交叉熵损失函数:

L = − 1 N ∑ i = 1 N [ y i log ⁡ ( y ^ i ) + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] \mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} [y_i \log(\hat{y}_i) + (1-y_i)\log(1-\hat{y}_i)] L=N1i=1N[yilog(y^i)+(1yi)log(1y^i)]

通过链式法则计算梯度:

∂ L ∂ W ( l ) = ∂ L ∂ z ( l ) ∂ z ( l ) ∂ W ( l ) \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{(l)}} = \frac{\partial \mathcal{L}}{\partial \mathbf{z}^{(l)}} \frac{\partial \mathbf{z}^{(l)}}{\partial \mathbf{W}^{(l)}} W(l)L=z(l)LW(l)z(l)

使用Adam等优化器更新参数:

W ( l ) ← W ( l ) − η ∂ L ∂ W ( l ) \mathbf{W}^{(l)} \leftarrow \mathbf{W}^{(l)} - \eta \frac{\partial \mathcal{L}}{\partial \mathbf{W}^{(l)}} W(l)W(l)ηW(l)L

3. 模型评估

常用评估指标

指标公式适用场景
AUC$$ \frac{\sum_{i \in P} \sum_{j \in N} I(\hat{y}_i > \hat{y}_j)}{P
LogLoss − 1 N ∑ i = 1 N [ y i log ⁡ y ^ i + ( 1 − y i ) log ⁡ ( 1 − y ^ i ) ] -\frac{1}{N} \sum_{i=1}^N [y_i \log \hat{y}_i + (1-y_i)\log(1-\hat{y}_i)] N1i=1N[yilogy^i+(1yi)log(1y^i)]概率预测质量
Precision T P T P + F P \frac{TP}{TP + FP} TP+FPTP关注假正例成本
Recall T P T P + F N \frac{TP}{TP + FN} TP+FNTP关注假负例成本

评估方法

  1. 离线评估:时间窗口划分训练/测试集
  2. 在线评估:A/B测试
  3. 特征重要性分析:通过消融实验评估特征贡献

4. 应用案例:推荐系统CTR预测

问题描述

预测用户点击广告的概率,特征包括:

  • 用户特征:ID、年龄、性别
  • 广告特征:ID、类别、价格
  • 上下文特征:时间、位置

模型架构

用户ID
User Embedding
广告ID
Ad Embedding
广告类别
Category Embedding
用户年龄
归一化
广告价格
特征拼接
全连接层 256
全连接层 128
全连接层 64
Sigmoid输出

性能优化

  1. 特征分桶:连续特征离散化
  2. 注意力机制:加权重要特征
  3. 多任务学习:同时优化CTR和CVR

5. 常见面试题

  1. 为什么需要Embedding层?直接使用one-hot编码输入MLP有什么问题?

    • 高维稀疏性导致参数爆炸( n n n类特征需要 n n n维输入)
    • 缺乏语义相似性(所有类别相互独立)
    • 计算效率低下(矩阵运算低效)
  2. 如何确定Embedding维度?

    • 经验公式: d = 6 × category_size 0.25 d = 6 \times \text{category\_size}^{0.25} d=6×category_size0.25
    • 网格搜索:尝试8/16/32/64等常用维度
    • 自动学习:通过矩阵分解确定最优维度
  3. 如何处理未见过的类别(冷启动问题)?

    • 使用哈希技巧: h ( id ) m o d B h(\text{id}) \mod B h(id)modB
    • 分配默认Embedding向量
    • 建立外部KV存储动态更新Embedding
  4. 如何加速Embedding层训练?

    # TensorFlow优化示例
    embedding_layer = tf.keras.layers.Embedding(input_dim=10000, output_dim=64,embeddings_initializer='uniform',embeddings_regularizer=tf.keras.regularizers.l2(1e-6)
    )
    

6. 优缺点分析

优点

  1. 高效处理稀疏特征:Embedding层大幅降低维度
  2. 特征自动学习:端到端学习特征表示
  3. 灵活可扩展:易于添加新特征
  4. 捕获非线性关系:MLP拟合复杂模式

缺点

  1. 特征交互有限:难以学习显式特征交叉
  2. 顺序不敏感:拼接操作丢失特征顺序信息
  3. 冷启动问题:新类别预测效果差
  4. 解释性差:Embedding向量难以解释

改进方向

问题解决方案代表模型
特征交互浅显式特征交叉DeepFM, xDeepFM
顺序不敏感序列建模DIN, DIEN
冷启动元学习, 图网络MAML, GNN

7. TensorFlow实现详解

完整实现代码

import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dense, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2def build_embedding_mlp(categorical_feature_info, continuous_feature_names, embedding_dim=16, hidden_units=[256, 128, 64]):# 输入层定义inputs = {}# 类别特征输入for name, vocab_size in categorical_feature_info.items():inputs[name] = Input(shape=(1,), name=name, dtype=tf.int32)# 连续特征输入for name in continuous_feature_names:inputs[name] = Input(shape=(1,), name=name, dtype=tf.float32)# Embedding层embeddings = []for name, vocab_size in categorical_feature_info.items():embedding = Embedding(input_dim=vocab_size + 1,  # +1 for unknownoutput_dim=embedding_dim,embeddings_regularizer=l2(1e-6),name=f'embed_{name}')(inputs[name])embeddings.append(tf.squeeze(embedding, axis=1))# 连续特征处理normalized_conts = [tf.keras.layers.BatchNormalization()(inputs[name]) for name in continuous_feature_names]# 特征拼接concat = Concatenate(axis=1)(embeddings + normalized_conts)# MLP部分x = concatfor i, units in enumerate(hidden_units):x = Dense(units, activation='relu', kernel_regularizer=l2(1e-5),name=f'fc_{i}')(x)x = tf.keras.layers.BatchNormalization()(x)x = tf.keras.layers.Dropout(0.3)(x)# 输出层output = Dense(1, activation='sigmoid', name='output')(x)# 构建模型model = Model(inputs=list(inputs.values()), outputs=output)# 编译模型model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss='binary_crossentropy',metrics=['accuracy', tf.keras.metrics.AUC(name='auc')])return model# 使用示例
categorical_features = {'user_id': 10000, 'item_id': 5000, 'category': 100}
continuous_features = ['age', 'price']
model = build_embedding_mlp(categorical_features, continuous_features)
model.summary()

关键实现细节

  1. 动态输入处理

    # 灵活处理不同特征
    inputs = {name: Input(shape=(1,), name=name, dtype=dtype) for name in feature_names}
    
  2. Embedding优化

    # 使用正则化防止过拟合
    Embedding(..., embeddings_regularizer=l2(1e-6))
    
  3. 特征归一化

    # BatchNorm加速收敛
    tf.keras.layers.BatchNormalization()(continuous_input)
    
  4. 防止过拟合

    # Dropout层
    x = tf.keras.layers.Dropout(0.3)(x)# L2正则化
    Dense(..., kernel_regularizer=l2(1e-5))
    

训练优化技巧

  1. 样本不平衡处理

    model.fit(..., class_weight={0: 1, 1: 10})  # 增加正样本权重
    
  2. 动态学习率

    lr_schedule = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
    
  3. 早停机制

    early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_auc', patience=5, mode='max')
    
  4. 大数据集训练

    # 使用TF Dataset优化
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.shuffle(100000).batch(512).prefetch(2)
    

总结

Embedding+MLP模型通过特征嵌入多层感知机的组合,有效解决了高维稀疏特征的处理问题,成为推荐系统、广告点击率预测等领域的基准模型。TensorFlow提供了灵活高效的实现方式:

  1. 特征处理:Embedding层高效处理类别特征
  2. 模型架构:灵活拼接+深度MLP捕获复杂模式
  3. 优化技术:正则化、归一化、动态学习率提升性能
  4. 评估体系:AUC、LogLoss等指标全面评估

尽管后续发展出更复杂的模型,Embedding+MLP因其简单高效易于实现的特点,仍是工业界广泛使用的解决方案。掌握其核心原理和TensorFlow实现,是深入理解现代深度学习推荐系统的重要基础。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.pswp.cn/web/82839.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

黑马点评【基于redis实现共享session登录】

目录 一、基于Session实现登录流程 1.发送验证码: 2.短信验证码登录、注册: 3.校验登录状态: 4.session共享问题 4.1为什么会出现 Session 集群共享问题? 4.2常见解决方案 1. 基于 Cookie 的 Session(客户端存储&#xff0…

Python读取阿里法拍网的html+解决登录cookie

效果图 import time from selenium import webdriver from selenium.webdriver.chrome.options import Options from selenium.webdriver.chrome.service import Service from webdriver_manager.chrome import ChromeDriverManager from lxml import etreedef get_taobao_auct…

【win | docker开启远程配置】使用 SSH 隧道访问 Docker的前操作

在主机A pycharm如何连接远程主机B win docker? 需要win docker配置什么? 快捷配置-主机B win OpenSSH SSH Server https://blog.csdn.net/z164470/article/details/121683333 winR,打开命令行,输入net start sshd,启动SSH。 或者右击我的电脑&#…

Cursor生成Java的架构设计图

文章目录 整体说明一、背景二、前置条件三、生成 Promt四、结果查看五、结果编辑 摘要: Cursor生成Java的架构设计图 关键词: Cursor、人工智能 、开发工具、Java 架构设计图 整体说明 Cursor 作为现在非常好用的开发工具,非常的火爆&#…

1Panel运行的.net程序无法读取系统字体(因为使用了docker)

问题来源 我之前都是使用的宝塔面板,之前我也部署过我的程序,就没有什么问题,但是上次我部署我的程序的时候,就提示无法找到字体Arial。 我的程序中使用该字体生成验证码。 我多次安装了微软的字体包,但是依旧没有效…

面试总结。

一、回流(重排)与重绘(Repaint) 优化回答: 概念区分: 回流(Reflow/Relayout):当元素的几何属性(如宽高、位置、隐藏 / 显示)发生改变时&#xff…

TensorFlow深度学习实战(20)——自组织映射详解

TensorFlow深度学习实战(20)——自组织映射详解 0. 前言1. 自组织映射原理2. 自组织映射的优缺点3. 使用自组织映射实现颜色映射小结系列链接 0. 前言 自组织映射 (Self-Organizing Map, SOM) 是一种无监督学习算法,主要用于高维数据的降维、…

Go内存泄漏排查与修复最佳实践

一、引言 即使Go语言拥有强大的垃圾回收机制,内存泄漏仍然是我们在生产环境中经常面临的挑战。与传统印象不同,垃圾回收并不是万能的"记忆清道夫",它只能处理那些不再被引用的内存,而无法识别那些仍被引用但实际上不再…

LeetCode刷题 -- 542. 01矩阵 基于 DFS 更新优化的多源最短路径实现

LeetCode刷题 – 542. 01矩阵 基于 DFS 更新优化的多源最短路径实现 题目描述简述 给定一个 m x n 的二进制矩阵 mat,其中: 每个元素为 0 或 1返回一个同样大小的矩阵 ans,其中 ans[i][j] 表示 mat[i][j] 到最近 0 的最短曼哈顿距离 算法思…

MySQL用户远程访问权限设置

mysql相关指令 一. MySQL给用户添加远程访问权限1. 创建或者修改用户权限方法一:创建用户并授予远程访问权限方法二:修改现有用户的访问限制方法三:授予特定数据库的特定权限 2. 修改 MySQL 配置文件3. 安全最佳实践4. 测试远程连接5. 撤销权…

如何使用 BPF 分析 Linux 内存泄漏,Linux 性能调优之 BPF 分析内核态、用户态内存泄漏

写在前面 博文内容为 通过 BCC 工具集 memleak 进行内存泄漏分析的简单认知包括 memleak 脚本简单认知,内核态(内核模块)、用户态(Java,Python,C)内存跟踪泄漏分析 Demo理解不足小伙伴帮忙指正 😃,生活加油知其不可奈何而安之若命,德之至也。----《庄子内篇人间世》 …

谷歌Sign Gemma: AI手语翻译,沟通从此无界!

嘿,朋友们!想象一下,语言不再是交流的障碍,每个人都能顺畅表达与理解。这听起来是不是很酷?谷歌最新发布的Sign Gemma AI模型,正朝着这个激动人心的未来迈出了一大步!它就像一位随身的、不知疲倦…

全生命周期的智慧城市管理

前言 全生命周期的智慧城市管理。未来,城市将在 实现从基础设施建设、日常运营到数据管理的 全生命周期统筹。这将避免过去智慧城市建设 中出现的“碎片化”问题,实现资源的高效配 置和项目的协调发展。城市管理者将运用先进 的信息技术,如物…

最新Spring Security实战教程(十七)企业级安全方案设计 - 多因素认证(MFA)实现

🌷 古之立大事者,不惟有超世之才,亦必有坚忍不拔之志 🎐 个人CSND主页——Micro麦可乐的博客 🐥《Docker实操教程》专栏以最新的Centos版本为基础进行Docker实操教程,入门到实战 🌺《RabbitMQ》…

logstash拉取redisStream的流数据,并存储ES

先说结论, window验证logstash截至2025-06-06 是没有原生支持的。 为啥考虑用redisStream呢?因为不想引入三方的kafka等组件, 让服务部署轻量化, 所以使用现有的redis来实现, 为啥不用list呢? 已经用strea…

IEC 61347-1:2015 灯控制装置安全通用要求详解

IEC 61347-1:2015 灯控制装置安全通用要求详解 IEC 61347-1:2015《灯控制装置 第1部分:一般要求和安全要求》是国际电工委员会(IEC)制定的关于灯控制装置安全性能的核心基础标准。它为各类用于启动和稳定工作电流的灯控制装置(如…

26、跳表

在C标准库中,std::map 和 std::set 是使用红黑树作为底层数据结构的容器。 红黑树是一种自平衡二叉搜索树,能够保证插入、删除和查找操作的时间复杂度为O(log n)。 以下是一些使用红黑树的C标准库容器: std::map:一种关联容器&a…

LabVIEW音频测试分析

LabVIEW通过读取指定WAV 文件,实现对音频信号的播放、多维度测量分析功能,为音频设备研发、声学研究及质量检测提供专业工具支持。 主要功能 文件读取与播放:支持持续读取示例数据文件夹内的 WAV 文件,可实时播放音频以监听被测信…

JUC并发编程(二)Monitor/自旋/轻量级/锁膨胀/wait/notify/锁消除

目录 一 基础 1 概念 2 卖票问题 3 转账问题 二 锁机制与优化策略 0 Monitor 1 轻量级锁 2 锁膨胀 3 自旋 4 偏向锁 5 锁消除 6 wait /notify 7 sleep与wait的对比 8 join原理 一 基础 1 概念 临界区 一段代码块内如果存在对共享资源的多线程读写操作&#xf…

Doris 与 Elasticsearch:谁更适合你的数据分析需求?

一、Doris 和 Elasticsearch 的基本概念 (一)Doris 是什么? Doris 是一个用于数据分析的分布式 MPP(大规模并行处理)数据库。它主要用于存储和分析大量的结构化数据(比如表格数据)&#xff0c…