@@ -5,14 +5,15 @@ import fs from 'fs-extra';
55import {
66 Settings , DesktopJob , RunPipeline , RunTraining ,
77 DesktopJobUpdater ,
8+ ExportTrainedPipeline ,
89} from 'platform/desktop/constants' ;
910import { cleanString } from 'platform/desktop/sharedUtils' ;
1011import { serialize } from 'platform/desktop/backend/serializers/viame' ;
1112import { observeChild } from 'platform/desktop/backend/native/processManager' ;
1213
1314import { MultiType , stereoPipelineMarker , multiCamPipelineMarkers } from 'dive-common/constants' ;
1415import * as common from './common' ;
15- import { jobFileEchoMiddleware , createWorkingDirectory } from './utils' ;
16+ import { jobFileEchoMiddleware , createWorkingDirectory , createCustomWorkingDirectory } from './utils' ;
1617import {
1718 getMultiCamImageFiles , getMultiCamVideoPath ,
1819 writeMultiCamStereoPipelineArgs ,
@@ -212,6 +213,98 @@ async function runPipeline(
212213 return jobBase ;
213214}
214215
216+ /**
217+ * a node.js implementation of dive_tasks.tasks.export_trained_model
218+ */
219+ async function exportTrainedPipeline ( settings : Settings ,
220+ exportTrainedPipelineArgs : ExportTrainedPipeline ,
221+ updater : DesktopJobUpdater ,
222+ validateViamePath : ( settings : Settings ) => Promise < true | string > ,
223+ viameConstants : ViameConstants ,
224+ ) : Promise < DesktopJob > {
225+ const { path, pipeline } = exportTrainedPipelineArgs ;
226+
227+ const isValid = await validateViamePath ( settings ) ;
228+ if ( isValid !== true ) {
229+ throw new Error ( isValid ) ;
230+ }
231+
232+ const exportPipelinePath = npath . join ( settings . viamePath , PipelineRelativeDir , "convert_to_onnx.pipe" ) ;
233+ if ( ! fs . existsSync ( npath . join ( exportPipelinePath ) ) ) {
234+ throw new Error ( "Your VIAME version doesn't support ONNX export. You have to update it to a newer version to be able to export models." ) ;
235+ }
236+
237+ const modelPipelineDir = npath . parse ( pipeline . pipe ) . dir ;
238+ let weightsPath : string ;
239+ if ( fs . existsSync ( npath . join ( modelPipelineDir , 'yolo.weights' ) ) ) {
240+ weightsPath = npath . join ( modelPipelineDir , 'yolo.weights' ) ;
241+ } else {
242+ throw new Error ( "Your pipeline has no trained weights (yolo.weights is missing)" ) ;
243+ }
244+
245+ const jobWorkDir = await createCustomWorkingDirectory ( settings , 'OnnxExport' , pipeline . name ) ;
246+
247+ const converterOutput = npath . join ( jobWorkDir , 'model.onnx' ) ;
248+ const joblog = npath . join ( jobWorkDir , 'runlog.txt' ) ;
249+
250+ const command = [
251+ `${ viameConstants . setupScriptAbs } &&` ,
252+ `"${ viameConstants . kwiverExe } " runner ${ exportPipelinePath } ` ,
253+ `-s "onnx_convert:model_path=${ weightsPath } "` ,
254+ `-s "onnx_convert:onnx_model_prefix=${ converterOutput } "` ,
255+ ] ;
256+
257+ const job = observeChild ( spawn ( command . join ( ' ' ) , {
258+ shell : viameConstants . shell ,
259+ cwd : jobWorkDir ,
260+ } ) ) ;
261+
262+ const jobBase : DesktopJob = {
263+ key : `pipeline_${ job . pid } _${ jobWorkDir } ` ,
264+ command : command . join ( ' ' ) ,
265+ jobType : 'export' ,
266+ pid : job . pid ,
267+ args : exportTrainedPipelineArgs ,
268+ title : `${ exportTrainedPipelineArgs . pipeline . name } to ONNX` ,
269+ workingDir : jobWorkDir ,
270+ datasetIds : [ ] ,
271+ exitCode : job . exitCode ,
272+ startTime : new Date ( ) ,
273+ } ;
274+
275+ fs . writeFile ( npath . join ( jobWorkDir , DiveJobManifestName ) , JSON . stringify ( jobBase , null , 2 ) ) ;
276+
277+ updater ( {
278+ ...jobBase ,
279+ body : [ '' ] ,
280+ } ) ;
281+
282+ job . stdout . on ( 'data' , jobFileEchoMiddleware ( jobBase , updater , joblog ) ) ;
283+ job . stderr . on ( 'data' , jobFileEchoMiddleware ( jobBase , updater , joblog ) ) ;
284+
285+ job . on ( 'exit' , async ( code ) => {
286+ if ( code === 0 ) {
287+ if ( fs . existsSync ( converterOutput ) ) {
288+ if ( fs . existsSync ( path ) ) {
289+ fs . unlinkSync ( path ) ;
290+ }
291+ // We move instead of copying because .onnx files can be huge
292+ fs . moveSync ( converterOutput , path ) ;
293+ } else {
294+ console . error ( "An error occured while creating the ONNX file." ) ;
295+ }
296+ }
297+ updater ( {
298+ ...jobBase ,
299+ body : [ '' ] ,
300+ exitCode : code ,
301+ endTime : new Date ( ) ,
302+ } ) ;
303+ } ) ;
304+
305+ return jobBase ;
306+ }
307+
215308/**
216309 * a node.js implementation of dive_tasks.tasks.run_training
217310 */
@@ -356,5 +449,6 @@ async function train(
356449
357450export {
358451 runPipeline ,
452+ exportTrainedPipeline ,
359453 train ,
360454} ;
0 commit comments