-
Notifications
You must be signed in to change notification settings - Fork 14
Expand file tree
/
Copy pathtrainThreads.lua
More file actions
184 lines (130 loc) · 4.5 KB
/
Copy pathtrainThreads.lua
File metadata and controls
184 lines (130 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
----------------------------------------------------------------------
-- SGD training of the network
----------------------------------------------------------------------
require 'optim'
require 'xlua'
require 'cutorch'
require 'cunn'
require 'VideoOptFlow'
local threads = require 'threads'
----------------------------------------------------------------------
-- Parse command-line options, unless the including script has already set
-- the global `opt`. In that case this defines the globals `cmd` and `opt`.
if not opt then
   print '==> processing options'
   cmd = torch.CmdLine()
   cmd:text()
   cmd:text('Options:')
   -- { flag, default value, help text }, registered in this order
   local optionSpecs = {
      { '-save', '/usr/local/data/jtaylor/Deep/save', 'subdirectory to save/log experiments in' },
      { '-LR', 1e-3, 'learning rate at t=0' },
      { '-LRDecay', 1e-5, 'learning rate decay' },
      { '-momentum', 0.9, 'momentum' },
      { '-weightDecay', 1e-7, 'weight decay' },
      { '-workers', 2, 'threads for asynchronous loading/training' },
   }
   for _, spec in ipairs(optionSpecs) do
      cmd:option(spec[1], spec[2], spec[3])
   end
   cmd:text()
   opt = cmd:parse(arg or {})
end
-- training logs, written to <opt.save>/train.log
trainLogger = optim.Logger(paths.concat(opt.save, 'train.log'))

-- flattened parameter and gradient tensors of the global `model`
-- (the model must already be defined by the including script)
parameters, gradParameters = model:getParameters()

-- SGD hyper-parameters, taken from the command-line options
optimState = {
   learningRate = opt.LR,
   weightDecay = opt.weightDecay,
   momentum = opt.momentum,
   learningRateDecay = opt.LRDecay,
}
optimMethod = optim.sgd
printFreq = 10 -- freq to print confusion matrix during epoch

-- Dataset loader used on the main thread (training index list, class names).
-- NOTE(review): opt.fps, opt.datapath and opt.optflow are not declared by this
-- file's option parser; they must be provided by the including script.
paths.dofile('dataset.lua')
loader = dataLoader(opt.fps, opt.datapath)

-- Worker pool for asynchronous sample loading.
-- The `threads` package serializes upvalues of the init functions into each
-- worker, so the loader is passed through the local `gloader`. Each worker then
-- receives its own copy of it as the worker-global `threadLoader`.
-- Reusing `loader` here, instead of constructing a second dataLoader, avoids
-- scanning the dataset twice. It also guarantees the workers index exactly the
-- dataset whose `trainIndeces` the main thread shuffles.
do
   local gloader = loader
   pool = threads.Threads(opt.workers,
      function()
         -- presumably needed so the dataLoader class exists in the worker
         -- before `gloader` is deserialized there
         require 'torch'
         paths.dofile('dataset.lua')
      end,
      function()
         threadLoader = gloader
      end
   )
end
--- Run one training epoch with asynchronous data loading.
-- The training indices are visited in a new random order. For each index, a
-- worker thread loads the sample. The main thread then moves the sample to the
-- GPU and takes one optimizer step on it. Samples whose input table has at most
-- one element are skipped.
-- Reads the globals loader, pool, model, criterion, parameters, gradParameters,
-- optimMethod, optimState, opt and trainLogger. (Re)assigns the globals epoch,
-- confusion and shuffle.
-- NOTE(review): epoch defaults to 1 here but is never incremented in this
-- function; confirm the caller advances it, otherwise every epoch prints as #1.
-- @treturn number global training accuracy over the epoch, in percent
function train()
   local startTime = sys.clock()
   epoch = epoch or 1

   -- fresh confusion matrix every epoch, over the loader's class list
   confusion = optim.ConfusionMatrix(loader.classes)

   -- put modules whose behaviour differs between training and testing
   -- (e.g. Dropout) into training mode
   model:training()

   local nSamples = #loader.trainIndeces
   -- new random visiting order each epoch
   shuffle = torch.randperm(nSamples)

   print("- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -")
   print("==> online epoch # " .. epoch)

   -- Runs inside a worker thread. It deliberately captures no locals of train()
   -- because it is shipped to the worker. It only reads the worker-global
   -- threadLoader and forwards exactly two results.
   local function loadSample(idx)
      local inputCPU, labelsCPU = threadLoader:get(idx)
      return inputCPU, labelsCPU
   end

   -- Runs on the main thread with a worker's results: one optimizer step on
   -- a single sample.
   local function sgdStep(inputCPU, labelsCPU)
      if #inputCPU <= 1 then
         return
      end

      local labels = labelsCPU:cuda()
      local input = {}
      for k = 1, #inputCPU do
         input[k] = inputCPU[k]:cuda()
      end

      local function feval(x)
         -- Optical flow is appended here rather than in the worker callback:
         -- it runs on the GPU, and concurrent workers could write to
         -- overlapping GPU memory.
         if opt.optflow then
            input = VideoOptFlow(input)
         end
         gradParameters:zero()
         model:forget()
         local output = model:forward(input)
         local loss = criterion:forward(output, labels)
         for k = 1, #output do
            confusion:add(output[k], labels[k])
         end
         local gradOutputs = criterion:backward(output, labels)
         model:backward(input, gradOutputs)
         -- normalise gradient and loss by the length of the input table
         -- (presumably the number of frames; confirm)
         local len = #input
         gradParameters:div(len)
         return loss / len, gradParameters
      end

      optimMethod(feval, parameters, optimState)
      model:forget()
   end

   for t = 1, nSamples do
      xlua.progress(t, nSamples)
      pool:addjob(loadSample, sgdStep, loader.trainIndeces[shuffle[t]])
      -- (periodic accuracy printing every printFreq samples is disabled)
   end

   -- wait for every queued job (and its main-thread step) to finish
   pool:synchronize()

   local elapsed = sys.clock() - startTime
   print("\n==> training time = " .. (elapsed*1000) .. 'ms')

   confusion:updateValids()
   local globalAccuracy = confusion.totalValid * 100
   print('Average row correct: ' .. (confusion.averageValid*100) .. '%')
   print('Average rowUcol correct (VOC measure): ' .. (confusion.averageUnionValid*100) .. '%')
   print('Global correct: ' .. globalAccuracy .. '%')

   trainLogger:add{['% mean class accuracy (train set)'] = globalAccuracy}
   return globalAccuracy
end