掌握聚合最新动态了解行业最新趋势
API接口,开发服务,免费咨询服务

快速开启你的第一个项目:TensorFlow项目架构模板

本文经机器之心(微信公众号:almosthuman2014)授权转载,禁止二次转载。

项目链接:https://github.com/Mrgemy95/Tensorflow-Project-Template

TensorFlow 项目模板

简洁而精密的结构对于深度学习项目来说是必不可少的,在经过多次练习和 TensorFlow 项目开发之后,本文作者提出了一个结合简便性、优化文件结构和良好 OOP 设计的 TensorFlow 项目模板。该模板可以帮助你快速启动自己的 TensorFlow 项目,直接从实现自己的核心思想开始。

这个简单的模板可以帮助你直接从构建模型、训练等任务开始工作。

目录

  • 概述

  • 详述

  • 项目架构

  • 文件夹结构

  • 主要组件

  • 模型

  • 训练器

  • 数据加载器

  • 记录器

  • 配置

  • Main

  • 未来工作

概述

简言之,本文介绍的是这一模板的使用方法,例如,如果你希望实现 VGG 模型,那么你应该:

在模型文件夹中创建一个名为 VGG 的类,由它继承「base_model」类

  1.   class VGGModel(BaseModel):

  2.        def __init__(self, config):

  3.            super(VGGModel, self).__init__(config)

  4.            #call the build_model and init_saver functions.

  5.            self.build_model()

  6.            self.init_saver()

覆写这两个函数 "build_model",在其中执行你的 VGG 模型;以及定义 TensorFlow 保存的「init_saver」,随后在 initalizer 中调用它们。

  1.    def build_model(self):

  2.        # here you build the tensorflow graph of any model you want and also define the loss.

  3.        pass

  4.     def init_saver(self):

  5.        #here you initalize the tensorflow saver that will be used in saving the checkpoints.

  6.        self.saver = tf.train.Saver(max_to_keep=self.config.max_to_keep)

在 trainers 文件夹中创建 VGG 训练器,继承「base_train」类。

  1.        class VGGTrainer(BaseTrain):

  2.        def __init__(self, sess, model, data, config, logger):

  3.            super(VGGTrainer, self).__init__(sess, model, data, config, logger)

覆写这两个函数「train_step」、「train_epoch」,在其中写入训练过程的逻辑。

  1.       def train_epoch(self):

  2.        """

  3.       implement the logic of epoch:

  4.       -loop ever the number of iteration in the config and call teh train step

  5.       -add any summaries you want using the sammary

  6.        """

  7.        pass

  8.    def train_step(self):

  9.        """

  10.       implement the logic of the train step

  11.       - run the tensorflow session

  12.       - return any metrics you need to summarize

  13.       """

  14.        pass

在主文件中创建会话,创建以下对象:「Model」、「Logger」、「Data_Generator」、「Trainer」与配置:

  1.      sess = tf.Session()

  2.    # create instance of the model you want

  3.    model = VGGModel(config)

  4.    # create your data generator

  5.    data = DataGenerator(config)

  6.    # create tensorboard logger

  7.    logger = Logger(sess, config)

向所有这些对象传递训练器对象,通过调用「trainer.train()」开始训练。

  1.       trainer = VGGTrainer(sess, model, data, config, logger)

  2.    # here you train your model

  3.    trainer.train()

你会看到模板文件、一个示例模型和训练文件夹,向你展示如何快速开始你的第一个模型。

详述

模型架构

文件夹结构

  1.       ├──  base

  2. │   ├── base_model.py   - this file contains the abstract class of the model.

  3. │   └── ease_train.py - this file contains the abstract class of the trainer.

  4. ├── model               -This folder contains any model of your project.

  5. │   └── example_model.py

  6. ├── trainer             -this folder contains trainers of your project.

  7. │   └── example_trainer.py

  8. │  

  9. ├──  mains              - here's the main/s of your project (you may need more than one main.

  10. │                        

  11. │  

  12. ├──  data _loader  

  13. │    └── data_generator.py  - here's the data_generator that responsible for all data handling.

  14. └── utils

  15.     ├── logger.py

  16.     └── any_other_utils_you_need

主要组件

模型

  • 基础模型

基础模型是一个必须由你所创建的模型继承的抽象类,其背后的思路是:绝大多数模型之间都有很多东西是可以共享的。基础模型包含:

  • Save-此函数可保存 checkpoint 至桌面。

  • Load-此函数可加载桌面上的 checkpoint。

  • Cur-epoch、Global_step counters-这些变量会跟踪训练 epoch 和全局步。

  • Init_Saver-一个抽象函数,用于初始化保存和加载 checkpoint 的操作,注意:请在要实现的模型中覆盖此函数。

  • Build_model-是一个定义模型的抽象函数,注意:请在要实现的模型中覆盖此函数。

  • 你的模型

以下是你在模型中执行的地方。因此,你应该:

  • 创建你的模型类并继承 base_model 类。

  • 覆写 "build_model",在其中写入你想要的 tensorflow 模型。

  • 覆写"init_save",在其中你创建 tensorflow 保存器,以用它保存和加载检查点。

  • 在 initalizer 中调用"build_model" 和 "init_saver"

训练器

  • 基础训练器

基础训练器(Base trainer)是一个只包装训练过程的抽象的类。

  • 你的训练器

以下是你应该在训练器中执行的。

  • 创建你的训练器类,并继承 base_trainer 类。

  • 覆写这两个函数,在其中你执行每一步和每一 epoch 的训练过程。

数据加载器

这些类负责所有的数据操作和处理,并提供一个可被训练器使用的易用接口。

记录器(Logger)

这个类负责 tensorboard 总结。在你的训练器中创建一个有关所有你想要的 tensorflow 变量的词典,并将其传递给 logger.summarize()。

配置

我使用 Json 作为配置方法,接着解析它,因此写入所有你想要的配置,然后用"utils/config/process_config"解析它,并把这个配置对象传递给所有其他对象。

Main

以下是你整合的所有之前的部分。

1. 解析配置文件。

2. 创建一个 TensorFlow 会话。

3. 创建 "Model"、"Data_Generator" 和 "Logger"实例,并解析所有它们的配置。

4. 创建一个"Trainer"实例,并把之前所有的对象传递给它。

5. 现在你可通过调用"Trainer.train()"训练你的模型。

未来工作

未来,该项目计划通过新的 TensorFlow 数据集 API 替代数据加载器。

原文来自:机器之心

声明:所有来源为“聚合数据”的内容信息,未经本网许可,不得转载!如对内容有异议或投诉,请与我们联系。邮箱:marketing@think-land.com

  • 购物小票识别

    支持识别各类商场、超市及药店的购物小票,包括店名、单号、总金额、消费时间、明细商品名称、单价、数量、金额等信息,可用于商品售卖信息统计、购物中心用户积分兑换及企业内部报销等场景

    支持识别各类商场、超市及药店的购物小票,包括店名、单号、总金额、消费时间、明细商品名称、单价、数量、金额等信息,可用于商品售卖信息统计、购物中心用户积分兑换及企业内部报销等场景

  • 涉农贷款地址识别

    涉农贷款地址识别,支持对私和对公两种方式。输入地址的行政区划越完整,识别准确度越高。

    涉农贷款地址识别,支持对私和对公两种方式。输入地址的行政区划越完整,识别准确度越高。

  • 人脸四要素

    根据给定的手机号、姓名、身份证、人像图片核验是否一致

    根据给定的手机号、姓名、身份证、人像图片核验是否一致

  • 个人/企业涉诉查询

    通过企业关键词查询企业涉讼详情,如裁判文书、开庭公告、执行公告、失信公告、案件流程等等。

    通过企业关键词查询企业涉讼详情,如裁判文书、开庭公告、执行公告、失信公告、案件流程等等。

  • IP反查域名

    IP反查域名是通过IP查询相关联的域名信息的功能,它提供IP地址历史上绑定过的域名信息。

    IP反查域名是通过IP查询相关联的域名信息的功能,它提供IP地址历史上绑定过的域名信息。

0512-88869195
数 据 驱 动 未 来
Data Drives The Future