o1vm/interpreters/riscv32im/
interpreter.rs

//! This module implements an interpreter for the RISCV32 IM instruction set
//! architecture.
3//!
//! The implementation mostly follows (and copies) code from the MIPS interpreter
//! available [here](../mips/interpreter.rs).
6//!
7//! ## Credits
8//!
//! We would like to thank the authors of the following documentation:
10//! - <https://msyksphinz-self.github.io/riscv-isadoc/html/rvm.html> ([CC BY
11//!   4.0](https://creativecommons.org/licenses/by/4.0/)) from
12//!   [msyksphinz-self](https://github.com/msyksphinz-self/riscv-isadoc)
13//! - <https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf>
14//!   from the course [CS 3410: Computer System Organization and
15//!   Programming](https://www.cs.cornell.edu/courses/cs3410/2024fa/home.html) at
16//!   Cornell University.
17//!
//! The format and description of each instruction are taken from these sources,
//! and copied in this file for offline reference.
//! If you are the author of the above documentation and would like to add or
//! modify the credits, please open a pull request.
22//!
//! For each instruction, we provide the format, description, and the
//! semantics of the instruction in pseudo-code.
//! When `signed` is mentioned in the pseudo-code, it means that the
//! operation is performed as a signed operation (i.e. `signed(v)`, where `v` is
//! a 32-bit value, means that `v` must be interpreted as an `i32` value in Rust,
//! the most significant bit being the sign: 1 for negative, 0 for positive).
29//! By default, unsigned operations are performed.
30
31use super::registers::{REGISTER_CURRENT_IP, REGISTER_HEAP_POINTER, REGISTER_NEXT_IP};
32use crate::lookups::{Lookup, LookupTableIDs};
33use ark_ff::{One, Zero};
34use strum::{EnumCount, IntoEnumIterator};
35use strum_macros::{EnumCount, EnumIter};
36
37#[derive(Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Hash, Ord, PartialOrd)]
38pub enum Instruction {
39    RType(RInstruction),
40    IType(IInstruction),
41    SType(SInstruction),
42    SBType(SBInstruction),
43    UType(UInstruction),
44    UJType(UJInstruction),
45    SyscallType(SyscallInstruction),
46    MType(MInstruction),
47}
48
49// See
50// https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf
51// for the order
52#[derive(
53    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
54)]
55pub enum RInstruction {
56    #[default]
57    /// Format: `add rd, rs1, rs2`
58    ///
59    /// Description: Adds the registers rs1 and rs2 and stores the result in rd.
60    /// Arithmetic overflow is ignored and the result is simply the low 32
61    /// bits of the result.
62    Add, // add
63    /// Format: `sub rd, rs1, rs2`
64    ///
65    /// Description: Subs the register rs2 from rs1 and stores the result in rd.
66    /// Arithmetic overflow is ignored and the result is simply the low 32
67    /// bits of the result.
68    Sub, // sub
69    /// Format: `sll rd, rs1, rs2`
70    ///
71    /// Description: Performs logical left shift on the value in register rs1 by
72    /// the shift amount held in the lower 5 bits of register rs2.
73    ShiftLeftLogical, // sll
74    /// Format: `slt rd, rs1, rs2`
75    ///
76    /// Description: Place the value 1 in register rd if register rs1 is less
77    /// than register rs2 when both are treated as signed numbers, else 0 is
78    /// written to rd.
79    SetLessThan, // slt
80    /// Format: `sltu rd, rs1, rs2`
81    ///
82    /// Description: Place the value 1 in register rd if register rs1 is less
83    /// than register rs2 when both are treated as unsigned numbers, else 0 is
84    /// written to rd.
85    SetLessThanUnsigned, // sltu
86    /// Format: `xor rd, rs1, rs2`
87    ///
88    /// Description: Performs bitwise XOR on registers rs1 and rs2 and place the
89    /// result in rd
90    Xor, // xor
91    /// Format: `srl rd, rs1, rs2`
92    ///
93    /// Description: Logical right shift on the value in register rs1 by the
94    /// shift amount held in the lower 5 bits of register rs2
95    ShiftRightLogical, // srl
96    /// Format: `sra rd, rs1, rs2`
97    ///
98    /// Description: Performs arithmetic right shift on the value in register
99    /// rs1 by the shift amount held in the lower 5 bits of register rs2
100    ShiftRightArithmetic, // sra
101    /// Format: `or rd, rs1, rs2`
102    ///
103    /// Description: Performs bitwise OR on registers rs1 and rs2 and place the
104    /// result in rd
105    Or, // or
106    /// Format: `and rd, rs1, rs2`
107    ///
108    /// Description: Performs bitwise AND on registers rs1 and rs2 and place the
109    /// result in rd
110    And, // and
111    /// Format: `fence`
112    ///
113    /// Description: Used to order device I/O and memory accesses as viewed by
114    /// other RISC-V harts and external devices or coprocessors.
115    /// Any combination of device input (I), device output (O), memory reads
116    /// (R), and memory writes (W) may be ordered with respect to any
117    /// combination of the same. Informally, no other RISC-V hart or external
118    /// device can observe any operation in the successor set following a FENCE
119    /// before any operation in the predecessor set preceding the FENCE.
120    Fence, // fence
121    /// Format: `fence.i`
122    ///
123    /// Description: Provides explicit synchronization between writes to
124    /// instruction memory and instruction fetches on the same hart.
125    FenceI, // fence.i
126}
127
128#[derive(
129    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
130)]
131pub enum IInstruction {
132    #[default]
133    /// Format: `lb rd, offset(rs1)`
134    ///
135    /// Description: Loads a 8-bit value from memory and sign-extends this to
136    /// 32 bits before storing it in register rd.
137    LoadByte, // lb
138    /// Format: `lh rd, offset(rs1)`
139    ///
140    /// Description: Loads a 16-bit value from memory and sign-extends this to
141    /// 32 bits before storing it in register rd.
142    LoadHalf, // lh
143    /// Format: `lw rd, offset(rs1)`
144    ///
145    /// Description: Loads a 32-bit value from memory and sign-extends this to
146    /// 32 bits before storing it in register rd.
147    LoadWord, // lw
148    /// Format: `lbu rd, offset(rs1)`
149    ///
150    /// Description: Loads a 8-bit value from memory and zero-extends this to
151    /// 32 bits before storing it in register rd.
152    LoadByteUnsigned, // lbu
153    /// Format: `lhu rd, offset(rs1)`
154    ///
155    /// Description: Loads a 16-bit value from memory and zero-extends this to
156    /// 32 bits before storing it in register rd.
157    LoadHalfUnsigned, // lhu
158
159    /// Format: `slli rd, rs1, shamt`
160    ///
161    /// Description: Performs logical left shift on the value in register rs1 by
162    /// the shift amount held in the lower 5 bits of the immediate
163    ShiftLeftLogicalImmediate, // slli
164    /// Format: `srli rd, rs1, shamt`
165    ///
166    /// Description: Performs logical right shift on the value in register rs1
167    /// by the shift amount held in the lower 5 bits of the immediate
168    ShiftRightLogicalImmediate, // srli
169    /// Format: `srai rd, rs1, shamt`
170    ///
171    /// Description: Performs arithmetic right shift on the value in register
172    /// rs1 by the shift amount held in the lower 5 bits of the immediate
173    ShiftRightArithmeticImmediate, // srai
174    /// Format: `slti rd, rs1, imm`
175    ///
176    /// Description: Place the value 1 in register rd if register rs1 is less
177    /// than the signextended immediate when both are treated as signed numbers,
178    /// else 0 is written to rd.
179    SetLessThanImmediate, // slti
180    /// Format: `sltiu rd, rs1, imm`
181    ///
182    /// Description: Place the value 1 in register rd if register rs1 is less
183    /// than the immediate when both are treated as unsigned numbers, else 0 is
184    /// written to rd.
185    SetLessThanImmediateUnsigned, // sltiu
186
187    /// Format: `addi rd, rs1, imm`
188    ///
189    /// Description: Adds the sign-extended 12-bit immediate to register rs1.
190    /// Arithmetic overflow is ignored and the result is simply the low 32
191    /// bits of the result. ADDI rd, rs1, 0 is used to implement the MV rd, rs1
192    /// assembler pseudo-instruction.
193    AddImmediate, // addi
194    /// Format: `xori rd, rs1, imm`
195    ///
196    /// Description: Performs bitwise XOR on register rs1 and the sign-extended
197    /// 12-bit immediate and place the result in rd Note, “XORI rd, rs1, -1”
198    /// performs a bitwise logical inversion of register rs1(assembler
199    /// pseudo-instruction NOT rd, rs)
200    XorImmediate, // xori
201    /// Format: `ori rd, rs1, imm`
202    ///
203    /// Description: Performs bitwise OR on register rs1 and the sign-extended
204    /// 12-bit immediate and place the result in rd
205    OrImmediate, // ori
206    /// Format: `andi rd, rs1, imm`
207    ///
208    /// Description: Performs bitwise AND on register rs1 and the sign-extended
209    /// 12-bit immediate and place the result in rd
210    AndImmediate, // andi
211
212    /// Format: `jalr rd, rs1, imm`
213    ///
214    /// Description: Jump to address and place return address in rd.
215    JumpAndLinkRegister, // jalr
216}
217
218#[derive(
219    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
220)]
221pub enum SInstruction {
222    #[default]
223    /// Format: `sb rs2, offset(rs1)`
224    ///
225    /// Description: Store 8-bit, values from the low bits of register rs2 to
226    /// memory.
227    StoreByte, // sb
228    /// Format: `sh rs2, offset(rs1)`
229    ///
230    /// Description: Store 16-bit, values from the low bits of register rs2 to
231    /// memory.
232    StoreHalf, // sh
233    /// Format: `sw rs2, offset(rs1)`
234    ///
235    /// Description: Store 32-bit, values from the low bits of register rs2 to
236    /// memory.
237    StoreWord, // sw
238}
239
240#[derive(
241    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
242)]
243pub enum SBInstruction {
244    #[default]
245    /// Format: `beq rs1, rs2, offset`
246    ///
247    /// Description: Take the branch if registers rs1 and rs2 are equal.
248    BranchEq, // beq
249    /// Format: `bne rs1, rs2, offset`
250    ///
251    /// Description: Take the branch if registers rs1 and rs2 are not equal.
252    BranchNeq, // bne
253    /// Format: `blt rs1, rs2, offset`
254    ///
255    /// Description: Take the branch if registers rs1 is less than rs2, using
256    /// signed comparison.
257    BranchLessThan, // blt
258    /// Format: `bge rs1, rs2, offset`
259    ///
260    /// Description: Take the branch if registers rs1 is greater than or equal
261    /// to rs2, using signed comparison.
262    BranchGreaterThanEqual, // bge
263    /// Format: `bltu rs1, rs2, offset`
264    ///
265    /// Description: Take the branch if registers rs1 is less than rs2, using
266    /// unsigned comparison.
267    BranchLessThanUnsigned, // bltu
268    /// Format: `bgeu rs1, rs2, offset`
269    ///
270    /// Description: Take the branch if registers rs1 is greater than or equal
271    /// to rs2, using unsigned comparison.
272    BranchGreaterThanEqualUnsigned, // bgeu
273}
274
275#[derive(
276    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
277)]
278pub enum UInstruction {
279    #[default]
280    /// Format: `lui rd,imm`
281    ///
282    /// Description: Build 32-bit constants and uses the U-type format. LUI
283    /// places the U-immediate value in the top 20 bits of the destination
284    /// register rd, filling in the lowest 12 bits with zeros.
285    LoadUpperImmediate, // lui
286    /// Format: `auipc rd,imm`
287    ///
288    /// Description: Build pc-relative addresses and uses the U-type format.
289    /// AUIPC (Add upper immediate to PC) forms a 32-bit offset from the 20-bit
290    /// U-immediate, filling in the lowest 12 bits with zeros, adds this offset
291    /// to the pc, then places the result in register rd.
292    AddUpperImmediate, // auipc
293}
294
295#[derive(
296    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
297)]
298pub enum UJInstruction {
299    #[default]
300    /// Format: `jal rd,imm`
301    ///
302    /// Description: Jump to address and place return address in rd.
303    JumpAndLink, // jal
304}
305
306#[derive(
307    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
308)]
309pub enum SyscallInstruction {
310    #[default]
311    SyscallSuccess,
312}
313
314/// M extension instructions
315/// Following <https://msyksphinz-self.github.io/riscv-isadoc/html/rvm.html>
316#[derive(
317    Debug, Clone, Copy, Eq, PartialEq, EnumCount, EnumIter, Default, Hash, Ord, PartialOrd,
318)]
319pub enum MInstruction {
320    /// Format: `mul rd, rs1, rs2`
321    ///
322    /// Description: performs an 32-bit 32-bit multiplication of signed rs1
323    /// by signed rs2 and places the lower 32 bits in the destination register.
324    /// Implementation: `x[rd] = x[rs1] * x[rs2]`
325    #[default]
326    Mul, // mul
327    /// Format: `mulh rd, rs1, rs2`
328    ///
329    /// Description: performs an 32-bit 32-bit multiplication of signed rs1 by
330    /// signed rs2 and places the upper 32 bits in the destination register.
331    /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32`
332    Mulh, // mulh
333    /// Format: `mulhsu rd, rs1, rs2`
334    ///
335    /// Description: performs an 32-bit 32-bit multiplication of signed rs1 by
336    /// unsigned rs2 and places the upper 32 bits in the destination register.
337    /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32`
338    Mulhsu, // mulhsu
339    /// Format: `mulhu rd, rs1, rs2`
340    ///
341    /// Description: performs an 32-bit 32-bit multiplication of unsigned rs1 by
342    /// unsigned rs2 and places the upper 32 bits in the destination register.
343    /// Implementation: `x[rd] = (x[rs1] * x[rs2]) >> 32`
344    Mulhu, // mulhu
345    /// Format: `div rd, rs1, rs2`
346    ///
347    /// Description: perform an 32 bits by 32 bits signed integer division of
348    /// rs1 by rs2, rounding towards zero
349    /// Implementation: `x[rd] = x[rs1] /s x[rs2]`
350    Div, // div
351    /// Format: `divu rd, rs1, rs2`
352    ///
353    /// Description: performs an 32 bits by 32 bits unsigned integer division of
354    /// rs1 by rs2, rounding towards zero.
355    /// Implementation: `x[rd] = x[rs1] /u x[rs2]`
356    Divu, // divu
357    /// Format: `rem rd, rs1, rs2`
358    ///
359    /// Description: performs an 32 bits by 32 bits signed integer reminder of
360    /// rs1 by rs2.
361    /// Implementation: `x[rd] = x[rs1] %s x[rs2]`
362    Rem, // rem
363    /// Format: `remu rd, rs1, rs2`
364    ///
365    /// Description: performs an 32 bits by 32 bits unsigned integer reminder of
366    /// rs1 by rs2.
367    /// Implementation: `x[rd] = x[rs1] %u x[rs2]`
368    Remu, // remu
369}
370
371impl IntoIterator for Instruction {
372    type Item = Instruction;
373    type IntoIter = std::vec::IntoIter<Instruction>;
374
375    fn into_iter(self) -> Self::IntoIter {
376        match self {
377            Instruction::RType(_) => {
378                let mut iter_contents = Vec::with_capacity(RInstruction::COUNT);
379                for rtype in RInstruction::iter() {
380                    iter_contents.push(Instruction::RType(rtype));
381                }
382                iter_contents.into_iter()
383            }
384            Instruction::IType(_) => {
385                let mut iter_contents = Vec::with_capacity(IInstruction::COUNT);
386                for itype in IInstruction::iter() {
387                    iter_contents.push(Instruction::IType(itype));
388                }
389                iter_contents.into_iter()
390            }
391            Instruction::SType(_) => {
392                let mut iter_contents = Vec::with_capacity(SInstruction::COUNT);
393                for stype in SInstruction::iter() {
394                    iter_contents.push(Instruction::SType(stype));
395                }
396                iter_contents.into_iter()
397            }
398            Instruction::SBType(_) => {
399                let mut iter_contents = Vec::with_capacity(SBInstruction::COUNT);
400                for sbtype in SBInstruction::iter() {
401                    iter_contents.push(Instruction::SBType(sbtype));
402                }
403                iter_contents.into_iter()
404            }
405            Instruction::UType(_) => {
406                let mut iter_contents = Vec::with_capacity(UInstruction::COUNT);
407                for utype in UInstruction::iter() {
408                    iter_contents.push(Instruction::UType(utype));
409                }
410                iter_contents.into_iter()
411            }
412            Instruction::UJType(_) => {
413                let mut iter_contents = Vec::with_capacity(UJInstruction::COUNT);
414                for ujtype in UJInstruction::iter() {
415                    iter_contents.push(Instruction::UJType(ujtype));
416                }
417                iter_contents.into_iter()
418            }
419            Instruction::SyscallType(_) => {
420                let mut iter_contents = Vec::with_capacity(SyscallInstruction::COUNT);
421                for syscall in SyscallInstruction::iter() {
422                    iter_contents.push(Instruction::SyscallType(syscall));
423                }
424                iter_contents.into_iter()
425            }
426            Instruction::MType(_) => {
427                let mut iter_contents = Vec::with_capacity(MInstruction::COUNT);
428                for mtype in MInstruction::iter() {
429                    iter_contents.push(Instruction::MType(mtype));
430                }
431                iter_contents.into_iter()
432            }
433        }
434    }
435}
436
437impl std::fmt::Display for Instruction {
438    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
439        match self {
440            Instruction::RType(rtype) => write!(f, "{}", rtype),
441            Instruction::IType(itype) => write!(f, "{}", itype),
442            Instruction::SType(stype) => write!(f, "{}", stype),
443            Instruction::SBType(sbtype) => write!(f, "{}", sbtype),
444            Instruction::UType(utype) => write!(f, "{}", utype),
445            Instruction::UJType(ujtype) => write!(f, "{}", ujtype),
446            Instruction::SyscallType(_syscall) => write!(f, "ecall"),
447            Instruction::MType(mtype) => write!(f, "{}", mtype),
448        }
449    }
450}
451
452impl std::fmt::Display for RInstruction {
453    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
454        match self {
455            RInstruction::Add => write!(f, "add"),
456            RInstruction::Sub => write!(f, "sub"),
457            RInstruction::ShiftLeftLogical => write!(f, "sll"),
458            RInstruction::SetLessThan => write!(f, "slt"),
459            RInstruction::SetLessThanUnsigned => write!(f, "sltu"),
460            RInstruction::Xor => write!(f, "xor"),
461            RInstruction::ShiftRightLogical => write!(f, "srl"),
462            RInstruction::ShiftRightArithmetic => write!(f, "sra"),
463            RInstruction::Or => write!(f, "or"),
464            RInstruction::And => write!(f, "and"),
465            RInstruction::Fence => write!(f, "fence"),
466            RInstruction::FenceI => write!(f, "fence.i"),
467        }
468    }
469}
470
471impl std::fmt::Display for IInstruction {
472    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473        match self {
474            IInstruction::LoadByte => write!(f, "lb"),
475            IInstruction::LoadHalf => write!(f, "lh"),
476            IInstruction::LoadWord => write!(f, "lw"),
477            IInstruction::LoadByteUnsigned => write!(f, "lbu"),
478            IInstruction::LoadHalfUnsigned => write!(f, "lhu"),
479            IInstruction::ShiftLeftLogicalImmediate => write!(f, "slli"),
480            IInstruction::ShiftRightLogicalImmediate => write!(f, "srli"),
481            IInstruction::ShiftRightArithmeticImmediate => write!(f, "srai"),
482            IInstruction::SetLessThanImmediate => write!(f, "slti"),
483            IInstruction::SetLessThanImmediateUnsigned => write!(f, "sltiu"),
484            IInstruction::AddImmediate => write!(f, "addi"),
485            IInstruction::XorImmediate => write!(f, "xori"),
486            IInstruction::OrImmediate => write!(f, "ori"),
487            IInstruction::AndImmediate => write!(f, "andi"),
488            IInstruction::JumpAndLinkRegister => write!(f, "jalr"),
489        }
490    }
491}
492
493impl std::fmt::Display for SInstruction {
494    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        match self {
496            SInstruction::StoreByte => write!(f, "sb"),
497            SInstruction::StoreHalf => write!(f, "sh"),
498            SInstruction::StoreWord => write!(f, "sw"),
499        }
500    }
501}
502
503impl std::fmt::Display for SBInstruction {
504    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505        match self {
506            SBInstruction::BranchEq => write!(f, "beq"),
507            SBInstruction::BranchNeq => write!(f, "bne"),
508            SBInstruction::BranchLessThan => write!(f, "blt"),
509            SBInstruction::BranchGreaterThanEqual => write!(f, "bge"),
510            SBInstruction::BranchLessThanUnsigned => write!(f, "bltu"),
511            SBInstruction::BranchGreaterThanEqualUnsigned => write!(f, "bgeu"),
512        }
513    }
514}
515
516impl std::fmt::Display for UInstruction {
517    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518        match self {
519            UInstruction::LoadUpperImmediate => write!(f, "lui"),
520            UInstruction::AddUpperImmediate => write!(f, "auipc"),
521        }
522    }
523}
524
525impl std::fmt::Display for UJInstruction {
526    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
527        match self {
528            UJInstruction::JumpAndLink => write!(f, "jal"),
529        }
530    }
531}
532
533impl std::fmt::Display for MInstruction {
534    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
535        match self {
536            MInstruction::Mul => write!(f, "mul"),
537            MInstruction::Mulh => write!(f, "mulh"),
538            MInstruction::Mulhsu => write!(f, "mulhsu"),
539            MInstruction::Mulhu => write!(f, "mulhu"),
540            MInstruction::Div => write!(f, "div"),
541            MInstruction::Divu => write!(f, "divu"),
542            MInstruction::Rem => write!(f, "rem"),
543            MInstruction::Remu => write!(f, "remu"),
544        }
545    }
546}
547
548pub trait InterpreterEnv {
    /// A position can be seen as an indexed variable.
    ///
    /// Positions are returned by [`Self::alloc_scratch`] and
    /// [`Self::alloc_scratch_inverse`], resolved to a value with
    /// [`Self::variable`], and passed as the `output` location of the
    /// `fetch_*` methods.
    type Position;
551
    /// Allocate a new abstract variable for the current step and return its
    /// position.
    ///
    /// The variable can be used to store temporary values.
    /// The variables are "freed" after each step/instruction.
    /// The variable allocation can be seen as an allocation on a stack that is
    /// popped after each step execution.
    /// At the moment, [crate::interpreters::riscv32im::SCRATCH_SIZE]
    /// elements can be allocated. If more temporary variables are required for
    /// an instruction, increase the value of
    /// [crate::interpreters::riscv32im::SCRATCH_SIZE].
    fn alloc_scratch(&mut self) -> Self::Position;
562
    /// Allocate a new scratch position, distinct from the ones returned by
    /// [`Self::alloc_scratch`] in name only as far as this trait shows.
    // NOTE(review): presumably reserved for variables holding a field inverse
    // (e.g. for zero tests); confirm against the implementors of this trait.
    fn alloc_scratch_inverse(&mut self) -> Self::Position;
564
    /// The values manipulated by the interpreter.
    ///
    /// A variable can be cloned, debug-printed, added, subtracted and
    /// multiplied with other variables, and provides the `Zero` and `One`
    /// constants.
    type Variable: Clone
        + std::ops::Add<Self::Variable, Output = Self::Variable>
        + std::ops::Sub<Self::Variable, Output = Self::Variable>
        + std::ops::Mul<Self::Variable, Output = Self::Variable>
        + std::fmt::Debug
        + Zero
        + One;
572
    /// Returns the variable in the current row corresponding to a given column alias.
    fn variable(&self, column: Self::Position) -> Self::Variable;
575
    /// Add a constraint to the proof system, asserting that
    /// `assert_equals_zero` is 0.
    ///
    /// This only records the constraint; [`Self::assert_is_zero`] and
    /// [`Self::assert_equal`] additionally check the witness value before
    /// calling it.
    fn add_constraint(&mut self, assert_equals_zero: Self::Variable);
579
    /// Activate the selector for the given instruction.
    // NOTE(review): presumably called once per step with the instruction
    // being executed; confirm against the callers.
    fn activate_selector(&mut self, selector: Instruction);
582
    /// Check that the witness value in `assert_equals_zero` is 0; otherwise abort.
    ///
    /// This is an associated function without access to `self`: it inspects
    /// the value only. Use [`Self::assert_is_zero`] to also add the
    /// corresponding constraint.
    fn check_is_zero(assert_equals_zero: &Self::Variable);
585
586    /// Assert that the value `assert_equals_zero` is 0, and add a constraint in the proof system.
587    fn assert_is_zero(&mut self, assert_equals_zero: Self::Variable) {
588        Self::check_is_zero(&assert_equals_zero);
589        self.add_constraint(assert_equals_zero);
590    }
591
    /// Check that the witness values in `x` and `y` are equal; otherwise abort.
    ///
    /// This is an associated function without access to `self`: it inspects
    /// the values only. Use [`Self::assert_equal`] to also add the
    /// corresponding constraint.
    fn check_equal(x: &Self::Variable, y: &Self::Variable);
594
595    /// Assert that the values `x` and `y` are equal, and add a constraint in the proof system.
596    fn assert_equal(&mut self, x: Self::Variable, y: Self::Variable) {
597        // NB: We use a different function to give a better error message for debugging.
598        Self::check_equal(&x, &y);
599        self.add_constraint(x - y);
600    }
601
    /// Assert that the value `x` is boolean, and add a constraint in the proof system.
    // NOTE(review): "boolean" presumably means that `x` is either 0 or 1;
    // confirm against the implementors of this trait.
    fn assert_boolean(&mut self, x: &Self::Variable);
604
    /// Add `lookup` to the environment.
    ///
    /// [`Self::access_register_if`] uses it to record the
    /// `LookupTableIDs::RegisterLookup` entries of a register access.
    fn add_lookup(&mut self, lookup: Lookup<Self::Variable>);
606
    /// Return the current value of the instruction counter as a variable.
    ///
    /// [`Self::access_register_if`] compares it against the last access index
    /// of a register.
    fn instruction_counter(&self) -> Self::Variable;
608
    /// Increase the instruction counter.
    ///
    /// [`Self::access_register_if`] calls it after every register access.
    // NOTE(review): the size of the increment is not visible here; confirm
    // against the implementors.
    fn increase_instruction_counter(&mut self);
610
    /// Fetch the value of the general purpose register with index `idx` and store it in local
    /// position `output`.
    ///
    /// The fetched value is returned as a variable. [`Self::read_register`] and
    /// [`Self::write_register_if`] pair this call with the required lookups.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this operation.
    unsafe fn fetch_register(
        &mut self,
        idx: &Self::Variable,
        output: Self::Position,
    ) -> Self::Variable;
623
    /// Set the general purpose register with index `idx` to `value` if `if_is_true` is true.
    ///
    /// [`Self::write_register_if`] calls this after recording the access with
    /// [`Self::access_register_if`].
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this operation.
    unsafe fn push_register_if(
        &mut self,
        idx: &Self::Variable,
        value: Self::Variable,
        if_is_true: &Self::Variable,
    );
636
637    /// Set the general purpose register with index `idx` to `value`.
638    ///
639    /// # Safety
640    ///
641    /// No lookups or other constraints are added as part of this operation. The caller must
642    /// manually add the lookups for this operation.
643    unsafe fn push_register(&mut self, idx: &Self::Variable, value: Self::Variable) {
644        self.push_register_if(idx, value, &Self::constant(1))
645    }
646
    /// Fetch the last 'access index' for the general purpose register with index `idx`, and store
    /// it in local position `output`.
    ///
    /// The fetched access index is returned as a variable. [`Self::access_register_if`] uses it
    /// to compute the time elapsed since the previous access.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this operation.
    unsafe fn fetch_register_access(
        &mut self,
        idx: &Self::Variable,
        output: Self::Position,
    ) -> Self::Variable;
659
    /// Set the last 'access index' for the general purpose register with index `idx` to `value` if
    /// `if_is_true` is true.
    ///
    /// [`Self::access_register_if`] calls this with the instruction counter plus one.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this operation.
    unsafe fn push_register_access_if(
        &mut self,
        idx: &Self::Variable,
        value: Self::Variable,
        if_is_true: &Self::Variable,
    );
673
674    /// Set the last 'access index' for the general purpose register with index `idx` to `value`.
675    ///
676    /// # Safety
677    ///
678    /// No lookups or other constraints are added as part of this operation. The caller must
679    /// manually add the lookups for this operation.
680    unsafe fn push_register_access(&mut self, idx: &Self::Variable, value: Self::Variable) {
681        self.push_register_access_if(idx, value, &Self::constant(1))
682    }
683
684    /// Access the general purpose register with index `idx`, adding constraints asserting that the
685    /// old value was `old_value` and that the new value will be `new_value`, if `if_is_true` is
686    /// true.
687    ///
688    /// # Safety
689    ///
690    /// Callers of this function must manually update the registers if required, this function will
691    /// only update the access counter.
692    unsafe fn access_register_if(
693        &mut self,
694        idx: &Self::Variable,
695        old_value: &Self::Variable,
696        new_value: &Self::Variable,
697        if_is_true: &Self::Variable,
698    ) {
699        let last_accessed = {
700            let last_accessed_location = self.alloc_scratch();
701            unsafe { self.fetch_register_access(idx, last_accessed_location) }
702        };
703        let instruction_counter = self.instruction_counter();
704        let elapsed_time = instruction_counter.clone() - last_accessed.clone();
705        let new_accessed = {
706            // Here, we write as if the register had been written *at the start of the next
707            // instruction*. This ensures that we can't 'time travel' within this
708            // instruction, and claim to read the value that we're about to write!
709            instruction_counter + Self::constant(1)
710            // A register should allow multiple accesses to the same register within the same instruction.
711            // In order to allow this, we always increase the instruction counter by 1.
712        };
713        unsafe { self.push_register_access_if(idx, new_accessed.clone(), if_is_true) };
714        self.add_lookup(Lookup::write_if(
715            if_is_true.clone(),
716            LookupTableIDs::RegisterLookup,
717            vec![idx.clone(), last_accessed, old_value.clone()],
718        ));
719        self.add_lookup(Lookup::read_if(
720            if_is_true.clone(),
721            LookupTableIDs::RegisterLookup,
722            vec![idx.clone(), new_accessed, new_value.clone()],
723        ));
724        self.range_check64(&elapsed_time);
725
726        // Update instruction counter after accessing a register.
727        self.increase_instruction_counter();
728    }
729
730    fn read_register(&mut self, idx: &Self::Variable) -> Self::Variable {
731        let value = {
732            let value_location = self.alloc_scratch();
733            unsafe { self.fetch_register(idx, value_location) }
734        };
735        unsafe {
736            self.access_register(idx, &value, &value);
737        };
738        value
739    }
740
741    /// Access the general purpose register with index `idx`, adding constraints asserting that the
742    /// old value was `old_value` and that the new value will be `new_value`.
743    ///
744    /// # Safety
745    ///
746    /// Callers of this function must manually update the registers if required, this function will
747    /// only update the access counter.
748    unsafe fn access_register(
749        &mut self,
750        idx: &Self::Variable,
751        old_value: &Self::Variable,
752        new_value: &Self::Variable,
753    ) {
754        self.access_register_if(idx, old_value, new_value, &Self::constant(1))
755    }
756
757    fn write_register_if(
758        &mut self,
759        idx: &Self::Variable,
760        new_value: Self::Variable,
761        if_is_true: &Self::Variable,
762    ) {
763        let old_value = {
764            let value_location = self.alloc_scratch();
765            unsafe { self.fetch_register(idx, value_location) }
766        };
767        // Ensure that we only write 0 to the 0 register.
768        let actual_new_value = {
769            let idx_is_zero = self.is_zero(idx);
770            let pos = self.alloc_scratch();
771            self.copy(&((Self::constant(1) - idx_is_zero) * new_value), pos)
772        };
773        unsafe {
774            self.access_register_if(idx, &old_value, &actual_new_value, if_is_true);
775        };
776        unsafe {
777            self.push_register_if(idx, actual_new_value, if_is_true);
778        };
779    }
780
781    fn write_register(&mut self, idx: &Self::Variable, new_value: Self::Variable) {
782        self.write_register_if(idx, new_value, &Self::constant(1))
783    }
784
    /// Fetch the memory value at address `addr` and store it in local position `output`.
    ///
    /// Returns the fetched value as a variable.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this memory operation (`read_memory` and `write_memory`
    /// do so by pairing this call with `access_memory`).
    unsafe fn fetch_memory(
        &mut self,
        addr: &Self::Variable,
        output: Self::Position,
    ) -> Self::Variable;
796
    /// Set the memory value at address `addr` to `value`.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this memory operation (`write_memory` does so by calling
    /// `access_memory` before this function).
    unsafe fn push_memory(&mut self, addr: &Self::Variable, value: Self::Variable);
804
    /// Fetch the last 'access index' that the memory at address `addr` was written at, and store
    /// it in local position `output`.
    ///
    /// Returns the fetched access index as a variable.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this memory operation.
    unsafe fn fetch_memory_access(
        &mut self,
        addr: &Self::Variable,
        output: Self::Position,
    ) -> Self::Variable;
817
    /// Set the last 'access index' for the memory at address `addr` to `value`.
    ///
    /// `access_memory` calls this with the current instruction counter plus one.
    ///
    /// # Safety
    ///
    /// No lookups or other constraints are added as part of this operation. The caller must
    /// manually add the lookups for this memory operation.
    unsafe fn push_memory_access(&mut self, addr: &Self::Variable, value: Self::Variable);
825
826    /// Access the memory address `addr`, adding constraints asserting that the old value was
827    /// `old_value` and that the new value will be `new_value`.
828    ///
829    /// # Safety
830    ///
831    /// Callers of this function must manually update the memory if required, this function will
832    /// only update the access counter.
833    unsafe fn access_memory(
834        &mut self,
835        addr: &Self::Variable,
836        old_value: &Self::Variable,
837        new_value: &Self::Variable,
838    ) {
839        let last_accessed = {
840            let last_accessed_location = self.alloc_scratch();
841            unsafe { self.fetch_memory_access(addr, last_accessed_location) }
842        };
843        let instruction_counter = self.instruction_counter();
844        let elapsed_time = instruction_counter.clone() - last_accessed.clone();
845        let new_accessed = {
846            // Here, we write as if the memory had been written *at the start of the next
847            // instruction*. This ensures that we can't 'time travel' within this
848            // instruction, and claim to read the value that we're about to write!
849            instruction_counter + Self::constant(1)
850        };
851        unsafe { self.push_memory_access(addr, new_accessed.clone()) };
852        self.add_lookup(Lookup::write_one(
853            LookupTableIDs::MemoryLookup,
854            vec![addr.clone(), last_accessed, old_value.clone()],
855        ));
856        self.add_lookup(Lookup::read_one(
857            LookupTableIDs::MemoryLookup,
858            vec![addr.clone(), new_accessed, new_value.clone()],
859        ));
860        self.range_check64(&elapsed_time);
861
862        // Update instruction counter after accessing a memory address.
863        self.increase_instruction_counter();
864    }
865
866    fn read_memory(&mut self, addr: &Self::Variable) -> Self::Variable {
867        let value = {
868            let value_location = self.alloc_scratch();
869            unsafe { self.fetch_memory(addr, value_location) }
870        };
871        unsafe {
872            self.access_memory(addr, &value, &value);
873        };
874        value
875    }
876
877    fn write_memory(&mut self, addr: &Self::Variable, new_value: Self::Variable) {
878        let old_value = {
879            let value_location = self.alloc_scratch();
880            unsafe { self.fetch_memory(addr, value_location) }
881        };
882        unsafe {
883            self.access_memory(addr, &old_value, &new_value);
884        };
885        unsafe {
886            self.push_memory(addr, new_value);
887        };
888    }
889
890    /// Adds a lookup to the RangeCheck16Lookup table
891    fn lookup_16bits(&mut self, value: &Self::Variable) {
892        self.add_lookup(Lookup::read_one(
893            LookupTableIDs::RangeCheck16Lookup,
894            vec![value.clone()],
895        ));
896    }
897
898    /// Range checks with 2 lookups to the RangeCheck16Lookup table that a value
899    /// is at most 2^`bits`-1  (bits <= 16).
900    fn range_check16(&mut self, value: &Self::Variable, bits: u32) {
901        assert!(bits <= 16);
902        // 0 <= value < 2^bits
903        // First, check lowerbound: 0 <= value < 2^16
904        self.lookup_16bits(value);
905        // Second, check upperbound: value + 2^16 - 2^bits < 2^16
906        self.lookup_16bits(&(value.clone() + Self::constant(1 << 16) - Self::constant(1 << bits)));
907    }
908
909    /// Adds a lookup to the ByteLookup table
910    fn lookup_8bits(&mut self, value: &Self::Variable) {
911        self.add_lookup(Lookup::read_one(
912            LookupTableIDs::ByteLookup,
913            vec![value.clone()],
914        ));
915    }
916
917    /// Range checks with 2 lookups to the ByteLookup table that a value
918    /// is at most 2^`bits`-1  (bits <= 8).
919    fn range_check8(&mut self, value: &Self::Variable, bits: u32) {
920        assert!(bits <= 8);
921        // 0 <= value < 2^bits
922        // First, check lowerbound: 0 <= value < 2^8
923        self.lookup_8bits(value);
924        // Second, check upperbound: value + 2^8 - 2^bits < 2^8
925        self.lookup_8bits(&(value.clone() + Self::constant(1 << 8) - Self::constant(1 << bits)));
926    }
927
928    /// Adds a lookup to the AtMost4Lookup table
929    fn lookup_2bits(&mut self, value: &Self::Variable) {
930        self.add_lookup(Lookup::read_one(
931            LookupTableIDs::AtMost4Lookup,
932            vec![value.clone()],
933        ));
934    }
935
    /// Placeholder for a 64-bit range check on `_value`.
    ///
    /// This is currently a no-op: no lookup or constraint is added. It is called by
    /// `access_memory` on the time elapsed since the last access.
    fn range_check64(&mut self, _value: &Self::Variable) {
        // TODO: implement the range check; until then the value is left unconstrained.
    }
939
940    fn set_instruction_pointer(&mut self, ip: Self::Variable) {
941        let idx = Self::constant(REGISTER_CURRENT_IP as u32);
942        let new_accessed = self.instruction_counter() + Self::constant(1);
943        unsafe {
944            self.push_register_access(&idx, new_accessed.clone());
945        }
946        unsafe {
947            self.push_register(&idx, ip.clone());
948        }
949        self.add_lookup(Lookup::read_one(
950            LookupTableIDs::RegisterLookup,
951            vec![idx, new_accessed, ip],
952        ));
953    }
954
955    fn get_instruction_pointer(&mut self) -> Self::Variable {
956        let idx = Self::constant(REGISTER_CURRENT_IP as u32);
957        let ip = {
958            let value_location = self.alloc_scratch();
959            unsafe { self.fetch_register(&idx, value_location) }
960        };
961        self.add_lookup(Lookup::write_one(
962            LookupTableIDs::RegisterLookup,
963            vec![idx, self.instruction_counter(), ip.clone()],
964        ));
965        ip
966    }
967
968    fn set_next_instruction_pointer(&mut self, ip: Self::Variable) {
969        let idx = Self::constant(REGISTER_NEXT_IP as u32);
970        let new_accessed = self.instruction_counter() + Self::constant(1);
971        unsafe {
972            self.push_register_access(&idx, new_accessed.clone());
973        }
974        unsafe {
975            self.push_register(&idx, ip.clone());
976        }
977        self.add_lookup(Lookup::read_one(
978            LookupTableIDs::RegisterLookup,
979            vec![idx, new_accessed, ip],
980        ));
981    }
982
983    fn get_next_instruction_pointer(&mut self) -> Self::Variable {
984        let idx = Self::constant(REGISTER_NEXT_IP as u32);
985        let ip = {
986            let value_location = self.alloc_scratch();
987            unsafe { self.fetch_register(&idx, value_location) }
988        };
989        self.add_lookup(Lookup::write_one(
990            LookupTableIDs::RegisterLookup,
991            vec![idx, self.instruction_counter(), ip.clone()],
992        ));
993        ip
994    }
995
    /// Build a variable representing the constant `x`.
    ///
    /// This is an associated function: it does not take the environment as an argument.
    fn constant(x: u32) -> Self::Variable;
997
    /// Extract the bits from the variable `x` between `highest_bit` and `lowest_bit`, and store
    /// the result in `position`.
    /// `lowest_bit` becomes the least-significant bit of the resulting value.
    ///
    /// The instruction decoding in this file treats `lowest_bit` as inclusive and `highest_bit`
    /// as exclusive: for instance, the 7-bit opcode stored in bits 0 to 6 is extracted with
    /// `highest_bit = 7` and `lowest_bit = 0`.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert the relationship with
    /// the source variable `x` and that the returned value fits in `highest_bit - lowest_bit`
    /// bits.
    ///
    /// Do not call this function with highest_bit - lowest_bit >= 32.
    // TODO: embed the range check in the function when highest_bit - lowest_bit <= 16?
    unsafe fn bitmask(
        &mut self,
        x: &Self::Variable,
        highest_bit: u32,
        lowest_bit: u32,
        position: Self::Position,
    ) -> Self::Variable;
1017
    /// Return the result of shifting `x` to the left by `by`, storing the result in `position`.
    ///
    /// Used to compute the result of the `sll` instruction.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert the relationship with
    /// the source variable `x` and the shift amount `by`.
    unsafe fn shift_left(
        &mut self,
        x: &Self::Variable,
        by: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1030
    /// Return the result of shifting `x` to the right by `by`, storing the result in `position`.
    ///
    /// NOTE(review): presumably a logical (zero-filling) shift, as opposed to
    /// `shift_right_arithmetic` - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert the relationship with
    /// the source variable `x` and the shift amount `by`.
    unsafe fn shift_right(
        &mut self,
        x: &Self::Variable,
        by: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1043
    /// Return the result of the arithmetic right shift of `x` by `by`, storing the result in
    /// `position`.
    ///
    /// NOTE(review): presumably the sign bit of `x` is replicated into the vacated bits -
    /// confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert the relationship with
    /// the source variable `x` and the shift amount `by`.
    unsafe fn shift_right_arithmetic(
        &mut self,
        x: &Self::Variable,
        by: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1056
    /// Returns 1 if `x` is 0, or 0 otherwise, storing the result in `position`.
    ///
    /// This is the unconstrained counterpart of `is_zero`.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert the relationship with
    /// `x`.
    unsafe fn test_zero(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;
1064
    /// Returns 1 if `x` is 0, or 0 otherwise.
    ///
    /// `write_register_if` relies on this to force the value written to register 0 to be 0.
    ///
    /// NOTE(review): unlike `test_zero`, this function is not `unsafe`, so implementations
    /// presumably constrain the returned value - confirm.
    fn is_zero(&mut self, x: &Self::Variable) -> Self::Variable;
1066
    /// Returns 1 if `x` is equal to `y`, or 0 otherwise.
    ///
    /// No output position is given by the caller: where the result is stored, if anywhere, is
    /// left to the implementation.
    fn equal(&mut self, x: &Self::Variable, y: &Self::Variable) -> Self::Variable;
1069
    /// Returns 1 if `x < y` as unsigned integers, or 0 otherwise, storing the result in
    /// `position`.
    ///
    /// Used to compute the result of the `sltu` instruction.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert that the value
    /// correctly represents the relationship between `x` and `y`.
    unsafe fn test_less_than(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1083
    /// Returns 1 if `x < y` as signed integers, or 0 otherwise, storing the result in `position`.
    ///
    /// Used to compute the result of the `slt` instruction.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must assert that the value
    /// correctly represents the relationship between `x` and `y`.
    unsafe fn test_less_than_signed(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1096
    /// Returns `x and y`, storing the result in `position`.
    ///
    /// NOTE(review): presumably the bitwise AND of the two values - confirm against the
    /// implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn and_witness(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1109
    /// Returns `x or y`, storing the result in `position`.
    ///
    /// NOTE(review): presumably the bitwise OR of the two values - confirm against the
    /// implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn or_witness(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1122
    /// Returns `x nor y`, storing the result in `position`.
    ///
    /// NOTE(review): presumably the bitwise NOR of the two values - confirm against the
    /// implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn nor_witness(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1135
    /// Returns `x xor y`, storing the result in `position`.
    ///
    /// Used to compute the result of the `xor` instruction.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn xor_witness(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1148
    /// Returns `x + y` and the overflow bit, storing the results in `out_position` and
    /// `overflow_position` respectively.
    ///
    /// The returned pair is `(sum, overflow)`. Used to compute the result of the `add`
    /// instruction, which ignores the overflow bit.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned values; callers must manually add constraints to
    /// ensure that they are correctly constructed.
    unsafe fn add_witness(
        &mut self,
        y: &Self::Variable,
        x: &Self::Variable,
        out_position: Self::Position,
        overflow_position: Self::Position,
    ) -> (Self::Variable, Self::Variable);
1163
    /// Returns the difference of the two operands and the underflow bit, storing the results in
    /// `out_position` and `underflow_position` respectively.
    ///
    /// The returned pair is `(difference, underflow)`.
    ///
    /// Beware of the declaration order of the operands: the first one is named `y` and the
    /// second one `x`. The `sub` instruction calls this function with `x[rs1]` first and
    /// `x[rs2]` second to compute `x[rs1] - x[rs2]`, i.e. the result is expected to be the first
    /// argument minus the second one (`y - x` with the names used here).
    /// NOTE(review): confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned values; callers must manually add constraints to
    /// ensure that they are correctly constructed.
    unsafe fn sub_witness(
        &mut self,
        y: &Self::Variable,
        x: &Self::Variable,
        out_position: Self::Position,
        underflow_position: Self::Position,
    ) -> (Self::Variable, Self::Variable);
1178
    /// Returns `x * y`, where `x` and `y` are treated as integers, storing the result in
    /// `position`.
    ///
    /// NOTE(review): given the name, the operands are presumably interpreted as signed
    /// integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn mul_signed_witness(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1191
    /// Returns `(x * y) >> 32`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, both operands are presumably interpreted as signed
    /// integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mul_hi_signed(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1205
    /// Returns `(x * y) & ((1 << 32) - 1)`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, both operands are presumably interpreted as signed
    /// integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mul_lo_signed(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1219
    /// Returns `(x * y) >> 32`, storing the result in `position`.
    ///
    /// NOTE(review): as opposed to `mul_hi_signed`, both operands are presumably interpreted
    /// as unsigned integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mul_hi(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1233
    /// Returns `(x * y) & ((1 << 32) - 1)`, storing the result in `position`.
    ///
    /// NOTE(review): as opposed to `mul_lo_signed`, both operands are presumably interpreted
    /// as unsigned integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mul_lo(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1247
    /// Returns `(x * y) >> 32`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, `x` is presumably interpreted as a signed integer and `y`
    /// as an unsigned integer - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mul_hi_signed_unsigned(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1261
    /// Returns `x / y`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, both operands are presumably interpreted as signed
    /// integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    ///
    /// Division by zero will create a panic! exception. The RISC-V
    /// specification leaves the case unspecified, and therefore we prefer to
    /// forbid this case while building the witness.
    ///
    /// NOTE(review): the "M" extension of the RISC-V specification does define the result of a
    /// division by zero (the quotient has all its bits set) - confirm whether panicking here is
    /// the intended behavior.
    unsafe fn div_signed(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1279
    /// Returns `x % y`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, both operands are presumably interpreted as signed
    /// integers; the behavior when `y` is zero is not documented here - confirm against the
    /// implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mod_signed(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1293
    /// Returns `x / y`, storing the result in `position`.
    ///
    /// NOTE(review): as opposed to `div_signed`, both operands are presumably interpreted as
    /// unsigned integers - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    ///
    /// Division by zero will create a panic! exception. The RISC-V
    /// specification leaves the case unspecified, and therefore we prefer to
    /// forbid this case while building the witness.
    ///
    /// NOTE(review): the "M" extension of the RISC-V specification does define the result of a
    /// division by zero (the quotient has all its bits set) - confirm whether panicking here is
    /// the intended behavior.
    unsafe fn div(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1311
    /// Returns `x % y`, storing the result in `position`.
    ///
    /// NOTE(review): given the name, both operands are presumably interpreted as unsigned
    /// integers; the behavior when `y` is zero is not documented here - confirm against the
    /// implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that the returned value corresponds to the given values `x` and `y`, and that it
    /// falls within the desired range.
    unsafe fn mod_unsigned(
        &mut self,
        x: &Self::Variable,
        y: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1325
    /// Returns the number of leading 0s in `x`, storing the result in `position`.
    ///
    /// NOTE(review): presumably counted from the most significant bit of the 32-bit
    /// representation of `x` - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn count_leading_zeros(
        &mut self,
        x: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1337
    /// Returns the number of leading 1s in `x`, storing the result in `position`.
    ///
    /// NOTE(review): presumably counted from the most significant bit of the 32-bit
    /// representation of `x` - confirm against the implementations.
    ///
    /// # Safety
    ///
    /// There are no constraints on the returned value; callers must manually add constraints to
    /// ensure that it is correctly constructed.
    unsafe fn count_leading_ones(
        &mut self,
        x: &Self::Variable,
        position: Self::Position,
    ) -> Self::Variable;
1349
    /// Copy the value `x` into `position` and return the resulting variable.
    ///
    /// `write_register_if` uses it to store the value actually written to a register in a
    /// freshly allocated scratch position.
    ///
    /// NOTE(review): this function is not `unsafe`, so implementations presumably constrain
    /// the copy to be equal to `x` - confirm.
    fn copy(&mut self, x: &Self::Variable, position: Self::Position) -> Self::Variable;
1351
1352    /// Increases the heap pointer by `by_amount` if `if_is_true` is `1`, and returns the previous
1353    /// value of the heap pointer.
1354    fn increase_heap_pointer(
1355        &mut self,
1356        by_amount: &Self::Variable,
1357        if_is_true: &Self::Variable,
1358    ) -> Self::Variable {
1359        let idx = Self::constant(REGISTER_HEAP_POINTER as u32);
1360        let old_ptr = {
1361            let value_location = self.alloc_scratch();
1362            unsafe { self.fetch_register(&idx, value_location) }
1363        };
1364        let new_ptr = old_ptr.clone() + by_amount.clone();
1365        unsafe {
1366            self.access_register_if(&idx, &old_ptr, &new_ptr, if_is_true);
1367        };
1368        unsafe {
1369            self.push_register_if(&idx, new_ptr, if_is_true);
1370        };
1371        old_ptr
1372    }
1373
    /// Set the halted flag of the environment to `flag`.
    ///
    /// NOTE(review): presumably `1` means that the program has halted and `0` that it is still
    /// running - confirm against the implementations.
    fn set_halted(&mut self, flag: Self::Variable);
1375
1376    /// Given a variable `x`, this function extends it to a signed integer of
1377    /// `bitlength` bits.
1378    fn sign_extend(&mut self, x: &Self::Variable, bitlength: u32) -> Self::Variable {
1379        assert!(bitlength <= 32);
1380        // FIXME: Constrain `high_bit`
1381        let high_bit = {
1382            let pos = self.alloc_scratch();
1383            unsafe { self.bitmask(x, bitlength, bitlength - 1, pos) }
1384        };
1385        // Casting in u64 for special case of bitlength = 0 to avoid overflow.
1386        // No condition for constant time execution.
1387        // Decomposing the steps for readability.
1388        let v: u64 = (1u64 << (32 - bitlength)) - 1;
1389        let v: u64 = v << bitlength;
1390        let v: u32 = v as u32;
1391        high_bit * Self::constant(v) + x.clone()
1392    }
1393
    /// Report that the program exits with the code `exit_code`.
    ///
    /// NOTE(review): how the exit code is reported (logged, stored, ...) is left to the
    /// implementations - confirm.
    fn report_exit(&mut self, exit_code: &Self::Variable);
1395
    /// Reset the environment.
    ///
    /// NOTE(review): presumably clears the per-instruction state (scratch positions, lookups,
    /// selector) before interpreting the next instruction - confirm against the
    /// implementations.
    fn reset(&mut self);
1397}
1398
1399pub fn interpret_instruction<Env: InterpreterEnv>(env: &mut Env, instr: Instruction) {
1400    env.activate_selector(instr);
1401    match instr {
1402        Instruction::RType(rtype) => interpret_rtype(env, rtype),
1403        Instruction::IType(itype) => interpret_itype(env, itype),
1404        Instruction::SType(stype) => interpret_stype(env, stype),
1405        Instruction::SBType(sbtype) => interpret_sbtype(env, sbtype),
1406        Instruction::UType(utype) => interpret_utype(env, utype),
1407        Instruction::UJType(ujtype) => interpret_ujtype(env, ujtype),
1408        Instruction::SyscallType(syscall) => interpret_syscall(env, syscall),
1409        Instruction::MType(mtype) => interpret_mtype(env, mtype),
1410    }
1411}
1412
1413/// Interpret an R-type instruction.
1414/// The encoding of an R-type instruction is as follows:
1415/// ```text
1416/// | 31               25 | 24      20 | 19     15 | 14        12 | 11    7 | 6      0 |
1417/// | funct5 & funct 2    |     rs2    |    rs1    |    funct3    |    rd   |  opcode  |
1418/// ```
1419/// Following the documentation found
1420/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
1421pub fn interpret_rtype<Env: InterpreterEnv>(env: &mut Env, instr: RInstruction) {
1422    let instruction_pointer = env.get_instruction_pointer();
1423    let next_instruction_pointer = env.get_next_instruction_pointer();
1424
1425    let instruction = {
1426        let v0 = env.read_memory(&instruction_pointer);
1427        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
1428        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
1429        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
1430        (v3 * Env::constant(1 << 24))
1431            + (v2 * Env::constant(1 << 16))
1432            + (v1 * Env::constant(1 << 8))
1433            + v0
1434    };
1435
1436    // FIXME: constrain the opcode to match the instruction given as a parameter
1437    let opcode = {
1438        let pos = env.alloc_scratch();
1439        unsafe { env.bitmask(&instruction, 7, 0, pos) }
1440    };
1441    env.range_check8(&opcode, 7);
1442
1443    let rd = {
1444        let pos = env.alloc_scratch();
1445        unsafe { env.bitmask(&instruction, 12, 7, pos) }
1446    };
1447    env.range_check8(&rd, 5);
1448
1449    let funct3 = {
1450        let pos = env.alloc_scratch();
1451        unsafe { env.bitmask(&instruction, 15, 12, pos) }
1452    };
1453    env.range_check8(&funct3, 3);
1454
1455    let rs1 = {
1456        let pos = env.alloc_scratch();
1457        unsafe { env.bitmask(&instruction, 20, 15, pos) }
1458    };
1459    env.range_check8(&rs1, 5);
1460
1461    let rs2 = {
1462        let pos = env.alloc_scratch();
1463        unsafe { env.bitmask(&instruction, 25, 20, pos) }
1464    };
1465    env.range_check8(&rs2, 5);
1466
1467    let funct2 = {
1468        let pos = env.alloc_scratch();
1469        unsafe { env.bitmask(&instruction, 27, 25, pos) }
1470    };
1471    env.range_check8(&funct2, 2);
1472
1473    let funct5 = {
1474        let pos = env.alloc_scratch();
1475        unsafe { env.bitmask(&instruction, 32, 27, pos) }
1476    };
1477    env.range_check8(&funct5, 5);
1478
1479    // Check correctness of decomposition
1480    env.add_constraint(
1481        instruction
1482    - (opcode.clone() * Env::constant(1 << 0))    // opcode at bits 0-6
1483    - (rd.clone() * Env::constant(1 << 7))        // rd at bits 7-11
1484    - (funct3.clone() * Env::constant(1 << 12))   // funct3 at bits 12-14
1485    - (rs1.clone() * Env::constant(1 << 15))      // rs1 at bits 15-19
1486    - (rs2.clone() * Env::constant(1 << 20))      // rs2 at bits 20-24
1487    - (funct2.clone() * Env::constant(1 << 25))   // funct2 at bits 25-26
1488    - (funct5.clone() * Env::constant(1 << 27)), // funct5 at bits 27-31
1489    );
1490
1491    match instr {
1492        RInstruction::Add => {
1493            // add: x[rd] = x[rs1] + x[rs2]
1494            let local_rs1 = env.read_register(&rs1);
1495            let local_rs2 = env.read_register(&rs2);
1496            let local_rd = {
1497                let overflow_scratch = env.alloc_scratch();
1498                let rd_scratch = env.alloc_scratch();
1499                let (local_rd, _overflow) = unsafe {
1500                    env.add_witness(&local_rs1, &local_rs2, rd_scratch, overflow_scratch)
1501                };
1502                local_rd
1503            };
1504            env.write_register(&rd, local_rd);
1505
1506            env.set_instruction_pointer(next_instruction_pointer.clone());
1507            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1508        }
1509        RInstruction::Sub => {
1510            /* sub: x[rd] = x[rs1] - x[rs2] */
1511            let local_rs1 = env.read_register(&rs1);
1512            let local_rs2 = env.read_register(&rs2);
1513            let local_rd = {
1514                let underflow_scratch = env.alloc_scratch();
1515                let rd_scratch = env.alloc_scratch();
1516                let (local_rd, _underflow) = unsafe {
1517                    env.sub_witness(&local_rs1, &local_rs2, rd_scratch, underflow_scratch)
1518                };
1519                local_rd
1520            };
1521            env.write_register(&rd, local_rd);
1522
1523            env.set_instruction_pointer(next_instruction_pointer.clone());
1524            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1525        }
1526        RInstruction::ShiftLeftLogical => {
1527            /* sll: x[rd] = x[rs1] << x[rs2] */
1528            let local_rs1 = env.read_register(&rs1);
1529            let local_rs2 = env.read_register(&rs2);
1530            let local_rd = {
1531                let rd_scratch = env.alloc_scratch();
1532                unsafe { env.shift_left(&local_rs1, &local_rs2, rd_scratch) }
1533            };
1534            env.write_register(&rd, local_rd);
1535
1536            env.set_instruction_pointer(next_instruction_pointer.clone());
1537            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1538        }
1539        RInstruction::SetLessThan => {
1540            /* slt: x[rd] = (x[rs1] < x[rs2]) ? 1 : 0 */
1541            let local_rs1 = env.read_register(&rs1);
1542            let local_rs2 = env.read_register(&rs2);
1543            let local_rd = {
1544                let rd_scratch = env.alloc_scratch();
1545                unsafe { env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch) }
1546            };
1547            env.write_register(&rd, local_rd);
1548
1549            env.set_instruction_pointer(next_instruction_pointer.clone());
1550            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1551        }
1552        RInstruction::SetLessThanUnsigned => {
1553            /* sltu: x[rd] = (x[rs1] < (u)x[rs2]) ? 1 : 0 */
1554            let local_rs1 = env.read_register(&rs1);
1555            let local_rs2 = env.read_register(&rs2);
1556            let local_rd = {
1557                let pos = env.alloc_scratch();
1558                unsafe { env.test_less_than(&local_rs1, &local_rs2, pos) }
1559            };
1560            env.write_register(&rd, local_rd);
1561
1562            env.set_instruction_pointer(next_instruction_pointer.clone());
1563            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1564        }
1565        RInstruction::Xor => {
1566            /* xor: x[rd] = x[rs1] ^ x[rs2] */
1567            let local_rs1 = env.read_register(&rs1);
1568            let local_rs2 = env.read_register(&rs2);
1569            let local_rd = {
1570                let pos = env.alloc_scratch();
1571                unsafe { env.xor_witness(&local_rs1, &local_rs2, pos) }
1572            };
1573            env.write_register(&rd, local_rd);
1574
1575            env.set_instruction_pointer(next_instruction_pointer.clone());
1576            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1577        }
1578        RInstruction::ShiftRightLogical => {
1579            /* srl: x[rd] = x[rs1] >> x[rs2] */
1580            let local_rs1 = env.read_register(&rs1);
1581            let local_rs2 = env.read_register(&rs2);
1582            let local_rd = {
1583                let pos = env.alloc_scratch();
1584                unsafe { env.shift_right(&local_rs1, &local_rs2, pos) }
1585            };
1586            env.write_register(&rd, local_rd);
1587
1588            env.set_instruction_pointer(next_instruction_pointer.clone());
1589            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1590        }
1591        RInstruction::ShiftRightArithmetic => {
1592            /* sra: x[rd] = x[rs1] >> x[rs2] */
1593            let local_rs1 = env.read_register(&rs1);
1594            let local_rs2 = env.read_register(&rs2);
1595            let local_rd = {
1596                let pos = env.alloc_scratch();
1597                unsafe { env.shift_right_arithmetic(&local_rs1, &local_rs2, pos) }
1598            };
1599            env.write_register(&rd, local_rd);
1600
1601            env.set_instruction_pointer(next_instruction_pointer.clone());
1602            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1603        }
1604        RInstruction::Or => {
1605            /* or: x[rd] = x[rs1] | x[rs2] */
1606            let local_rs1 = env.read_register(&rs1);
1607            let local_rs2 = env.read_register(&rs2);
1608            let local_rd = {
1609                let pos = env.alloc_scratch();
1610                unsafe { env.or_witness(&local_rs1, &local_rs2, pos) }
1611            };
1612            env.write_register(&rd, local_rd);
1613
1614            env.set_instruction_pointer(next_instruction_pointer.clone());
1615            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1616        }
1617        RInstruction::And => {
1618            /* and: x[rd] = x[rs1] & x[rs2] */
1619            let local_rs1 = env.read_register(&rs1);
1620            let local_rs2 = env.read_register(&rs2);
1621            let local_rd = {
1622                let pos = env.alloc_scratch();
1623                unsafe { env.and_witness(&local_rs1, &local_rs2, pos) }
1624            };
1625            env.write_register(&rd, local_rd);
1626
1627            env.set_instruction_pointer(next_instruction_pointer.clone());
1628            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1629        }
1630        RInstruction::Fence => {
1631            unimplemented!("Fence")
1632        }
1633        RInstruction::FenceI => {
1634            unimplemented!("FenceI")
1635        }
1636    };
1637}
1638
1639/// Interpret an I-type instruction.
1640/// The encoding of an I-type instruction is as follows:
1641/// ```text
1642/// | 31     20 | 19     15 | 14    12 | 11    7 | 6      0 |
1643/// | immediate |    rs1    |  funct3  |    rd   |  opcode  |
1644/// ```
1645/// Following the documentation found
1646/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
1647pub fn interpret_itype<Env: InterpreterEnv>(env: &mut Env, instr: IInstruction) {
1648    let instruction_pointer = env.get_instruction_pointer();
1649    let next_instruction_pointer = env.get_next_instruction_pointer();
1650
1651    let instruction = {
1652        let v0 = env.read_memory(&instruction_pointer);
1653        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
1654        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
1655        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
1656        (v3 * Env::constant(1 << 24))
1657            + (v2 * Env::constant(1 << 16))
1658            + (v1 * Env::constant(1 << 8))
1659            + v0
1660    };
1661
1662    let opcode = {
1663        let pos = env.alloc_scratch();
1664        unsafe { env.bitmask(&instruction, 7, 0, pos) }
1665    };
1666    env.range_check8(&opcode, 7);
1667
1668    let rd = {
1669        let pos = env.alloc_scratch();
1670        unsafe { env.bitmask(&instruction, 12, 7, pos) }
1671    };
1672    env.range_check8(&rd, 5);
1673
1674    let funct3 = {
1675        let pos = env.alloc_scratch();
1676        unsafe { env.bitmask(&instruction, 15, 12, pos) }
1677    };
1678    env.range_check8(&funct3, 3);
1679
1680    let rs1 = {
1681        let pos = env.alloc_scratch();
1682        unsafe { env.bitmask(&instruction, 20, 15, pos) }
1683    };
1684    env.range_check8(&rs1, 5);
1685
1686    let imm = {
1687        let pos = env.alloc_scratch();
1688        unsafe { env.bitmask(&instruction, 32, 20, pos) }
1689    };
1690
1691    env.range_check16(&imm, 12);
1692
1693    let shamt = {
1694        let pos = env.alloc_scratch();
1695        unsafe { env.bitmask(&imm, 5, 0, pos) }
1696    };
1697    env.range_check8(&shamt, 5);
1698
1699    let imm_header = {
1700        let pos = env.alloc_scratch();
1701        unsafe { env.bitmask(&imm, 12, 5, pos) }
1702    };
1703    env.range_check8(&imm_header, 7);
1704
1705    // check the correctness of the immediate and shamt
1706    env.add_constraint(imm.clone() - (imm_header.clone() * Env::constant(1 << 5)) - shamt.clone());
1707
1708    // check correctness of decomposition
1709    env.add_constraint(
1710        instruction
1711            - (opcode.clone() * Env::constant(1 << 0))    // opcode at bits 0-6
1712            - (rd.clone() * Env::constant(1 << 7))        // rd at bits 7-11
1713            - (funct3.clone() * Env::constant(1 << 12))   // funct3 at bits 12-14
1714            - (rs1.clone() * Env::constant(1 << 15))      // rs1 at bits 15-19
1715            - (imm.clone() * Env::constant(1 << 20)), // imm at bits 20-32
1716    );
1717
1718    match instr {
1719        IInstruction::LoadByte => {
1720            // lb:  x[rd] = sext(M[x[rs1] + sext(offset)][7:0])
1721            let local_rs1 = env.read_register(&rs1);
1722            let local_imm = env.sign_extend(&imm, 12);
1723            let address = {
1724                let address_scratch = env.alloc_scratch();
1725                let overflow_scratch = env.alloc_scratch();
1726                let (address, _overflow) = unsafe {
1727                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
1728                };
1729                address
1730            };
1731            // Add a range check here for address
1732            let value = env.read_memory(&address);
1733            let value = env.sign_extend(&value, 8);
1734            env.write_register(&rd, value);
1735            env.set_instruction_pointer(next_instruction_pointer.clone());
1736            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1737        }
1738        IInstruction::LoadHalf => {
1739            // lh:  x[rd] = sext(M[x[rs1] + sext(offset)][15:0])
1740            let local_rs1 = env.read_register(&rs1);
1741            let local_imm = env.sign_extend(&imm, 12);
1742            let address = {
1743                let address_scratch = env.alloc_scratch();
1744                let overflow_scratch = env.alloc_scratch();
1745                let (address, _overflow) = unsafe {
1746                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
1747                };
1748                address
1749            };
1750            // Add a range check here for address
1751            let v0 = env.read_memory(&address);
1752            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
1753            let value = (v0 * Env::constant(1 << 8)) + v1;
1754            let value = env.sign_extend(&value, 16);
1755            env.write_register(&rd, value);
1756            env.set_instruction_pointer(next_instruction_pointer.clone());
1757            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1758        }
1759        IInstruction::LoadWord => {
1760            // lw:  x[rd] = sext(M[x[rs1] + sext(offset)][31:0])
1761            let base = env.read_register(&rs1);
1762            let offset = env.sign_extend(&imm, 12);
1763            let address = {
1764                let address_scratch = env.alloc_scratch();
1765                let overflow_scratch = env.alloc_scratch();
1766                let (address, _overflow) =
1767                    unsafe { env.add_witness(&base, &offset, address_scratch, overflow_scratch) };
1768                address
1769            };
1770            // Add a range check here for address
1771            let v0 = env.read_memory(&address);
1772            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
1773            let v2 = env.read_memory(&(address.clone() + Env::constant(2)));
1774            let v3 = env.read_memory(&(address.clone() + Env::constant(3)));
1775            let value = (v0 * Env::constant(1 << 24))
1776                + (v1 * Env::constant(1 << 16))
1777                + (v2 * Env::constant(1 << 8))
1778                + v3;
1779            let value = env.sign_extend(&value, 32);
1780            env.write_register(&rd, value);
1781            env.set_instruction_pointer(next_instruction_pointer.clone());
1782            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1783        }
1784        IInstruction::LoadByteUnsigned => {
1785            //lbu: x[rd] = M[x[rs1] + sext(offset)][7:0]
1786            let local_rs1 = env.read_register(&rs1);
1787            let local_imm = env.sign_extend(&imm, 12);
1788            let address = {
1789                let address_scratch = env.alloc_scratch();
1790                let overflow_scratch = env.alloc_scratch();
1791                let (address, _overflow) = unsafe {
1792                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
1793                };
1794                address
1795            };
1796            // lhu: Add a range check here for address
1797            let value = env.read_memory(&address);
1798            env.write_register(&rd, value);
1799            env.set_instruction_pointer(next_instruction_pointer.clone());
1800            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1801        }
1802        IInstruction::LoadHalfUnsigned => {
1803            // lhu: x[rd] = M[x[rs1] + sext(offset)][15:0]
1804            let local_rs1 = env.read_register(&rs1);
1805            let local_imm = env.sign_extend(&imm, 12);
1806            let address = {
1807                let address_scratch = env.alloc_scratch();
1808                let overflow_scratch = env.alloc_scratch();
1809                let (address, _overflow) = unsafe {
1810                    env.add_witness(&local_rs1, &local_imm, address_scratch, overflow_scratch)
1811                };
1812                address
1813            };
1814            // Add a range check here for address
1815            let v0 = env.read_memory(&address);
1816            let v1 = env.read_memory(&(address.clone() + Env::constant(1)));
1817            let value = (v0 * Env::constant(1 << 8)) + v1;
1818            env.write_register(&rd, value);
1819            env.set_instruction_pointer(next_instruction_pointer.clone());
1820            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1821        }
1822        IInstruction::ShiftLeftLogicalImmediate => {
1823            // slli: x[rd] = x[rs1] << shamt
1824            let local_rs1 = env.read_register(&rs1);
1825
1826            let local_rd = {
1827                let pos = env.alloc_scratch();
1828                unsafe { env.shift_left(&local_rs1, &shamt.clone(), pos) }
1829            };
1830
1831            env.write_register(&rd, local_rd);
1832            env.set_instruction_pointer(next_instruction_pointer.clone());
1833            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1834        }
1835        IInstruction::ShiftRightLogicalImmediate => {
1836            // srli: x[rd] = x[rs1] >> shamt
1837            let local_rs1 = env.read_register(&rs1);
1838            let local_rd = {
1839                let pos = env.alloc_scratch();
1840                unsafe { env.shift_right(&local_rs1, &shamt, pos) }
1841            };
1842            env.write_register(&rd, local_rd);
1843            env.set_instruction_pointer(next_instruction_pointer.clone());
1844            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1845        }
1846        IInstruction::ShiftRightArithmeticImmediate => {
1847            // srai: x[rd] = x[rs1] >> shamt
1848            let local_rs1 = env.read_register(&rs1);
1849
1850            let local_rd = {
1851                let pos = env.alloc_scratch();
1852                unsafe { env.shift_right_arithmetic(&local_rs1, &shamt, pos) }
1853            };
1854            env.write_register(&rd, local_rd);
1855            env.set_instruction_pointer(next_instruction_pointer.clone());
1856            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1857        }
1858        IInstruction::SetLessThanImmediate => {
1859            // slti: x[rd] = (x[rs1] < sext(immediate)) ? 1 : 0
1860            let local_rs1 = env.read_register(&rs1);
1861            let local_imm = env.sign_extend(&imm, 12);
1862            let local_rd = {
1863                let pos = env.alloc_scratch();
1864                unsafe { env.test_less_than_signed(&local_rs1, &local_imm, pos) }
1865            };
1866            env.write_register(&rd, local_rd);
1867            env.set_instruction_pointer(next_instruction_pointer.clone());
1868            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1869        }
1870        IInstruction::SetLessThanImmediateUnsigned => {
1871            // sltiu: x[rd] = (x[rs1] < (u)sext(immediate)) ? 1 : 0
1872            let local_rs1 = env.read_register(&rs1);
1873            let local_imm = env.sign_extend(&imm, 12);
1874            let local_rd = {
1875                let pos = env.alloc_scratch();
1876                unsafe { env.test_less_than(&local_rs1, &local_imm, pos) }
1877            };
1878            env.write_register(&rd, local_rd);
1879            env.set_instruction_pointer(next_instruction_pointer.clone());
1880            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1881        }
1882        IInstruction::AddImmediate => {
1883            // addi: x[rd] = x[rs1] + sext(immediate)
1884            let local_rs1 = env.read_register(&(rs1.clone()));
1885            let local_imm = env.sign_extend(&imm, 12);
1886            let local_rd = {
1887                let overflow_scratch = env.alloc_scratch();
1888                let rd_scratch = env.alloc_scratch();
1889                let (local_rd, _overflow) = unsafe {
1890                    env.add_witness(&local_rs1, &local_imm, rd_scratch, overflow_scratch)
1891                };
1892                local_rd
1893            };
1894            env.write_register(&rd, local_rd);
1895            env.set_instruction_pointer(next_instruction_pointer.clone());
1896            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1897        }
1898        IInstruction::XorImmediate => {
1899            // xori: x[rd] = x[rs1] ^ sext(immediate)
1900            let local_rs1 = env.read_register(&rs1);
1901            let local_imm = env.sign_extend(&imm, 12);
1902            let local_rd = {
1903                let rd_scratch = env.alloc_scratch();
1904                unsafe { env.xor_witness(&local_rs1, &local_imm, rd_scratch) }
1905            };
1906            env.write_register(&rd, local_rd);
1907            env.set_instruction_pointer(next_instruction_pointer.clone());
1908            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1909        }
1910        IInstruction::OrImmediate => {
1911            // ori: x[rd] = x[rs1] | sext(immediate)
1912            let local_rs1 = env.read_register(&rs1);
1913            let local_imm = env.sign_extend(&imm, 12);
1914            let local_rd = {
1915                let rd_scratch = env.alloc_scratch();
1916                unsafe { env.or_witness(&local_rs1, &local_imm, rd_scratch) }
1917            };
1918            env.write_register(&rd, local_rd);
1919            env.set_instruction_pointer(next_instruction_pointer.clone());
1920            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1921        }
1922        IInstruction::AndImmediate => {
1923            // andi: x[rd] = x[rs1] & sext(immediate)
1924            let local_rs1 = env.read_register(&rs1);
1925            let local_imm = env.sign_extend(&imm, 12);
1926            let local_rd = {
1927                let rd_scratch = env.alloc_scratch();
1928                unsafe { env.and_witness(&local_rs1, &local_imm, rd_scratch) }
1929            };
1930            env.write_register(&rd, local_rd);
1931            env.set_instruction_pointer(next_instruction_pointer.clone());
1932            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
1933        }
1934        IInstruction::JumpAndLinkRegister => {
1935            let addr = env.read_register(&rs1);
1936            // jalr:
1937            //  t  = pc+4;
1938            //  pc = (x[rs1] + sext(offset)) & ∼1; <- NOT NOW
1939            //  pc = (x[rs1] + sext(offset)); <- PLEASE FIXME
1940            //  x[rd] = t
1941            let offset = env.sign_extend(&imm, 12);
1942            let new_addr = {
1943                let res_scratch = env.alloc_scratch();
1944                let overflow_scratch = env.alloc_scratch();
1945                let (res, _overflow) =
1946                    unsafe { env.add_witness(&addr, &offset, res_scratch, overflow_scratch) };
1947                res
1948            };
1949            env.write_register(&rd, next_instruction_pointer.clone());
1950            env.set_instruction_pointer(new_addr.clone());
1951            env.set_next_instruction_pointer(new_addr.clone() + Env::constant(4u32));
1952        }
1953    };
1954}
1955
1956/// Interpret an S-type instruction.
1957/// The encoding of an S-type instruction is as follows:
1958/// ```text
1959/// | 31     25 | 24      20 | 19     15 | 14        12 | 11    7 | 6      0 |
1960/// | immediate |     rs2    |    rs1    |    funct3    |    imm  |  opcode  |
1961/// ```
1962/// Following the documentation found
1963/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
1964pub fn interpret_stype<Env: InterpreterEnv>(env: &mut Env, instr: SInstruction) {
1965    /* fetch instruction pointer from the program state */
1966    let instruction_pointer = env.get_instruction_pointer();
1967    /* compute the next instruction ptr and add one, as well record raml lookup */
1968    let next_instruction_pointer = env.get_next_instruction_pointer();
1969    /* read instruction from ip address */
1970    let instruction = {
1971        let v0 = env.read_memory(&instruction_pointer);
1972        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
1973        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
1974        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
1975        (v3 * Env::constant(1 << 24))
1976            + (v2 * Env::constant(1 << 16))
1977            + (v1 * Env::constant(1 << 8))
1978            + v0
1979    };
1980
1981    /* fetch opcode from instruction bit 0 - 6 for a total len of 7 */
1982    let opcode = {
1983        let pos = env.alloc_scratch();
1984        unsafe { env.bitmask(&instruction, 7, 0, pos) }
1985    };
1986    /* verify opcode is 7 bits */
1987    env.range_check8(&opcode, 7);
1988
1989    let imm0_4 = {
1990        let pos = env.alloc_scratch();
1991        unsafe { env.bitmask(&instruction, 12, 7, pos) }
1992        // bytes 7-11
1993    };
1994    env.range_check8(&imm0_4, 5);
1995    let funct3 = {
1996        let pos = env.alloc_scratch();
1997        unsafe { env.bitmask(&instruction, 15, 12, pos) }
1998    };
1999    env.range_check8(&funct3, 3);
2000
2001    let rs1 = {
2002        let pos = env.alloc_scratch();
2003        unsafe { env.bitmask(&instruction, 20, 15, pos) }
2004    };
2005    env.range_check8(&rs1, 5);
2006    let rs2 = {
2007        let pos = env.alloc_scratch();
2008        unsafe { env.bitmask(&instruction, 25, 20, pos) }
2009    };
2010    env.range_check8(&rs2, 5);
2011
2012    let imm5_11 = {
2013        let pos = env.alloc_scratch();
2014        unsafe { env.bitmask(&instruction, 32, 25, pos) }
2015        // bytes 25-31
2016    };
2017    env.range_check8(&imm5_11, 7);
2018
2019    // check correctness of decomposition of S type function
2020    env.add_constraint(
2021        instruction
2022         - (opcode.clone() * Env::constant(1 << 0))    // opcode at bits 0-6
2023         - (imm0_4.clone() * Env::constant(1 << 7))    // imm0_4 at bits 7-11
2024         - (funct3.clone() * Env::constant(1 << 12))   // funct3 at bits 12-14
2025         - (rs1.clone() * Env::constant(1 << 15))      // rs1 at bits 15-19
2026         - (rs2.clone() * Env::constant(1 << 20))      // rs2 at bits 20-24
2027         - (imm5_11.clone() * Env::constant(1 << 25)), // imm5_11 at bits 25-31
2028    );
2029
2030    let local_rs1 = env.read_register(&rs1);
2031    let local_imm0_4 = env.sign_extend(&imm0_4, 5);
2032    let local_imm5_11 = env.sign_extend(&imm5_11, 7);
2033    let local_imm0_11 = {
2034        let pos = env.alloc_scratch();
2035        let shift_pos = env.alloc_scratch();
2036        let shifted_imm5_11 =
2037            unsafe { env.shift_left(&local_imm5_11, &Env::constant(5), shift_pos) };
2038        let local_imm0_11 = unsafe { env.or_witness(&shifted_imm5_11, &local_imm0_4, pos) };
2039        env.sign_extend(&local_imm0_11, 11)
2040    };
2041    let address = {
2042        let address_scratch = env.alloc_scratch();
2043        let overflow_scratch = env.alloc_scratch();
2044        let (address, _overflow) = unsafe {
2045            env.add_witness(
2046                &local_rs1,
2047                &local_imm0_11,
2048                address_scratch,
2049                overflow_scratch,
2050            )
2051        };
2052        address
2053    };
2054    let local_rs2 = env.read_register(&rs2);
2055
2056    match instr {
2057        SInstruction::StoreByte => {
2058            // sb: M[x[rs1] + sext(offset)] = x[rs2][7:0]
2059            let v0 = {
2060                let value_scratch = env.alloc_scratch();
2061                unsafe { env.bitmask(&local_rs2, 8, 0, value_scratch) }
2062            };
2063
2064            env.lookup_8bits(&v0);
2065            env.write_memory(&address, v0);
2066
2067            env.set_instruction_pointer(next_instruction_pointer.clone());
2068            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2069        }
2070        SInstruction::StoreHalf => {
2071            // sh: M[x[rs1] + sext(offset)] = x[rs2][15:0]
2072            let [v0, v1] = [
2073                {
2074                    let value_scratch = env.alloc_scratch();
2075                    unsafe { env.bitmask(&local_rs2, 8, 0, value_scratch) }
2076                },
2077                {
2078                    let value_scratch = env.alloc_scratch();
2079                    unsafe { env.bitmask(&local_rs2, 16, 8, value_scratch) }
2080                },
2081            ];
2082
2083            env.lookup_8bits(&v0);
2084            env.lookup_8bits(&v1);
2085
2086            env.write_memory(&address, v0);
2087            env.write_memory(&(address.clone() + Env::constant(1u32)), v1);
2088
2089            env.set_instruction_pointer(next_instruction_pointer.clone());
2090            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2091        }
2092        SInstruction::StoreWord => {
2093            // sw: M[x[rs1] + sext(offset)] = x[rs2][31:0]
2094            let [v0, v1, v2, v3] = [
2095                {
2096                    let value_scratch = env.alloc_scratch();
2097                    unsafe { env.bitmask(&local_rs2, 32, 24, value_scratch) }
2098                },
2099                {
2100                    let value_scratch = env.alloc_scratch();
2101                    unsafe { env.bitmask(&local_rs2, 24, 16, value_scratch) }
2102                },
2103                {
2104                    let value_scratch = env.alloc_scratch();
2105                    unsafe { env.bitmask(&local_rs2, 16, 8, value_scratch) }
2106                },
2107                {
2108                    let value_scratch = env.alloc_scratch();
2109                    unsafe { env.bitmask(&local_rs2, 8, 0, value_scratch) }
2110                },
2111            ];
2112
2113            env.lookup_8bits(&v0);
2114            env.lookup_8bits(&v1);
2115            env.lookup_8bits(&v2);
2116            env.lookup_8bits(&v3);
2117
2118            env.write_memory(&address, v0);
2119            env.write_memory(&(address.clone() + Env::constant(1u32)), v1);
2120            env.write_memory(&(address.clone() + Env::constant(2u32)), v2);
2121            env.write_memory(&(address.clone() + Env::constant(3u32)), v3);
2122
2123            env.set_instruction_pointer(next_instruction_pointer.clone());
2124            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2125        }
2126    };
2127}
2128
2129/// Interpret an SB-type instruction.
2130/// The encoding of an SB-type instruction is as follows:
2131/// ```text
2132/// | 31     25 | 24     20 | 19     15 | 14        12 | 11      7 | 6      0 |
2133/// |   imm2    |    rs2    |    rs1    |    funct3    |    imm1   |  opcode  |
2134/// ```
2135/// Following the documentation found
2136/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
2137pub fn interpret_sbtype<Env: InterpreterEnv>(env: &mut Env, instr: SBInstruction) {
2138    let instruction_pointer = env.get_instruction_pointer();
2139    let next_instruction_pointer = env.get_next_instruction_pointer();
2140    let instruction = {
2141        let v0 = env.read_memory(&instruction_pointer);
2142        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
2143        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
2144        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
2145        (v3 * Env::constant(1 << 24))
2146            + (v2 * Env::constant(1 << 16))
2147            + (v1 * Env::constant(1 << 8))
2148            + v0
2149    };
2150    let opcode = {
2151        let pos = env.alloc_scratch();
2152        unsafe { env.bitmask(&instruction, 7, 0, pos) }
2153    };
2154
2155    env.range_check8(&opcode, 7);
2156
2157    let funct3 = {
2158        let pos = env.alloc_scratch();
2159        unsafe { env.bitmask(&instruction, 15, 12, pos) }
2160    };
2161    env.range_check8(&funct3, 3);
2162
2163    let rs1 = {
2164        let pos = env.alloc_scratch();
2165        unsafe { env.bitmask(&instruction, 20, 15, pos) }
2166    };
2167    env.range_check8(&rs1, 5);
2168
2169    let rs2 = {
2170        let pos = env.alloc_scratch();
2171        unsafe { env.bitmask(&instruction, 25, 20, pos) }
2172    };
2173    env.range_check8(&rs2, 5);
2174
2175    let imm0_12 = {
2176        let imm11 = {
2177            let pos = env.alloc_scratch();
2178            unsafe { env.bitmask(&instruction, 8, 7, pos) }
2179        };
2180
2181        env.assert_boolean(&imm11);
2182
2183        let imm1_4 = {
2184            let pos = env.alloc_scratch();
2185            unsafe { env.bitmask(&instruction, 12, 8, pos) }
2186        };
2187        env.range_check8(&imm1_4, 4);
2188
2189        let imm5_10 = {
2190            let pos = env.alloc_scratch();
2191            unsafe { env.bitmask(&instruction, 31, 25, pos) }
2192        };
2193        env.range_check8(&imm5_10, 6);
2194
2195        let imm12 = {
2196            let pos = env.alloc_scratch();
2197            unsafe { env.bitmask(&instruction, 32, 31, pos) }
2198        };
2199        env.assert_boolean(&imm12);
2200
2201        // check correctness of decomposition of SB type function
2202        env.add_constraint(
2203            instruction.clone()
2204                - (opcode * Env::constant(1 << 0))    // opcode at bits 0-7
2205                - (imm11.clone() * Env::constant(1 << 7))     // imm11 at bits 8
2206                - (imm1_4.clone() * Env::constant(1 << 8))    // imm1_4 at bits 9-11
2207                - (funct3 * Env::constant(1 << 11))   // funct3 at bits 11-14
2208                - (rs1.clone() * Env::constant(1 << 14))      // rs1 at bits 15-20
2209                - (rs2.clone() * Env::constant(1 << 19))      // rs2 at bits 20-24
2210                - (imm5_10.clone() * Env::constant(1 << 24))  // imm5_10 at bits 25-30
2211                - (imm12.clone() * Env::constant(1 << 31)), // imm12 at bits 31
2212        );
2213
2214        (imm12 * Env::constant(1 << 12))
2215            + (imm11 * Env::constant(1 << 11))
2216            + (imm5_10 * Env::constant(1 << 5))
2217            + (imm1_4 * Env::constant(1 << 1))
2218    };
2219    // extra bit is because the 0th bit in the immediate is always 0 i.e you cannot jump to an odd address
2220    let imm0_12 = env.sign_extend(&imm0_12, 13);
2221
2222    match instr {
2223        SBInstruction::BranchEq => {
2224            // beq: if (x[rs1] == x[rs2]) pc += sext(offset)
2225            let local_rs1 = env.read_register(&rs1);
2226            let local_rs2 = env.read_register(&rs2);
2227
2228            let equals = env.equal(&local_rs1, &local_rs2);
2229            let offset = (Env::constant(1) - equals.clone()) * Env::constant(4) + equals * imm0_12;
2230            let offset = env.sign_extend(&offset, 12);
2231            let addr = {
2232                let res_scratch = env.alloc_scratch();
2233                let overflow_scratch = env.alloc_scratch();
2234                let (res, _overflow) = unsafe {
2235                    env.add_witness(
2236                        &next_instruction_pointer,
2237                        &offset,
2238                        res_scratch,
2239                        overflow_scratch,
2240                    )
2241                };
2242                // FIXME: Requires a range check
2243                res
2244            };
2245            env.set_instruction_pointer(next_instruction_pointer);
2246            env.set_next_instruction_pointer(addr);
2247        }
2248        SBInstruction::BranchNeq => {
2249            // bne: if (x[rs1] != x[rs2]) pc += sext(offset)
2250            let local_rs1 = env.read_register(&rs1);
2251            let local_rs2 = env.read_register(&rs2);
2252
2253            let equals = env.equal(&local_rs1, &local_rs2);
2254            let offset = equals.clone() * Env::constant(4) + (Env::constant(1) - equals) * imm0_12;
2255            let addr = {
2256                let res_scratch = env.alloc_scratch();
2257                let overflow_scratch = env.alloc_scratch();
2258                let (res, _overflow) = unsafe {
2259                    env.add_witness(
2260                        &next_instruction_pointer,
2261                        &offset,
2262                        res_scratch,
2263                        overflow_scratch,
2264                    )
2265                };
2266                // FIXME: Requires a range check
2267                res
2268            };
2269            env.set_instruction_pointer(next_instruction_pointer);
2270            env.set_next_instruction_pointer(addr);
2271        }
2272        SBInstruction::BranchLessThan => {
2273            // blt: if (x[rs1] < x[rs2]) pc += sext(offset)
2274            let local_rs1 = env.read_register(&rs1);
2275            let local_rs2 = env.read_register(&rs2);
2276
2277            let less_than = {
2278                let rd_scratch = env.alloc_scratch();
2279                unsafe { env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch) }
2280            };
2281            let offset = (less_than.clone()) * imm0_12
2282                + (Env::constant(1) - less_than.clone()) * Env::constant(4);
2283
2284            let addr = {
2285                let res_scratch = env.alloc_scratch();
2286                let overflow_scratch = env.alloc_scratch();
2287                let (res, _overflow) = unsafe {
2288                    env.add_witness(
2289                        &next_instruction_pointer,
2290                        &offset,
2291                        res_scratch,
2292                        overflow_scratch,
2293                    )
2294                };
2295                // FIXME: Requires a range check
2296                res
2297            };
2298            env.set_instruction_pointer(next_instruction_pointer);
2299            env.set_next_instruction_pointer(addr);
2300        }
2301        SBInstruction::BranchGreaterThanEqual => {
2302            // bge: if (x[rs1] >= x[rs2]) pc += sext(offset)
2303            let local_rs1 = env.read_register(&rs1);
2304            let local_rs2 = env.read_register(&rs2);
2305
2306            let less_than = {
2307                let rd_scratch = env.alloc_scratch();
2308                unsafe { env.test_less_than_signed(&local_rs1, &local_rs2, rd_scratch) }
2309            };
2310
2311            let offset =
2312                less_than.clone() * Env::constant(4) + (Env::constant(1) - less_than) * imm0_12;
2313            // greater than equal is the negation of less than
2314            let addr = {
2315                let res_scratch = env.alloc_scratch();
2316                let overflow_scratch = env.alloc_scratch();
2317                let (res, _overflow) = unsafe {
2318                    env.add_witness(
2319                        &next_instruction_pointer,
2320                        &offset,
2321                        res_scratch,
2322                        overflow_scratch,
2323                    )
2324                };
2325                // FIXME: Requires a range check
2326                res
2327            };
2328            env.set_instruction_pointer(next_instruction_pointer);
2329            env.set_next_instruction_pointer(addr);
2330        }
2331        SBInstruction::BranchLessThanUnsigned => {
2332            // bltu: if (x[rs1] <u x[rs2]) pc += sext(offset)
2333            let local_rs1 = env.read_register(&rs1);
2334            let local_rs2 = env.read_register(&rs2);
2335
2336            let less_than = {
2337                let rd_scratch = env.alloc_scratch();
2338                unsafe { env.test_less_than(&local_rs1, &local_rs2, rd_scratch) }
2339            };
2340
2341            let offset = (Env::constant(1) - less_than.clone()) * Env::constant(4)
2342                + less_than.clone() * imm0_12;
2343
2344            let addr = {
2345                let res_scratch = env.alloc_scratch();
2346                let overflow_scratch = env.alloc_scratch();
2347                let (res, _overflow) = unsafe {
2348                    env.add_witness(&instruction_pointer, &offset, res_scratch, overflow_scratch)
2349                };
2350                // FIXME: Requires a range check
2351                res
2352            };
2353
2354            env.set_instruction_pointer(next_instruction_pointer);
2355            env.set_next_instruction_pointer(addr);
2356        }
2357        SBInstruction::BranchGreaterThanEqualUnsigned => {
2358            // bgeu: if (x[rs1] >=u x[rs2]) pc += sext(offset)
2359            let local_rs1 = env.read_register(&rs1);
2360            let local_rs2 = env.read_register(&rs2);
2361
2362            let rd_scratch = env.alloc_scratch();
2363            let less_than = unsafe { env.test_less_than(&local_rs1, &local_rs2, rd_scratch) };
2364            let offset =
2365                less_than.clone() * Env::constant(4) + (Env::constant(1) - less_than) * imm0_12;
2366
2367            // greater than equal is the negation of less than
2368            let addr = {
2369                let res_scratch = env.alloc_scratch();
2370                let overflow_scratch = env.alloc_scratch();
2371                let (res, _overflow) = unsafe {
2372                    env.add_witness(&instruction_pointer, &offset, res_scratch, overflow_scratch)
2373                };
2374                res
2375            };
2376
2377            env.set_instruction_pointer(next_instruction_pointer);
2378            env.set_next_instruction_pointer(addr);
2379        }
2380    };
2381}
2382
2383/// Interpret an U-type instruction.
2384/// The encoding of an U-type instruction is as follows:
2385/// ```text
2386/// | 31     12 | 11    7 | 6      0 |
2387/// | immediate |    rd   |  opcode  |
2388/// ```
2389/// Following the documentation found
2390/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
2391pub fn interpret_utype<Env: InterpreterEnv>(env: &mut Env, instr: UInstruction) {
2392    let instruction_pointer = env.get_instruction_pointer();
2393    let next_instruction_pointer = env.get_next_instruction_pointer();
2394
2395    let instruction = {
2396        let v0 = env.read_memory(&instruction_pointer);
2397        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
2398        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
2399        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
2400        (v3 * Env::constant(1 << 24))
2401            + (v2 * Env::constant(1 << 16))
2402            + (v1 * Env::constant(1 << 8))
2403            + v0
2404    };
2405
2406    let opcode = {
2407        let pos = env.alloc_scratch();
2408        unsafe { env.bitmask(&instruction, 7, 0, pos) }
2409    };
2410    env.range_check8(&opcode, 7);
2411
2412    let rd = {
2413        let pos = env.alloc_scratch();
2414        unsafe { env.bitmask(&instruction, 12, 7, pos) }
2415    };
2416    env.range_check8(&rd, 5);
2417
2418    let imm = {
2419        let pos = env.alloc_scratch();
2420        unsafe { env.bitmask(&instruction, 32, 12, pos) }
2421    };
2422    // FIXME: rangecheck
2423
2424    // check correctness of decomposition of U type function
2425    env.add_constraint(
2426        instruction
2427            - (opcode.clone() * Env::constant(1 << 0))    // opcode at bits 0-6
2428            - (rd.clone() * Env::constant(1 << 7))        // rd at bits 7-11
2429            - (imm.clone() * Env::constant(1 << 12)), // imm at bits 12-31
2430    );
2431
2432    match instr {
2433        UInstruction::LoadUpperImmediate => {
2434            // lui: x[rd] = sext(immediate[31:12] << 12)
2435            let local_imm = {
2436                let shifted_imm = {
2437                    let pos = env.alloc_scratch();
2438                    unsafe { env.shift_left(&imm, &Env::constant(12), pos) }
2439                };
2440                env.sign_extend(&shifted_imm, 32)
2441            };
2442            env.write_register(&rd, local_imm);
2443
2444            env.set_instruction_pointer(next_instruction_pointer.clone());
2445            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2446        }
2447        UInstruction::AddUpperImmediate => {
2448            // auipc: x[rd] = pc + sext(immediate[31:12] << 12)
2449            let local_imm = {
2450                let pos = env.alloc_scratch();
2451                let shifted_imm = unsafe { env.shift_left(&imm, &Env::constant(12), pos) };
2452                env.sign_extend(&shifted_imm, 32)
2453            };
2454            let local_pc = instruction_pointer.clone();
2455            let (local_rd, _) = {
2456                let pos = env.alloc_scratch();
2457                let overflow_pos = env.alloc_scratch();
2458                unsafe { env.add_witness(&local_pc, &local_imm, pos, overflow_pos) }
2459            };
2460            env.write_register(&rd, local_rd);
2461
2462            env.set_instruction_pointer(next_instruction_pointer.clone());
2463            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2464        }
2465    };
2466}
2467
2468/// Interpret an UJ-type instruction.
2469/// The encoding of an UJ-type instruction is as follows:
2470/// ```text
2471/// | 31                12  | 11    7 | 6      0 |
2472/// | imm[20|10:1|11|19:12] |    rd   |  opcode  |
2473/// ```
2474/// Following the documentation found
2475/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
2476///
2477/// The interpretation of the immediate is as follow:
2478/// ```text
2479/// imm_20    = instruction[31]
2480/// imm_10_1  = instruction[30..21]
2481/// imm_11    = instruction[20]
2482/// imm_19_12 = instruction[19..12]
2483///
2484/// imm = imm_20    << 19   +
2485///       imm_19_12 << 11   +
2486///       imm_11    << 10   +
2487///       imm_10_1
2488///
2489/// # The immediate is then sign-extended. The sign-extension is in the bit imm20
2490/// imm = imm << 1
2491/// ```
2492pub fn interpret_ujtype<Env: InterpreterEnv>(env: &mut Env, instr: UJInstruction) {
2493    let instruction_pointer = env.get_instruction_pointer();
2494    let next_instruction_pointer = env.get_next_instruction_pointer();
2495
2496    let instruction = {
2497        let v0 = env.read_memory(&instruction_pointer);
2498        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
2499        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
2500        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
2501        (v3 * Env::constant(1 << 24))
2502            + (v2 * Env::constant(1 << 16))
2503            + (v1 * Env::constant(1 << 8))
2504            + v0
2505    };
2506
2507    let opcode = {
2508        let pos = env.alloc_scratch();
2509        unsafe { env.bitmask(&instruction, 7, 0, pos) }
2510    };
2511    env.range_check8(&opcode, 7);
2512
2513    let rd = {
2514        let pos = env.alloc_scratch();
2515        unsafe { env.bitmask(&instruction, 12, 7, pos) }
2516    };
2517    env.range_check8(&rd, 5);
2518
2519    let imm20 = {
2520        let pos = env.alloc_scratch();
2521        unsafe { env.bitmask(&instruction, 32, 31, pos) }
2522    };
2523    env.assert_boolean(&imm20);
2524
2525    let imm10_1 = {
2526        let pos = env.alloc_scratch();
2527        unsafe { env.bitmask(&instruction, 31, 21, pos) }
2528    };
2529    env.range_check16(&imm10_1, 10);
2530
2531    let imm11 = {
2532        let pos = env.alloc_scratch();
2533        unsafe { env.bitmask(&instruction, 21, 20, pos) }
2534    };
2535    env.assert_boolean(&imm11);
2536
2537    let imm19_12 = {
2538        let pos = env.alloc_scratch();
2539        unsafe { env.bitmask(&instruction, 20, 12, pos) }
2540    };
2541    env.range_check8(&imm19_12, 8);
2542
2543    let offset = {
2544        imm10_1.clone() * Env::constant(1 << 1)
2545            + imm11.clone() * Env::constant(1 << 11)
2546            + imm19_12.clone() * Env::constant(1 << 12)
2547            + imm20.clone() * Env::constant(1 << 20)
2548    };
2549
2550    // FIXME: check correctness of decomposition
2551
2552    match instr {
2553        UJInstruction::JumpAndLink => {
2554            let offset = env.sign_extend(&offset, 21);
2555            let new_addr = {
2556                let res_scratch = env.alloc_scratch();
2557                let overflow_scratch = env.alloc_scratch();
2558                let (res, _overflow) = unsafe {
2559                    env.add_witness(&instruction_pointer, &offset, res_scratch, overflow_scratch)
2560                };
2561                res
2562            };
2563            env.write_register(&rd, next_instruction_pointer.clone());
2564            env.set_instruction_pointer(new_addr.clone());
2565            env.set_next_instruction_pointer(new_addr + Env::constant(4u32));
2566        }
2567    }
2568}
2569
2570pub fn interpret_syscall<Env: InterpreterEnv>(env: &mut Env, _instr: SyscallInstruction) {
2571    // FIXME: check if it is syscall success. There is only one syscall atm
2572    env.set_halted(Env::constant(1));
2573}
2574
2575/// Interpret an M-type instruction.
2576/// The encoding of an M-type instruction is as follows:
2577/// ```text
2578/// | 31     27 | 26    25 | 24     20 | 19     15 | 14        12 | 11    7 | 6      0 |
2579/// |   00000   |    01    |    rs2    |    rs1    |    funct3    |    rd   |  opcode  |
2580/// ```
2581/// Following the documentation found
2582/// [here](https://www.cs.cornell.edu/courses/cs3410/2024fa/assignments/cpusim/riscv-instructions.pdf)
2583pub fn interpret_mtype<Env: InterpreterEnv>(env: &mut Env, instr: MInstruction) {
2584    let instruction_pointer = env.get_instruction_pointer();
2585    let next_instruction_pointer = env.get_next_instruction_pointer();
2586
2587    let instruction = {
2588        let v0 = env.read_memory(&instruction_pointer);
2589        let v1 = env.read_memory(&(instruction_pointer.clone() + Env::constant(1)));
2590        let v2 = env.read_memory(&(instruction_pointer.clone() + Env::constant(2)));
2591        let v3 = env.read_memory(&(instruction_pointer.clone() + Env::constant(3)));
2592        (v3 * Env::constant(1 << 24))
2593            + (v2 * Env::constant(1 << 16))
2594            + (v1 * Env::constant(1 << 8))
2595            + v0
2596    };
2597
2598    let opcode = {
2599        let pos = env.alloc_scratch();
2600        unsafe { env.bitmask(&instruction, 7, 0, pos) }
2601    };
2602    env.range_check8(&opcode, 7);
2603
2604    let rd = {
2605        let pos = env.alloc_scratch();
2606        unsafe { env.bitmask(&instruction, 12, 7, pos) }
2607    };
2608    env.range_check8(&rd, 5);
2609
2610    let funct3 = {
2611        let pos = env.alloc_scratch();
2612        unsafe { env.bitmask(&instruction, 15, 12, pos) }
2613    };
2614    env.range_check8(&funct3, 3);
2615
2616    let rs1 = {
2617        let pos = env.alloc_scratch();
2618        unsafe { env.bitmask(&instruction, 20, 15, pos) }
2619    };
2620    env.range_check8(&rs1, 5);
2621
2622    let rs2 = {
2623        let pos = env.alloc_scratch();
2624        unsafe { env.bitmask(&instruction, 25, 20, pos) }
2625    };
2626    env.range_check8(&rs2, 5);
2627
2628    let funct2 = {
2629        let pos = env.alloc_scratch();
2630        unsafe { env.bitmask(&instruction, 27, 25, pos) }
2631    };
2632    // FIXME: check it is equal to 01?
2633    env.range_check8(&funct2, 2);
2634
2635    let funct5 = {
2636        let pos = env.alloc_scratch();
2637        unsafe { env.bitmask(&instruction, 32, 27, pos) }
2638    };
2639    // FIXME: check it is equal to 00000?
2640    env.range_check8(&funct5, 5);
2641
2642    // Check decomposition of M type instruction
2643    env.add_constraint(
2644        instruction
2645            - (opcode.clone() * Env::constant(1 << 0))    // opcode at bits 0-6
2646            - (rd.clone() * Env::constant(1 << 7))        // rd at bits 7-11
2647            - (funct3.clone() * Env::constant(1 << 12))   // funct3 at bits 12-14
2648            - (rs1.clone() * Env::constant(1 << 15))      // rs1 at bits 15-19
2649            - (rs2.clone() * Env::constant(1 << 20))      // rs2 at bits 20-24
2650            - (funct2.clone() * Env::constant(1 << 25))   // funct2 at bits 25-26
2651            - (funct5.clone() * Env::constant(1 << 27)), // funct5 at bits 27-31
2652    );
2653
2654    match instr {
2655        MInstruction::Mul => {
2656            // x[rd] = x[rs1] * x[rs2]
2657            let rs1 = env.read_register(&rs1);
2658            let rs2 = env.read_register(&rs2);
2659            // FIXME: constrain
2660            let res = {
2661                let pos = env.alloc_scratch();
2662                unsafe { env.mul_lo_signed(&rs1, &rs2, pos) }
2663            };
2664            env.write_register(&rd, res);
2665
2666            env.set_instruction_pointer(next_instruction_pointer.clone());
2667            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2668        }
2669        MInstruction::Mulh => {
2670            // x[rd] = (signed(x[rs1]) * signed(x[rs2])) >> 32
2671            let rs1 = env.read_register(&rs1);
2672            let rs2 = env.read_register(&rs2);
2673            // FIXME: constrain
2674            let res = {
2675                let pos = env.alloc_scratch();
2676                unsafe { env.mul_hi_signed(&rs1, &rs2, pos) }
2677            };
2678            env.write_register(&rd, res);
2679
2680            env.set_instruction_pointer(next_instruction_pointer.clone());
2681            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2682        }
2683        MInstruction::Mulhsu => {
2684            // x[rd] = (signed(x[rs1]) * x[rs2]) >> 32
2685            let rs1 = env.read_register(&rs1);
2686            let rs2 = env.read_register(&rs2);
2687            // FIXME: constrain
2688            let res = {
2689                let pos = env.alloc_scratch();
2690                unsafe { env.mul_hi_signed_unsigned(&rs1, &rs2, pos) }
2691            };
2692            env.write_register(&rd, res);
2693
2694            env.set_instruction_pointer(next_instruction_pointer.clone());
2695            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2696        }
2697        MInstruction::Mulhu => {
2698            // x[rd] = (x[rs1] * x[rs2]) >> 32
2699            let rs1 = env.read_register(&rs1);
2700            let rs2 = env.read_register(&rs2);
2701            // FIXME: constrain
2702            let res = {
2703                let pos = env.alloc_scratch();
2704                unsafe { env.mul_hi(&rs1, &rs2, pos) }
2705            };
2706            env.write_register(&rd, res);
2707
2708            env.set_instruction_pointer(next_instruction_pointer.clone());
2709            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2710        }
2711        MInstruction::Div => {
2712            // x[rd] = signed(x[rs1]) / signed(x[rs2])
2713            let rs1 = env.read_register(&rs1);
2714            let rs2 = env.read_register(&rs2);
2715            // FIXME: constrain
2716            let res = {
2717                let pos = env.alloc_scratch();
2718                unsafe { env.div_signed(&rs1, &rs2, pos) }
2719            };
2720            env.write_register(&rd, res);
2721
2722            env.set_instruction_pointer(next_instruction_pointer.clone());
2723            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2724        }
2725        MInstruction::Divu => {
2726            // x[rd] = x[rs1] / x[rs2]
2727            let rs1 = env.read_register(&rs1);
2728            let rs2 = env.read_register(&rs2);
2729            // FIXME: constrain
2730            let res = {
2731                let pos = env.alloc_scratch();
2732                unsafe { env.div(&rs1, &rs2, pos) }
2733            };
2734            env.write_register(&rd, res);
2735
2736            env.set_instruction_pointer(next_instruction_pointer.clone());
2737            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2738        }
2739        MInstruction::Rem => {
2740            // x[rd] = signed(x[rs1]) % signed(x[rs2])
2741            let rs1 = env.read_register(&rs1);
2742            let rs2 = env.read_register(&rs2);
2743            // FIXME: constrain
2744            let res = {
2745                let pos = env.alloc_scratch();
2746                unsafe { env.mod_signed(&rs1, &rs2, pos) }
2747            };
2748            env.write_register(&rd, res);
2749
2750            env.set_instruction_pointer(next_instruction_pointer.clone());
2751            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2752        }
2753        MInstruction::Remu => {
2754            // x[rd] = x[rs1] % x[rs2]
2755            let rs1 = env.read_register(&rs1);
2756            let rs2 = env.read_register(&rs2);
2757            // FIXME: constrain
2758            let res = {
2759                let pos = env.alloc_scratch();
2760                unsafe { env.mod_unsigned(&rs1, &rs2, pos) }
2761            };
2762            env.write_register(&rd, res);
2763
2764            env.set_instruction_pointer(next_instruction_pointer.clone());
2765            env.set_next_instruction_pointer(next_instruction_pointer + Env::constant(4u32));
2766        }
2767    }
2768}