-- model.lua (156 lines, 123 loc, 4.02 KB)
-- NOTE: scraped GitHub page chrome and line-number gutter removed.
require 'nn'
local nn_extras = require('DeepUtils/nn_extras')
require('pl.stringx').import()
local file = require('pl.file')
local List = require('pl.List')
local path = require('pl.path')
local tablex = require('pl.tablex')
-- Parse cmd line
local opt = require('cmdlineargs')
--- Fetch `tbl[option]`, raising `message` when the key is absent.
-- Uses an explicit `~= nil` check so falsy-but-present values
-- (e.g. `false`) are returned rather than treated as missing.
-- @param tbl lookup table
-- @param option key to fetch
-- @param message error message raised when the key is missing
-- @return the stored value
local function optget(tbl, option, message)
  local value = tbl[option]
  assert(value ~= nil, message)
  return value
end
-- Parse nonlinearity
-- Dispatch table from the command-line nonlinearity name to its nn
-- module constructor; `nonlinearity()` is instantiated once per conv
-- layer in create_conv_base_model.
local nltbl = {
['tanh'] = nn.Tanh,
['relu'] = nn.ReLU,
['prelu'] = nn.PReLU,
}
-- Fails fast with a readable message when opt.nonlinearity is unknown.
local nonlinearity = optget(nltbl, opt.nonlinearity, "Invalid nonlinearity of "..opt.nonlinearity)
-- Let's assume a length M sequence
-- N is AA embed sizes + psi sizes
--- Append the shared input-processing stage to `seq`.
-- Builds a ParallelTable whose first branch embeds amino-acid indices
-- through a LookupTable and whose remaining opt.PSINum branches lift each
-- PSI feature sequence to an extra trailing dimension; the branches are
-- then joined into a single MxN matrix, optionally followed by input
-- dropout. Side effect: sets opt.processed_dim to the joined feature
-- width (opt.AAEmbedSize + opt.PSINum).
-- @param seq an nn.Sequential to extend
-- @return seq, for chaining
local function create_common_model(seq)
  local branches = nn.ParallelTable()
  opt.processed_dim = 0
  -- Vocabulary size is the number of lines in the AA hash file.
  -- NOTE(review): `path` is assumed to be Penlight's pl.path — confirm it
  -- is in scope when this module loads.
  local alphabet = file.read(path.join(opt.hashDir, "aa1.lst"))
  local vocab_size = #(alphabet:splitlines())
  branches:add(nn.LookupTable(vocab_size, opt.AAEmbedSize))
  opt.processed_dim = opt.processed_dim + opt.AAEmbedSize
  for _ = 1, opt.PSINum do
    branches:add(nn.UpDim(nn.UpDim.END))
    opt.processed_dim = opt.processed_dim + 1
  end
  -- Table of N M-length sequences -> one MxN matrix
  seq:add(branches)
  seq:add(nn.JoinTable(2))
  if opt.indropout > 0 then
    seq:add(nn.Dropout(opt.indropout))
  end
  return seq
end
-- Convolution Models
--- Append the stacked temporal-convolution trunk to `seq`.
-- opt.modelparams and opt.pools are comma-separated lists giving, for each
-- hidden layer in opt.nhus, the convolution kernel width and the
-- max-pooling width (each defaults to 1 when its list is shorter than
-- opt.nhus). Side effect: sets opt.nhus[0] to opt.processed_dim.
-- @param seq an nn.Sequential already containing the common input stage
-- @return seq, for chaining
local function create_conv_base_model(seq)
  -- FIX: `pools` was declared `local` twice (once as the raw option string,
  -- once as the parsed list), shadowing the first; collapsed into single
  -- declarations parsed directly from opt.
  local conv_kernels = tablex.map(tonumber, opt.modelparams:split(','))
  local pools = tablex.map(tonumber, opt.pools:split(','))
  assert(#pools == #conv_kernels, "num pools must equal num conv_kernels")
  local model = seq
  model:add(nn.UpDim())
  -- Layer 0 input width is the feature width produced by create_common_model.
  opt.nhus[0] = opt.processed_dim
  for i,_ in ipairs(opt.nhus) do
    local kern = conv_kernels[i] or 1
    local pool = pools[i] or 1
    local padding = math.floor(kern/2)
    -- Zero-pad so the temporal convolution preserves sequence length,
    -- emitting one shifted copy per pooling step.
    -- NOTE(review): the `+pool` term in the bottom padding and the
    -- JoinTable(1) layout look deliberate but are not obvious — confirm
    -- against the original training setup before altering.
    if padding > 0 then
      local concat = nn.ConcatTable()
      for j = 0,pool-1 do
        concat:add(nn.SpatialZeroPadding(0,0,padding-j,kern-padding-1+j+pool))
      end
      model:add(concat)
      model:add(nn.JoinTable(1))
    end
    model:add(nn.TemporalConvolution(opt.nhus[i-1], opt.nhus[i], kern))
    model:add(nonlinearity())
    if pool > 1 then
      model:add(nn.TemporalMaxPooling(pool))
    end
    if opt.dropout > 0 then model:add(nn.Dropout(opt.dropout)) end
  end
  model:add(nn.TemporalZip())
  return model
end
--- Build one per-task output head: a width-1 temporal convolution mapping
-- the final hidden width (opt.nhus[#opt.nhus]) to nClass scores per
-- sequence position.
-- @param nClass number of output classes for this task
-- @return an nn.Sequential containing the head
local function create_conv_sub_model(nClass)
  local head = nn.Sequential()
  head:add(nn.TemporalConvolution(opt.nhus[#opt.nhus], nClass, 1))
  return head
end
--- Assemble the full convolutional model: common input stage, conv trunk,
-- and one output head per subtask collected in a ConcatTable.
-- @return the complete nn.Sequential model
local function create_conv_model()
  local net = nn.Sequential()
  create_common_model(net)
  create_conv_base_model(net)
  local heads = nn.ConcatTable()
  for task in opt.subtasks:iter() do
    heads:add(create_conv_sub_model(opt.nClass[task]))
  end
  net:add(heads)
  return net
end
-- Criterions
--- Attach a LogSoftMax to each task head and build the matching
-- multi-task negative-log-likelihood criterion (one ClassNLLCriterion
-- per subtask).
-- @param model a model whose last module is the ConcatTable of task heads
-- @return an nn.MultiCriterionTable
local function nll_criterion(model)
  local heads = model:get(model:size())
  local multi = nn.MultiCriterionTable()
  for idx, _ in ipairs(opt.subtasks) do
    heads:get(idx):add(nn.LogSoftMax())
    multi:add(nn.ClassNLLCriterion())
  end
  return multi
end
-- Dispatch table from the opt.loss name to its criterion-builder function.
local criterion_choice = {
['nll'] = nll_criterion,
}
--- Public entry point: build the model and its criterion per opt.loss.
-- @return model, criterion
local function getModel()
  local net = create_conv_model()
  local build_criterion = criterion_choice[opt.loss]
  return net, build_criterion(net)
end
--- Replace a pretrained model's output heads with freshly initialized
-- ones for the current subtasks, and rebuild the matching criterion.
-- @param model a model previously produced by getModel()
-- @return model, criterion
local function finetune_model(model)
  local sub_model_choice = {
    ['conv'] = create_conv_sub_model,
  }
  local sub_mlp_concat = nn.ConcatTable()
  for task in opt.subtasks:iter() do
    sub_mlp_concat:add(sub_model_choice[opt.model](opt.nClass[task]))
  end
  -- Reset the LookupTable's cached input buffer.
  -- NOTE(review): reaching into modules[1].modules[1]._input assumes the
  -- ParallelTable/LookupTable layout built by create_common_model — confirm.
  model.modules[1].modules[1]._input = torch.LongTensor()
  -- Swap the final ConcatTable of heads for the new one.
  model.modules[model:size()] = sub_mlp_concat
  model.output = sub_mlp_concat.output
  -- FIX: `criterion` was assigned without `local`, leaking a global.
  local criterion = criterion_choice[opt.loss](model)
  return model, criterion
end
-- Module exports
return {
getModel=getModel,
finetune_model=finetune_model
}