1use crate::cannon::{
2 Hint, HostProgram, Preimage, HINT_CLIENT_READ_FD, HINT_CLIENT_WRITE_FD,
3 PREIMAGE_CLIENT_READ_FD, PREIMAGE_CLIENT_WRITE_FD,
4};
5use command_fds::{CommandFdExt, FdMapping};
6use log::debug;
7use os_pipe::{PipeReader, PipeWriter};
8use std::{
9 io::{Read, Write},
10 os::fd::{AsRawFd, FromRawFd, OwnedFd},
11 process::{Child, Command},
12};
13
/// Handle on the external host program that answers preimage and hint requests.
///
/// [`PreImageOracle::create`] builds `cmd` from a [`HostProgram`] and wires two
/// bidirectional channels (one for preimages, one for hints) into it;
/// [`PreImageOracle::start`] actually spawns the process.
pub struct PreImageOracle {
    /// Command that launches the host program; not spawned until `start` is called.
    pub cmd: Command,
    /// Parent-side endpoint of the preimage channel, used by `get_preimage`.
    pub oracle_client: RW,
    /// Host-side endpoint of the preimage channel; the child's
    /// `PREIMAGE_CLIENT_WRITE_FD` / `PREIMAGE_CLIENT_READ_FD` are wired to it.
    pub oracle_server: RW,
    /// Parent-side endpoint of the hint channel, used by `hint`.
    pub hint_client: RW,
    /// Host-side endpoint of the hint channel; the child's
    /// `HINT_CLIENT_WRITE_FD` / `HINT_CLIENT_READ_FD` are wired to it.
    pub hint_server: RW,
}
21
/// A provider of preimages, looked up by 32-byte key, that also accepts hints.
///
/// Implemented by [`PreImageOracle`] (which talks to a host program over pipes)
/// and by [`NullPreImageOracle`] (which panics on every call).
pub trait PreImageOracleT {
    /// Returns the preimage associated with `key`.
    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage;

    /// Forwards `hint` to the oracle.
    fn hint(&mut self, hint: Hint);
}
27
/// A reader and a writer bundled together so they can be used as one endpoint
/// of a duplex channel.
pub struct ReadWrite<R, W> {
    /// Side this endpoint receives bytes from.
    pub reader: R,
    /// Side this endpoint sends bytes to.
    pub writer: W,
}
32
/// One endpoint of a pipe-based bidirectional channel: reads come from one OS
/// pipe and writes go into another (see `create_bidirectional_channel`).
pub struct RW(pub ReadWrite<PipeReader, PipeWriter>);
34
35#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "haiku", windows)))]
41fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
42 let mut fds: [libc::c_int; 2] = [0; 2];
43 let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_DIRECT) };
44 if res != 0 {
45 return Err(std::io::Error::last_os_error());
46 }
47 unsafe {
48 Ok((
49 PipeReader::from_raw_fd(fds[0]),
50 PipeWriter::from_raw_fd(fds[1]),
51 ))
52 }
53}
54
/// Creates an anonymous pipe whose descriptors are inherited across `exec`.
///
/// This variant uses plain `pipe(2)`. It then clears the descriptor flags
/// (including `FD_CLOEXEC`) of both ends with `fcntl(F_SETFD, 0)`.
///
/// # Errors
///
/// Returns the OS error if `pipe` or either `fcntl` call fails. Both ends are
/// wrapped in owning types before the `fcntl` calls, so on failure they are
/// closed rather than leaked.
#[cfg(any(target_os = "ios", target_os = "macos", target_os = "haiku"))]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let mut fds: [libc::c_int; 2] = [0; 2];
    // SAFETY: `fds` is a writable buffer of exactly two `c_int`s, as `pipe` requires.
    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `pipe` succeeded, so both descriptors are open and owned solely by us.
    // Taking ownership now means an early return below closes them instead of leaking.
    let (reader, writer) =
        unsafe { (PipeReader::from_raw_fd(fds[0]), PipeWriter::from_raw_fd(fds[1])) };

    for fd in [reader.as_raw_fd(), writer.as_raw_fd()] {
        // Clear FD_CLOEXEC so the descriptor survives `exec` into the host program.
        // SAFETY: `fd` is an open descriptor owned by `reader` / `writer` above.
        if unsafe { libc::fcntl(fd, libc::F_SETFD, 0) } != 0 {
            // The error is captured before `reader`/`writer` are dropped (and closed),
            // so `close` cannot clobber errno first.
            return Err(std::io::Error::last_os_error());
        }
    }

    Ok((reader, writer))
}
78
/// Creates an anonymous pipe using `os_pipe`'s portable constructor.
///
/// # Errors
///
/// Propagates any error returned by `os_pipe::pipe`.
#[cfg(windows)]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let (reader, writer) = os_pipe::pipe()?;
    Ok((reader, writer))
}
83
84pub fn create_bidirectional_channel() -> Option<(RW, RW)> {
94 let (ar, bw) = create_pipe().ok()?;
95 let (br, aw) = create_pipe().ok()?;
96 Some((
97 RW(ReadWrite {
98 reader: ar,
99 writer: aw,
100 }),
101 RW(ReadWrite {
102 reader: br,
103 writer: bw,
104 }),
105 ))
106}
107
108impl PreImageOracle {
109 pub fn create(host_program: HostProgram) -> PreImageOracle {
110 let mut cmd = Command::new(&host_program.name);
111 cmd.args(&host_program.arguments);
112
113 let (oracle_client, oracle_server) =
114 create_bidirectional_channel().expect("Could not create bidirectional oracle channel");
115 let (hint_client, hint_server) =
116 create_bidirectional_channel().expect("Could not create bidirectional hint channel");
117
118 cmd.fd_mappings(vec![
122 FdMapping {
123 parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.writer.as_raw_fd()) },
124 child_fd: HINT_CLIENT_WRITE_FD,
125 },
126 FdMapping {
127 parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.reader.as_raw_fd()) },
128 child_fd: HINT_CLIENT_READ_FD,
129 },
130 FdMapping {
131 parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.writer.as_raw_fd()) },
132 child_fd: PREIMAGE_CLIENT_WRITE_FD,
133 },
134 FdMapping {
135 parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.reader.as_raw_fd()) },
136 child_fd: PREIMAGE_CLIENT_READ_FD,
137 },
138 ])
139 .unwrap_or_else(|_| panic!("Could not map file descriptors to preimage server process"));
140
141 PreImageOracle {
142 cmd,
143 oracle_client,
144 oracle_server,
145 hint_client,
146 hint_server,
147 }
148 }
149
150 pub fn start(&mut self) -> Child {
151 self.cmd
153 .spawn()
154 .expect("Could not spawn pre-image oracle process")
155 }
156}
157
/// Placeholder oracle whose every method panics.
///
/// Presumably used when no host program is configured; it is only safe when no
/// preimage or hint is ever requested.
pub struct NullPreImageOracle;

impl PreImageOracleT for NullPreImageOracle {
    /// Always panics: there is no oracle to fetch the preimage from.
    fn get_preimage(&mut self, _key: [u8; 32]) -> Preimage {
        panic!("No preimage oracle specified for preimage retrieval");
    }

    /// Always panics: there is no oracle to forward the hint to.
    fn hint(&mut self, _hint: Hint) {
        panic!("No preimage oracle specified for hints");
    }
}
169
170impl PreImageOracleT for PreImageOracle {
171 fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
180 let RW(ReadWrite { reader, writer }) = &mut self.oracle_client;
181
182 let r = writer.write_all(&key);
183 assert!(r.is_ok());
184 let r = writer.flush();
185 assert!(r.is_ok());
186
187 debug!("Reading response");
188 let mut buf = [0_u8; 8];
189 let resp = reader.read_exact(&mut buf);
190 assert!(resp.is_ok());
191
192 debug!("Extracting contents");
193 let length = u64::from_be_bytes(buf);
194 let mut preimage = vec![0_u8; length as usize];
195 let resp = reader.read_exact(&mut preimage);
196
197 assert!(resp.is_ok());
198
199 debug!(
200 "Got preimage of length {}\n {}",
201 preimage.len(),
202 hex::encode(&preimage)
203 );
204 assert_eq!(preimage.len(), length as usize);
206
207 Preimage::create(preimage)
208 }
209
210 fn hint(&mut self, hint: Hint) {
218 let RW(ReadWrite { reader, writer }) = &mut self.hint_client;
219
220 let mut hint_bytes = hint.get();
222 let hint_length = hint_bytes.len();
223
224 let mut msg: Vec<u8> = vec![];
225 msg.append(&mut u64::to_be_bytes(hint_length as u64).to_vec());
226 msg.append(&mut hint_bytes);
227
228 let _ = writer.write(&msg);
229
230 let mut buf = [0_u8];
232 let _ = reader.read_exact(&mut buf);
234 }
235}
236
237impl PreImageOracleT for Box<dyn PreImageOracleT> {
238 fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
239 self.as_mut().get_preimage(key)
240 }
241
242 fn hint(&mut self, hint: Hint) {
243 self.as_mut().hint(hint)
244 }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises both directions of a channel returned by `create_bidirectional_channel`.
    #[test]
    fn test_bidir_channels() {
        let (mut c0, mut c1) = create_bidirectional_channel().unwrap();

        // Direction 1: c0 writes a single byte, c1 reads it back.
        let msg = [42_u8];
        let mut buf = [0_u8; 1];

        let writer_joiner = std::thread::spawn(move || {
            let r = c0.0.writer.write(&msg);
            assert!(r.is_ok());
        });

        let reader_joiner = std::thread::spawn(move || {
            let r = c1.0.reader.read_exact(&mut buf);
            assert!(r.is_ok());
            buf
        });

        let buf = reader_joiner.join().unwrap();
        writer_joiner.join().unwrap();

        assert_eq!(msg, buf);

        // Direction 2: c1 sends a length-prefixed message (u64 big-endian length,
        // then the payload), the same framing `get_preimage` expects in responses.
        let msg2 = vec![42_u8, 43_u8];
        let len = msg2.len() as u64;
        let mut msg = u64::to_be_bytes(len).to_vec();
        msg.extend_from_slice(&msg2);

        // The `move` closure takes ownership of `c1.0.writer`. That writer is the only
        // write end of the pipe read by `c0.0.reader`, so it is closed when this closure
        // finishes.
        let writer_joiner = std::thread::spawn(move || {
            let r = c1.0.writer.write(&msg);
            assert!(r.is_ok());
            msg
        });

        let reader_joiner = std::thread::spawn(move || {
            // `read_to_end` returns only at EOF, i.e. once the writer above has been dropped.
            let mut response_vec = vec![];
            let r = c0.0.reader.read_to_end(&mut response_vec);
            assert!(r.is_ok());

            // Decode the 8-byte big-endian length prefix, then slice out the payload.
            let n = u64::from_be_bytes(response_vec[0..8].try_into().unwrap());

            let data = response_vec[8..(n + 8) as usize].to_vec();
            (n, data)
        });

        let (n, data) = reader_joiner.join().unwrap();

        writer_joiner.join().unwrap();

        assert_eq!(n, len);
        assert_eq!(data, msg2);
    }
}