Caffe的数据变换器(DataTransformer)主要提供了对原始输入图像的预处理方法,包括随机切块、随机镜像、幅度缩放、去均值、灰度/色度变换等。相信熟悉图像处理、OpenCV的读者对上述操作并不陌生。
数据结构描述
message TransformationParameter {
// For data pre-processing, we can do simple scaling and subtracting the
// data mean, if provided. Note that the mean subtraction is always carried
// out before scaling.
// Pixel magnitude scaling factor; defaults to 1, i.e. no scaling.
optional float scale = 1 [default = 1];
// Specify if we want to randomly mirror data.
// Random mirroring switch; defaults to false, i.e. no mirroring.
optional bool mirror = 2 [default = false];
// Specify if we would like to randomly crop an image.
// Side length of the random crop; defaults to 0, i.e. no cropping.
optional uint32 crop_size = 3 [default = 0];
// mean_file and mean_value cannot be specified at the same time
// (mean_file names the file that stores the image mean).
optional string mean_file = 4;
// if specified can be repeated once (would subtract it from all the channels)
// or can be repeated the same number of times as channels
// (would subtract them from the corresponding channel)
// Mean given directly as numbers, so no file has to be read. When the number
// of values equals the number of image channels, each channel has its own
// value subtracted; when a single value is given, it is subtracted from
// every channel.
repeated float mean_value = 5;
// Force the decoded image to have 3 color channels.
// (Forces a three-channel color image as input.)
optional bool force_color = 6 [default = false];
// Force the decoded image to have 1 color channels.
// (Forces a single-channel grayscale image as input.)
optional bool force_gray = 7 [default = false];
}
数据变换器的实现
数据变换器声明头文件位于include/caffe/data_transformer.hpp中,如果需要单独使用该模块,应包含这个头文件。文件内容如下:
#ifndef CAFFE_DATA_TRANSFORMER_HPP
#define CAFFE_DATA_TRANSFORMER_HPP
#include <vector>
#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
namespace caffe {
/**
* @brief Applies common transformations to the input data, such as
* scaling, mirroring, subtracting the image mean...
*/
// Declaration of the DataTransformer class template.
template <typename Dtype>
class DataTransformer {
public:
// Explicit constructor taking the transformation parameters and the phase.
explicit DataTransformer(const TransformationParameter& param, Phase phase);
// Virtual destructor (empty body).
virtual ~DataTransformer() {}
/**
* @brief Initialize the Random number generations if needed by the
* transformation.
*/
// Sets up (seeds) the random number generator.
void InitRand();
/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to the data.
*
* @param datum
* Datum containing the data to be transformed.
* @param transformed_blob
* This is destination blob. It can be part of top blob's data if
* set_cpu_data() is used. See data_layer.cpp for an example.
*/
// The Transform overloads below accept different kinds of input data.
void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to a vector of Datum.
*
* @param datum_vector
* A vector of Datum containing the data to be transformed.
* @param transformed_blob
* This is destination blob. It can be part of top blob's data if
* set_cpu_data() is used. See memory_layer.cpp for an example.
*/
void Transform(const vector<Datum> & datum_vector,
Blob<Dtype>* transformed_blob);
#ifdef USE_OPENCV
/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to a vector of Mat.
*
* @param mat_vector
* A vector of Mat containing the data to be transformed.
* @param transformed_blob
* This is destination blob. It can be part of top blob's data if
* set_cpu_data() is used. See memory_layer.cpp for an example.
*/
void Transform(const vector<cv::Mat> & mat_vector,
Blob<Dtype>* transformed_blob);
/**
* @brief Applies the transformation defined in the data layer's
* transform_param block to a cv::Mat
*
* @param cv_img
* cv::Mat containing the data to be transformed.
* @param transformed_blob
* This is destination blob. It can be part of top blob's data if
* set_cpu_data() is used. See image_data_layer.cpp for an example.
*/
void Transform(const cv::Mat& cv_img, Blob<Dtype>* transformed_blob);
#endif // USE_OPENCV
/**
* @brief Applies the same transformation defined in the data layer's
* transform_param block to all the num images in a input_blob.
*
* @param input_blob
* A Blob containing the data to be transformed. It applies the same
* transformation to all the num images in the blob.
* @param transformed_blob
* This is destination blob, it will contain as many images as the
* input blob. It can be part of top blob's data.
*/
void Transform(Blob<Dtype>* input_blob, Blob<Dtype>* transformed_blob);
// The InferBlobShape overloads report the shape of the output blob that the
// corresponding Transform call would produce.
/**
* @brief Infers the shape of transformed_blob will have when
* the transformation is applied to the data.
*
* @param datum
* Datum containing the data to be transformed.
*/
vector<int> InferBlobShape(const Datum& datum);
/**
* @brief Infers the shape of transformed_blob will have when
* the transformation is applied to the data.
* It uses the first element to infer the shape of the blob.
*
* @param datum_vector
* A vector of Datum containing the data to be transformed.
*/
vector<int> InferBlobShape(const vector<Datum> & datum_vector);
/**
* @brief Infers the shape of transformed_blob will have when
* the transformation is applied to the data.
* It uses the first element to infer the shape of the blob.
*
* @param mat_vector
* A vector of Mat containing the data to be transformed.
*/
#ifdef USE_OPENCV
vector<int> InferBlobShape(const vector<cv::Mat> & mat_vector);
/**
* @brief Infers the shape of transformed_blob will have when
* the transformation is applied to the data.
*
* @param cv_img
* cv::Mat containing the data to be transformed.
*/
vector<int> InferBlobShape(const cv::Mat& cv_img);
#endif // USE_OPENCV
protected:
/**
* @brief Generates a random integer from Uniform({0, 1, ..., n-1}).
*
* @param n
* The upperbound (exclusive) value of the random number.
* @return
* A uniformly random integer value from ({0, 1, ..., n-1}).
*/
// Virtual, so a subclass can substitute its own random source.
virtual int Rand(int n);
// Core overload: transforms one Datum into a raw output buffer.
void Transform(const Datum& datum, Dtype* transformed_data);
// Transformation parameters (per the original note, this type is generated
// by the protobuf compiler).
TransformationParameter param_;
// Random number generator; per the original note, Caffe::RNG is declared in
// include/caffe/common.hpp.
shared_ptr<Caffe::RNG> rng_;
// Current phase, TRAIN or TEST; the transformation differs between phases.
Phase phase_;
// Mean image, filled from the mean file.
Blob<Dtype> data_mean_;
// Mean values, extracted from param_.
vector<Dtype> mean_values_;
};
} // namespace caffe
#endif // CAFFE_DATA_TRANSFORMER_HPP
数据变换器的实现文件位于src/caffe/data_transformer.cpp,我们来深入阅读一下。
#ifdef USE_OPENCV
#include <opencv2/core/core.hpp>
#endif // USE_OPENCV
#include <string>
#include <vector>
#include "caffe/data_transformer.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/math_functions.hpp"
#include "caffe/util/rng.hpp"
namespace caffe {
// Constructor: keeps a copy of the transformation parameters and the phase,
// then prepares the mean that will be subtracted -- either a mean image read
// from param.mean_file, or the scalar mean_value entries. Giving both is a
// fatal error.
template<typename Dtype>
DataTransformer<Dtype>::DataTransformer(const TransformationParameter& param,
    Phase phase)
    : param_(param), phase_(phase) {
  // Mean supplied as a file: read it into a BlobProto and load data_mean_.
  if (param_.has_mean_file()) {
    CHECK_EQ(param_.mean_value_size(), 0) <<
      "Cannot specify mean_file and mean_value at the same time";
    const string& mean_path = param.mean_file();
    if (Caffe::root_solver()) {
      LOG(INFO) << "Loading mean file from: " << mean_path;
    }
    BlobProto mean_proto;
    ReadProtoFromBinaryFileOrDie(mean_path.c_str(), &mean_proto);
    data_mean_.FromProto(mean_proto);
  }

  // Mean supplied as explicit numbers: copy them into mean_values_.
  const int mean_value_count = param_.mean_value_size();
  if (mean_value_count > 0) {
    CHECK(param_.has_mean_file() == false) <<
      "Cannot specify mean_file and mean_value at the same time";
    for (int i = 0; i < mean_value_count; ++i) {
      mean_values_.push_back(param_.mean_value(i));
    }
  }
}
//变换函数,从众多重载函数中,我们选择一个重点讲解,其他的计算流程都类似
//下面函数使用了Datum作为输入,这个结构体我们可以从caffe.proto中一窥究竟
/*
// Datum用来从LMDB/LEVELDB中读取数据,或将数据写入LMDB/LEVELDB,和BlobProto有相似的功能,只是BlobProto用于模型权值序列化/反序列化,而Datum专为数据或特征图(feature map)提供序列化/反序列化服务。
message Datum {
//数据维度信息,channels * height * width
optional int32 channels = 1;
optional int32 height = 2;
optional int32 width = 3;
// the actual image data, in bytes(图像数据,以字节类型存储)
optional bytes data = 4;
//标签数据,统一用int32类型存储
optional int32 label = 5;
// Optionally, the datum could also hold float data.(可选,图像数据也可以用float类型存储 )
repeated float float_data = 6;
// If true data contains an encoded image that need to be decoded(是否为编码数据,默认不是)
optional bool encoded = 7 [default = false];
}
*/
// Transforms one Datum into a raw output buffer.
//
// Reads the pixels either from datum.data() (one byte per element, used
// whenever that string is non-empty) or from datum.float_data(), optionally
// crops a crop_size x crop_size window, optionally mirrors it horizontally,
// subtracts the mean (mean image or mean values) and multiplies by scale.
//
// @param datum             source sample; indexed as channels x height x width.
// @param transformed_data  destination buffer; written at indices
//                          [0, channels * out_height * out_width), where the
//                          output size is crop_size when cropping is enabled
//                          and the datum size otherwise. The caller must
//                          provide a buffer of at least that size.
//
// Fails a CHECK when the dimensions are inconsistent with crop_size or the
// mean, or when the datum payload is shorter than its declared dimensions.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Dtype* transformed_data) {
  // Payload and declared dimensions of the datum.
  const string& data = datum.data();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // Processing parameters. Note that Rand(2) is only called when mirroring
  // is enabled, because of the short-circuit &&.
  const int crop_size = param_.crop_size();
  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_uint8 = data.size() > 0;
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  // The loops below index the payload with positions derived from the
  // declared dimensions, so make sure the payload really is that large
  // instead of reading past its end on a malformed datum.
  const int datum_size = datum_channels * datum_height * datum_width;
  if (has_uint8) {
    CHECK_GE(data.size(), static_cast<size_t>(datum_size)) <<
        "Datum byte data is smaller than channels * height * width";
  } else {
    CHECK_GE(datum.float_data_size(), datum_size) <<
        "Datum float data is smaller than channels * height * width";
  }

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean image must have exactly the same dimensions as the datum.
    CHECK_EQ(datum_channels, data_mean_.channels());
    CHECK_EQ(datum_height, data_mean_.height());
    CHECK_EQ(datum_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == datum_channels) <<
        "Specify either 1 mean_value or as many as channels: " << datum_channels;
    if (datum_channels > 1 && mean_values_.size() == 1) {
      // Replicate the single mean value so it can be indexed per channel.
      // This changes the member, so it happens at most once.
      const Dtype shared_mean = mean_values_[0];
      mean_values_.resize(datum_channels, shared_mean);
    }
  }

  // Output size and crop offsets; without cropping the whole image is used.
  int height = datum_height;
  int width = datum_width;
  int h_off = 0;
  int w_off = 0;
  if (crop_size) {
    height = crop_size;
    width = crop_size;
    // We only do random crop when we do training; otherwise the crop is
    // centered.
    if (phase_ == TRAIN) {
      h_off = Rand(datum_height - crop_size + 1);
      w_off = Rand(datum_width - crop_size + 1);
    } else {
      h_off = (datum_height - crop_size) / 2;
      w_off = (datum_width - crop_size) / 2;
    }
  }

  for (int c = 0; c < datum_channels; ++c) {
    for (int h = 0; h < height; ++h) {
      for (int w = 0; w < width; ++w) {
        // Source index inside the (uncropped) datum.
        const int data_index =
            (c * datum_height + h_off + h) * datum_width + w_off + w;
        // Destination index; mirroring reverses the column order.
        const int top_w = do_mirror ? (width - 1 - w) : w;
        const int top_index = (c * height + h) * width + top_w;

        // Bytes are reinterpreted as unsigned before widening to Dtype.
        Dtype datum_element;
        if (has_uint8) {
          datum_element =
              static_cast<Dtype>(static_cast<uint8_t>(data[data_index]));
        } else {
          datum_element = datum.float_data(data_index);
        }

        // Mean subtraction (if any) happens before scaling.
        if (has_mean_file) {
          transformed_data[top_index] =
              (datum_element - mean[data_index]) * scale;
        } else if (has_mean_values) {
          transformed_data[top_index] =
              (datum_element - mean_values_[c]) * scale;
        } else {
          transformed_data[top_index] = datum_element * scale;
        }
      }
    }
  }
}
// Transforms one Datum into transformed_blob.
//
// An encoded datum is decoded to a cv::Mat (honouring force_color /
// force_gray) and handed to the cv::Mat overload; this needs USE_OPENCV.
// Otherwise the blob's shape is validated against the datum and crop_size,
// and the raw-buffer overload does the actual work.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const Datum& datum,
                                       Blob<Dtype>* transformed_blob) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    // With either force flag set, decode in color when force_color is set
    // and in gray otherwise; with no flag, keep the image's native format.
    const bool forced_decode = param_.force_color() || param_.force_gray();
    cv::Mat cv_img = forced_decode
        ? DecodeDatumToCVMat(datum, param_.force_color())
        : DecodeDatumToCVMatNative(datum);
    Transform(cv_img, transformed_blob);
    return;
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  } else if (param_.force_color() || param_.force_gray()) {
    LOG(ERROR) << "force_color and force_gray only for encoded datum";
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // The destination must match the datum (or the crop) in shape.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, datum_channels);
  CHECK_LE(height, datum_height);
  CHECK_LE(width, datum_width);
  CHECK_GE(num, 1);

  if (crop_size) {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
  } else {
    CHECK_EQ(datum_height, height);
    CHECK_EQ(datum_width, width);
  }

  // Shape verified; delegate to the raw-buffer overload.
  Transform(datum, transformed_blob->mutable_cpu_data());
}
// Transforms every Datum of datum_vector into consecutive items of
// transformed_blob. The vector must be non-empty and must not hold more
// items than the blob.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int datum_num = datum_vector.size();
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(datum_num, 0) << "There is no datum to add";
  CHECK_LE(datum_num, num) <<
    "The size of datum_vector must be no greater than transformed_blob->num()";

  // Single-item blob that is re-pointed at each item's slot in turn, so the
  // per-datum overload writes straight into transformed_blob's memory.
  Blob<Dtype> item_view(1, channels, height, width);
  for (int i = 0; i < datum_num; ++i) {
    Dtype* item_slot =
        transformed_blob->mutable_cpu_data() + transformed_blob->offset(i);
    item_view.set_cpu_data(item_slot);
    Transform(datum_vector[i], &item_view);
  }
}
//对一组输入cv::Mat对象进行变换,放入Blob中
#ifdef USE_OPENCV
// Transforms every cv::Mat of mat_vector into the matching item of
// transformed_blob. The vector must be non-empty and hold exactly as many
// images as the blob has items.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
                                       Blob<Dtype>* transformed_blob) {
  const int mat_num = mat_vector.size();
  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();

  CHECK_GT(mat_num, 0) << "There is no MAT to add";
  CHECK_EQ(mat_num, num) <<
    "The size of mat_vector must be equals to transformed_blob->num()";

  // Single-item blob that is re-pointed at each item's slot in turn, so the
  // per-image overload writes straight into transformed_blob's memory.
  Blob<Dtype> item_view(1, channels, height, width);
  for (int i = 0; i < mat_num; ++i) {
    Dtype* item_slot =
        transformed_blob->mutable_cpu_data() + transformed_blob->offset(i);
    item_view.set_cpu_data(item_slot);
    Transform(mat_vector[i], &item_view);
  }
}
// Transforms a single cv::Mat into transformed_blob.
//
// The image must have unsigned-byte depth. Its interleaved pixels are
// written out channel by channel ((c * height + h) * width + w), after
// optional cropping and horizontal mirroring, mean subtraction and scaling.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(const cv::Mat& cv_img,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // The destination must be compatible with the source image.
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int num = transformed_blob->num();

  CHECK_EQ(channels, img_channels);
  CHECK_LE(height, img_height);
  CHECK_LE(width, img_width);
  CHECK_GE(num, 1);

  CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte";

  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  Dtype* mean = NULL;
  if (has_mean_file) {
    // The mean image has to match the (uncropped) input image.
    CHECK_EQ(img_channels, data_mean_.channels());
    CHECK_EQ(img_height, data_mean_.height());
    CHECK_EQ(img_width, data_mean_.width());
    mean = data_mean_.mutable_cpu_data();
  }
  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == img_channels) <<
     "Specify either 1 mean_value or as many as channels: " << img_channels;
    if (img_channels > 1 && mean_values_.size() == 1) {
      // Expand the single mean value to one entry per channel.
      const Dtype shared_mean = mean_values_[0];
      mean_values_.resize(img_channels, shared_mean);
    }
  }

  int h_off = 0;
  int w_off = 0;
  cv::Mat cv_cropped_img = cv_img;
  if (!crop_size) {
    CHECK_EQ(img_height, height);
    CHECK_EQ(img_width, width);
  } else {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // Random window while training, centered window otherwise.
    const bool random_crop = (phase_ == TRAIN);
    h_off = random_crop ? Rand(img_height - crop_size + 1)
                        : (img_height - crop_size) / 2;
    w_off = random_crop ? Rand(img_width - crop_size + 1)
                        : (img_width - crop_size) / 2;
    cv_cropped_img = cv_img(cv::Rect(w_off, h_off, crop_size, crop_size));
  }

  CHECK(cv_cropped_img.data);

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  for (int h = 0; h < height; ++h) {
    const uchar* row = cv_cropped_img.ptr<uchar>(h);
    for (int w = 0; w < width; ++w) {
      // Mirroring reverses the destination column.
      const int top_w = do_mirror ? (width - 1 - w) : w;
      for (int c = 0; c < img_channels; ++c) {
        const int top_index = (c * height + h) * width + top_w;
        // Channels of one pixel are adjacent within the row.
        const Dtype pixel = static_cast<Dtype>(row[w * img_channels + c]);
        if (has_mean_file) {
          // The mean image is indexed in uncropped image coordinates.
          const int mean_index =
              (c * img_height + h_off + h) * img_width + w_off + w;
          transformed_data[top_index] = (pixel - mean[mean_index]) * scale;
        } else if (has_mean_values) {
          transformed_data[top_index] = (pixel - mean_values_[c]) * scale;
        } else {
          transformed_data[top_index] = pixel * scale;
        }
      }
    }
  }
}
#endif // USE_OPENCV
// Blob-to-blob transform: the same crop offsets and mirror decision are
// applied to every item of input_blob.
//
// An empty transformed_blob is reshaped to fit. Mean subtraction passes the
// input buffer as the destination of caffe_sub / caffe_add_scalar, so
// input_blob itself appears to be modified in place (NOTE(review): confirm
// against the math function definitions). The scale is applied to the whole
// transformed_blob at the end, and only when it differs from 1.
template<typename Dtype>
void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
                                       Blob<Dtype>* transformed_blob) {
  const int crop_size = param_.crop_size();
  const int input_num = input_blob->num();
  const int input_channels = input_blob->channels();
  const int input_height = input_blob->height();
  const int input_width = input_blob->width();

  if (transformed_blob->count() == 0) {
    // Give the empty destination the right shape: the crop size when
    // cropping, the input's spatial size otherwise.
    const int out_height = crop_size ? crop_size : input_height;
    const int out_width = crop_size ? crop_size : input_width;
    transformed_blob->Reshape(input_num, input_channels,
                              out_height, out_width);
  }

  const int num = transformed_blob->num();
  const int channels = transformed_blob->channels();
  const int height = transformed_blob->height();
  const int width = transformed_blob->width();
  const int size = transformed_blob->count();

  CHECK_LE(input_num, num);
  CHECK_EQ(input_channels, channels);
  CHECK_GE(input_height, height);
  CHECK_GE(input_width, width);

  const Dtype scale = param_.scale();
  const bool do_mirror = param_.mirror() && Rand(2);
  const bool has_mean_file = param_.has_mean_file();
  const bool has_mean_values = mean_values_.size() > 0;

  int h_off = 0;
  int w_off = 0;
  if (!crop_size) {
    CHECK_EQ(input_height, height);
    CHECK_EQ(input_width, width);
  } else {
    CHECK_EQ(crop_size, height);
    CHECK_EQ(crop_size, width);
    // Random window while training, centered window otherwise.
    const bool random_crop = (phase_ == TRAIN);
    h_off = random_crop ? Rand(input_height - crop_size + 1)
                        : (input_height - crop_size) / 2;
    w_off = random_crop ? Rand(input_width - crop_size + 1)
                        : (input_width - crop_size) / 2;
  }

  Dtype* input_data = input_blob->mutable_cpu_data();
  if (has_mean_file) {
    CHECK_EQ(input_channels, data_mean_.channels());
    CHECK_EQ(input_height, data_mean_.height());
    CHECK_EQ(input_width, data_mean_.width());
    // Subtract the mean image from every input item.
    for (int n = 0; n < input_num; ++n) {
      Dtype* item = input_data + input_blob->offset(n);
      caffe_sub(data_mean_.count(), item, data_mean_.cpu_data(), item);
    }
  }

  if (has_mean_values) {
    CHECK(mean_values_.size() == 1 || mean_values_.size() == input_channels) <<
     "Specify either 1 mean_value or as many as channels: " << input_channels;
    if (mean_values_.size() != 1) {
      // One mean per channel: handle each channel plane separately.
      const int plane_size = input_height * input_width;
      for (int n = 0; n < input_num; ++n) {
        for (int c = 0; c < input_channels; ++c) {
          Dtype* plane = input_data + input_blob->offset(n, c);
          caffe_add_scalar(plane_size, -(mean_values_[c]), plane);
        }
      }
    } else {
      // A single mean for everything: one pass over the whole blob.
      caffe_add_scalar(input_blob->count(), -(mean_values_[0]), input_data);
    }
  }

  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
  for (int n = 0; n < input_num; ++n) {
    for (int c = 0; c < channels; ++c) {
      const int plane = n * channels + c;
      for (int h = 0; h < height; ++h) {
        Dtype* dst_row = transformed_data + (plane * height + h) * width;
        const Dtype* src_row = input_data +
            (plane * input_height + h_off + h) * input_width + w_off;
        // Copy the cropped row, reversing the columns when mirroring.
        for (int w = 0; w < width; ++w) {
          dst_row[do_mirror ? (width - 1 - w) : w] = src_row[w];
        }
      }
    }
  }

  if (scale != Dtype(1)) {
    DLOG(INFO) << "Scale: " << scale;
    caffe_scal(size, scale, transformed_data);
  }
}
// Returns the 4-element shape {1, channels, height, width} of the blob that
// transforming this datum would produce. Height and width are crop_size when
// cropping is enabled, otherwise the datum's own. Encoded datums are decoded
// first (requires USE_OPENCV) and measured as images.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const Datum& datum) {
  if (datum.encoded()) {
#ifdef USE_OPENCV
    CHECK(!(param_.force_color() && param_.force_gray()))
        << "cannot set both force_color and force_gray";
    // With either force flag set, decode in color when force_color is set
    // and in gray otherwise; with no flag, keep the image's native format.
    const bool forced_decode = param_.force_color() || param_.force_gray();
    cv::Mat cv_img = forced_decode
        ? DecodeDatumToCVMat(datum, param_.force_color())
        : DecodeDatumToCVMatNative(datum);
    return InferBlobShape(cv_img);
#else
    LOG(FATAL) << "Encoded datum requires OpenCV; compile with USE_OPENCV.";
#endif  // USE_OPENCV
  }

  const int crop_size = param_.crop_size();
  const int datum_channels = datum.channels();
  const int datum_height = datum.height();
  const int datum_width = datum.width();

  // The datum must have channels and be at least as large as the crop.
  CHECK_GT(datum_channels, 0);
  CHECK_GE(datum_height, crop_size);
  CHECK_GE(datum_width, crop_size);

  vector<int> shape(4);
  shape[0] = 1;
  shape[1] = datum_channels;
  if (crop_size) {
    shape[2] = crop_size;
    shape[3] = crop_size;
  } else {
    shape[2] = datum_height;
    shape[3] = datum_width;
  }
  return shape;
}
// Shape for a batch of datums: the per-item shape is inferred from the first
// datum, and the leading dimension is set to the number of datums. The
// vector must not be empty.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<Datum> & datum_vector) {
  const int num = datum_vector.size();
  CHECK_GT(num, 0) << "There is no datum to in the vector";
  vector<int> batch_shape = InferBlobShape(datum_vector[0]);
  batch_shape[0] = num;
  return batch_shape;
}
#ifdef USE_OPENCV
// Returns the 4-element shape {1, channels, height, width} of the blob that
// transforming this image would produce. Height and width are crop_size when
// cropping is enabled, otherwise the image's rows and cols.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(const cv::Mat& cv_img) {
  const int crop_size = param_.crop_size();
  const int img_channels = cv_img.channels();
  const int img_height = cv_img.rows;
  const int img_width = cv_img.cols;

  // The image must have channels and be at least as large as the crop.
  CHECK_GT(img_channels, 0);
  CHECK_GE(img_height, crop_size);
  CHECK_GE(img_width, crop_size);

  vector<int> shape(4);
  shape[0] = 1;
  shape[1] = img_channels;
  if (crop_size) {
    shape[2] = crop_size;
    shape[3] = crop_size;
  } else {
    shape[2] = img_height;
    shape[3] = img_width;
  }
  return shape;
}
// Shape for a batch of images: the per-item shape is inferred from the first
// image, and the leading dimension is set to the number of images. The
// vector must not be empty.
template<typename Dtype>
vector<int> DataTransformer<Dtype>::InferBlobShape(
    const vector<cv::Mat> & mat_vector) {
  const int num = mat_vector.size();
  CHECK_GT(num, 0) << "There is no cv_img to in the vector";
  vector<int> batch_shape = InferBlobShape(mat_vector[0]);
  batch_shape[0] = num;
  return batch_shape;
}
#endif // USE_OPENCV
// Creates the random number generator when the configured transformation
// needs randomness -- mirroring is enabled, or a crop is requested during
// TRAIN -- and releases it otherwise.
template <typename Dtype>
void DataTransformer<Dtype>::InitRand() {
  const bool random_crop = (phase_ == TRAIN) && param_.crop_size();
  if (!(param_.mirror() || random_crop)) {
    // Nothing random to do: drop any existing generator.
    rng_.reset();
    return;
  }
  const unsigned int seed = caffe_rng_rand();
  rng_.reset(new Caffe::RNG(seed));
}
// Draws a random integer in [0, n) by reducing the generator's next output
// modulo n. Requires an initialized generator (see InitRand) and n > 0.
template <typename Dtype>
int DataTransformer<Dtype>::Rand(int n) {
  CHECK(rng_);
  CHECK_GT(n, 0);
  caffe::rng_t& engine = *static_cast<caffe::rng_t*>(rng_->generator());
  return engine() % n;
}
INSTANTIATE_CLASS(DataTransformer);
} // namespace caffe