1use crate::cannon::{
2    Hint, HostProgram, Preimage, HINT_CLIENT_READ_FD, HINT_CLIENT_WRITE_FD,
3    PREIMAGE_CLIENT_READ_FD, PREIMAGE_CLIENT_WRITE_FD,
4};
5use command_fds::{CommandFdExt, FdMapping};
6use log::debug;
7use os_pipe::{PipeReader, PipeWriter};
8use std::{
9    io::{Read, Write},
10    os::fd::{AsRawFd, FromRawFd, OwnedFd},
11    process::{Child, Command},
12};
13
/// Handle to an external pre-image oracle (host) program and the channels used
/// to talk to it.
///
/// `cmd` is the host command, built but not yet spawned. The `*_server`
/// channel ends are mapped onto the child's descriptor numbers by
/// [`PreImageOracle::create`]; this process talks through the `*_client` ends.
pub struct PreImageOracle {
    /// Command used to launch the host program (see `start`).
    pub cmd: Command,
    /// This process's end of the preimage request/response channel.
    pub oracle_client: RW,
    /// The host's end of the preimage channel (mapped into the child process).
    pub oracle_server: RW,
    /// This process's end of the hint channel.
    pub hint_client: RW,
    /// The host's end of the hint channel (mapped into the child process).
    pub hint_server: RW,
}
21
/// Interface to a pre-image oracle: a source of preimages and a sink for hints.
pub trait PreImageOracleT {
    /// Returns the preimage associated with the 32-byte `key`.
    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage;

    /// Passes `hint` to the oracle. The implementations in this file either
    /// forward it over the hint channel or panic.
    fn hint(&mut self, hint: Hint);
}
27
/// A reader/writer pair forming one endpoint of a bidirectional channel.
pub struct ReadWrite<R, W> {
    /// The half this endpoint reads incoming data from.
    pub reader: R,
    /// The half this endpoint writes outgoing data to.
    pub writer: W,
}

/// Pipe-backed channel endpoint, as produced by `create_bidirectional_channel`.
pub struct RW(pub ReadWrite<PipeReader, PipeWriter>);
34
35#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "haiku", windows)))]
41fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
42    let mut fds: [libc::c_int; 2] = [0; 2];
43    let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_DIRECT) };
44    if res != 0 {
45        return Err(std::io::Error::last_os_error());
46    }
47    unsafe {
48        Ok((
49            PipeReader::from_raw_fd(fds[0]),
50            PipeWriter::from_raw_fd(fds[1]),
51        ))
52    }
53}
54
/// Creates an anonymous pipe, returning its read and write ends.
///
/// `FD_CLOEXEC` is explicitly cleared on both ends so they stay open across
/// `exec` and are inherited by child processes.
///
/// # Errors
///
/// Returns the OS error if `pipe` or either `fcntl` call fails. If a failure
/// happens after the pipe was created, both descriptors are closed before
/// returning rather than leaked.
#[cfg(any(target_os = "ios", target_os = "macos", target_os = "haiku"))]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let mut fds: [libc::c_int; 2] = [0; 2];
    // SAFETY: `fds` is a writable array of exactly two `c_int`s, as `pipe` requires.
    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `pipe` succeeded, so both descriptors are open and owned solely by us.
    // Taking ownership right away guarantees they are closed on drop if an
    // `fcntl` below fails.
    let (reader, writer) =
        unsafe { (PipeReader::from_raw_fd(fds[0]), PipeWriter::from_raw_fd(fds[1])) };
    for fd in [reader.as_raw_fd(), writer.as_raw_fd()] {
        // SAFETY: `fd` is open (owned by `reader`/`writer`); `F_SETFD` only
        // rewrites its descriptor flags. POSIX reports failure as -1.
        if unsafe { libc::fcntl(fd, libc::F_SETFD, 0) } == -1 {
            // The error is captured before `reader`/`writer` are dropped.
            return Err(std::io::Error::last_os_error());
        }
    }
    Ok((reader, writer))
}
78
/// Creates an anonymous pipe through `os_pipe`, returning its read and write ends.
///
/// # Errors
///
/// Propagates any error reported by `os_pipe::pipe`.
#[cfg(windows)]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let (read_end, write_end) = os_pipe::pipe()?;
    Ok((read_end, write_end))
}
83
84pub fn create_bidirectional_channel() -> Option<(RW, RW)> {
94    let (ar, bw) = create_pipe().ok()?;
95    let (br, aw) = create_pipe().ok()?;
96    Some((
97        RW(ReadWrite {
98            reader: ar,
99            writer: aw,
100        }),
101        RW(ReadWrite {
102            reader: br,
103            writer: bw,
104        }),
105    ))
106}
107
108impl PreImageOracle {
109    pub fn create(host_program: HostProgram) -> PreImageOracle {
110        let mut cmd = Command::new(&host_program.name);
111        cmd.args(&host_program.arguments);
112
113        let (oracle_client, oracle_server) =
114            create_bidirectional_channel().expect("Could not create bidirectional oracle channel");
115        let (hint_client, hint_server) =
116            create_bidirectional_channel().expect("Could not create bidirectional hint channel");
117
118        cmd.fd_mappings(vec![
122            FdMapping {
123                parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.writer.as_raw_fd()) },
124                child_fd: HINT_CLIENT_WRITE_FD,
125            },
126            FdMapping {
127                parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.reader.as_raw_fd()) },
128                child_fd: HINT_CLIENT_READ_FD,
129            },
130            FdMapping {
131                parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.writer.as_raw_fd()) },
132                child_fd: PREIMAGE_CLIENT_WRITE_FD,
133            },
134            FdMapping {
135                parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.reader.as_raw_fd()) },
136                child_fd: PREIMAGE_CLIENT_READ_FD,
137            },
138        ])
139        .unwrap_or_else(|_| panic!("Could not map file descriptors to preimage server process"));
140
141        PreImageOracle {
142            cmd,
143            oracle_client,
144            oracle_server,
145            hint_client,
146            hint_server,
147        }
148    }
149
150    pub fn start(&mut self) -> Child {
151        self.cmd
153            .spawn()
154            .expect("Could not spawn pre-image oracle process")
155    }
156}
157
158pub struct NullPreImageOracle;
159
160impl PreImageOracleT for NullPreImageOracle {
161    fn get_preimage(&mut self, _key: [u8; 32]) -> Preimage {
162        panic!("No preimage oracle specified for preimage retrieval");
163    }
164
165    fn hint(&mut self, _hint: Hint) {
166        panic!("No preimage oracle specified for hints");
167    }
168}
169
170impl PreImageOracleT for PreImageOracle {
171    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
180        let RW(ReadWrite { reader, writer }) = &mut self.oracle_client;
181
182        let r = writer.write_all(&key);
183        assert!(r.is_ok());
184        let r = writer.flush();
185        assert!(r.is_ok());
186
187        debug!("Reading response");
188        let mut buf = [0_u8; 8];
189        let resp = reader.read_exact(&mut buf);
190        assert!(resp.is_ok());
191
192        debug!("Extracting contents");
193        let length = u64::from_be_bytes(buf);
194        let mut preimage = vec![0_u8; length as usize];
195        let resp = reader.read_exact(&mut preimage);
196
197        assert!(resp.is_ok());
198
199        debug!(
200            "Got preimage of length {}\n {}",
201            preimage.len(),
202            hex::encode(&preimage)
203        );
204        assert_eq!(preimage.len(), length as usize);
206
207        Preimage::create(preimage)
208    }
209
210    fn hint(&mut self, hint: Hint) {
218        let RW(ReadWrite { reader, writer }) = &mut self.hint_client;
219
220        let mut hint_bytes = hint.get();
222        let hint_length = hint_bytes.len();
223
224        let mut msg: Vec<u8> = vec![];
225        msg.append(&mut u64::to_be_bytes(hint_length as u64).to_vec());
226        msg.append(&mut hint_bytes);
227
228        let _ = writer.write(&msg);
229
230        let mut buf = [0_u8];
232        let _ = reader.read_exact(&mut buf);
234    }
235}
236
237impl PreImageOracleT for Box<dyn PreImageOracleT> {
238    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
239        self.as_mut().get_preimage(key)
240    }
241
242    fn hint(&mut self, hint: Hint) {
243        self.as_mut().hint(hint)
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises both directions of a bidirectional channel. The writer and
    /// reader of each transfer run on separate threads.
    #[test]
    fn test_bidir_channels() {
        let (mut c0, mut c1) = create_bidirectional_channel().unwrap();

        // c0 -> c1: a single byte, read back with `read_exact`.
        let msg = [42_u8];
        let mut buf = [0_u8; 1];

        let writer_joiner = std::thread::spawn(move || {
            // `write_all`, not `write`: a short write must not pass silently.
            c0.0.writer.write_all(&msg).expect("c0 could not write");
        });

        let reader_joiner = std::thread::spawn(move || {
            c1.0.reader.read_exact(&mut buf).expect("c1 could not read");
            buf
        });

        let buf = reader_joiner.join().unwrap();
        writer_joiner.join().unwrap();

        assert_eq!(msg, buf);

        // c1 -> c0: a length-prefixed message, read until EOF. EOF arrives once
        // the writer thread finishes and drops the write end it captured.
        let msg2 = vec![42_u8, 43_u8];
        let len = msg2.len() as u64;
        let mut msg = u64::to_be_bytes(len).to_vec();
        msg.extend_from_slice(&msg2);

        let writer_joiner = std::thread::spawn(move || {
            c1.0.writer.write_all(&msg).expect("c1 could not write");
        });

        let reader_joiner = std::thread::spawn(move || {
            let mut response = vec![];
            c0.0.reader
                .read_to_end(&mut response)
                .expect("c0 could not read");

            let n = u64::from_be_bytes(response[0..8].try_into().unwrap());
            let data = response[8..(n + 8) as usize].to_vec();
            (n, data)
        });

        let (n, data) = reader_joiner.join().unwrap();
        writer_joiner.join().unwrap();

        assert_eq!(n, len);
        assert_eq!(data, msg2);
    }
}