TensorFlow2 手把手教你实现自定义层

TensorFlow2 手把手教你实现自定义层

  • 概述
  • Sequential
  • Model & Layer
  • 案例
    • 数据集介绍
    • 完整代码

概述

通过自定义网络, 我们可以自己创建网络并和现有的网络串联起来, 从而实现各种各样的网络结构.

Sequential

Sequential 是 Keras 的一个网络容器. 可以帮助我们将多层网络封装在一起.

在这里插入图片描述

通过 Sequential 我们可以把现有的层已经我们自己的层实现结合, 一次前向传播就可以实现数据从第一层到最后一层的计算.

格式:

tf.keras.Sequential(
    layers=None, name=None
)

例子:

# 5层网络模型
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(10)
])

Model & Layer

通过 Model 和 Layer 的__init__call()我们可以自定义层和模型.

Model:

class My_Model(tf.keras.Model):  # 继承Model

    def __init__(self):
        """
        初始化
        """
        
        super(My_Model, self).__init__()
        self.fc1 = My_Dense(784, 256)  # 第一层
        self.fc2 = My_Dense(256, 128)  # 第二层
        self.fc3 = My_Dense(128, 64)  # 第三层
        self.fc4 = My_Dense(64, 32)  # 第四层
        self.fc5 = My_Dense(32, 10)  # 第五层

    def call(self, inputs, training=None):
        """
        在Model被调用的时候执行
        :param inputs: 输入
        :param training: 默认为None
        :return: 返回输出
        """
        
        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)

        return x

Layer:

class My_Dense(tf.keras.layers.Layer):  # 继承Layer

    def __init__(self, input_dim, output_dim):
        """
        初始化
        :param input_dim:
        :param output_dim:
        """

        super(My_Dense, self).__init__()

        # 添加变量
        self.kernel = self.add_variable("w", [input_dim, output_dim])  # 权重
        self.bias = self.add_variable("b", [output_dim])  # 偏置

    def call(self, inputs, training=None):
        """
        在Layer被调用的时候执行, 计算结果
        :param inputs: 输入
        :param training: 默认为None
        :return: 返回计算结果
        """

        # y = w * x + b
        out = inputs @ self.kernel + self.bias

        return out

案例

数据集介绍

CIFAR-10 是由 10 类不同的物品组成的 6 万张彩色图片的数据集. 其中 5 万张为训练集, 1 万张为测试集.
在这里插入图片描述

完整代码

import tensorflow as tf

def pre_process(x, y):

    # 转换x
    x = 2 * tf.cast(x, dtype=tf.float32) / 255 - 1  # 转换为-1~1的形式
    x = tf.reshape(x, [-1, 32 * 32 * 3])  # 把x铺平

    # 转换y
    y = tf.convert_to_tensor(y)  # 转换为0~1的形式
    y = tf.one_hot(y, depth=10)  # 转成one_hot编码

    # 返回x, y
    return x, y

def get_data():
    """
    获取数据
    :return:
    """

    # 获取数据
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # 调试输出维度
    print(X_train.shape)  # (50000, 32, 32, 3)
    print(y_train.shape)  # (50000, 1)

    # squeeze
    y_train = tf.squeeze(y_train)  # (50000, 1) => (50000,)
    y_test = tf.squeeze(y_test)  # (10000, 1) => (10000,)

    # 分割训练集
    train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000, seed=0)
    train_db = train_db.batch(batch_size).map(pre_process).repeat(iteration_num)  # 迭代20次

    # 分割测试集
    test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(10000, seed=0)
    test_db = test_db.batch(batch_size).map(pre_process)

    return train_db, test_db

class My_Dense(tf.keras.layers.Layer):  # 继承Layer

    def __init__(self, input_dim, output_dim):
        """
        初始化
        :param input_dim:
        :param output_dim:
        """

        super(My_Dense, self).__init__()

        # 添加变量
        self.kernel = self.add_weight("w", [input_dim, output_dim])  # 权重
        self.bias = self.add_weight("b", [output_dim])  # 偏置

    def call(self, inputs, training=None):
        """
        在Layer被调用的时候执行, 计算结果
        :param inputs: 输入
        :param training: 默认为None
        :return: 返回计算结果
        """

        # y = w * x + b
        out = inputs @ self.kernel + self.bias

        return out


class My_Model(tf.keras.Model):  # 继承Model

    def __init__(self):
        """
        初始化
        """

        super(My_Model, self).__init__()
        self.fc1 = My_Dense(32 * 32 * 3, 256)  # 第一层
        self.fc2 = My_Dense(256, 128)  # 第二层
        self.fc3 = My_Dense(128, 64)  # 第三层
        self.fc4 = My_Dense(64, 32)  # 第四层
        self.fc5 = My_Dense(32, 10)  # 第五层

    def call(self, inputs, training=None):
        """
        在Model被调用的时候执行
        :param inputs: 输入
        :param training: 默认为None
        :return: 返回输出
        """

        x = self.fc1(inputs)
        x = tf.nn.relu(x)
        x = self.fc2(x)
        x = tf.nn.relu(x)
        x = self.fc3(x)
        x = tf.nn.relu(x)
        x = self.fc4(x)
        x = tf.nn.relu(x)
        x = self.fc5(x)

        return x

# 定义超参数
batch_size = 256  # 一次训练的样本数目
learning_rate = 0.001  # 学习率
iteration_num = 20  # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)  # 优化器
loss = tf.losses.CategoricalCrossentropy(from_logits=True)  # 损失
network = My_Model()  # 实例化网络

# 调试输出summary
network.build(input_shape=[None, 32 * 32 * 3])
print(network.summary())

# 组合
network.compile(optimizer=optimizer,
                loss=loss,
                metrics=["accuracy"])

if __name__ == "__main__":
    # 获取分割的数据集
    train_db, test_db = get_data()

    # 拟合
    network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)

输出结果:

Model: "my__model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
my__dense (My_Dense)         multiple                  786688    
_________________________________________________________________
my__dense_1 (My_Dense)       multiple                  32896     
_________________________________________________________________
my__dense_2 (My_Dense)       multiple                  8256      
_________________________________________________________________
my__dense_3 (My_Dense)       multiple                  2080      
_________________________________________________________________
my__dense_4 (My_Dense)       multiple                  330       
=================================================================
Total params: 830,250
Trainable params: 830,250
Non-trainable params: 0
_________________________________________________________________
None
(50000, 32, 32, 3)
(50000, 1)
2021-06-15 14:35:26.600766: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/5
3920/3920 [==============================] - 39s 10ms/step - loss: 0.9676 - accuracy: 0.6595 - val_loss: 1.8961 - val_accuracy: 0.5220
Epoch 2/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.3338 - accuracy: 0.8831 - val_loss: 3.3207 - val_accuracy: 0.5141
Epoch 3/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1713 - accuracy: 0.9410 - val_loss: 4.2247 - val_accuracy: 0.5122
Epoch 4/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1237 - accuracy: 0.9581 - val_loss: 4.9458 - val_accuracy: 0.5050
Epoch 5/5
3920/3920 [==============================] - 42s 11ms/step - loss: 0.1003 - accuracy: 0.9666 - val_loss: 5.2425 - val_accuracy: 0.5097

热门文章

暂无图片
编程学习 ·

Java输出数组的内容

Java输出数组的内容_一万个小时-CSDN博客_java打印数组内容1. 输出内容最常见的方式// List<String>类型的列表List<String> list new ArrayList<String>();list.add("First");list.add("Second");list.add("Third");list.ad…
暂无图片
编程学习 ·

母螳螂的“魅惑之术”

在它们对大蝗虫发起进攻的时候&#xff0c;我认认真真地观察了一次&#xff0c;因为它们突然像触电一样浑身痉挛起来&#xff0c;警觉地面对限前这个大家伙&#xff0c;然后放下自己优雅的身段和祈祷的双手&#xff0c;摆出了一个可怕的姿势。我被眼前的一幕吓到了&#xff0c;…
暂无图片
编程学习 ·

疯狂填词 mad_libs 第9章9.9.2

#win7 python3.7.0 import os,reos.chdir(d:\documents\program_language) file1open(.\疯狂填词_d9z9d2_r.txt) file2open(.\疯狂填词_d9z9d2_w.txt,w) words[ADJECTIVE,NOUN,VERB,NOUN] str1file1.read()#方法1 for word in words :word_replaceinput(fEnter a {word} :)str1…
暂无图片
编程学习 ·

HBASE 高可用

为了保证HBASE是高可用的,所依赖的HDFS和zookeeper也要是高可用的. 通过参数hbase.rootdir指定了连接到Hadoop的地址,mycluster表示为Hadoop的集群. HBASE本身的高可用很简单,只要在一个健康的集群其他节点通过命令 hbase-daemon.sh start master启动一个Hmaster进程,这个Hmast…
暂无图片
编程学习 ·

js事件操作语法

一、事件的绑定语法 语法形式1 事件监听 标签对象.addEventListener(click,function(){}); 语法形式2 on语法绑定 标签对象.onclick function(){} on语法是通过 等于赋值绑定的事件处理函数 , 等于赋值本质上执行的是覆盖赋值,后赋值的数据会覆盖之前存储的数据,也就是on…
暂无图片
编程学习 ·

Photoshop插件--晕影动态--选区--脚本开发--PS插件

文章目录1.插件界面2.关键代码2.1 选区2.2 动态晕影3.作者寄语PS是一款栅格图像编辑软件&#xff0c;具有许多强大的功能&#xff0c;本文演示如何通过脚本实现晕影动态和选区相关功能&#xff0c;展示从互联网收集而来的一个小插件&#xff0c;供大家学习交流&#xff0c;请勿…
暂无图片
编程学习 ·

vs LNK1104 无法打开文件“xxx.obj”

写在前面&#xff1a; 向大家推荐两本新书&#xff0c;《深度学习计算机视觉实战》和《学习OpenCV4&#xff1a;基于Python的算法实战》。 《深度学习计算机视觉实战》讲了计算机视觉理论基础&#xff0c;讲了案例项目&#xff0c;讲了模型部署&#xff0c;这些项目学会之后可以…
暂无图片
编程学习 ·

工业元宇宙的定义与实施路线图

工业元宇宙的定义与实施路线图 李正海 1 工业元宇宙 给大家做一个关于工业元宇宙的定义。对于工业&#xff0c;从设计的角度来讲&#xff0c;现在的设计人员已经做到了普遍的三维设计&#xff0c;但是进入元宇宙时代&#xff0c;就不仅仅只是三维设计了&#xff0c;我们的目…
暂无图片
编程学习 ·

【leectode 2022.1.15】完成一半题目

有 N 位扣友参加了微软与力扣举办了「以扣会友」线下活动。主办方提供了 2*N 道题目&#xff0c;整型数组 questions 中每个数字对应了每道题目所涉及的知识点类型。 若每位扣友选择不同的一题&#xff0c;请返回被选的 N 道题目至少包含多少种知识点类型。 示例 1&#xff1a…
暂无图片
编程学习 ·

js 面试题总结

一、js原型与原型链 1. prototype 每个函数都有一个prototype属性&#xff0c;被称为显示原型 2._ _proto_ _ 每个实例对象都会有_ _proto_ _属性,其被称为隐式原型 每一个实例对象的隐式原型_ _proto_ _属性指向自身构造函数的显式原型prototype 3. constructor 每个prot…
暂无图片
编程学习 ·

java练习代码

打印自定义行数的空心菱形练习代码如下 import java.util.Scanner; public class daYinLengXing{public static void main(String[] args) {System.out.println("请输入行数");Scanner myScanner new Scanner(System.in);int g myScanner.nextInt();int num g%2;//…
暂无图片
编程学习 ·

RocketMQ-什么是死信队列?怎么解决

目录 什么是死信队列 死信队列的特征 死信消息的处理 什么是死信队列 当一条消息初次消费失败&#xff0c;消息队列会自动进行消费重试&#xff1b;达到最大重试次数后&#xff0c;若消费依然失败&#xff0c;则表明消费者在正常情况下无法正确地消费该消息&#xff0c;此时…
暂无图片
编程学习 ·

项目 cg day04

第4章 lua、Canal实现广告缓存 学习目标 Lua介绍 Lua语法 输出、变量定义、数据类型、流程控制(if..)、循环操作、函数、表(数组)、模块OpenResty介绍(理解配置) 封装了Nginx&#xff0c;并且提供了Lua扩展&#xff0c;大大提升了Nginx对并发处理的能&#xff0c;10K-1000K Lu…
暂无图片
编程学习 ·

输出三角形

#include <stdio.h> int main() { int i,j; for(i0;i<5;i) { for(j0;j<i;j) { printf("*"); } printf("\n"); } }
暂无图片
编程学习 ·

stm32的BOOTLOADER学习1

序言 最近计划学习stm32的BOOTLOADER学习,把学习过程记录下来 因为现在网上STM32C8T6还是比较贵的,根据我的需求flash空间小一些也可以,所以我决定使用stm32c6t6.这个芯片的空间是32kb的。 #熟悉芯片内部的空间地址 1、flash ROM&#xff1a; 大小32KB&#xff0c;范围&#xf…
暂无图片
编程学习 ·

通过awk和shell来限制IP多次访问之学不会你打死我

学不会你打死我 今天我们用shell脚本&#xff0c;awk工具来分析日志来判断是否存在扫描器来进行破解网站密码——限制访问次数过多的IP地址&#xff0c;通过Iptables来进行限制。代码在末尾 首先我们要先查看日志的格式&#xff0c;分析出我们需要筛选的内容&#xff0c;日志…
暂无图片
编程学习 ·

Python - 如何像程序员一样思考

在为计算机编写程序之前&#xff0c;您必须学会如何像程序员一样思考。学习像程序员一样思考对任何学生都很有价值。以下步骤可帮助任何人学习编码并了解计算机科学的价值——即使他们不打算成为计算机科学家。 顾名思义&#xff0c;Python经常被想要学习编程的人用作第一语言…
暂无图片
编程学习 ·

蓝桥杯python-数字三角形

问题描述 虽然我前后用了三种做法&#xff0c;但是我发现只有“优化思路_1”可以通过蓝桥杯官网中的测评&#xff0c;但是如果用c/c的话&#xff0c;每个都通得过&#xff0c;足以可见python的效率之低&#xff08;但耐不住人家好用啊&#xff08;哭笑&#xff09;&#xff09…