o1vm/
preimage_oracle.rs

1use crate::cannon::{
2    Hint, HostProgram, Preimage, HINT_CLIENT_READ_FD, HINT_CLIENT_WRITE_FD,
3    PREIMAGE_CLIENT_READ_FD, PREIMAGE_CLIENT_WRITE_FD,
4};
5use command_fds::{CommandFdExt, FdMapping};
6use log::debug;
7use os_pipe::{PipeReader, PipeWriter};
8use std::{
9    io::{Read, Write},
10    os::fd::{AsRawFd, FromRawFd, OwnedFd},
11    process::{Child, Command},
12};
13
/// Handle on an external pre-image oracle process and the pipes used to talk to it.
pub struct PreImageOracle {
    /// Command used to launch the oracle process; `create` configures its
    /// arguments and file-descriptor mappings, `start` spawns it.
    pub cmd: Command,
    /// Local end of the pre-image channel: `get_preimage` writes keys to it and
    /// reads pre-images back from it.
    pub oracle_client: RW,
    /// Remote end of the pre-image channel; `create` maps its descriptors into
    /// the child as `PREIMAGE_CLIENT_READ_FD` / `PREIMAGE_CLIENT_WRITE_FD`.
    pub oracle_server: RW,
    /// Local end of the hint channel: `hint` writes hints to it and reads the
    /// acknowledgment byte from it.
    pub hint_client: RW,
    /// Remote end of the hint channel; `create` maps its descriptors into the
    /// child as `HINT_CLIENT_READ_FD` / `HINT_CLIENT_WRITE_FD`.
    pub hint_server: RW,
}
21
/// Interface of a pre-image oracle: something that can resolve 32-byte keys
/// into pre-images and that can be sent hints.
pub trait PreImageOracleT {
    /// Returns the pre-image associated with the 32-byte `key`.
    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage;

    /// Sends `hint` to the oracle.
    fn hint(&mut self, hint: Hint);
}
27
/// A reading end and a writing end bundled together to form one side of a
/// two-way channel.
///
/// `R` is the type of the reading end and `W` the type of the writing end.
#[derive(Debug)]
pub struct ReadWrite<R, W> {
    /// End this side reads incoming data from.
    pub reader: R,
    /// End this side writes outgoing data to.
    pub writer: W,
}
32
/// One side of a bidirectional channel built from `os_pipe` pipe ends.
pub struct RW(pub ReadWrite<PipeReader, PipeWriter>);
34
35// Here, we implement `os_pipe::pipe` in a way that allows us to pass flags. In particular, we
36// don't pass the `CLOEXEC` flag, because we want these pipes to survive an exec, and we set
37// `DIRECT` to handle writes as single atomic operations (up to splitting at the buffer size).
38// This fixes the IPC hangs. This is bad, but the hang is worse.
39
40#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "haiku", windows)))]
41fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
42    let mut fds: [libc::c_int; 2] = [0; 2];
43    let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_DIRECT) };
44    if res != 0 {
45        return Err(std::io::Error::last_os_error());
46    }
47    unsafe {
48        Ok((
49            PipeReader::from_raw_fd(fds[0]),
50            PipeWriter::from_raw_fd(fds[1]),
51        ))
52    }
53}
54
/// Creates a unidirectional pipe and returns its `(reader, writer)` ends.
///
/// The pipe is created with plain `pipe`, then the descriptor flags of both
/// ends are reset to `0` so that close-on-exec is not set and the descriptors
/// are inherited across an exec.
///
/// # Errors
///
/// Returns the OS error reported by `pipe` or `fcntl`. If `fcntl` fails, both
/// ends of the freshly created pipe are closed before returning.
#[cfg(any(target_os = "ios", target_os = "macos", target_os = "haiku"))]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let mut fds: [libc::c_int; 2] = [0; 2];
    // SAFETY: `fds` is a valid, writable array of two `c_int`s, which is what
    // `pipe` expects to fill in.
    let res = unsafe { libc::pipe(fds.as_mut_ptr()) };
    if res != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // Take ownership immediately so that the descriptors are closed, rather
    // than leaked, if one of the `fcntl` calls below fails.
    // SAFETY: `pipe` succeeded, so both descriptors are open and nothing else
    // owns them yet; each one is handed to exactly one owner here.
    let ends = unsafe {
        (
            PipeReader::from_raw_fd(fds[0]),
            PipeWriter::from_raw_fd(fds[1]),
        )
    };
    // It appears that Mac and friends don't have DIRECT. Oh well. Don't use a Mac.
    for &fd in &fds {
        // SAFETY: `fd` is an open descriptor kept alive by `ends`.
        let res = unsafe { libc::fcntl(fd, libc::F_SETFD, 0) };
        // `fcntl` signals failure with -1. The error is built before `ends` is
        // dropped, so errno is not clobbered by the implicit `close` calls.
        if res == -1 {
            return Err(std::io::Error::last_os_error());
        }
    }
    Ok(ends)
}
78
/// Creates a unidirectional pipe and returns its `(reader, writer)` ends.
///
/// On Windows this simply defers to `os_pipe::pipe`, forwarding its error.
#[cfg(windows)]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let (reader, writer) = os_pipe::pipe()?;
    Ok((reader, writer))
}
83
84// Create bidirectional channel between A and B
85//
86// Schematically we create 2 unidirectional pipes and creates 2 structures made
87// by taking the writer from one and the reader from the other.
88//
89//     A                     B
90//     |     ar  <---- bw    |
91//     |     aw  ----> br    |
92//
93pub fn create_bidirectional_channel() -> Option<(RW, RW)> {
94    let (ar, bw) = create_pipe().ok()?;
95    let (br, aw) = create_pipe().ok()?;
96    Some((
97        RW(ReadWrite {
98            reader: ar,
99            writer: aw,
100        }),
101        RW(ReadWrite {
102            reader: br,
103            writer: bw,
104        }),
105    ))
106}
107
108impl PreImageOracle {
109    pub fn create(host_program: HostProgram) -> PreImageOracle {
110        let mut cmd = Command::new(&host_program.name);
111        cmd.args(&host_program.arguments);
112
113        let (oracle_client, oracle_server) =
114            create_bidirectional_channel().expect("Could not create bidirectional oracle channel");
115        let (hint_client, hint_server) =
116            create_bidirectional_channel().expect("Could not create bidirectional hint channel");
117
118        // file descriptors 0, 1, 2 respectively correspond to the inherited stdin,
119        // stdout, stderr.
120        // We need to map 3, 4, 5, 6 in the child process
121        cmd.fd_mappings(vec![
122            FdMapping {
123                parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.writer.as_raw_fd()) },
124                child_fd: HINT_CLIENT_WRITE_FD,
125            },
126            FdMapping {
127                parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.reader.as_raw_fd()) },
128                child_fd: HINT_CLIENT_READ_FD,
129            },
130            FdMapping {
131                parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.writer.as_raw_fd()) },
132                child_fd: PREIMAGE_CLIENT_WRITE_FD,
133            },
134            FdMapping {
135                parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.reader.as_raw_fd()) },
136                child_fd: PREIMAGE_CLIENT_READ_FD,
137            },
138        ])
139        .unwrap_or_else(|_| panic!("Could not map file descriptors to preimage server process"));
140
141        PreImageOracle {
142            cmd,
143            oracle_client,
144            oracle_server,
145            hint_client,
146            hint_server,
147        }
148    }
149
150    pub fn start(&mut self) -> Child {
151        // Spawning inherits the current process's stdin/stdout/stderr descriptors
152        self.cmd
153            .spawn()
154            .expect("Could not spawn pre-image oracle process")
155    }
156}
157
/// Placeholder oracle for runs where no pre-image oracle is available: both
/// of its `PreImageOracleT` methods panic when called.
#[derive(Debug, Default, Clone, Copy)]
pub struct NullPreImageOracle;
159
160impl PreImageOracleT for NullPreImageOracle {
161    fn get_preimage(&mut self, _key: [u8; 32]) -> Preimage {
162        panic!("No preimage oracle specified for preimage retrieval");
163    }
164
165    fn hint(&mut self, _hint: Hint) {
166        panic!("No preimage oracle specified for hints");
167    }
168}
169
170impl PreImageOracleT for PreImageOracle {
171    // The preimage protocol goes as follows
172    // 1. Ask for data through a key
173    // 2. Get the answers in the following format
174    //      +------------+--------------------+
175    //      | length <8> | pre-image <length> |
176    //      +---------------------------------+
177    //   a. a 64-bit integer indicating the length of the actual data
178    //   b. the preimage data, with a size of <length> bits
179    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
180        let RW(ReadWrite { reader, writer }) = &mut self.oracle_client;
181
182        let r = writer.write_all(&key);
183        assert!(r.is_ok());
184        let r = writer.flush();
185        assert!(r.is_ok());
186
187        debug!("Reading response");
188        let mut buf = [0_u8; 8];
189        let resp = reader.read_exact(&mut buf);
190        assert!(resp.is_ok());
191
192        debug!("Extracting contents");
193        let length = u64::from_be_bytes(buf);
194        let mut preimage = vec![0_u8; length as usize];
195        let resp = reader.read_exact(&mut preimage);
196
197        assert!(resp.is_ok());
198
199        debug!(
200            "Got preimage of length {}\n {}",
201            preimage.len(),
202            hex::encode(&preimage)
203        );
204        // We should have read exactly <length> bytes
205        assert_eq!(preimage.len(), length as usize);
206
207        Preimage::create(preimage)
208    }
209
210    // The hint protocol goes as follows:
211    // 1. Write a hint request with the following byte-stream format
212    //       +------------+---------------+
213    //       | length <8> | hint <length> |
214    //       +----------------------------+
215    //
216    // 2. Get back a single ack byte informing the hint has been processed.
217    fn hint(&mut self, hint: Hint) {
218        let RW(ReadWrite { reader, writer }) = &mut self.hint_client;
219
220        // Write hint request
221        let mut hint_bytes = hint.get();
222        let hint_length = hint_bytes.len();
223
224        let mut msg: Vec<u8> = vec![];
225        msg.append(&mut u64::to_be_bytes(hint_length as u64).to_vec());
226        msg.append(&mut hint_bytes);
227
228        let _ = writer.write(&msg);
229
230        // Read single byte acknowledgment response
231        let mut buf = [0_u8];
232        // And do nothing with it anyway
233        let _ = reader.read_exact(&mut buf);
234    }
235}
236
237impl PreImageOracleT for Box<dyn PreImageOracleT> {
238    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
239        self.as_mut().get_preimage(key)
240    }
241
242    fn hint(&mut self, hint: Hint) {
243        self.as_mut().hint(hint)
244    }
245}
246
#[cfg(test)]
mod tests {
    use super::*;

    // Checks both directions of a bidirectional channel. With the two sides
    // called A and B:
    // 1. B's reader receives what A's writer sends
    // 2. A's reader receives what B's writer sends
    #[test]
    fn test_bidir_channels() {
        let (
            RW(ReadWrite {
                reader: mut a_reader,
                writer: mut a_writer,
            }),
            RW(ReadWrite {
                reader: mut b_reader,
                writer: mut b_writer,
            }),
        ) = create_bidirectional_channel().unwrap();

        // First direction (A -> B): a one-byte message.
        let single_byte = [42_u8];
        let mut received = [0_u8; 1];

        let sender = std::thread::spawn(move || {
            let outcome = a_writer.write(&single_byte);
            assert!(outcome.is_ok());
        });

        let receiver = std::thread::spawn(move || {
            let outcome = b_reader.read_exact(&mut received);
            assert!(outcome.is_ok());
            received
        });

        // Collect what the reader got, then make sure the writer is done too.
        let received = receiver.join().unwrap();
        sender.join().unwrap();

        // The byte must have arrived unchanged.
        assert_eq!(single_byte, received);

        // Second direction (B -> A): a message in the preimage format
        //      +------------+--------------------+
        //      | length <8> | pre-image <length> |
        //      +---------------------------------+
        //   with a two-byte payload.
        let payload = vec![42_u8, 43_u8];
        let payload_len = payload.len() as u64;
        let mut framed = payload_len.to_be_bytes().to_vec();
        framed.extend_from_slice(&payload);

        let sender = std::thread::spawn(move || {
            let outcome = b_writer.write(&framed);
            assert!(outcome.is_ok());
            framed
        });

        let receiver = std::thread::spawn(move || {
            // Read until end-of-file into one buffer, then parse the length
            // prefix and the payload out of that buffer.
            let mut raw = Vec::new();
            let outcome = a_reader.read_to_end(&mut raw);
            assert!(outcome.is_ok());

            let (header, body) = raw.split_at(8);
            let n = u64::from_be_bytes(header.try_into().unwrap());
            let data = body[..n as usize].to_vec();
            (n, data)
        });

        // Collect the decoded message, then make sure the writer is done too.
        let (n, data) = receiver.join().unwrap();
        sender.join().unwrap();

        // Length prefix and payload must both round-trip.
        assert_eq!(n, payload_len);
        assert_eq!(data, payload);
    }
}