近一个月来,我都在阅读Keras之父著作的《Python深度学习》一书。书中最常使用的便是Keras的Model类。
通过查阅官方文档和Model类源码,本文将Keras Model类中的常用方法进行了梳理和总结。
compile
用于配置训练模型。
fit
以给定数量的轮次(数据集上的迭代)训练模型。
返回
一个 History
对象。其 History.history
属性是连续 epoch 训练损失和评估值,以及验证集损失和评估值的记录(如果适用)。
evaluate
1 | evaluate(self, |
在测试模式下返回模型的误差值和评估标准值。
计算是分批进行的。
summary
1 | # 继承自Network类 |
打印网络的总结信息。
1 | line_length: Total length of printed lines |
predict
为输入样本生成输出预测。
计算是分批进行的
predict_classes
1 | # 子类Sequential中的方法 |
为输入样本生成类别预测。
计算是分批进行的
参数
x: 输入数据,Numpy 数组 (或者 Numpy 数组的列表,如果模型有多个输出)。
batch_size: 批量大小。如未指定,默认为 32。
- verbose: 日志显示模式,0 或 1。
save
函数原型:
1 | # 继承自Network类 |
将模型保存到一个HDF5文件中。
示例:
1 | from keras.models import load_model |
save_weights
函数原型:
1 | # 继承自Network类 |
将各层的权重存储到HDF5文件中。
1 | filepath: String, path to the file to save the weights to. |
load_weights
函数原型:
1 | # 继承自Network类 |
get_layer
1 | # 继承自Network类 |
根据名称(唯一)或索引值查找网络层。
如果同时提供了 name
和 index
,则 index
将优先。索引值来自于水平图遍历的顺序(自下而上)。
Keras Model类中文文档 https://keras.io/zh/models/model/#model
Model类源码 https://github.com/keras-team/keras/blob/master/keras/engine/training.py