model.compile 是 TensorFlow Keras 中用于配置训练模型的方法。在开始训练之前,需要通过这个方法来指定模型的优化器、损失函数和评估指标等。
注意事项: 在开始训练(调用 model.fit)之前,必须先调用 model.compile()。
1 基本用法
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
1) optimizer: 优化器
可以是预定义优化器的字符串(如 'adam', 'sgd' 等),也可以是 tf.keras.optimizers 下的优化器实例。优化器负责调整模型的权重以最小化损失函数。
以下是可以使用的字符串参数:
'sgd': 随机梯度下降优化器'adam': Adam 优化器'rmsprop': RMSprop 优化器'adagrad': Adagrad 优化器'adadelta': Adadelta 优化器'adamax': Adamax 优化器'nadam': Nadam 优化器'ftrl': Ftrl 优化器
需要注意的是:
-
这些字符串参数是不区分大小写的。例如,‘Adam’ 和 ‘adam’ 都是有效的。
-
使用字符串参数时,优化器会使用其默认参数值。如果你需要自定义优化器的参数(如学习率),最好直接使用优化器类:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001) model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy']) -
‘adam’ 通常是一个很好的默认选择,因为它在各种问题上都表现良好。但对于特定问题,其他优化器可能会表现得更好。
-
在实践中,选择合适的优化器和调整其参数(如学习率)往往比选择特定的优化器算法更重要。
2) loss: 损失函数
用于计算模型的预测值和真实值之间的差异。可以是字符串(预定义损失函数的名称),也可以是 tf.keras.losses 下的损失函数对象。对于不同类型的问题(如分类、回归等),需要选择合适的损失函数。
以下是一些常用的字符串参数对应的损失函数:
'binary_crossentropy': 用于二分类问题的交叉熵损失。'categorical_crossentropy': 用于多分类问题的交叉熵损失,要求标签为 one-hot 编码。'sparse_categorical_crossentropy': 用于多分类问题的交叉熵损失,标签为整数。'mean_squared_error'或'mse': 均方误差损失,用于回归问题。'mean_absolute_error'或'mae': 平均绝对误差损失,用于回归问题。'mean_absolute_percentage_error'或'mape': 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'或'msle': 均方对数误差,用于回归问题,对小差异不敏感。'poisson': 泊松损失,适用于计数问题或其他泊松分布问题。'kullback_leibler_divergence'或'kld': Kullback-Leibler 散度,用于衡量两个概率分布之间的差异。'hinge': 用于“最大间隔”分类问题的铰链损失。'squared_hinge': 铰链损失的平方版本。'logcosh': 对数双曲余弦损失,用于回归问题,对异常值不敏感。
3) metrics: 评估指标列表,用于评估模型的性能
这些指标在训练过程中不会用于梯度计算,仅用于观察。常见的指标包括 'accuracy'、'precision'、'recall' 等。
在 model.compile() 方法中,metrics 参数用于指定在训练和评估期间模型将评估哪些指标。这些指标不会用于训练过程中的反向传播和权重更新,仅用于观察模型的性能。以下是一些可以通过字符串参数传入的常用指标:
'accuracy'或'acc': 准确率,用于分类问题。'binary_accuracy': 二分类准确率。'categorical_accuracy': 多分类准确率,要求标签为 one-hot 编码。'sparse_categorical_accuracy': 多分类准确率,标签为整数。'top_k_categorical_accuracy': Top-k 准确率,即目标类别在模型预测的前 k 个最可能的类别中的准确率,用于多分类问题。'sparse_top_k_categorical_accuracy': 与'top_k_categorical_accuracy'类似,但适用于标签为整数的情况。'mean_squared_error'或'mse': 均方误差,用于回归问题。'mean_absolute_error'或'mae': 平均绝对误差,用于回归问题。'mean_absolute_percentage_error'或'mape': 平均绝对百分比误差,用于回归问题。'mean_squared_logarithmic_error'或'msle': 均方对数误差,用于回归问题。'cosine_similarity': 余弦相似度,用于回归问题或多标签分类问题。'precision': 精确率,用于二分类或多标签分类问题。'recall': 召回率,用于二分类或多标签分类问题。'auc': 曲线下面积(Area Under the Curve),用于二分类问题。
使用示例:
# 二分类问题
model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy', 'precision', 'recall'])# 多分类问题
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy', 'top_k_categorical_accuracy'])# 回归问题
model.compile(optimizer='adam',loss='mean_squared_error',metrics=['mae', 'mse'])
对于一些特定的指标(如 'precision', 'recall', 'auc' 等),可能需要使用 tf.keras.metrics 下的类实例来获得更多的配置选项,例如设置阈值或为多标签分类问题指定平均方法。
from tensorflow.keras.metrics import Precision, Recallmodel.compile(optimizer='adam',loss='binary_crossentropy',metrics=[Precision(thresholds=0.5), Recall(thresholds=0.5)])
2 高级用法
- 使用自定义优化器:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
- 使用自定义损失函数:
def custom_loss(y_true, y_pred):# 自定义损失计算逻辑return tf.reduce_mean(tf.square(y_true - y_pred))model.compile(optimizer='adam', loss=custom_loss, metrics=['accuracy'])
- 使用多个损失函数和评估指标:
如果模型有多个输出,你可以为每个输出指定不同的损失函数和评估指标。
model.compile(optimizer='adam',loss={'output_a': 'sparse_categorical_crossentropy', 'output_b': 'mse'},metrics={'output_a': ['accuracy'], 'output_b': ['mae', 'mse']})
- 使用学习率衰减:
from tensorflow.keras.optimizers.schedules import ExponentialDecaylr_schedule = ExponentialDecay(initial_learning_rate=1e-2, decay_steps=10000, decay_rate=0.9)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])