use crate::cannon::{
Hint, HostProgram, Preimage, HINT_CLIENT_READ_FD, HINT_CLIENT_WRITE_FD,
PREIMAGE_CLIENT_READ_FD, PREIMAGE_CLIENT_WRITE_FD,
};
use command_fds::{CommandFdExt, FdMapping};
use log::debug;
use os_pipe::{PipeReader, PipeWriter};
use std::{
io::{Read, Write},
os::fd::{AsRawFd, FromRawFd, OwnedFd},
process::{Child, Command},
};
/// Handle to the host pre-image oracle process and the pipes used to talk to it.
///
/// Built by [`PreImageOracle::create`]; the host process is launched with
/// [`PreImageOracle::start`].
pub struct PreImageOracle {
    /// Command that launches the host program (name and arguments taken from `HostProgram`).
    pub cmd: Command,
    /// Parent-side end of the pre-image channel; used by `get_preimage`.
    pub oracle_client: RW,
    /// Child-side end of the pre-image channel; `create` maps its descriptors into the
    /// child at `PREIMAGE_CLIENT_READ_FD` / `PREIMAGE_CLIENT_WRITE_FD`.
    pub oracle_server: RW,
    /// Parent-side end of the hint channel; used by `hint`.
    pub hint_client: RW,
    /// Child-side end of the hint channel; `create` maps its descriptors into the child
    /// at `HINT_CLIENT_READ_FD` / `HINT_CLIENT_WRITE_FD`.
    pub hint_server: RW,
}
/// Interface for requesting pre-images from, and sending hints to, a pre-image oracle.
pub trait PreImageOracleT {
    /// Returns the pre-image associated with the 32-byte `key`.
    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage;
    /// Sends `hint` to the oracle.
    fn hint(&mut self, hint: Hint);
}
/// A reader/writer pair forming one endpoint of a bidirectional channel.
pub struct ReadWrite<R, W> {
    /// Receives the bytes written by the peer endpoint.
    pub reader: R,
    /// Sends bytes to the peer endpoint.
    pub writer: W,
}
/// One endpoint of a pipe-based bidirectional channel, as produced by
/// `create_bidirectional_channel`.
pub struct RW(pub ReadWrite<PipeReader, PipeWriter>);
/// Creates an anonymous pipe and returns its `(read end, write end)`.
///
/// The pipe is opened with `pipe2(.., O_DIRECT)` and without `O_CLOEXEC`, so both ends
/// are inherited by child processes. On Linux, `O_DIRECT` selects "packet" mode:
/// - each `write` becomes its own packet (split at `PIPE_BUF`);
/// - a `read` whose buffer is smaller than the next packet discards the packet's excess
///   bytes.
///
/// NOTE(review): confirm every reader of these pipes sizes its reads to whole packets.
///
/// This variant is `pub`, like the macOS/iOS/Haiku and Windows variants, so the function
/// has the same visibility on every platform.
///
/// # Errors
///
/// Returns the OS error reported by `pipe2`.
#[cfg(not(any(target_os = "ios", target_os = "macos", target_os = "haiku", windows)))]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let mut fds: [libc::c_int; 2] = [0; 2];
    // SAFETY: `fds` is a writable array of exactly two `c_int`s, as `pipe2` requires.
    let res = unsafe { libc::pipe2(fds.as_mut_ptr(), libc::O_DIRECT) };
    if res != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `pipe2` succeeded, so both descriptors are freshly opened and not owned by
    // anything else; each one is handed to exactly one owner here.
    let (reader, writer) =
        unsafe { (PipeReader::from_raw_fd(fds[0]), PipeWriter::from_raw_fd(fds[1])) };
    Ok((reader, writer))
}
/// Creates an anonymous pipe with `FD_CLOEXEC` cleared on both ends, returning
/// `(read end, write end)`.
///
/// # Errors
///
/// Returns the OS error reported by `pipe` or `fcntl`. On error no descriptor is
/// leaked: both ends are wrapped in their owners before `fcntl` runs, so an early
/// return closes them.
#[cfg(any(target_os = "ios", target_os = "macos", target_os = "haiku"))]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let mut fds: [libc::c_int; 2] = [0; 2];
    // SAFETY: `fds` is a writable array of exactly two `c_int`s, as `pipe` requires.
    if unsafe { libc::pipe(fds.as_mut_ptr()) } != 0 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `pipe` succeeded, so both descriptors are freshly opened and not owned by
    // anything else; each one is handed to exactly one owner here.
    let (reader, writer) =
        unsafe { (PipeReader::from_raw_fd(fds[0]), PipeWriter::from_raw_fd(fds[1])) };
    for fd in [reader.as_raw_fd(), writer.as_raw_fd()] {
        // `F_SETFD` with flags 0 clears `FD_CLOEXEC`. `fcntl` reports failure with -1;
        // for `F_SETFD`, POSIX only guarantees "a value other than -1" on success.
        // SAFETY: `fd` is an open descriptor owned by `reader` or `writer`.
        if unsafe { libc::fcntl(fd, libc::F_SETFD, 0) } == -1 {
            // The error is captured before `reader`/`writer` are dropped (and closed).
            return Err(std::io::Error::last_os_error());
        }
    }
    Ok((reader, writer))
}
/// Creates an anonymous pipe through `os_pipe::pipe`, returning `(read end, write end)`.
///
/// # Errors
///
/// Propagates any error reported by `os_pipe::pipe`.
#[cfg(windows)]
pub fn create_pipe() -> std::io::Result<(PipeReader, PipeWriter)> {
    let (read_end, write_end) = os_pipe::pipe()?;
    Ok((read_end, write_end))
}
/// Builds two connected endpoints out of two unidirectional pipes.
///
/// Whatever the first endpoint writes is read by the second, and vice versa.
/// Returns `None` if either pipe cannot be created (the OS error is discarded).
pub fn create_bidirectional_channel() -> Option<(RW, RW)> {
    // Pipe carrying bytes from the second endpoint to the first.
    let (first_reader, second_writer) = create_pipe().ok()?;
    // Pipe carrying bytes from the first endpoint to the second.
    let (second_reader, first_writer) = create_pipe().ok()?;

    let first = RW(ReadWrite {
        reader: first_reader,
        writer: first_writer,
    });
    let second = RW(ReadWrite {
        reader: second_reader,
        writer: second_writer,
    });
    Some((first, second))
}
impl PreImageOracle {
pub fn create(hp_opt: &Option<HostProgram>) -> PreImageOracle {
let host_program = hp_opt.as_ref().expect("No host program given");
let mut cmd = Command::new(&host_program.name);
cmd.args(&host_program.arguments);
let (oracle_client, oracle_server) =
create_bidirectional_channel().expect("Could not create bidirectional oracle channel");
let (hint_client, hint_server) =
create_bidirectional_channel().expect("Could not create bidirectional hint channel");
cmd.fd_mappings(vec![
FdMapping {
parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.writer.as_raw_fd()) },
child_fd: HINT_CLIENT_WRITE_FD,
},
FdMapping {
parent_fd: unsafe { OwnedFd::from_raw_fd(hint_server.0.reader.as_raw_fd()) },
child_fd: HINT_CLIENT_READ_FD,
},
FdMapping {
parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.writer.as_raw_fd()) },
child_fd: PREIMAGE_CLIENT_WRITE_FD,
},
FdMapping {
parent_fd: unsafe { OwnedFd::from_raw_fd(oracle_server.0.reader.as_raw_fd()) },
child_fd: PREIMAGE_CLIENT_READ_FD,
},
])
.unwrap_or_else(|_| panic!("Could not map file descriptors to preimage server process"));
PreImageOracle {
cmd,
oracle_client,
oracle_server,
hint_client,
hint_server,
}
}
pub fn start(&mut self) -> Child {
self.cmd
.spawn()
.expect("Could not spawn pre-image oracle process")
}
}
impl PreImageOracleT for PreImageOracle {
    /// Requests the pre-image of `key` from the host over the oracle channel.
    ///
    /// Wire format as implemented here: write the 32-byte key, then read a big-endian
    /// `u64` length followed by exactly that many payload bytes.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - writing the key fails;
    /// - reading the response fails (for example, the host closed its end);
    /// - the announced length does not fit in `usize`.
    fn get_preimage(&mut self, key: [u8; 32]) -> Preimage {
        let RW(ReadWrite { reader, writer }) = &mut self.oracle_client;

        writer
            .write_all(&key)
            .expect("Could not write pre-image key to oracle channel");
        writer
            .flush()
            .expect("Could not flush pre-image key to oracle channel");

        debug!("Reading response");
        let mut len_buf = [0_u8; 8];
        reader
            .read_exact(&mut len_buf)
            .expect("Could not read pre-image length from oracle channel");

        debug!("Extracting contents");
        // `as usize` would silently truncate on 32-bit targets and then read the wrong
        // number of bytes, desynchronising the channel.
        let length = usize::try_from(u64::from_be_bytes(len_buf))
            .expect("Pre-image length does not fit in usize on this platform");
        let mut preimage = vec![0_u8; length];
        reader
            .read_exact(&mut preimage)
            .expect("Could not read pre-image contents from oracle channel");
        debug!(
            "Got preimage of length {}\n {}",
            preimage.len(),
            hex::encode(&preimage)
        );
        Preimage::create(preimage)
    }

    /// Sends `hint` to the host over the hint channel and waits for a one-byte reply.
    ///
    /// The message is a big-endian `u64` length prefix followed by the hint bytes,
    /// written with a single `write_all`. The reply byte's value is not inspected.
    ///
    /// NOTE(review): on Linux the pipes are created with `O_DIRECT` (packet mode), so a
    /// host that reads the length prefix with a separate, shorter read would lose the
    /// rest of this packet. Confirm the host's hint reader tolerates that.
    ///
    /// # Panics
    ///
    /// Panics if the hint cannot be written or the reply cannot be read. This matches
    /// `get_preimage`; previously these errors were silently ignored, and a short
    /// `write` could send a truncated hint.
    fn hint(&mut self, hint: Hint) {
        let RW(ReadWrite { reader, writer }) = &mut self.hint_client;

        let hint_bytes = hint.get();
        let mut msg: Vec<u8> = Vec::with_capacity(8 + hint_bytes.len());
        msg.extend_from_slice(&u64::to_be_bytes(hint_bytes.len() as u64));
        msg.extend_from_slice(&hint_bytes);

        writer
            .write_all(&msg)
            .expect("Could not write hint to hint channel");
        writer
            .flush()
            .expect("Could not flush hint to hint channel");

        let mut ack = [0_u8];
        reader
            .read_exact(&mut ack)
            .expect("Could not read hint acknowledgement from hint channel");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips data through a bidirectional channel in both directions:
    /// - a single byte from `c0` to `c1`;
    /// - then a `u64`-length-prefixed message from `c1` to `c0`.
    #[test]
    fn test_bidir_channels() {
        let (mut c0, mut c1) = create_bidirectional_channel().unwrap();

        // c0 -> c1: one byte.
        let msg = [42_u8];
        let mut buf = [0_u8; 1];
        let writer_joiner = std::thread::spawn(move || {
            // `write` may legally report a short write; `write_all` sends everything.
            let r = c0.0.writer.write_all(&msg);
            assert!(r.is_ok());
        });
        let reader_joiner = std::thread::spawn(move || {
            let r = c1.0.reader.read_exact(&mut buf);
            assert!(r.is_ok());
            buf
        });
        let buf = reader_joiner.join().unwrap();
        writer_joiner.join().unwrap();
        assert_eq!(msg, buf);

        // c1 -> c0: length-prefixed payload.
        let msg2 = vec![42_u8, 43_u8];
        let len = msg2.len() as u64;
        let mut msg = u64::to_be_bytes(len).to_vec();
        msg.extend_from_slice(&msg2);
        let writer_joiner = std::thread::spawn(move || {
            let r = c1.0.writer.write_all(&msg);
            assert!(r.is_ok());
            // `c1.0.writer` is dropped when this thread ends, which lets the
            // `read_to_end` below observe EOF.
            msg
        });
        let reader_joiner = std::thread::spawn(move || {
            let mut response_vec = vec![];
            let r = c0.0.reader.read_to_end(&mut response_vec);
            assert!(r.is_ok());
            let n = u64::from_be_bytes(response_vec[0..8].try_into().unwrap());
            let data = response_vec[8..(n + 8) as usize].to_vec();
            (n, data)
        });
        let (n, data) = reader_joiner.join().unwrap();
        writer_joiner.join().unwrap();
        assert_eq!(n, len);
        assert_eq!(data, msg2);
    }
}