o1vm/
cannon.rs

// Data structures and (de)serialization helpers for compatibility with OP Cannon
2
3use base64::{engine::general_purpose, Engine as _};
4
5use core::{
6    fmt,
7    fmt::{Display, Formatter},
8};
9use libflate::zlib::{Decoder, Encoder};
10use regex::Regex;
11use serde::{Deserialize, Deserializer, Serialize, Serializer};
12use std::io::{Read, Write};
13
/// Base-2 logarithm of the page size: a page holds `1 << PAGE_ADDRESS_SIZE` bytes.
pub const PAGE_ADDRESS_SIZE: u32 = 12;
/// Number of bytes in one memory page (4096).
pub const PAGE_SIZE: u32 = 1u32 << PAGE_ADDRESS_SIZE;
/// Mask keeping only the low `PAGE_ADDRESS_SIZE` bits of an address.
pub const PAGE_ADDRESS_MASK: u32 = PAGE_SIZE - 1;
17
18#[derive(Serialize, Deserialize, Debug)]
19pub struct Page {
20    pub index: u32,
21    #[serde(deserialize_with = "from_base64", serialize_with = "to_base64")]
22    pub data: Vec<u8>,
23}
24
25fn from_base64<'de, D>(deserializer: D) -> Result<Vec<u8>, D::Error>
26where
27    D: Deserializer<'de>,
28{
29    let s: String = Deserialize::deserialize(deserializer)?;
30    let b64_decoded = general_purpose::STANDARD.decode(s).unwrap();
31    let mut decoder = Decoder::new(&b64_decoded[..]).unwrap();
32    let mut data = Vec::new();
33    decoder.read_to_end(&mut data).unwrap();
34    assert_eq!(data.len(), PAGE_SIZE as usize);
35    Ok(data)
36}
37
38fn to_base64<S>(v: &[u8], serializer: S) -> Result<S::Ok, S::Error>
39where
40    S: Serializer,
41{
42    let encoded_v = Vec::new();
43    let mut encoder = Encoder::new(encoded_v).unwrap();
44    encoder.write_all(v).unwrap();
45    let res = encoder.finish().into_result().unwrap();
46    let b64_encoded = general_purpose::STANDARD.encode(res);
47    serializer.serialize_str(&b64_encoded)
48}
49
50// The renaming below keeps compatibility with OP Cannon's state format
51#[derive(Serialize, Deserialize, Debug)]
52pub struct State {
53    pub memory: Vec<Page>,
54    #[serde(
55        rename = "preimageKey",
56        deserialize_with = "deserialize_preimage_key",
57        serialize_with = "serialize_preimage_key"
58    )]
59    pub preimage_key: [u8; 32],
60    #[serde(rename = "preimageOffset")]
61    pub preimage_offset: u32,
62    pub pc: u32,
63    #[serde(rename = "nextPC")]
64    pub next_pc: u32,
65    pub lo: u32,
66    pub hi: u32,
67    pub heap: u32,
68    pub exit: u8,
69    pub exited: bool,
70    pub step: u64,
71    pub registers: [u32; 32],
72    pub last_hint: Option<Vec<u8>>,
73    pub preimage: Option<Vec<u8>>,
74}
75
/// Error returned when a string cannot be parsed as a [`PreimageKey`]. It wraps
/// a human-readable description of the problem.
#[derive(Debug, PartialEq, Eq)]
pub struct ParsePreimageKeyError(String);

// Implementing `Display` + `Error` lets callers use `?`, `Box<dyn Error>` and
// `{}` formatting with this error like any other standard error type.
impl Display for ParsePreimageKeyError {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}

impl std::error::Error for ParsePreimageKeyError {}

/// A 32-byte preimage key, parseable via `FromStr` from 64 hexadecimal digits
/// (optionally `0x`-prefixed).
#[derive(Debug, PartialEq)]
pub struct PreimageKey(pub [u8; 32]);
81
82use std::str::FromStr;
83
84impl FromStr for PreimageKey {
85    type Err = ParsePreimageKeyError;
86
87    fn from_str(s: &str) -> Result<Self, Self::Err> {
88        let parts = s.split('x').collect::<Vec<&str>>();
89        let hex_value: &str = if parts.len() == 1 {
90            parts[0]
91        } else {
92            if parts.len() != 2 {
93                return Err(ParsePreimageKeyError(
94                    format!("Badly structured value to convert {s}").to_string(),
95                ));
96            };
97            parts[1]
98        };
99        // We only handle a hexadecimal representations of exactly 32 bytes (no auto-padding)
100        if hex_value.len() == 64 {
101            hex::decode(hex_value).map_or_else(
102                |_| {
103                    Err(ParsePreimageKeyError(
104                        format!("Could not hex decode {hex_value}").to_string(),
105                    ))
106                },
107                |h| {
108                    h.clone().try_into().map_or_else(
109                        |_| {
110                            Err(ParsePreimageKeyError(
111                                format!("Could not cast vector {:#?} into 32 bytes array", h)
112                                    .to_string(),
113                            ))
114                        },
115                        |res| Ok(PreimageKey(res)),
116                    )
117                },
118            )
119        } else {
120            Err(ParsePreimageKeyError(
121                format!("{hex_value} is not 32-bytes long").to_string(),
122            ))
123        }
124    }
125}
126
127fn deserialize_preimage_key<'de, D>(deserializer: D) -> Result<[u8; 32], D::Error>
128where
129    D: Deserializer<'de>,
130{
131    let s: String = Deserialize::deserialize(deserializer)?;
132    let p = PreimageKey::from_str(s.as_str())
133        .unwrap_or_else(|_| panic!("Parsing {s} as preimage key failed"));
134    Ok(p.0)
135}
136
137fn serialize_preimage_key<S>(v: &[u8], serializer: S) -> Result<S::Ok, S::Error>
138where
139    S: Serializer,
140{
141    let s: String = format!("0x{}", hex::encode(v));
142    serializer.serialize_str(&s)
143}
144
/// Step schedule on which a recurring action fires. In `VmConfiguration` these
/// drive `proof_at`, `stop_at`, `snapshot_state_at` and `info_at`.
///
/// Parsed from Cannon's "frequency format" (shown in parentheses per variant).
#[derive(Clone, Debug, PartialEq)]
pub enum StepFrequency {
    /// Never fires (`never`).
    Never,
    /// Fires at every step (`always`).
    Always,
    /// Fires only at step `n` (`=<n>`).
    Exactly(u64),
    /// Fires at every step that is a multiple of `n` (`%<n>`).
    Every(u64),
    /// Fires from step `n` on, until `m` (excluded) if given, otherwise until
    /// the end (`n..[m]`).
    Range(u64, Option<u64>),
}
153
154impl FromStr for StepFrequency {
155    type Err = String;
156    // Simple parser for Cannon's "frequency format"
157    // A frequency input is either
158    // - never/always
159    // - =<n> (only at step n)
160    // - %<n> (every steps multiple of n)
161    // - n..[m] (from n on, until m excluded if specified, until the end otherwise)
162    fn from_str(s: &str) -> std::result::Result<StepFrequency, String> {
163        use StepFrequency::*;
164
165        let mod_re = Regex::new(r"^%([0-9]+)").unwrap();
166        let eq_re = Regex::new(r"^=([0-9]+)").unwrap();
167        let ival_re = Regex::new(r"^([0-9]+)..([0-9]+)?").unwrap();
168
169        match s {
170            "never" => Ok(Never),
171            "always" => Ok(Always),
172            s => {
173                if let Some(m) = mod_re.captures(s) {
174                    Ok(Every(m[1].parse::<u64>().unwrap()))
175                } else if let Some(m) = eq_re.captures(s) {
176                    Ok(Exactly(m[1].parse::<u64>().unwrap()))
177                } else if let Some(m) = ival_re.captures(s) {
178                    let lo = m[1].parse::<u64>().unwrap();
179                    let hi_opt = m.get(2).map(|x| x.as_str().parse::<u64>().unwrap());
180                    Ok(Range(lo, hi_opt))
181                } else {
182                    Err(format!("Unknown frequency format {}", s))
183                }
184            }
185        }
186    }
187}
188
189impl Display for State {
190    // A very debatable and incomplete, but serviceable, `to_string` implementation.
191    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
192        write!(f,
193            "memory_size (length): {}\nfirst page size: {}\npreimage key: {:#?}\npreimage offset:{}\npc: {}\nlo: {}\nhi: {}\nregisters:{:#?} ",
194            self.memory.len(),
195            self.memory[0].data.len(),
196            self.preimage_key,
197            self.preimage_offset,
198            self.pc,
199            self.lo,
200            self.hi,
201            self.registers
202        )
203    }
204}
205
/// An external program given by its executable name and command-line
/// arguments. Presumably spawned as the preimage/hint host — TODO confirm with
/// the runner.
#[derive(Clone, Debug)]
pub struct HostProgram {
    /// Executable name or path.
    pub name: String,
    /// Arguments passed to the executable, in order.
    pub arguments: Vec<String>,
}
211
212#[derive(Debug, Clone)]
213pub struct VmConfiguration {
214    pub input_state_file: String,
215    pub output_state_file: String,
216    pub metadata_file: Option<String>,
217    pub proof_at: StepFrequency,
218    pub stop_at: StepFrequency,
219    pub snapshot_state_at: StepFrequency,
220    pub info_at: StepFrequency,
221    pub proof_fmt: String,
222    pub snapshot_fmt: String,
223    pub pprof_cpu: bool,
224    pub halt_address: Option<u32>,
225    pub host: Option<HostProgram>,
226}
227
228impl Default for VmConfiguration {
229    fn default() -> Self {
230        VmConfiguration {
231            input_state_file: "state.json".to_string(),
232            output_state_file: "out.json".to_string(),
233            metadata_file: None,
234            proof_at: StepFrequency::Never,
235            stop_at: StepFrequency::Never,
236            snapshot_state_at: StepFrequency::Never,
237            info_at: StepFrequency::Never,
238            proof_fmt: "proof-%d.json".to_string(),
239            snapshot_fmt: "state-%d.json".to_string(),
240            pprof_cpu: false,
241            halt_address: None,
242            host: None,
243        }
244    }
245}
246
/// Pairs a monotonic timestamp with a step number, e.g. to measure elapsed time
/// and steps since some starting point.
#[derive(Clone, Debug)]
pub struct Start {
    pub time: std::time::Instant,
    pub step: usize,
}

impl Start {
    /// Records the current instant together with `step`.
    pub fn create(step: usize) -> Start {
        let time = std::time::Instant::now();
        Self { time, step }
    }
}
261
262#[derive(Debug, PartialEq, Clone, Deserialize)]
263pub struct Symbol {
264    pub name: String,
265    pub start: u32,
266    pub size: usize,
267}
268
269#[derive(Debug, PartialEq, Clone, Deserialize)]
270pub struct Meta {
271    #[serde(deserialize_with = "filtered_ordered")]
272    pub symbols: Vec<Symbol>, // Needs to be in ascending order w.r.t start address
273}
274
275// Make sure that deserialized data are ordered in ascending order and that we
276// have removed 0-size symbols
277fn filtered_ordered<'de, D>(deserializer: D) -> Result<Vec<Symbol>, D::Error>
278where
279    D: Deserializer<'de>,
280{
281    let v: Vec<Symbol> = Deserialize::deserialize(deserializer)?;
282    let mut filtered: Vec<Symbol> = v.into_iter().filter(|e| e.size != 0).collect();
283    filtered.sort_by(|a, b| a.start.cmp(&b.start));
284    Ok(filtered)
285}
286
287impl Meta {
288    pub fn find_address_symbol(&self, address: u32) -> Option<String> {
289        use std::cmp::Ordering;
290
291        self.symbols
292            .binary_search_by(
293                |Symbol {
294                     start,
295                     size,
296                     name: _,
297                 }| {
298                    if address < *start {
299                        Ordering::Greater
300                    } else {
301                        let end = *start + *size as u32;
302                        if address >= end {
303                            Ordering::Less
304                        } else {
305                            Ordering::Equal
306                        }
307                    }
308                },
309            )
310            .map_or_else(|_| None, |idx| Some(self.symbols[idx].name.to_string()))
311    }
312}
313
// File-descriptor numbers of the hint and preimage pipes as seen from the
// client side (presumably matching Cannon's host/client layout — confirm).

/// Descriptor the client reads hint responses from.
pub const HINT_CLIENT_READ_FD: i32 = 3;
/// Descriptor the client writes hints to.
pub const HINT_CLIENT_WRITE_FD: i32 = 4;
/// Descriptor the client reads preimage data from.
pub const PREIMAGE_CLIENT_READ_FD: i32 = 5;
/// Descriptor the client writes preimage requests to.
pub const PREIMAGE_CLIENT_WRITE_FD: i32 = 6;
318
/// Owned bytes of a preimage, wrapped in a newtype.
pub struct Preimage(Vec<u8>);

impl Preimage {
    /// Wraps `v` as a preimage.
    pub fn create(v: Vec<u8>) -> Self {
        Self(v)
    }

    /// Consumes the preimage and hands back its bytes.
    pub fn get(self) -> Vec<u8> {
        let Self(bytes) = self;
        bytes
    }
}
330
/// Owned bytes of a hint, wrapped in a newtype.
pub struct Hint(Vec<u8>);

impl Hint {
    /// Wraps `v` as a hint.
    pub fn create(v: Vec<u8>) -> Self {
        Self(v)
    }

    /// Consumes the hint and hands back its bytes.
    pub fn get(self) -> Vec<u8> {
        let Self(bytes) = self;
        bytes
    }
}
342
#[cfg(test)]
mod tests {

    use super::*;
    use std::{
        fs::File,
        io::{BufReader, Write},
    };

    #[test]
    fn sp_parser() {
        use StepFrequency::*;
        assert_eq!(StepFrequency::from_str("never"), Ok(Never));
        assert_eq!(StepFrequency::from_str("always"), Ok(Always));
        assert_eq!(StepFrequency::from_str("=123"), Ok(Exactly(123)));
        assert_eq!(StepFrequency::from_str("%123"), Ok(Every(123)));
        assert_eq!(StepFrequency::from_str("1..3"), Ok(Range(1, Some(3))));
        assert_eq!(StepFrequency::from_str("1.."), Ok(Range(1, None)));
        assert!(StepFrequency::from_str("@123").is_err());
    }

    // This sample is a subset taken from a Cannon-generated "meta.json" file.
    // Interestingly, it contains 0-size symbols - they are removed by
    // deserialization.
    const META_SAMPLE: &str = r#"{
  "symbols": [
    {
      "name": "go.go",
      "start": 0,
      "size": 0
    },
    {
      "name": "internal/cpu.processOptions",
      "start": 69632,
      "size": 1872
    },
    {
      "name": "runtime.text",
      "start": 69632,
      "size": 0
    },  
    {
      "name": "runtime/internal/atomic.(*Uint8).Load",
      "start": 71504,
      "size": 28
    },
    {
      "name": "runtime/internal/atomic.(*Uint8).Store",
      "start": 71532,
      "size": 28
    },
    {
      "name": "runtime/internal/atomic.(*Uint8).And",
      "start": 71560,
      "size": 88
    },
    {
      "name": "runtime/internal/atomic.(*Uint8).Or",
      "start": 71648,
      "size": 72
    }]}"#;

    fn deserialize_meta_sample() -> Meta {
        serde_json::from_str::<Meta>(META_SAMPLE).unwrap()
    }

    #[test]
    fn test_serialize_deserialize_page() {
        let value: &str = r#"{"index":16,"data":"eJztlkFoE0EUht8k21ZEtFYFg1FCTW0qSGoTS6pFJU3TFlNI07TEQJHE1kJMmhwi1ihaRJCqiAdBKR5Ez4IXvQk5eBaP4iEWpAchV0Hoof5vd14SoQcvve0H/5s3O//OzuzMLHtvNBZVDkUNHLQLUdHugSTKINJgnDoNZB60+MhFBq63Q0G4LCFYQptZoKR9r0hpEc1r4bopy8WRtdptmCJqM+t89RHiY60Xc39M8b26XXUjHLdEbf4qdTyMIWvn9vnyxhTy7eBxGwvGoRWU23ASIqNE5MT4H2DslogOa/EY+f38LxiNKYyrEwW02sV9CJLfgdjnMOfLc0+6biMKHohJFLe2fqO0qLl4Hui0AfcB1H0EzEFTc73GtSfIBO0jnhvnDvpx5CLVIJoKoS7Ic59C2pdfoRpEe+KoC+J7CWnf8leqQf/CbcwbiHP2rcO3TuENfr+C9HcGYp+T15nXnMjdOl/JOyDtc3tUt9tDzto31AXprwuyfCc2SfVsohZ8j7ogPh4Lr7NT+fxV1Yv9pXJ11AXxHYUsX99aVfnWqkT11vcsvk8QnstWJD4EUr0Igt4HqodD0wdP59kIUkH76DvU9IXOXSfnr0tIBe1T5zlAJmrY+xHFICRIG+8p5Lq/YW+djt1tfX/S314ODV/67Wc6eOEZUkF8CxwavqWfSWo/9QWpoH2UhXjtHDhn+E6wzO+EIL4RnEk+nOzDnmWZayRYDyJ6BzkgE3Vjv5faYrjV9F6DuD/eMx+gxvlQlbnndMDdh1TA2G1sbGxsbGxsbGx2Co9Sqvk/2gL/r05DxlgRP8bZK0O50cJQPjMxO5HKhCOlQr8/sVy5uRTuD5RGKuXFaDgYSQ+E/LOlsZlEIZ8NBqKlcmby8mIpPOjPpWYmxwPF06lI+mpqPB+O35ou0l+FGHpe"}"#;
        let decoded_page: Page = serde_json::from_str(value).unwrap();
        let res = serde_json::to_string(&decoded_page).unwrap();
        assert_eq!(res, value);
    }

    #[test]
    fn test_preimage_key_serialisation() {
        #[derive(Serialize, Deserialize)]
        struct TestPreimageKeyStruct {
            #[serde(
                rename = "preimageKey",
                deserialize_with = "deserialize_preimage_key",
                serialize_with = "serialize_preimage_key"
            )]
            pub preimage_key: [u8; 32],
        }

        let preimage_key: &str = r#"{"preimageKey":"0x0000000000000000000000000000000000000000000000000000000000000000"}"#;
        let s: TestPreimageKeyStruct = serde_json::from_str(preimage_key).unwrap();
        let res = serde_json::to_string(&s).unwrap();
        assert_eq!(preimage_key, res);
    }

    #[test]
    fn test_meta_deserialize_from_file() {
        // Use the OS temp directory and a process-unique name so the test neither
        // litters the working directory (previously `./meta_test.json` was left
        // behind) nor collides with concurrent test runs.
        let path =
            std::env::temp_dir().join(format!("o1vm_meta_test_{}.json", std::process::id()));
        let display = path.display();
        {
            let mut output =
                File::create(&path).unwrap_or_else(|_| panic!("Could not create file {display}"));
            write!(output, "{}", META_SAMPLE)
                .unwrap_or_else(|_| panic!("Could not write to file {display}"));
        } // `output` is dropped here, closing the file before it is reopened.

        let input = File::open(&path).unwrap_or_else(|_| panic!("Could not open file {display}"));
        let buffered = BufReader::new(input);
        let read: Meta = serde_json::from_reader(buffered)
            .unwrap_or_else(|_| panic!("Failed to deserialize metadata from file {display}"));
        // Best-effort cleanup: a leftover temp file must not fail the test.
        let _ = std::fs::remove_file(&path);

        let expected = Meta {
            symbols: vec![
                Symbol {
                    name: "internal/cpu.processOptions".to_string(),
                    start: 69632,
                    size: 1872,
                },
                Symbol {
                    name: "runtime/internal/atomic.(*Uint8).Load".to_string(),
                    start: 71504,
                    size: 28,
                },
                Symbol {
                    name: "runtime/internal/atomic.(*Uint8).Store".to_string(),
                    start: 71532,
                    size: 28,
                },
                Symbol {
                    name: "runtime/internal/atomic.(*Uint8).And".to_string(),
                    start: 71560,
                    size: 88,
                },
                Symbol {
                    name: "runtime/internal/atomic.(*Uint8).Or".to_string(),
                    start: 71648,
                    size: 72,
                },
            ],
        };

        assert_eq!(read, expected);
    }

    #[test]
    fn test_find_address_symbol() {
        let meta = deserialize_meta_sample();

        assert_eq!(
            meta.find_address_symbol(69633),
            Some("internal/cpu.processOptions".to_string())
        );
        assert_eq!(
            meta.find_address_symbol(69632),
            Some("internal/cpu.processOptions".to_string())
        );
        // Range ends are exclusive: 69632 + 1872 = 71504 starts the next symbol.
        assert_eq!(
            meta.find_address_symbol(71503),
            Some("internal/cpu.processOptions".to_string())
        );
        assert_eq!(
            meta.find_address_symbol(71504),
            Some("runtime/internal/atomic.(*Uint8).Load".to_string())
        );
        // Just past the last symbol (71648 + 72).
        assert_eq!(meta.find_address_symbol(71720), None);
        assert_eq!(meta.find_address_symbol(42), None);
    }

    #[test]
    fn test_parse_preimagekey() {
        assert_eq!(
            PreimageKey::from_str(
                "0x0000000000000000000000000000000000000000000000000000000000000000"
            ),
            Ok(PreimageKey([0; 32]))
        );
        assert_eq!(
            PreimageKey::from_str(
                "0x0000000000000000000000000000000000000000000000000000000000000001"
            ),
            Ok(PreimageKey([
                0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
                0, 0, 0, 1
            ]))
        );
        assert!(PreimageKey::from_str("0x01").is_err());
    }
}