[CAFFE] What files do you need to train your own network

The following files serve as an example of what you need to run your own training in Caffe.

  • train.sh

If you are using bash, you run this script to train your network. It tells Caffe where to look for the solver prototxt and, optionally, which existing ‘.solverstate’ file to resume training from. Note that ‘--snapshot’ is optional; use it only when you want to restart training from an existing state.

TOOLS=./build/tools        # assume you are in the caffe directory
# the verb "train" starts the training process
$TOOLS/caffe train \
--solver=<PATH TO solver prototxt> \
--snapshot=<PATH TO solver .solverstate file>
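Since ‘--snapshot’ is optional, one way to cover both the fresh and the resumed case in a single script is to build the command conditionally. Below is a minimal sketch, not a definitive train.sh: the function name caffe_train_cmd and the example paths are hypothetical, and the function only prints the command instead of launching Caffe.

```shell
#!/bin/sh
# Hypothetical helper: prints the caffe command for a fresh or a resumed run.
# $1 = path to the solver prototxt
# $2 = optional path to a .solverstate file to resume from
caffe_train_cmd() {
    tools="${TOOLS:-./build/tools}"   # same default as above: run from the caffe directory
    cmd="$tools/caffe train --solver=$1"
    if [ -n "${2:-}" ]; then
        # Only add --snapshot when a solver state was actually given.
        cmd="$cmd --snapshot=$2"
    fi
    printf '%s\n' "$cmd"
}

# Dry run with placeholder paths: first a fresh run, then a resumed run.
caffe_train_cmd models/solver.prototxt
caffe_train_cmd models/solver.prototxt snapshots/net_iter_4000.solverstate
```

Printing the command first makes it easy to check the paths before committing to a long training run. The sketch assumes the paths contain no spaces.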
  • train.prototxt

This is the network architecture you define for your training. You define all the layers you need, including data, convolutional, pooling, and ReLU layers. More examples of the prototxt file can be found in the Caffe Model Zoo, where trained network architectures are hosted.

The trickier component of train.prototxt is the data layer. I will cover the data layer in a separate post.

  • deploy.prototxt

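The deploy prototxt is basically a duplicate of the train prototxt. This makes sense since you want your test data to be forwarded through the same network architecture. The main difference is that you have to replace the data layer in train.prototxt with a specification of the input data dimensions. Layers that consume the label, such as loss and accuracy layers, have nothing to read in deploy mode, so they are typically removed or swapped for their inference counterparts as well.

Let’s say you have this data layer in your train.prototxt: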

layer {
  name: "..."
  type: "Data"
  top: "data"
  top: "label"
  include {
    ...
  }
  transform_param {
    ...
  }
  data_param {
    ...
  }
}

You would want to replace the above layer with the following in your deploy.prototxt:

input: "data"
input_shape {
  dim: 1        # batch size
  dim: 3        # channels (3 for an RGB color image)
  dim: Height   # replace with your input height
  dim: Width    # replace with your input width
}
  • data

Caffe supports several data sources for training. The simplest but slowest is a txt file with an image path and its label written on each line. However, every image then has to be read and decoded from its own file during training, and this fetching latency could significantly slow down your training process.
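To make the txt format concrete, here is a minimal sketch that generates such a list. It assumes a hypothetical layout of one sub-folder per class (for example data/train/cat/, data/train/dog/) containing .jpg files; the function name make_image_list and the folder names are my own choices for illustration, not something Caffe requires.

```shell
#!/bin/sh
# Hypothetical helper: prints "<image path> <integer label>" lines for a
# layout of the form <root>/<class name>/<image>.jpg.
# Labels are assigned 0, 1, 2, ... in the sorted order of the class folders.
make_image_list() {
    root="$1"
    label=0
    for class_dir in "$root"/*/; do
        [ -d "$class_dir" ] || continue      # skip when the glob matched nothing
        for img in "$class_dir"*.jpg; do
            [ -f "$img" ] || continue        # skip folders without .jpg files
            printf '%s %d\n' "$img" "$label"
        done
        label=$((label + 1))
    done
}

# Usage (placeholder paths): make_image_list data/train > train.txt
```

Each output line is an image path followed by a space and its integer label. The sketch assumes the paths contain no spaces.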

I’m more used to using lmdb files as the data source. LMDB is a memory-mapped key-value store, so Caffe can read the pre-packed records much faster than it can open individual image files, which gives a big speedup for training. Creating data in lmdb format is less straightforward, though; I may have a separate post on that.

After you have your data ready, you specify its path in your train.prototxt inside the data layer.

  • solver.prototxt

This contains all the hyper-parameters for your training. An example is shown below:

# The train/test net protocol buffer definition
net: "<PATH TO train prototxt>"
# test_iter specifies how many forward passes the test should carry out.
# total_test_number = test_iter * batch_size
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
base_lr: 0.001
momentum: 0.9
weight_decay: 0.004
# The learning rate policy
lr_policy: "fixed"
# Display every 100 iterations
display: 100
# The maximum number of iterations
max_iter: 4000
# snapshot intermediate results
snapshot: 4000
snapshot_format: HDF5
snapshot_prefix: "PATH TO PREFIX LOCATION"
# solver mode: CPU or GPU
solver_mode: GPU
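The relation total_test_number = test_iter * batch_size in the comments above can be turned around to pick test_iter for a given test set, where batch_size is the batch size of the test-phase data layer. A small sketch of that arithmetic follows; the function name test_iter_for and the image counts are made up for illustration.

```shell
#!/bin/sh
# Hypothetical helper: smallest test_iter that covers every test image when
# each forward pass consumes batch_size images (ceiling division).
# $1 = number of test images, $2 = test batch size
test_iter_for() {
    num_images="$1"
    batch_size="$2"
    echo $(( (num_images + batch_size - 1) / batch_size ))
}

test_iter_for 10000 100   # prints 100
test_iter_for 10050 100   # prints 101
```

When the test set size is not a multiple of the batch size, rounding up means the last pass extends past the end of the test set, so you may prefer a batch size that divides it evenly.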
  • [optional] solver2.prototxt

If you want to decay your learning rate after a certain number of training iterations, you would specify another solver prototxt here with a reduced learning rate. This file is optional and is only needed when you want to restart your training with a different set of hyper-parameters.
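Put together, such a two-stage schedule could look like the sketch below. It only prints the two commands. The function name two_stage_cmds and the snapshot prefix are hypothetical, and the snapshot file name is my assumption about how Caffe names solver states when snapshot_format is HDF5 (prefix, then _iter_ and the iteration number, then .solverstate.h5), so check what actually lands in your snapshot folder.

```shell
#!/bin/sh
# Hypothetical two-stage schedule: prints the commands instead of running them.
# $1 = snapshot_prefix used in solver.prototxt
# $2 = iteration at which the first stage wrote its snapshot
two_stage_cmds() {
    prefix="$1"
    first_iter="$2"
    tools="${TOOLS:-./build/tools}"
    # Stage 1: train from scratch with the original hyper-parameters.
    printf '%s\n' "$tools/caffe train --solver=solver.prototxt"
    # Stage 2: resume from the stage-1 solver state with the reduced learning rate.
    # The ".solverstate.h5" suffix assumes snapshot_format: HDF5.
    printf '%s\n' "$tools/caffe train --solver=solver2.prototxt --snapshot=${prefix}_iter_${first_iter}.solverstate.h5"
}

# Dry run with a placeholder prefix.
two_stage_cmds snapshots/mynet 4000
```

Because the restored solver state carries its iteration count with it, max_iter in solver2.prototxt should be larger than the iteration of the snapshot; otherwise the second stage would have nothing left to do.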

 

Basically, this is all you need to train your network with Caffe. Happy brewing!
