Skip to main content

Kaldi: editing an nnet3 chain model to use the auxiliary xent output as the main output

I had a task to edit a trained Kaldi nnet3 chain model so that the output node is the auxiliary `output-xent` node instead of the original chain `output` node.

First, let's look at the nnet structure:

nnet3-am-info final.mdl
input-dim: 20
ivector-dim: -1
num-pdfs: 6105
prior-dimension: 0
# Nnet info follows.
left-context: 15
right-context: 15
num-parameters: 15499085
modulus: 1
input-node name=input dim=20
component-node name=L0_fixaffine component=L0_fixaffine input=Append(Offset(input, -1), input, Offset(input, 1)) input-dim=60 output-dim=60
component-node name=Tdnn_0_affine component=Tdnn_0_affine input=L0_fixaffine input-dim=60 output-dim=625
component-node name=Tdnn_0_relu component=Tdnn_0_relu input=Tdnn_0_affine input-dim=625 output-dim=625
component-node name=Tdnn_0_renorm component=Tdnn_0_renorm input=Tdnn_0_relu input-dim=625 output-dim=625
component-node name=Tdnn_1_affine component=Tdnn_1_affine input=Append(Offset(Tdnn_0_renorm, -1), Tdnn_0_renorm, Offset(Tdnn_0_renorm, 1)) input-dim=1875 output-dim=625
component-node name=Tdnn_1_relu component=Tdnn_1_relu input=Tdnn_1_affine input-dim=625 output-dim=625
component-node name=Tdnn_1_renorm component=Tdnn_1_renorm input=Tdnn_1_relu input-dim=625 output-dim=625
component-node name=Tdnn_2_affine component=Tdnn_2_affine input=Append(Offset(Tdnn_1_renorm, -1), Tdnn_1_renorm, Offset(Tdnn_1_renorm, 1)) input-dim=1875 output-dim=625
component-node name=Tdnn_2_relu component=Tdnn_2_relu input=Tdnn_2_affine input-dim=625 output-dim=625
component-node name=Tdnn_2_renorm component=Tdnn_2_renorm input=Tdnn_2_relu input-dim=625 output-dim=625
component-node name=Tdnn_3_affine component=Tdnn_3_affine input=Append(Offset(Tdnn_2_renorm, -3), Tdnn_2_renorm, Offset(Tdnn_2_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_3_relu component=Tdnn_3_relu input=Tdnn_3_affine input-dim=625 output-dim=625
component-node name=Tdnn_3_renorm component=Tdnn_3_renorm input=Tdnn_3_relu input-dim=625 output-dim=625
component-node name=Tdnn_4_affine component=Tdnn_4_affine input=Append(Offset(Tdnn_3_renorm, -3), Tdnn_3_renorm, Offset(Tdnn_3_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_4_relu component=Tdnn_4_relu input=Tdnn_4_affine input-dim=625 output-dim=625
component-node name=Tdnn_4_renorm component=Tdnn_4_renorm input=Tdnn_4_relu input-dim=625 output-dim=625
component-node name=Tdnn_5_affine component=Tdnn_5_affine input=Append(Offset(Tdnn_4_renorm, -3), Tdnn_4_renorm, Offset(Tdnn_4_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_5_relu component=Tdnn_5_relu input=Tdnn_5_affine input-dim=625 output-dim=625
component-node name=Tdnn_5_renorm component=Tdnn_5_renorm input=Tdnn_5_relu input-dim=625 output-dim=625
component-node name=Tdnn_6_affine component=Tdnn_6_affine input=Append(Offset(Tdnn_5_renorm, -3), Tdnn_5_renorm, Offset(Tdnn_5_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_6_relu component=Tdnn_6_relu input=Tdnn_6_affine input-dim=625 output-dim=625
component-node name=Tdnn_6_renorm component=Tdnn_6_renorm input=Tdnn_6_relu input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_chain_affine component=Tdnn_pre_final_chain_affine input=Tdnn_6_renorm input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_chain_relu component=Tdnn_pre_final_chain_relu input=Tdnn_pre_final_chain_affine input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_chain_renorm component=Tdnn_pre_final_chain_renorm input=Tdnn_pre_final_chain_relu input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_affine component=Tdnn_pre_final_xent_affine input=Tdnn_6_renorm input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_relu component=Tdnn_pre_final_xent_relu input=Tdnn_pre_final_xent_affine input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_renorm component=Tdnn_pre_final_xent_renorm input=Tdnn_pre_final_xent_relu input-dim=625 output-dim=625
component-node name=Final_affine component=Final_affine input=Tdnn_pre_final_chain_renorm input-dim=625 output-dim=6105
output-node name=output input=Final_affine dim=6105 objective=linear
component-node name=Final-xent_affine component=Final-xent_affine input=Tdnn_pre_final_xent_renorm input-dim=625 output-dim=6105
component-node name=Final-xent_log_softmax component=Final-xent_log_softmax input=Final-xent_affine input-dim=6105 output-dim=6105
output-node name=output-xent input=Final-xent_log_softmax dim=6105 objective=linear
...

The solution has three steps:
1. Remove the chain output node (`output`).
2. Remove orphans, since the chain-specific nodes and components that fed it are no longer needed.
3. Rename the xent output node (`output-xent`) to `output`.

Now we can apply the edits:

nnet3-am-copy --edits='remove-output-nodes name=output;remove-orphans;rename-node old-name=output-xent new-name=output' final.mdl final.xent.mdl

LOG (nnet3-am-copy:ReadEditConfig():nnet-utils.cc:687) Removing 1 output nodes.
LOG (nnet3-am-copy:RemoveSomeNodes():nnet-nnet.cc:885) Removed 1 orphan nodes.
LOG (nnet3-am-copy:RemoveSomeNodes():nnet-nnet.cc:885) Removed 8 orphan nodes.
LOG (nnet3-am-copy:RemoveOrphanComponents():nnet-nnet.cc:810) Removing 4 orphan components.
LOG (nnet3-am-copy:main():nnet3-am-copy.cc:156) Copied neural net from final.mdl to final.xent.mdl


Finally, let's look at the modified structure:

nnet3-am-info final.xent.mdl
input-dim: 20
ivector-dim: -1
num-pdfs: 6105
prior-dimension: 0
# Nnet info follows.
left-context: 15
right-context: 15
num-parameters: 11286105
modulus: 1
input-node name=input dim=20
component-node name=L0_fixaffine component=L0_fixaffine input=Append(Offset(input, -1), input, Offset(input, 1)) input-dim=60 output-dim=60
component-node name=Tdnn_0_affine component=Tdnn_0_affine input=L0_fixaffine input-dim=60 output-dim=625
component-node name=Tdnn_0_relu component=Tdnn_0_relu input=Tdnn_0_affine input-dim=625 output-dim=625
component-node name=Tdnn_0_renorm component=Tdnn_0_renorm input=Tdnn_0_relu input-dim=625 output-dim=625
component-node name=Tdnn_1_affine component=Tdnn_1_affine input=Append(Offset(Tdnn_0_renorm, -1), Tdnn_0_renorm, Offset(Tdnn_0_renorm, 1)) input-dim=1875 output-dim=625
component-node name=Tdnn_1_relu component=Tdnn_1_relu input=Tdnn_1_affine input-dim=625 output-dim=625
component-node name=Tdnn_1_renorm component=Tdnn_1_renorm input=Tdnn_1_relu input-dim=625 output-dim=625
component-node name=Tdnn_2_affine component=Tdnn_2_affine input=Append(Offset(Tdnn_1_renorm, -1), Tdnn_1_renorm, Offset(Tdnn_1_renorm, 1)) input-dim=1875 output-dim=625
component-node name=Tdnn_2_relu component=Tdnn_2_relu input=Tdnn_2_affine input-dim=625 output-dim=625
component-node name=Tdnn_2_renorm component=Tdnn_2_renorm input=Tdnn_2_relu input-dim=625 output-dim=625
component-node name=Tdnn_3_affine component=Tdnn_3_affine input=Append(Offset(Tdnn_2_renorm, -3), Tdnn_2_renorm, Offset(Tdnn_2_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_3_relu component=Tdnn_3_relu input=Tdnn_3_affine input-dim=625 output-dim=625
component-node name=Tdnn_3_renorm component=Tdnn_3_renorm input=Tdnn_3_relu input-dim=625 output-dim=625
component-node name=Tdnn_4_affine component=Tdnn_4_affine input=Append(Offset(Tdnn_3_renorm, -3), Tdnn_3_renorm, Offset(Tdnn_3_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_4_relu component=Tdnn_4_relu input=Tdnn_4_affine input-dim=625 output-dim=625
component-node name=Tdnn_4_renorm component=Tdnn_4_renorm input=Tdnn_4_relu input-dim=625 output-dim=625
component-node name=Tdnn_5_affine component=Tdnn_5_affine input=Append(Offset(Tdnn_4_renorm, -3), Tdnn_4_renorm, Offset(Tdnn_4_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_5_relu component=Tdnn_5_relu input=Tdnn_5_affine input-dim=625 output-dim=625
component-node name=Tdnn_5_renorm component=Tdnn_5_renorm input=Tdnn_5_relu input-dim=625 output-dim=625
component-node name=Tdnn_6_affine component=Tdnn_6_affine input=Append(Offset(Tdnn_5_renorm, -3), Tdnn_5_renorm, Offset(Tdnn_5_renorm, 3)) input-dim=1875 output-dim=625
component-node name=Tdnn_6_relu component=Tdnn_6_relu input=Tdnn_6_affine input-dim=625 output-dim=625
component-node name=Tdnn_6_renorm component=Tdnn_6_renorm input=Tdnn_6_relu input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_affine component=Tdnn_pre_final_xent_affine input=Tdnn_6_renorm input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_relu component=Tdnn_pre_final_xent_relu input=Tdnn_pre_final_xent_affine input-dim=625 output-dim=625
component-node name=Tdnn_pre_final_xent_renorm component=Tdnn_pre_final_xent_renorm input=Tdnn_pre_final_xent_relu input-dim=625 output-dim=625
component-node name=Final-xent_affine component=Final-xent_affine input=Tdnn_pre_final_xent_renorm input-dim=625 output-dim=6105
component-node name=Final-xent_log_softmax component=Final-xent_log_softmax input=Final-xent_affine input-dim=6105 output-dim=6105
output-node name=output input=Final-xent_log_softmax dim=6105 objective=linear
...


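Cool, it worked! The node named `output` now reads from `Final-xent_log_softmax`. The parameter count dropped from 15499085 to 11286105. The difference, 4212980, is exactly the size of the removed chain-only layers: `Tdnn_pre_final_chain_affine` (625 × 625 + 625 = 391250) plus `Final_affine` (625 × 6105 + 6105 = 3821730).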

Comments