bf/
bf.rs

1// SPDX-License-Identifier: MIT
2//
3// Copyright (c) 2024, Johannes Stoelp <dev@memzero.de>
4
5//! Brainfuck VM.
6//!
7//! This example implements a simple [brainfuck][bf] interpreter
8//! [`BrainfuckInterp`] and a jit compiler [`BrainfuckJit`].
9//!
//! Brainfuck is an esoteric programming language consisting of 8 commands.
11//! - `>` increment data pointer.
12//! - `<` decrement data pointer.
13//! - `+` increment data at current data pointer.
14//! - `-` decrement data at current data pointer.
15//! - `.` output data at current data pointer.
16//! - `,` read input and store at current data pointer.
//! - `[` if data at data pointer is zero, jump to just after the matching `]`.
//! - `]` if data at data pointer is non-zero, jump to just after the matching `[`.
19//!
20//! The following is the `hello-world` program from [wikipedia][hw].
//! ```text
22//! ++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++.
23//! ```
24//!
25//! [bf]: https://en.wikipedia.org/wiki/Brainfuck
26//! [hw]: https://en.wikipedia.org/wiki/Brainfuck#Hello_World!
27
28use std::collections::HashMap;
29use std::io::Write;
30
31use juicebox_asm::insn::*;
32use juicebox_asm::Runtime;
33use juicebox_asm::{Asm, Imm64, Imm8, Label, Mem8, Reg64, Reg8};
34
35// -- BRAINFUCK INTERPRETER ----------------------------------------------------
36
/// Interpreter state for a parsed bf program.
struct BrainfuckInterp {
    /// Index of the current instruction in `imem`.
    pc: usize,
    /// Program with whitespace stripped; every entry is a valid bf command.
    imem: Vec<char>,
    /// Index of the active cell in `dmem`.
    dptr: usize,
    /// Data memory of 256 byte-sized cells.
    dmem: [u8; 256],
    /// Maps the index of every bracket in `imem` to the index of its partner.
    branches: HashMap<usize, usize>,
}

impl BrainfuckInterp {
    /// Parse `prog` into an interpreter positioned at the first instruction
    /// with zeroed data memory.
    ///
    /// Whitespace is dropped before indexing, so the indices stored in
    /// `branches` refer to positions in the whitespace-free `imem`.
    ///
    /// # Errors
    ///
    /// Returns a message if `prog` contains a character that is neither
    /// whitespace nor a bf command, or if its brackets are unbalanced.
    fn new(prog: &str) -> Result<Self, String> {
        let mut imem = Vec::new();
        // Indices of '[' that have not been closed yet.
        let mut open = Vec::new();
        let mut branches = HashMap::new();

        let tokens = prog.chars().filter(|c| !c.is_whitespace());
        for (idx, token) in tokens.enumerate() {
            match token {
                '[' => open.push(idx),
                ']' => {
                    let lhs = open.pop().ok_or_else(|| {
                        format!("encountered un-balanced brackets, found ']' at index {idx} without matching '['")
                    })?;
                    // Record the jump target in both directions.
                    branches.insert(lhs, idx);
                    branches.insert(idx, lhs);
                }
                '<' | '>' | '+' | '-' | '.' | ',' => { /* plain command, nothing to record */ }
                _ => return Err(format!("invalid bf token '{token}'")),
            }
            imem.push(token);
        }

        if open.is_empty() {
            Ok(Self {
                pc: 0,
                imem,
                dptr: 0,
                dmem: [0; 256],
                branches,
            })
        } else {
            Err(String::from(
                "encountered un-balanced brackets, left-over '[' after parsing bf program",
            ))
        }
    }
}
93
94fn run_interp(prog: &str) {
95    let mut vm = BrainfuckInterp::new(prog).unwrap();
96
97    loop {
98        let insn = match vm.imem.get(vm.pc) {
99            Some(insn) => insn,
100            None => break, // End of bf program.
101        };
102
103        let putchar = |val: u8| {
104            std::io::stdout()
105                .write(&[val])
106                .expect("Failed to write to stdout!");
107        };
108
109        match insn {
110            '>' => {
111                vm.dptr += 1;
112                assert!(vm.dptr < vm.dmem.len());
113            }
114            '<' => {
115                assert!(vm.dptr > 0);
116                vm.dptr -= 1;
117            }
118            '+' => {
119                vm.dmem[vm.dptr] += 1;
120            }
121            '-' => {
122                vm.dmem[vm.dptr] -= 1;
123            }
124            '.' => {
125                putchar(vm.dmem[vm.dptr]);
126            }
127            ',' => {
128                unimplemented!("getchar");
129            }
130            '[' => {
131                if vm.dmem[vm.dptr] == 0 {
132                    vm.pc = *vm.branches.get(&vm.pc).unwrap();
133                }
134            }
135            ']' => {
136                if vm.dmem[vm.dptr] != 0 {
137                    vm.pc = *vm.branches.get(&vm.pc).unwrap();
138                }
139            }
140            _ => unreachable!(),
141        }
142
143        vm.pc += 1;
144    }
145}
146
147// -- BRAINFUCK JIT ------------------------------------------------------------
148
// The jit emits x86_64 machine code and relies on the SystemV calling
// convention, so BOTH conditions must hold. With `any`, targets such as
// aarch64-linux or x86_64-windows would slip past this check.
#[cfg(not(all(target_arch = "x86_64", target_os = "linux")))]
compile_error!("Only supported on x86_64 with SystemV abi");
151
/// Jit state: the validated program and the data memory handed to the jitted code.
struct BrainfuckJit {
    /// Program with whitespace stripped; every entry is a valid bf command.
    imem: Vec<char>,
    /// Data memory of 256 byte-sized cells.
    dmem: [u8; 256],
}

impl BrainfuckJit {
    /// Parse `prog`, keeping only its bf commands, and start with zeroed data memory.
    ///
    /// Bracket balance is NOT checked here; that happens while jitting.
    ///
    /// # Errors
    ///
    /// Returns a message naming the first character that is neither whitespace
    /// nor a bf command.
    fn new(prog: &str) -> Result<Self, String> {
        let mut imem = Vec::new();
        for c in prog.chars() {
            if c.is_whitespace() {
                continue;
            }
            if !matches!(c, '<' | '>' | '+' | '-' | '.' | ',' | '[' | ']') {
                return Err(format!("invalid bf token '{c}'"));
            }
            imem.push(c);
        }

        Ok(Self {
            imem,
            dmem: [0; 256],
        })
    }
}
176
/// Write the single byte `c` to stdout; jitted code calls this for `.`.
///
/// Uses `write_all`, because `write` may legally write fewer bytes than
/// requested (even zero), which would silently drop the output byte.
///
/// # Panics
///
/// Panics if writing to stdout fails. A panic cannot unwind out of an
/// `extern "C"` function; on recent Rust versions it aborts the process.
extern "C" fn putchar(c: u8) {
    std::io::stdout()
        .write_all(&[c])
        .expect("Failed to write to stdout!");
}
182
183fn run_jit(prog: &str) {
184    let mut vm = BrainfuckJit::new(prog).unwrap();
185
186    // Use callee saved registers to hold vm state, such that we don't need to
187    // save any state before calling out to putchar.
188    let dmem_base = Reg64::rbx;
189    let dmem_size = Reg64::r12;
190    let dmem_idx = Reg64::r13;
191
192    let mut asm = Asm::new();
193
194    // Save callee saved registers before we tamper them.
195    asm.push(dmem_base);
196    asm.push(dmem_size);
197    asm.push(dmem_idx);
198
199    // Move data memory pointer (argument on jit entry) into correct register.
200    asm.mov(dmem_base, Reg64::rdi);
201    // Move data memory size (compile time constant) into correct register.
202    asm.mov(dmem_size, Imm64::from(vm.dmem.len()));
203    // Clear data memory index.
204    asm.xor(dmem_idx, dmem_idx);
205
206    // A stack of label pairs, used to link up forward and backward jumps for a
207    // given '[]' pair.
208    let mut label_stack = Vec::new();
209
210    // Label to jump to when a data pointer overflow is detected.
211    let mut oob_ov = Label::new();
212    // Label to jump to when a data pointer underflow is detected.
213    let mut oob_uv = Label::new();
214
215    // Generate code for each instruction in the bf program.
216    let mut pc = 0;
217    while pc < vm.imem.len() {
218        match vm.imem[pc] {
219            '>' => {
220                asm.inc(dmem_idx);
221
222                // Check for data pointer overflow and jump to error handler if needed.
223                asm.cmp(dmem_idx, dmem_size);
224                asm.jz(&mut oob_ov);
225            }
226            '<' => {
227                // Check for data pointer underflow and jump to error handler if needed.
228                asm.test(dmem_idx, dmem_idx);
229                asm.jz(&mut oob_uv);
230
231                asm.dec(dmem_idx);
232            }
233            '+' => {
234                // Apply optimization to fold consecutive '+' instructions to a
235                // single add instruction during compile time.
236
237                match vm.imem[pc..].iter().take_while(|&&i| i.eq(&'+')).count() {
238                    1 => {
239                        asm.inc(Mem8::indirect_base_index(dmem_base, dmem_idx));
240                    }
241                    cnt if cnt <= u8::MAX as usize => {
242                        asm.add(
243                            Mem8::indirect_base_index(dmem_base, dmem_idx),
244                            Imm8::from(cnt as u8),
245                        );
246
247                        // Advance pc, but account for pc increment at the end
248                        // of the loop.
249                        pc += cnt - 1;
250                    }
251                    cnt @ _ => unimplemented!("cnt={cnt} oob, add with larger imm"),
252                }
253            }
254            '-' => {
255                // Apply optimization to fold consecutive '-' instructions to a
256                // single sub instruction during compile time.
257
258                match vm.imem[pc..].iter().take_while(|&&i| i.eq(&'-')).count() {
259                    1 => {
260                        asm.dec(Mem8::indirect_base_index(dmem_base, dmem_idx));
261                    }
262                    cnt if cnt <= u8::MAX as usize => {
263                        asm.sub(
264                            Mem8::indirect_base_index(dmem_base, dmem_idx),
265                            Imm8::from(cnt as u8),
266                        );
267
268                        // Advance pc, but account for pc increment at the end
269                        // of the loop.
270                        pc += cnt - 1;
271                    }
272                    cnt @ _ => unimplemented!("cnt={cnt} oob, sub with larger imm"),
273                }
274            }
275            '.' => {
276                // Load data memory from active cell into di register, which is
277                // the first argument register according to the SystemV abi,
278                // then call into putchar. Since we stored all out vm state in
279                // callee saved registers we don't need to save any registers
280                // before the call.
281                asm.mov(Reg8::dil, Mem8::indirect_base_index(dmem_base, dmem_idx));
282                asm.mov(Reg64::rax, Imm64::from(putchar as usize));
283                asm.call(Reg64::rax);
284            }
285            ',' => {
286                unimplemented!("getchar");
287            }
288            '[' => {
289                // Create new label pair.
290                label_stack.push((Label::new(), Label::new()));
291                // UNWRAP: We just pushed a new entry on the stack.
292                let label_pair = label_stack.last_mut().unwrap();
293
294                // Goto label_pair.0 if data memory at active cell is 0.
295                //   if vm.dmem[vm.dptr] == 0 goto label_pair.0
296                asm.cmp(
297                    Mem8::indirect_base_index(dmem_base, dmem_idx),
298                    Imm8::from(0u8),
299                );
300                asm.jz(&mut label_pair.0);
301
302                // Bind label_pair.1 after the jump instruction, which will be
303                // the branch target for the matching ']'.
304                asm.bind(&mut label_pair.1);
305            }
306            ']' => {
307                let mut label_pair = label_stack
308                    .pop()
309                    .expect("encountered un-balanced brackets, found ']' without matching '['");
310
311                // Goto label_pair.1 if data memory at active cell is not 0.
312                //   if vm.dmem[vm.dptr] != 0 goto label_pair.1
313                asm.cmp(
314                    Mem8::indirect_base_index(dmem_base, dmem_idx),
315                    Imm8::from(0u8),
316                );
317                asm.jnz(&mut label_pair.1);
318
319                // Bind label_pair.0 after the jump instruction, which is the
320                // branch target for the matching '['.
321                asm.bind(&mut label_pair.0);
322            }
323            _ => unreachable!(),
324        }
325
326        // Increment pc to next instruction.
327        pc += 1;
328    }
329
330    let mut epilogue = Label::new();
331
332    // Successful return from bf program.
333    asm.xor(Reg64::rax, Reg64::rax);
334    asm.bind(&mut epilogue);
335    // Restore callee saved registers before returning from jit.
336    asm.pop(dmem_idx);
337    asm.pop(dmem_size);
338    asm.pop(dmem_base);
339    asm.ret();
340
341    // Return because of data pointer overflow.
342    asm.bind(&mut oob_ov);
343    asm.mov(Reg64::rax, Imm64::from(1));
344    asm.jmp(&mut epilogue);
345
346    // Return because of data pointer underflow.
347    asm.bind(&mut oob_uv);
348    asm.mov(Reg64::rax, Imm64::from(2));
349    asm.jmp(&mut epilogue);
350
351    if !label_stack.is_empty() {
352        panic!("encountered un-balanced brackets, left-over '[' after jitting bf program")
353    }
354
355    // Get function pointer to jitted bf program.
356    let mut rt = Runtime::new();
357    let bf_entry = unsafe { rt.add_code::<extern "C" fn(*mut u8) -> u64>(asm.into_code()) };
358
359    // Execute jitted bf program.
360    match bf_entry(&mut vm.dmem as *mut u8) {
361        0 => { /* success */ }
362        1 => panic!("oob: data pointer overflow"),
363        2 => panic!("oob: data pointer underflow"),
364        _ => unreachable!(),
365    }
366}
367
368// -- MAIN ---------------------------------------------------------------------
369
370fn main() {
371    // https://en.wikipedia.org/wiki/Brainfuck#Hello_World!
372    let inp = "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++.";
373    println!("hello-world (wikipedia.org) - interp");
374    run_interp(inp);
375    println!("hello-world (wikipedia.org) - jit");
376    run_jit(inp);
377
378    // https://programmingwiki.de/Brainfuck
379    let inp = ">+++++++++[<++++++++>-]<.>+++++++[<++++>-]<+.+++++++..+++.[-]>++++++++[<++++>-] <.>+++++++++++[<++++++++>-]<-.--------.+++.------.--------.[-]>++++++++[<++++>- ]<+.[-]++++++++++.";
380    println!("hello-world (programmingwiki.de) - interp");
381    run_interp(inp);
382    println!("hello-world (programmingwiki.de) - jit");
383    run_jit(inp);
384}
385
#[cfg(test)]
mod test {
    use super::*;

    // -- jit -----------------------------------------------------------------

    #[test]
    fn data_ptr_no_overflow() {
        // 255 moves reach the last cell (index 255) of the 256 cell data memory.
        run_jit(&">".repeat(255));
    }

    #[test]
    #[should_panic]
    fn data_ptr_overflow() {
        run_jit(&">".repeat(255 + 1));
    }

    #[test]
    fn data_ptr_no_underflow() {
        run_jit(">><< ><");
    }

    #[test]
    #[should_panic]
    fn data_ptr_underflow() {
        run_jit(">><< >< <");
    }

    #[test]
    fn jit_rejects_invalid_tokens() {
        assert!(BrainfuckJit::new("+x").is_err());
        assert!(BrainfuckJit::new(" + - ").is_ok());
    }

    // -- interpreter ---------------------------------------------------------

    #[test]
    fn interp_data_ptr_no_overflow() {
        run_interp(&">".repeat(255));
    }

    #[test]
    #[should_panic]
    fn interp_data_ptr_overflow() {
        run_interp(&">".repeat(255 + 1));
    }

    #[test]
    fn interp_data_ptr_no_underflow() {
        run_interp(">><< ><");
    }

    #[test]
    #[should_panic]
    fn interp_data_ptr_underflow() {
        run_interp(">><< >< <");
    }

    #[test]
    fn interp_rejects_invalid_programs() {
        assert!(BrainfuckInterp::new("+x").is_err());
        assert!(BrainfuckInterp::new("]").is_err());
        assert!(BrainfuckInterp::new("[").is_err());
        assert!(BrainfuckInterp::new("[ ]").is_ok());
    }
}