o1vm/interpreters/riscv32im/
witness.rs

// TODO: do we want to be more restrictive and refer to the number of accesses
//       to the SAME register/memory address?
3use super::{
4    column::Column,
5    interpreter::{
6        self, IInstruction, Instruction, InterpreterEnv, MInstruction, RInstruction, SBInstruction,
7        SInstruction, SyscallInstruction, UInstruction, UJInstruction,
8    },
9    registers::Registers,
10    INSTRUCTION_SET_SIZE, SCRATCH_SIZE, SCRATCH_SIZE_INVERSE,
11};
12use crate::{
13    cannon::{State, PAGE_ADDRESS_MASK, PAGE_ADDRESS_SIZE, PAGE_SIZE},
14    lookups::Lookup,
15};
16use ark_ff::Field;
17use std::array;
18
/// Maximum number of register accesses per instruction (based on demo)
// FIXME: can be different
pub const MAX_NB_REG_ACC: u64 = 7;
/// Maximum number of memory accesses per instruction (based on demo)
// FIXME: can be different
pub const MAX_NB_MEM_ACC: u64 = 12;
/// Maximum number of memory or register accesses per instruction
pub const MAX_ACC: u64 = MAX_NB_REG_ACC + MAX_NB_MEM_ACC;

/// Number of lookup terms in the "global" group.
// NOTE(review): what each lookup group covers is not visible in this file —
// confirm against the lookup definitions.
pub const NUM_GLOBAL_LOOKUP_TERMS: usize = 1;
/// Number of lookup terms in the "decoding" group.
pub const NUM_DECODING_LOOKUP_TERMS: usize = 2;
/// Number of lookup terms in the "instruction" group.
pub const NUM_INSTRUCTION_LOOKUP_TERMS: usize = 5;
/// Total number of lookup terms: the sum of the global, decoding and
/// instruction groups.
pub const NUM_LOOKUP_TERMS: usize =
    NUM_GLOBAL_LOOKUP_TERMS + NUM_DECODING_LOOKUP_TERMS + NUM_INSTRUCTION_LOOKUP_TERMS;
33
/// This structure represents the environment the virtual machine state will use
/// to transition. This environment will be used by the interpreter. The virtual
/// machine has access to its internal state and some external memory. In
/// addition to that, it has access to the environment of the Keccak interpreter
/// that is used to verify the preimage requested during the execution.
// NOTE(review): no Keccak environment field is declared on this struct; the
// sentence above may be a leftover from another interpreter — confirm.
pub struct Env<Fp> {
    /// Counter returned by `instruction_counter()` and bumped by
    /// `increase_instruction_counter()`; initialised from `state.step`.
    pub instruction_counter: u64,
    /// Memory pages as `(page index, page bytes)` pairs.
    pub memory: Vec<(u32, Vec<u8>)>,
    /// Initialised to zeros in `create`.
    // NOTE(review): presumably a cache of recently used indices into
    // `memory`; its use is not visible in this part of the file.
    pub last_memory_accesses: [usize; 3],
    /// Per page, one `u64` entry per byte, read by `fetch_memory_access` and
    /// written by `push_memory_access`.
    pub memory_write_index: Vec<(u32, Vec<u64>)>,
    /// Initialised to zeros in `create`.
    // NOTE(review): presumably the counterpart of `last_memory_accesses` for
    // `memory_write_index` — confirm.
    pub last_memory_write_index_accesses: [usize; 3],
    /// Current register values.
    pub registers: Registers<u32>,
    /// Per-register `u64` entry, read by `fetch_register_access` and written
    /// by `push_register_access_if`.
    pub registers_write_index: Registers<u64>,
    /// Index of the next scratch cell handed out by `alloc_scratch`.
    pub scratch_state_idx: usize,
    /// Scratch cells of the current step.
    pub scratch_state: [Fp; SCRATCH_SIZE],
    /// Index of the next cell handed out by `alloc_scratch_inverse`.
    pub scratch_state_inverse_idx: usize,
    /// Scratch cells holding the "inverse or zero" advice values.
    pub scratch_state_inverse: [Fp; SCRATCH_SIZE_INVERSE],
    /// Halt flag, set through `set_halted`; initialised from `state.exited`.
    pub halt: bool,
    /// Selector of the active instruction, set by `activate_selector`;
    /// `INSTRUCTION_SET_SIZE` when no instruction is selected.
    pub selector: usize,
}
54
55fn fresh_scratch_state<Fp: Field, const N: usize>() -> [Fp; N] {
56    array::from_fn(|_| Fp::zero())
57}
58
59impl<Fp: Field> InterpreterEnv for Env<Fp> {
60    type Position = Column;
61
62    fn alloc_scratch(&mut self) -> Self::Position {
63        let scratch_idx = self.scratch_state_idx;
64        self.scratch_state_idx += 1;
65        Column::ScratchState(scratch_idx)
66    }
67
68    fn alloc_scratch_inverse(&mut self) -> Self::Position {
69        let scratch_inverse_idx = self.scratch_state_inverse_idx;
70        self.scratch_state_inverse_idx += 1;
71        Column::ScratchStateInverse(scratch_inverse_idx)
72    }
73
74    type Variable = u64;
75
    /// Not implemented for the witness environment: any call panics through
    /// `todo!()`.
    fn variable(&self, _column: Self::Position) -> Self::Variable {
        todo!()
    }
79
    /// Intentionally does nothing: the witness environment records no
    /// constraints and does not check the given value.
    fn add_constraint(&mut self, _assert_equals_zero: Self::Variable) {
        // No-op for witness
        // Do not assert that _assert_equals_zero is zero here!
        // Some variables may have placeholders that do not faithfully
        // represent the underlying values.
    }
86
87    fn activate_selector(&mut self, instruction: Instruction) {
88        self.selector = instruction.into();
89    }
90
91    fn check_is_zero(assert_equals_zero: &Self::Variable) {
92        assert_eq!(*assert_equals_zero, 0);
93    }
94
95    fn check_equal(x: &Self::Variable, y: &Self::Variable) {
96        assert_eq!(*x, *y);
97    }
98
99    fn assert_boolean(&mut self, x: &Self::Variable) {
100        if *x != 0 && *x != 1 {
101            panic!("The value {} is not a boolean", *x);
102        }
103    }
104
    /// Intentionally does nothing: the lookup is discarded by the witness
    /// environment.
    fn add_lookup(&mut self, _lookup: Lookup<Self::Variable>) {
        // No-op, constraints only
        // TODO: keep track of multiplicities of fixed tables here as in Keccak?
    }
109
110    fn instruction_counter(&self) -> Self::Variable {
111        self.instruction_counter
112    }
113
114    fn increase_instruction_counter(&mut self) {
115        self.instruction_counter += 1;
116    }
117
118    unsafe fn fetch_register(
119        &mut self,
120        idx: &Self::Variable,
121        output: Self::Position,
122    ) -> Self::Variable {
123        let res = self.registers[*idx as usize] as u64;
124        self.write_column(output, res);
125        res
126    }
127
128    unsafe fn push_register_if(
129        &mut self,
130        idx: &Self::Variable,
131        value: Self::Variable,
132        if_is_true: &Self::Variable,
133    ) {
134        let value: u32 = value.try_into().unwrap();
135        if *if_is_true == 1 {
136            self.registers[*idx as usize] = value
137        } else if *if_is_true == 0 {
138            // No-op
139        } else {
140            panic!("Bad value for flag in push_register: {}", *if_is_true);
141        }
142    }
143
144    unsafe fn fetch_register_access(
145        &mut self,
146        idx: &Self::Variable,
147        output: Self::Position,
148    ) -> Self::Variable {
149        let res = self.registers_write_index[*idx as usize];
150        self.write_column(output, res);
151        res
152    }
153
154    unsafe fn push_register_access_if(
155        &mut self,
156        idx: &Self::Variable,
157        value: Self::Variable,
158        if_is_true: &Self::Variable,
159    ) {
160        if *if_is_true == 1 {
161            self.registers_write_index[*idx as usize] = value
162        } else if *if_is_true == 0 {
163            // No-op
164        } else {
165            panic!("Bad value for flag in push_register: {}", *if_is_true);
166        }
167    }
168
169    unsafe fn fetch_memory(
170        &mut self,
171        addr: &Self::Variable,
172        output: Self::Position,
173    ) -> Self::Variable {
174        let addr: u32 = (*addr).try_into().unwrap();
175        let page = addr >> PAGE_ADDRESS_SIZE;
176        let page_address = (addr & PAGE_ADDRESS_MASK) as usize;
177        let memory_page_idx = self.get_memory_page_index(page);
178        let value = self.memory[memory_page_idx].1[page_address];
179        self.write_column(output, value.into());
180        value.into()
181    }
182
183    unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable) {
184        let addr: u32 = (*addr).try_into().unwrap();
185        let page = addr >> PAGE_ADDRESS_SIZE;
186        let page_address = (addr & PAGE_ADDRESS_MASK) as usize;
187        let memory_page_idx = self.get_memory_page_index(page);
188        self.memory[memory_page_idx].1[page_address] =
189            value.try_into().expect("push_memory values fit in a u8");
190    }
191
192    unsafe fn fetch_memory_access(
193        &mut self,
194        addr: &Self::Variable,
195        output: Self::Position,
196    ) -> Self::Variable {
197        let addr: u32 = (*addr).try_into().unwrap();
198        let page = addr >> PAGE_ADDRESS_SIZE;
199        let page_address = (addr & PAGE_ADDRESS_MASK) as usize;
200        let memory_write_index_page_idx = self.get_memory_access_page_index(page);
201        let value = self.memory_write_index[memory_write_index_page_idx].1[page_address];
202        self.write_column(output, value);
203        value
204    }
205
206    unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable) {
207        let addr = *addr as u32;
208        let page = addr >> PAGE_ADDRESS_SIZE;
209        let page_address = (addr & PAGE_ADDRESS_MASK) as usize;
210        let memory_write_index_page_idx = self.get_memory_access_page_index(page);
211        self.memory_write_index[memory_write_index_page_idx].1[page_address] = value;
212    }
213
214    fn constant(x: u32) -> Self::Variable {
215        x as u64
216    }
217
218    unsafe fn bitmask(
219        &mut self,
220        x: &Self::Variable,
221        highest_bit: u32,
222        lowest_bit: u32,
223        position: Self::Position,
224    ) -> Self::Variable {
225        assert!(
226            lowest_bit < highest_bit,
227            "The lowest bit must be strictly lower than the highest bit"
228        );
229        assert!(
230            highest_bit <= 32,
231            "The interpreter is for a 32bits architecture"
232        );
233        let x: u32 = (*x).try_into().unwrap();
234        let res = (x >> lowest_bit) & ((1 << (highest_bit - lowest_bit)) - 1);
235        let res = res as u64;
236        self.write_column(position, res);
237        res
238    }
239
240    unsafe fn shift_left(
241        &mut self,
242        x: &Self::Variable,
243        by: &Self::Variable,
244        position: Self::Position,
245    ) -> Self::Variable {
246        let x: u32 = (*x).try_into().unwrap();
247        let by: u32 = (*by).try_into().unwrap();
248        let res = x << by;
249        let res = res as u64;
250        self.write_column(position, res);
251        res
252    }
253
254    unsafe fn shift_right(
255        &mut self,
256        x: &Self::Variable,
257        by: &Self::Variable,
258        position: Self::Position,
259    ) -> Self::Variable {
260        let x: u32 = (*x).try_into().unwrap();
261        let by: u32 = (*by).try_into().unwrap();
262        let res = x >> by;
263        let res = res as u64;
264        self.write_column(position, res);
265        res
266    }
267
268    unsafe fn shift_right_arithmetic(
269        &mut self,
270        x: &Self::Variable,
271        by: &Self::Variable,
272        position: Self::Position,
273    ) -> Self::Variable {
274        let x: u32 = (*x).try_into().unwrap();
275        let by: u32 = (*by).try_into().unwrap();
276        let res = ((x as i32) >> by) as u32;
277        let res = res as u64;
278        self.write_column(position, res);
279        res
280    }
281
282    unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable {
283        let res = if *x == 0 { 1 } else { 0 };
284        self.write_column(position, res);
285        res
286    }
287
288    fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable {
289        // write the result
290        let pos = self.alloc_scratch();
291        let res = if *x == 0 { 1 } else { 0 };
292        self.write_column(pos, res);
293        // write the non deterministic advice inv_or_zero
294        let pos = self.alloc_scratch_inverse();
295        if *x == 0 {
296            self.write_field_column(pos, Fp::zero());
297        } else {
298            self.write_field_column(pos, Fp::from(*x));
299        };
300        // return the result
301        res
302    }
303
304    fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable {
305        // We replicate is_zero(x-y), but working on field elt,
306        // to avoid subtraction overflow in the witness interpreter for u32
307        let to_zero_test = Fp::from(*x) - Fp::from(*y);
308        let res = {
309            let pos = self.alloc_scratch();
310            let is_zero: u64 = if to_zero_test == Fp::zero() { 1 } else { 0 };
311            self.write_column(pos, is_zero);
312            is_zero
313        };
314        let pos = self.alloc_scratch_inverse();
315        if to_zero_test == Fp::zero() {
316            self.write_field_column(pos, Fp::zero());
317        } else {
318            self.write_field_column(pos, to_zero_test);
319        };
320        res
321    }
322
323    unsafe fn test_less_than(
324        &mut self,
325        x: &Self::Variable,
326        y: &Self::Variable,
327        position: Self::Position,
328    ) -> Self::Variable {
329        let x: u32 = (*x).try_into().unwrap();
330        let y: u32 = (*y).try_into().unwrap();
331        let res = if x < y { 1 } else { 0 };
332        let res = res as u64;
333        self.write_column(position, res);
334        res
335    }
336
337    unsafe fn test_less_than_signed(
338        &mut self,
339        x: &Self::Variable,
340        y: &Self::Variable,
341        position: Self::Position,
342    ) -> Self::Variable {
343        let x: u32 = (*x).try_into().unwrap();
344        let y: u32 = (*y).try_into().unwrap();
345        let res = if (x as i32) < (y as i32) { 1 } else { 0 };
346        let res = res as u64;
347        self.write_column(position, res);
348        res
349    }
350
351    unsafe fn and_witness(
352        &mut self,
353        x: &Self::Variable,
354        y: &Self::Variable,
355        position: Self::Position,
356    ) -> Self::Variable {
357        let x: u32 = (*x).try_into().unwrap();
358        let y: u32 = (*y).try_into().unwrap();
359        let res = x & y;
360        let res = res as u64;
361        self.write_column(position, res);
362        res
363    }
364
365    unsafe fn nor_witness(
366        &mut self,
367        x: &Self::Variable,
368        y: &Self::Variable,
369        position: Self::Position,
370    ) -> Self::Variable {
371        let x: u32 = (*x).try_into().unwrap();
372        let y: u32 = (*y).try_into().unwrap();
373        let res = !(x | y);
374        let res = res as u64;
375        self.write_column(position, res);
376        res
377    }
378
379    unsafe fn or_witness(
380        &mut self,
381        x: &Self::Variable,
382        y: &Self::Variable,
383        position: Self::Position,
384    ) -> Self::Variable {
385        let x: u32 = (*x).try_into().unwrap();
386        let y: u32 = (*y).try_into().unwrap();
387        let res = x | y;
388        let res = res as u64;
389        self.write_column(position, res);
390        res
391    }
392
393    unsafe fn xor_witness(
394        &mut self,
395        x: &Self::Variable,
396        y: &Self::Variable,
397        position: Self::Position,
398    ) -> Self::Variable {
399        let x: u32 = (*x).try_into().unwrap();
400        let y: u32 = (*y).try_into().unwrap();
401        let res = x ^ y;
402        let res = res as u64;
403        self.write_column(position, res);
404        res
405    }
406
407    unsafe fn add_witness(
408        &mut self,
409        x: &Self::Variable,
410        y: &Self::Variable,
411        out_position: Self::Position,
412        overflow_position: Self::Position,
413    ) -> (Self::Variable, Self::Variable) {
414        let x: u32 = (*x).try_into().unwrap();
415        let y: u32 = (*y).try_into().unwrap();
416        // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_add
417        let res = x.overflowing_add(y);
418        let (res_, overflow) = (res.0 as u64, res.1 as u64);
419        self.write_column(out_position, res_);
420        self.write_column(overflow_position, overflow);
421        (res_, overflow)
422    }
423
424    unsafe fn sub_witness(
425        &mut self,
426        x: &Self::Variable,
427        y: &Self::Variable,
428        out_position: Self::Position,
429        underflow_position: Self::Position,
430    ) -> (Self::Variable, Self::Variable) {
431        let x: u32 = (*x).try_into().unwrap();
432        let y: u32 = (*y).try_into().unwrap();
433        // https://doc.rust-lang.org/std/primitive.u32.html#method.overflowing_sub
434        let res = x.overflowing_sub(y);
435        let (res_, underflow) = (res.0 as u64, res.1 as u64);
436        self.write_column(out_position, res_);
437        self.write_column(underflow_position, underflow);
438        (res_, underflow)
439    }
440
441    unsafe fn mul_signed_witness(
442        &mut self,
443        x: &Self::Variable,
444        y: &Self::Variable,
445        position: Self::Position,
446    ) -> Self::Variable {
447        let x: u32 = (*x).try_into().unwrap();
448        let y: u32 = (*y).try_into().unwrap();
449        let res = ((x as i32) * (y as i32)) as u32;
450        let res = res as u64;
451        self.write_column(position, res);
452        res
453    }
454
455    unsafe fn mul_hi_signed(
456        &mut self,
457        x: &Self::Variable,
458        y: &Self::Variable,
459        position: Self::Position,
460    ) -> Self::Variable {
461        let x: i32 = (*x).try_into().unwrap();
462        let y: i32 = (*y).try_into().unwrap();
463        let res = (x as i64) * (y as i64);
464        let res = (res >> 32) as i32;
465        let res = res as u64;
466        self.write_column(position, res);
467        res
468    }
469
470    unsafe fn mul_lo_signed(
471        &mut self,
472        x: &Self::Variable,
473        y: &Self::Variable,
474        position: Self::Position,
475    ) -> Self::Variable {
476        let x: i32 = (*x).try_into().unwrap();
477        let y: i32 = (*y).try_into().unwrap();
478        let res = ((x as i64) * (y as i64)) as u64;
479        let res = (res & ((1 << 32) - 1)) as u32;
480        let res = res as u64;
481        self.write_column(position, res);
482        res
483    }
484
485    unsafe fn mul_hi(
486        &mut self,
487        x: &Self::Variable,
488        y: &Self::Variable,
489        position: Self::Position,
490    ) -> Self::Variable {
491        let x: u32 = (*x).try_into().unwrap();
492        let y: u32 = (*y).try_into().unwrap();
493        let res = (x as u64) * (y as u64);
494        let res = (res >> 32) as u32;
495        let res = res as u64;
496        self.write_column(position, res);
497        res
498    }
499
500    unsafe fn mul_hi_signed_unsigned(
501        &mut self,
502        x: &Self::Variable,
503        y: &Self::Variable,
504        position: Self::Position,
505    ) -> Self::Variable {
506        let x: u32 = (*x).try_into().unwrap();
507        let y: u32 = (*y).try_into().unwrap();
508        let res = (((x as i32) as i64) * (y as i64)) as u64;
509        let res = (res >> 32) as u32;
510        let res = res as u64;
511        self.write_column(position, res);
512        res
513    }
514
515    unsafe fn div_signed(
516        &mut self,
517        x: &Self::Variable,
518        y: &Self::Variable,
519        position: Self::Position,
520    ) -> Self::Variable {
521        let x: i32 = (*x).try_into().unwrap();
522        let y: i32 = (*y).try_into().unwrap();
523        let res = (x / y) as u32;
524        let res = res as u64;
525        self.write_column(position, res);
526        res
527    }
528
529    unsafe fn mul_lo(
530        &mut self,
531        x: &Self::Variable,
532        y: &Self::Variable,
533        position: Self::Position,
534    ) -> Self::Variable {
535        let x: u32 = (*x).try_into().unwrap();
536        let y: u32 = (*y).try_into().unwrap();
537        let res = (x as u64) * (y as u64);
538        let res = (res & ((1 << 32) - 1)) as u32;
539        let res = res as u64;
540        self.write_column(position, res);
541        res
542    }
543
544    unsafe fn mod_signed(
545        &mut self,
546        x: &Self::Variable,
547        y: &Self::Variable,
548        position: Self::Position,
549    ) -> Self::Variable {
550        let x: i32 = (*x).try_into().unwrap();
551        let y: i32 = (*y).try_into().unwrap();
552        let res = (x % y) as u32;
553        let res = res as u64;
554        self.write_column(position, res);
555        res
556    }
557
558    unsafe fn div(
559        &mut self,
560        x: &Self::Variable,
561        y: &Self::Variable,
562        position: Self::Position,
563    ) -> Self::Variable {
564        let x: u32 = (*x).try_into().unwrap();
565        let y: u32 = (*y).try_into().unwrap();
566        let res = x / y;
567        let res = res as u64;
568        self.write_column(position, res);
569        res
570    }
571
572    unsafe fn mod_unsigned(
573        &mut self,
574        x: &Self::Variable,
575        y: &Self::Variable,
576        position: Self::Position,
577    ) -> Self::Variable {
578        let x: u32 = (*x).try_into().unwrap();
579        let y: u32 = (*y).try_into().unwrap();
580        let res = x % y;
581        let res = res as u64;
582        self.write_column(position, res);
583        res
584    }
585
586    unsafe fn count_leading_zeros(
587        &mut self,
588        x: &Self::Variable,
589        position: Self::Position,
590    ) -> Self::Variable {
591        let x: u32 = (*x).try_into().unwrap();
592        let res = x.leading_zeros();
593        let res = res as u64;
594        self.write_column(position, res);
595        res
596    }
597
598    unsafe fn count_leading_ones(
599        &mut self,
600        x: &Self::Variable,
601        position: Self::Position,
602    ) -> Self::Variable {
603        let x: u32 = (*x).try_into().unwrap();
604        let res = x.leading_ones();
605        let res = res as u64;
606        self.write_column(position, res);
607        res
608    }
609
610    fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable {
611        self.write_column(position, *x);
612        *x
613    }
614
615    fn set_halted(&mut self, flag: Self::Variable) {
616        if flag == 0 {
617            self.halt = false
618        } else if flag == 1 {
619            self.halt = true
620        } else {
621            panic!("Bad value for flag in set_halted: {}", flag);
622        }
623    }
624
625    fn report_exit(&mut self, exit_code: &Self::Variable) {
626        println!(
627            "Exited with code {} at step {}",
628            *exit_code,
629            self.normalized_instruction_counter()
630        );
631    }
632
633    fn reset(&mut self) {
634        self.scratch_state_idx = 0;
635        self.scratch_state = fresh_scratch_state();
636        self.selector = INSTRUCTION_SET_SIZE;
637    }
638}
639
640impl<Fp: Field> Env<Fp> {
641    pub fn create(page_size: usize, state: State) -> Self {
642        let initial_instruction_pointer = state.pc;
643        let next_instruction_pointer = state.next_pc;
644
645        let selector = INSTRUCTION_SET_SIZE;
646
647        let mut initial_memory: Vec<(u32, Vec<u8>)> = state
648            .memory
649            .into_iter()
650            // Check that the conversion from page data is correct
651            .map(|page| (page.index, page.data))
652            .collect();
653
654        for (_address, initial_memory) in initial_memory.iter_mut() {
655            initial_memory.extend((0..(page_size - initial_memory.len())).map(|_| 0u8));
656            assert_eq!(initial_memory.len(), page_size);
657        }
658
659        let memory_offsets = initial_memory
660            .iter()
661            .map(|(offset, _)| *offset)
662            .collect::<Vec<_>>();
663
664        let initial_registers = {
665            Registers {
666                general_purpose: state.registers,
667                current_instruction_pointer: initial_instruction_pointer,
668                next_instruction_pointer,
669                heap_pointer: state.heap,
670            }
671        };
672
673        let mut registers = initial_registers.clone();
674        registers[2] = 0x408004f0;
675        // set the stack pointer to the top of the stack
676
677        Env {
678            instruction_counter: state.step,
679            memory: initial_memory.clone(),
680            last_memory_accesses: [0usize; 3],
681            memory_write_index: memory_offsets
682                .iter()
683                .map(|offset| (*offset, vec![0u64; page_size]))
684                .collect(),
685            last_memory_write_index_accesses: [0usize; 3],
686            registers,
687            registers_write_index: Registers::default(),
688            scratch_state_idx: 0,
689            scratch_state: fresh_scratch_state(),
690            scratch_state_inverse_idx: 0,
691            scratch_state_inverse: fresh_scratch_state(),
692            halt: state.exited,
693            selector,
694        }
695    }
696
697    pub fn next_instruction_counter(&self) -> u64 {
698        (self.normalized_instruction_counter() + 1) * MAX_ACC
699    }
700
701    pub fn decode_instruction(&mut self) -> (Instruction, u32) {
702        /* https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf */
703        let instruction =
704            ((self.get_memory_direct(self.registers.current_instruction_pointer) as u32) << 24)
705                | ((self.get_memory_direct(self.registers.current_instruction_pointer + 1) as u32)
706                    << 16)
707                | ((self.get_memory_direct(self.registers.current_instruction_pointer + 2) as u32)
708                    << 8)
709                | (self.get_memory_direct(self.registers.current_instruction_pointer + 3) as u32);
710        let instruction = instruction.to_be(); // convert to big endian for more straightforward decoding
711        let opcode = {
712            match instruction & 0b1111111 // bits 0-6
713            {
714                0b0110111 => Instruction::UType(UInstruction::LoadUpperImmediate),
715                0b0010111 => Instruction::UType(UInstruction::AddUpperImmediate),
716                0b1101111 => Instruction::UJType(UJInstruction::JumpAndLink),
717                0b1100011 =>
718                match (instruction >> 12) & 0x7 // bits 12-14 for func3
719                {
720                    0b000 => Instruction::SBType(SBInstruction::BranchEq),
721                    0b001 => Instruction::SBType(SBInstruction::BranchNeq),
722                    0b100 => Instruction::SBType(SBInstruction::BranchLessThan),
723                    0b101 => Instruction::SBType(SBInstruction::BranchGreaterThanEqual),
724                    0b110 => Instruction::SBType(SBInstruction::BranchLessThanUnsigned),
725                    0b111 => Instruction::SBType(SBInstruction::BranchGreaterThanEqualUnsigned),
726                    _ => panic!("Unknown SBType instruction with full inst {}", instruction),
727                },
728                0b1100111 => Instruction::IType(IInstruction::JumpAndLinkRegister),
729                0b0000011 =>
730                match (instruction >> 12) & 0x7 // bits 12-14 for func3
731                {
732                    0b000 => Instruction::IType(IInstruction::LoadByte),
733                    0b001 => Instruction::IType(IInstruction::LoadHalf),
734                    0b010 => Instruction::IType(IInstruction::LoadWord),
735                    0b100 => Instruction::IType(IInstruction::LoadByteUnsigned),
736                    0b101 => Instruction::IType(IInstruction::LoadHalfUnsigned),
737                    _ => panic!("Unknown IType instruction with full inst {}", instruction),
738                },
739                0b0100011 =>
740                match (instruction >> 12) & 0x7 // bits 12-14 for func3
741                {
742                    0b000 => Instruction::SType(SInstruction::StoreByte),
743                    0b001 => Instruction::SType(SInstruction::StoreHalf),
744                    0b010 => Instruction::SType(SInstruction::StoreWord),
745                    _ => panic!("Unknown SType instruction with full inst {}", instruction),
746                },
747                0b0010011 =>
748                match (instruction >> 12) & 0x7 // bits 12-14 for func3
749                {
750                    0b000 => Instruction::IType(IInstruction::AddImmediate),
751                    0b010 => Instruction::IType(IInstruction::SetLessThanImmediate),
752                    0b011 => Instruction::IType(IInstruction::SetLessThanImmediateUnsigned),
753                    0b100 => Instruction::IType(IInstruction::XorImmediate),
754                    0b110 => Instruction::IType(IInstruction::OrImmediate),
755                    0b111 => Instruction::IType(IInstruction::AndImmediate),
756                    0b001 => Instruction::IType(IInstruction::ShiftLeftLogicalImmediate),
757                    0b101 =>
758                    match (instruction >> 30) & 0x1 // bit 30 in simm component of IType
759                    {
760                    0b0 => Instruction::IType(IInstruction::ShiftRightLogicalImmediate),
761                    0b1 => Instruction::IType(IInstruction::ShiftRightArithmeticImmediate),
762                    _ => panic!("Unknown IType in shift right instructions with full inst {}", instruction),
763                    },
764                    _ => panic!("Unknown IType instruction with full inst {}", instruction),
765                },
766                0b0110011 => {
767                    let funct5 = instruction >> 27 & 0x1F; // bits 27-31 for funct5
768                    let funct2 = instruction >> 25 & 0x3; // bits 25-26 for func2
769                    let funct3 = instruction >> 12 & 0x7; // bits 12-14 for func3
770                    match funct2 {
771                        // These are the instructions for the base integer set
772                        0b00 => {
773                            // The integer set have two sets of instructions
774                            // using a different funct5 value
775                            match funct5 {
776                                0b00000 => {
777                                    // Note: all possible values are handled here
778                                    match funct3 {
779                                        0b000 => Instruction::RType(RInstruction::Add),
780                                        0b001 => Instruction::RType(RInstruction::ShiftLeftLogical),
781                                        0b010 => Instruction::RType(RInstruction::SetLessThan),
782                                        0b011 => Instruction::RType(RInstruction::SetLessThanUnsigned),
783                                        0b100 => Instruction::RType(RInstruction::Xor),
784                                        0b101 => Instruction::RType(RInstruction::ShiftRightLogical),
785                                        0b110 => Instruction::RType(RInstruction::Or),
786                                        0b111 => Instruction::RType(RInstruction::And),
787                                        _ => panic!("This case should never happen as funct3 is 8 bits long and all possible case are implemented. However, we still have an unknown opcode 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
788                                    }
789                                },
790                                // Note that there are still some values unhandled here.
791                                0b01000 => {
792                                    // Note that there are still 6 values unhandled here.
793                                    match funct3 {
794                                        0b000 => Instruction::RType(RInstruction::Sub),
795                                        0b101 => Instruction::RType(RInstruction::ShiftRightArithmetic),
796                                        _ => panic!("Unknown opcode 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
797                                    }
798                                },
799                                // All the unhandled cases
800                                1_u32..=7_u32 | 9_u32..=u32::MAX =>
801                                    panic!("Unknown opcode 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
802                            }
803                        },
804                        // These are the instructions for the M type
805                        0b01 => {
806                            match funct5 {
807                                // All instructions for the M type have the same
808                                // funct5 value. Still catching it here to be
809                                // sure we do not misinterpret an instruction
810                                0b00000 => {
811                                    match funct3 {
812                                        0b000 => Instruction::MType(MInstruction::Mul),
813                                        0b001 => Instruction::MType(MInstruction::Mulh),
814                                        0b010 => Instruction::MType(MInstruction::Mulhsu),
815                                        0b011 => Instruction::MType(MInstruction::Mulhu),
816                                        0b100 => Instruction::MType(MInstruction::Div),
817                                        0b101 => Instruction::MType(MInstruction::Divu),
818                                        0b110 => Instruction::MType(MInstruction::Rem),
819                                        0b111 => Instruction::MType(MInstruction::Remu),
820                                        _ => panic!("This case should never happen as funct3 is 8 bits long and all possible case are implemented. However, we still have an unknown opcode 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
821                                    }
822                                },
823                                // Note that there are still some values unhandled here.
824                                1_u32..=u32::MAX => panic!("Unknown 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
825                            }
826                        },
827                        _ => panic!("Unknown RType 0110011 instruction with full inst {} (funct5 = {}, funct2 = {}, funct3 = {})", instruction, funct5, funct2, funct3),
828                    }
829                }
830                0b0001111 =>
831                match (instruction >> 12) & 0x7 // bits 12-14 for func3
832                {
833                    0b000 => Instruction::RType(RInstruction::Fence),
834                    0b001 => Instruction::RType(RInstruction::FenceI),
835                    _ => panic!("Unknown RType 0001111 (Fence) instruction with full inst {}", instruction),
836                },
837                // FIXME: we should implement more syscalls here, and check the register state.
838                // Even better, only one constructor call ecall, and in the
839                // interpreter, we do the action depending on it
840                0b1110011 => Instruction::SyscallType(SyscallInstruction::SyscallSuccess),
841                _ => panic!("Unknown instruction with full inst {:b}, and opcode {:b}", instruction, instruction & 0b1111111),
842            }
843        };
844        (opcode, instruction)
845    }
846
847    /// Execute a single step in the RISCV32i program
848    pub fn step(&mut self) -> Instruction {
849        self.reset_scratch_state();
850        self.reset_scratch_state_inverse();
851        let (opcode, _instruction) = self.decode_instruction();
852
853        interpreter::interpret_instruction(self, opcode);
854
855        self.instruction_counter = self.next_instruction_counter();
856
857        // Integer division by MAX_ACC to obtain the actual instruction count
858        if self.halt {
859            println!(
860                "Halted at step={} instruction={:?}",
861                self.normalized_instruction_counter(),
862                opcode
863            );
864        }
865        opcode
866    }
867
868    pub fn reset_scratch_state(&mut self) {
869        self.scratch_state_idx = 0;
870        self.scratch_state = fresh_scratch_state();
871        self.selector = INSTRUCTION_SET_SIZE;
872    }
873
874    pub fn reset_scratch_state_inverse(&mut self) {
875        self.scratch_state_inverse_idx = 0;
876        self.scratch_state_inverse = fresh_scratch_state();
877    }
878
879    pub fn write_column(&mut self, column: Column, value: u64) {
880        self.write_field_column(column, value.into())
881    }
882
883    pub fn write_field_column(&mut self, column: Column, value: Fp) {
884        match column {
885            Column::ScratchState(idx) => self.scratch_state[idx] = value,
886            Column::ScratchStateInverse(idx) => self.scratch_state_inverse[idx] = value,
887            Column::InstructionCounter => panic!("Cannot overwrite the column {:?}", column),
888            Column::Selector(s) => self.selector = s,
889        }
890    }
891
892    pub fn update_last_memory_access(&mut self, i: usize) {
893        let [i_0, i_1, _] = self.last_memory_accesses;
894        self.last_memory_accesses = [i, i_0, i_1]
895    }
896
897    pub fn get_memory_page_index(&mut self, page: u32) -> usize {
898        for &i in self.last_memory_accesses.iter() {
899            if self.memory_write_index[i].0 == page {
900                return i;
901            }
902        }
903        for (i, (page_index, _memory)) in self.memory.iter_mut().enumerate() {
904            if *page_index == page {
905                self.update_last_memory_access(i);
906                return i;
907            }
908        }
909
910        // Memory not found; dynamically allocate
911        let memory = vec![0u8; PAGE_SIZE as usize];
912        self.memory.push((page, memory));
913        let i = self.memory.len() - 1;
914        self.update_last_memory_access(i);
915        i
916    }
917
918    pub fn update_last_memory_write_index_access(&mut self, i: usize) {
919        let [i_0, i_1, _] = self.last_memory_write_index_accesses;
920        self.last_memory_write_index_accesses = [i, i_0, i_1]
921    }
922
923    pub fn get_memory_access_page_index(&mut self, page: u32) -> usize {
924        for &i in self.last_memory_write_index_accesses.iter() {
925            if self.memory_write_index[i].0 == page {
926                return i;
927            }
928        }
929        for (i, (page_index, _memory_write_index)) in self.memory_write_index.iter_mut().enumerate()
930        {
931            if *page_index == page {
932                self.update_last_memory_write_index_access(i);
933                return i;
934            }
935        }
936
937        // Memory not found; dynamically allocate
938        let memory_write_index = vec![0u64; PAGE_SIZE as usize];
939        self.memory_write_index.push((page, memory_write_index));
940        let i = self.memory_write_index.len() - 1;
941        self.update_last_memory_write_index_access(i);
942        i
943    }
944
945    pub fn get_memory_direct(&mut self, addr: u32) -> u8 {
946        let page = addr >> PAGE_ADDRESS_SIZE;
947        let page_address = (addr & PAGE_ADDRESS_MASK) as usize;
948        let memory_idx = self.get_memory_page_index(page);
949        self.memory[memory_idx].1[page_address]
950    }
951
952    /// The actual number of instructions executed results from dividing the
953    /// instruction counter by MAX_ACC (floor).
954    ///
955    /// NOTE: actually, in practice it will be less than that, as there is no
956    ///       single instruction that performs all of them.
957    pub fn normalized_instruction_counter(&self) -> u64 {
958        self.instruction_counter / MAX_ACC
959    }
960}