1、caffecpp代码解析main()函数-GetBrewFunction函数-train函数-Solve()#ifdef WITH_PYTHON_LAYER#include boost/python.hppnamespace bp = boost:python;#endif#include #include #include #include #include #include #include boost/algorithm/string.hpp#include caffe/caffe.hpp#include caffe/util/signal_handler.husing caffe:Bl
2、ob;using caffe:Caffe;using caffe:Net;using caffe:Layer;using caffe:Solver;using caffe:shared_ptr;using caffe:string;using caffe:Timer;using caffe:vector;using std:ostringstream;/*gflags是google的一个开源的处理命令行参数的库。 在使用命令行参数的文件文件中(源文件或头文件),首先使用一下定义语句进行变量的定义。 DEFINE_int32,DEFINE_int64,DEFINE_bool,DEFINE_dou
3、ble,DEFINE_string等, 语法为:DEFINE_int32(name, default_value, description)。 接着你就可以使用FLAGS_name变量了,这些变量的值则是由命令行参数传递,无则为默认值, 在其他代码文件中若想用该命令参数,可以用DECLARE_int32(name)声明(name为int32类型,也可以使用其他支持的类型)。 在caffe.cpp中有很多FLAGS_name定义,如DEFINE_string(gpu,some description),则命令行后-gpu 0,表示FLAGS_gpu=0,默认值为空。*/DEFINE_string
4、(gpu, , Optional; run in GPU mode on given device IDs separated by ,. Use -gpu all to run on all available GPUs. The effective training batch size is multiplied by the number of devices.);DEFINE_string(solver, , The solver definition protocol buffer text file.);DEFINE_string(model, , The model defin
5、ition protocol buffer text file.);DEFINE_string(phase, , Optional; network phase (TRAIN or TEST). Only used for time.);DEFINE_int32(level, 0, Optional; network level.);DEFINE_string(stage, , Optional; network stages (not to be confused with phase), separated by ,.);DEFINE_string(snapshot, , Optional
6、; the snapshot solver state to resume training.);DEFINE_string(weights, , Optional; the pretrained weights to initialize finetuning, separated by ,. Cannot be set simultaneously with snapshot.);DEFINE_int32(iterations, 50, The number of iterations to run.);DEFINE_string(sigint_effect, stop, Optional
7、; action to take when a SIGINT signal is received: snapshot, stop or none.);DEFINE_string(sighup_effect, snapshot, Optional; action to take when a SIGHUP signal is received: snapshot, stop or none.);/ A simple registry for caffe commands.typedef int (*BrewFunction)();/*声明了一个BrewFunction函数指针类型,可以用它来定
8、义一个函数指针*/typedef std:map BrewMap;/*声明了一个BrewFunction函数指针类型,可以用它来定义一个函数指针,创建了一个名为BrewMap,并且包含caffe:string类型数据的map空对象,该对象使用BrewFunction函数来对集合中的元素进行排序*/BrewMap g_brew_map; /*定义key为string的map容器实例*/*这里巧妙的用宏定义的方式声明了分别包含train(),test(), device_query(),time()四个函数的四个不同类*/ /*理解这个关键理解宏在预编译阶段是如何被展开*/#define Regi
9、sterBrewFunction(func) namespace class _Registerer_#func public: /* NOLINT */ _Registerer_#func() g_brew_map#func = &func; ; _Registerer_#func g_registerer_#func; /在C/C+的宏中,#的功能是将其后面的宏参数进行字符串化操作(Stringfication),简单说就是在对它所引用的宏变量通过替换后在其左右各加上一个双引号。 /”#”被称为连接符(concatenator),用来将两个子串Token连接为一个Token。注意这里连接的
10、对象是Token就行,而不一定是宏的变量。 所谓的子串(token)就是指编译器能够识别的最小语法单元举例说明#用法:#define PRINT( n ) printf( token #n = %d, token#n )同时又定义了二个整形变量:int token9 = 9;现在在主程序中以下面的方式调用这个宏:PRINT( 9 );那么在编译时,上面的这句话被扩展为:printf( token 9 = %d, token9 );注意到在这个例子中,PRINT(9);中的这个”9”被原封不动的当成了一个字符串,与”token”连接在了一起,从而成为了token9。而#n也被”9”所替代。/*在
11、caffe.cpp 中 BrewFunction 作为GetBrewFunction()函数的返回类型, 可以是 train(),test(),device_query(),time() 这四个函数指针的其中一个。 在train(),test(),中可以调用solver类的函数,从而进入到net,进入到每一层,运行整个caffe程序。*/*1:加了static后表示该函数失去了全局可见性,只在该函数所在的文件作用域内可见 2:当函数声明为static以后,编译器在该目标编译单元内只含有该函数的入口地址,没有函数名,其它编译单元便不能通过该函数名来调用该函数,这也是对1的解析与说明*/stati
12、c BrewFunction GetBrewFunction(const caffe:string& name) if (g_brew_map.count(name) /判断输入的是不是g_brew_map中train,test,device_query,time中一个 return g_brew_mapname; /如果是的话,就调用相应的train(),test(),device_query(),time() else LOG(ERROR) Available caffe actions:; for (BrewMap:iterator it = g_brew_map.begin(); it
13、 != g_brew_map.end(); +it) LOG(ERROR) t first; /LOG来源于google的glog库,控制程序的日志输出消息和测试消息(根据不同的level输出消息) LOG(FATAL) Unknown action: name; return NULL; / not reachable, just to suppress old compiler warnings. / Parse GPU ids or use all available devices/解析可用GPU,使用所有可用硬件static void get_gpus(vector* gpus) i
14、f (FLAGS_gpu = all) int count = 0;#ifndef CPU_ONLY CUDA_CHECK(cudaGetDeviceCount(&count);#else NO_GPU;#endif for (int i = 0; i push_back(i); else if (FLAGS_gpu.size() vector strings; boost:split(strings, FLAGS_gpu, boost:is_any_of(,); for (int i = 0; i push_back(boost:lexical_cast(stringsi); else CH
15、ECK_EQ(gpus-size(), 0); / Parse phase from flagscaffe:Phase get_phase_from_flags(caffe:Phase default_value) if (FLAGS_phase = ) return default_value; if (FLAGS_phase = TRAIN) return caffe:TRAIN; if (FLAGS_phase = TEST) return caffe:TEST; LOG(FATAL) phase must be TRAIN or TEST; return caffe:TRAIN; /
16、Avoid warning/ Parse stages from flagsvector get_stages_from_flags() vector stages; boost:split(stages, FLAGS_stage, boost:is_any_of(,); return stages;/ caffe commands to call by/ caffe / To add a command, define a function int command() and register it with/ RegisterBrewFunction(action);/ Device Qu
17、ery: show diagnostic information for a GPU device.int device_query() /*这里定义device_query函数*/ LOG(INFO) Querying GPUs FLAGS_gpu; vector gpus; get_gpus(&gpus); /*获得有几个GPU*/ for (int i = 0; i gpus.size(); +i) /*依次查询每个GPU信息*/ caffe:Caffe:SetDevice(gpusi); caffe:Caffe:DeviceQuery(); return 0; RegisterBrew
18、Function(device_query); /*这里通过预编译阶段的宏替换,将定义的device_query函数指针赋值到map容器中*/ /*加载训练的或者传入的模型*/ Load the weights from the specified caffemodel(s) into the train and/ test nets.void CopyLayers(caffe:Solver* solver, const std:string& model_list) std:vector model_names; boost:split(model_names, model_list, bo
19、ost:is_any_of(,) ); for (int i = 0; i model_names.size(); +i) LOG(INFO) Finetuning from net()-CopyTrainedLayersFrom(model_namesi); for (int j = 0; j test_nets().size(); +j) solver-test_nets()j-CopyTrainedLayersFrom(model_namesi); /将交互端传来的string类型的标志转成枚举类型的变量/ Translate the signal effect the user spe
20、cified on the command-line to the/ corresponding enumeration.caffe:SolverAction:Enum GetRequestedAction( const std:string& flag_value) if (flag_value = stop) return caffe:SolverAction:STOP; if (flag_value = snapshot) return caffe:SolverAction:SNAPSHOT; if (flag_value = none) return caffe:SolverActio
21、n:NONE; LOG(FATAL) Invalid signal effect flag_value was specified; return caffe:SolverAction:NONE;/ Train / Finetune a model./*训练或者微调网络都是走这个分支*/int train() /*定义train函数*/ google的glog库,检查-solver、-snapshot和-weight并输出消息;必须有指定solver,并且snapshot和weight两者只需指定其一; CHECK_GT(FLAGS_solver.size(), 0) Need a solve
22、r definition to train.; CHECK(!FLAGS_snapshot.size() | !FLAGS_weights.size() Give a snapshot to resume training or weights to finetune but not both.; vector stages = get_stages_from_flags();/*实例化SolverParameter类,该类保存solver参数和相应的方法,SolverParameter是通过Google Protocol Buffer自动生成的一个类*/ caffe:SolverParame
23、ter solver_param; /*定义SolverParameter的对象,该类保存solver参数和相应的方法*/ caffe:ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param); solver_param.mutable_train_state()-set_level(FLAGS_level); for (int i = 0; i add_stage(stagesi); / If the gpus flag is not provided, allow the mode and device to be set
24、 / in the solver prototxt. if (FLAGS_gpu.size() = 0 /根据命令参数-gpu或者solver.prototxt提供的信息设置GPU & solver_param.solver_mode() = caffe:SolverParameter_SolverMode_GPU) if (solver_param.has_device_id() FLAGS_gpu = + boost:lexical_cast(solver_param.device_id(); else / Set default GPU if unspecified FLAGS_gpu
25、= + boost:lexical_cast(0); /boost:lexical_cast(0)是将数值0转换为字符串“0”; /*上述代码: 首先是判断用户在Command Line中是否输入了gpu相关的参数, 如果没有(FLAGS_gpu.size()=0)但是用户在solver的prototxt定义中提供了相关的参数, 那就把相关的参数放到FLAGS_gpu中,如果用户仅仅是选择了在solver的prototxt定义中选择了GPU模式, 但是没有指明具体的gpu_id,那么就默认设置为0。*/ /多GPU下,将GPU编号存入vector容器中(get_gpus()函数通过FLAGS_
26、gpu获取) vector gpus; get_gpus(&gpus); if (gpus.size() = 0) LOG(INFO) Use CPU.; Caffe:set_mode(Caffe:CPU); else ostringstream s; for (int i = 0; i gpus.size(); +i) s (i ? , : ) gpusi; LOG(INFO) Using GPUs s.str();#ifndef CPU_ONLY cudaDeviceProp device_prop; for (int i = 0; i gpus.size(); +i) cudaGetDe
27、viceProperties(&device_prop, gpusi); LOG(INFO) GPU gpusi : device_prop.name; #endif solver_param.set_device_id(gpus0); Caffe:SetDevice(gpus0); Caffe:set_mode(Caffe:GPU); Caffe:set_solver_count(gpus.size(); /处理snapshot, stop or none信号,其声明在include/caffe/util/signal_Handler.h中; /GetRequestedAction在caff
28、e.cpp中,将stop,snapshot,none转换为标准信号,即解析; caffe:SignalHandler signal_handler( GetRequestedAction(FLAGS_sigint_effect), GetRequestedAction(FLAGS_sighup_effect);/声明boost库中智能指针solver,指向caffe:Solver对象,该对象由CreateSolver创建,后续细讲; shared_ptrcaffe:Solver solver(caffe:SolverRegistry:CreateSolver(solver_param);/*通过GetActionFunction来处理获得的系统信号*/ /*在SetActionFunction中将GetActionFunction函数地址传给参数action_request_function_*/ /*在网络训练的过程中,在GetRequestedAction中来处理action_request_function_得到的函数指针*/ s
copyright@ 2008-2022 冰豆网网站版权所有
经营许可证编号:鄂ICP备2022015515号-1