@@ -4,6 +4,7 @@ use std::ffi::CString;
4
4
use std:: marker;
5
5
use std:: path:: Path ;
6
6
use std:: ptr;
7
+ use super :: { Buffer , BufferTrait } ;
7
8
use super :: Code ;
8
9
use super :: DataType ;
9
10
use super :: Graph ;
@@ -16,6 +17,15 @@ use super::Status;
16
17
use super :: Tensor ;
17
18
use super :: TensorType ;
18
19
20
+ /// Aggregation type for a saved model bundle.
21
+ #[ derive( Debug ) ]
22
+ pub struct SavedModelBundle {
23
+ /// The loaded session.
24
+ pub session : Session ,
25
+ /// A meta graph defition as raw protocol buffer.
26
+ pub meta_graph_def : Vec < u8 > ,
27
+ }
28
+
19
29
/// Manages a single graph and execution.
20
30
#[ derive( Debug ) ]
21
31
pub struct Session {
@@ -73,6 +83,51 @@ impl Session {
73
83
}
74
84
}
75
85
86
+ /// Loads a session from an exported model, creating a bundle
87
+ pub fn from_saved_model_to_bundle < P : AsRef < Path > , Tag : AsRef < str > , Tags : IntoIterator < Item = Tag > >
88
+ ( options : & SessionOptions ,
89
+ tags : Tags ,
90
+ graph : & mut Graph ,
91
+ export_dir : P )
92
+ -> Result < SavedModelBundle > {
93
+ let mut status = Status :: new ( ) ;
94
+
95
+ let export_dir_cstr =
96
+ try!( export_dir. as_ref ( )
97
+ . to_str ( )
98
+ . and_then ( |s| CString :: new ( s. as_bytes ( ) ) . ok ( ) )
99
+ . ok_or_else ( || invalid_arg ! ( "Invalid export directory path" ) ) ) ;
100
+
101
+ let tags_cstr: Vec < _ > = try!( tags. into_iter ( )
102
+ . map ( |t| CString :: new ( t. as_ref ( ) ) )
103
+ . collect :: < :: std:: result:: Result < _ , _ > > ( )
104
+ . map_err ( |_| invalid_arg ! ( "Invalid tag name" ) ) ) ;
105
+ let tags_ptr: Vec < * const c_char > = tags_cstr. iter ( ) . map ( |t| t. as_ptr ( ) ) . collect ( ) ;
106
+
107
+ // The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108
+ let mut meta = unsafe { Buffer :: < u8 > :: from_ptr ( ptr:: null_mut ( ) , 0 ) } ;
109
+
110
+ let inner = unsafe {
111
+ tf:: TF_LoadSessionFromSavedModel ( options. inner ,
112
+ ptr:: null ( ) ,
113
+ export_dir_cstr. as_ptr ( ) ,
114
+ tags_ptr. as_ptr ( ) ,
115
+ tags_ptr. len ( ) as c_int ,
116
+ graph. inner ( ) ,
117
+ meta. inner_mut ( ) ,
118
+ status. inner ( ) )
119
+ } ;
120
+ if inner. is_null ( ) {
121
+ Err ( status)
122
+ } else {
123
+ let session = Session { inner : inner } ;
124
+ Ok ( SavedModelBundle {
125
+ session : session,
126
+ meta_graph_def : Vec :: from ( meta. as_ref ( ) )
127
+ } )
128
+ }
129
+ }
130
+
76
131
/// Closes the session.
77
132
pub fn close ( & mut self ) -> Result < ( ) > {
78
133
let mut status = Status :: new ( ) ;
0 commit comments