Skip to content

Commit dd2e334

Browse files
committed
Add Session::from_saved_model_to_bundle
1 parent 73ed289 commit dd2e334

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

src/session.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::ffi::CString;
44
use std::marker;
55
use std::path::Path;
66
use std::ptr;
7+
use super::{Buffer, BufferTrait};
78
use super::Code;
89
use super::DataType;
910
use super::Graph;
@@ -16,6 +17,15 @@ use super::Status;
1617
use super::Tensor;
1718
use super::TensorType;
1819

20+
/// Aggregation type for a saved model bundle.
21+
#[derive(Debug)]
22+
pub struct SavedModelBundle {
23+
/// The loaded session.
24+
pub session: Session,
25+
/// A meta graph defition as raw protocol buffer.
26+
pub meta_graph_def: Vec<u8>,
27+
}
28+
1929
/// Manages a single graph and execution.
2030
#[derive(Debug)]
2131
pub struct Session {
@@ -73,6 +83,51 @@ impl Session {
7383
}
7484
}
7585

86+
/// Loads a session from an exported model, creating a bundle
87+
pub fn from_saved_model_to_bundle<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
88+
(options: &SessionOptions,
89+
tags: Tags,
90+
graph: &mut Graph,
91+
export_dir: P)
92+
-> Result<SavedModelBundle> {
93+
let mut status = Status::new();
94+
95+
let export_dir_cstr =
96+
try!(export_dir.as_ref()
97+
.to_str()
98+
.and_then(|s| CString::new(s.as_bytes()).ok())
99+
.ok_or_else(|| invalid_arg!("Invalid export directory path")));
100+
101+
let tags_cstr: Vec<_> = try!(tags.into_iter()
102+
.map(|t| CString::new(t.as_ref()))
103+
.collect::<::std::result::Result<_, _>>()
104+
.map_err(|_| invalid_arg!("Invalid tag name")));
105+
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();
106+
107+
// The empty TF_Buffer will be filled by LoadSessionFromSavedModel
108+
let mut meta = unsafe { Buffer::<u8>::from_ptr(ptr::null_mut(), 0) };
109+
110+
let inner = unsafe {
111+
tf::TF_LoadSessionFromSavedModel(options.inner,
112+
ptr::null(),
113+
export_dir_cstr.as_ptr(),
114+
tags_ptr.as_ptr(),
115+
tags_ptr.len() as c_int,
116+
graph.inner(),
117+
meta.inner_mut(),
118+
status.inner())
119+
};
120+
if inner.is_null() {
121+
Err(status)
122+
} else {
123+
let session = Session { inner: inner };
124+
Ok(SavedModelBundle {
125+
session: session,
126+
meta_graph_def: Vec::from(meta.as_ref())
127+
})
128+
}
129+
}
130+
76131
/// Closes the session.
77132
pub fn close(&mut self) -> Result<()> {
78133
let mut status = Status::new();

0 commit comments

Comments
 (0)