1use std::collections::HashMap;
29use std::io::Write;
30
31use juicebox_asm::insn::*;
32use juicebox_asm::Runtime;
33use juicebox_asm::{Asm, Imm64, Imm8, Label, Mem8, Reg64, Reg8};
34
/// State of the brainfuck interpreter.
struct BrainfuckInterp {
    /// Program counter: index into `imem` of the instruction to execute next.
    pc: usize,
    /// Instruction memory: the program's tokens, with whitespace stripped.
    imem: Vec<char>,
    /// Data pointer: index into `dmem` of the currently selected cell.
    dptr: usize,
    /// Data memory: 256 byte-sized cells, zero-initialized by `new`.
    dmem: [u8; 256],
    /// Jump table: maps the `imem` index of every bracket to the index of its
    /// matching partner (stored in both directions).
    branches: HashMap<usize, usize>,
}

impl BrainfuckInterp {
    /// Parses `prog` and builds an interpreter positioned at the first instruction.
    ///
    /// Whitespace in `prog` is skipped; every other character must be one of the
    /// eight brainfuck tokens. While scanning, the jump table for `[` / `]` pairs
    /// is computed.
    ///
    /// # Errors
    ///
    /// Returns a message describing the problem if `prog` contains a character
    /// that is not a brainfuck token, a `]` without a preceding unmatched `[`, or
    /// a `[` that is never closed.
    fn new(prog: &str) -> Result<Self, String> {
        let mut imem = Vec::new();
        let mut branches = HashMap::new();
        // Indices of the '[' tokens that are still waiting for their ']'.
        let mut open_brackets = Vec::new();

        // Indices are assigned after dropping whitespace, so they line up with
        // the positions in `imem`.
        let tokens = prog.chars().filter(|c| !c.is_whitespace());
        for (idx, token) in tokens.enumerate() {
            match token {
                '[' => open_brackets.push(idx),
                ']' => {
                    let lhs = open_brackets.pop().ok_or_else(|| {
                        format!("encountered un-balanced brackets, found ']' at index {idx} without matching '['")
                    })?;
                    // Record the pair in both directions.
                    branches.insert(lhs, idx);
                    branches.insert(idx, lhs);
                }
                '<' | '>' | '+' | '-' | '.' | ',' => {}
                _ => return Err(format!("invalid bf token '{token}'")),
            }
            imem.push(token);
        }

        if open_brackets.is_empty() {
            Ok(Self {
                pc: 0,
                imem,
                dptr: 0,
                dmem: [0; 256],
                branches,
            })
        } else {
            Err(String::from(
                "encountered un-balanced brackets, left-over '[' after parsing bf program",
            ))
        }
    }
}
93
94fn run_interp(prog: &str) {
95 let mut vm = BrainfuckInterp::new(prog).unwrap();
96
97 loop {
98 let insn = match vm.imem.get(vm.pc) {
99 Some(insn) => insn,
100 None => break, };
102
103 let putchar = |val: u8| {
104 std::io::stdout()
105 .write(&[val])
106 .expect("Failed to write to stdout!");
107 };
108
109 match insn {
110 '>' => {
111 vm.dptr += 1;
112 assert!(vm.dptr < vm.dmem.len());
113 }
114 '<' => {
115 assert!(vm.dptr > 0);
116 vm.dptr -= 1;
117 }
118 '+' => {
119 vm.dmem[vm.dptr] += 1;
120 }
121 '-' => {
122 vm.dmem[vm.dptr] -= 1;
123 }
124 '.' => {
125 putchar(vm.dmem[vm.dptr]);
126 }
127 ',' => {
128 unimplemented!("getchar");
129 }
130 '[' => {
131 if vm.dmem[vm.dptr] == 0 {
132 vm.pc = *vm.branches.get(&vm.pc).unwrap();
133 }
134 }
135 ']' => {
136 if vm.dmem[vm.dptr] != 0 {
137 vm.pc = *vm.branches.get(&vm.pc).unwrap();
138 }
139 }
140 _ => unreachable!(),
141 }
142
143 vm.pc += 1;
144 }
145}
146
// The JIT emits x86_64 machine code and expects its argument in `rdi`, i.e. it
// depends on BOTH the x86_64 architecture and the SystemV calling convention.
// The guard must therefore fire unless all requirements hold (`all`); with
// `any` it only fired when neither held, letting e.g. aarch64 linux or x86_64
// windows builds through. `unix` is used for the abi requirement so that every
// x86_64 unix-like target that built before keeps building.
#[cfg(not(all(target_arch = "x86_64", unix)))]
compile_error!("Only supported on x86_64 with SystemV abi");
151
/// State of the brainfuck JIT.
struct BrainfuckJit {
    /// Instruction memory: the program's tokens, with whitespace stripped.
    imem: Vec<char>,
    /// Data memory: 256 byte-sized cells, zero-initialized by `new`.
    dmem: [u8; 256],
}

impl BrainfuckJit {
    /// Tokenizes `prog` and builds the JIT state with zeroed data memory.
    ///
    /// Whitespace is skipped. Bracket balance is NOT checked here.
    ///
    /// # Errors
    ///
    /// Returns `invalid bf token '<c>'` for the first non-whitespace character
    /// that is not one of the eight brainfuck tokens.
    fn new(prog: &str) -> Result<Self, String> {
        let mut imem = Vec::new();

        for token in prog.chars() {
            if token.is_whitespace() {
                continue;
            }
            if !matches!(token, '<' | '>' | '+' | '-' | '.' | ',' | '[' | ']') {
                return Err(format!("invalid bf token '{token}'"));
            }
            imem.push(token);
        }

        Ok(Self {
            imem,
            dmem: [0; 256],
        })
    }
}
176
/// Writes the byte `c` to stdout.
///
/// Has the C calling convention because its address is embedded into the
/// jitted code, which calls it for every `.` instruction.
///
/// # Panics
///
/// If writing to stdout fails.
extern "C" fn putchar(c: u8) {
    // `write_all` instead of `write`: a plain `write` may return `Ok(0)`
    // without having written the byte, which would silently drop output.
    std::io::stdout()
        .write_all(&[c])
        .expect("Failed to write to stdout!");
}
182
183fn run_jit(prog: &str) {
184 let mut vm = BrainfuckJit::new(prog).unwrap();
185
186 let dmem_base = Reg64::rbx;
189 let dmem_size = Reg64::r12;
190 let dmem_idx = Reg64::r13;
191
192 let mut asm = Asm::new();
193
194 asm.push(dmem_base);
196 asm.push(dmem_size);
197 asm.push(dmem_idx);
198
199 asm.mov(dmem_base, Reg64::rdi);
201 asm.mov(dmem_size, Imm64::from(vm.dmem.len()));
203 asm.xor(dmem_idx, dmem_idx);
205
206 let mut label_stack = Vec::new();
209
210 let mut oob_ov = Label::new();
212 let mut oob_uv = Label::new();
214
215 let mut pc = 0;
217 while pc < vm.imem.len() {
218 match vm.imem[pc] {
219 '>' => {
220 asm.inc(dmem_idx);
221
222 asm.cmp(dmem_idx, dmem_size);
224 asm.jz(&mut oob_ov);
225 }
226 '<' => {
227 asm.test(dmem_idx, dmem_idx);
229 asm.jz(&mut oob_uv);
230
231 asm.dec(dmem_idx);
232 }
233 '+' => {
234 match vm.imem[pc..].iter().take_while(|&&i| i.eq(&'+')).count() {
238 1 => {
239 asm.inc(Mem8::indirect_base_index(dmem_base, dmem_idx));
240 }
241 cnt if cnt <= u8::MAX as usize => {
242 asm.add(
243 Mem8::indirect_base_index(dmem_base, dmem_idx),
244 Imm8::from(cnt as u8),
245 );
246
247 pc += cnt - 1;
250 }
251 cnt @ _ => unimplemented!("cnt={cnt} oob, add with larger imm"),
252 }
253 }
254 '-' => {
255 match vm.imem[pc..].iter().take_while(|&&i| i.eq(&'-')).count() {
259 1 => {
260 asm.dec(Mem8::indirect_base_index(dmem_base, dmem_idx));
261 }
262 cnt if cnt <= u8::MAX as usize => {
263 asm.sub(
264 Mem8::indirect_base_index(dmem_base, dmem_idx),
265 Imm8::from(cnt as u8),
266 );
267
268 pc += cnt - 1;
271 }
272 cnt @ _ => unimplemented!("cnt={cnt} oob, sub with larger imm"),
273 }
274 }
275 '.' => {
276 asm.mov(Reg8::dil, Mem8::indirect_base_index(dmem_base, dmem_idx));
282 asm.mov(Reg64::rax, Imm64::from(putchar as usize));
283 asm.call(Reg64::rax);
284 }
285 ',' => {
286 unimplemented!("getchar");
287 }
288 '[' => {
289 label_stack.push((Label::new(), Label::new()));
291 let label_pair = label_stack.last_mut().unwrap();
293
294 asm.cmp(
297 Mem8::indirect_base_index(dmem_base, dmem_idx),
298 Imm8::from(0u8),
299 );
300 asm.jz(&mut label_pair.0);
301
302 asm.bind(&mut label_pair.1);
305 }
306 ']' => {
307 let mut label_pair = label_stack
308 .pop()
309 .expect("encountered un-balanced brackets, found ']' without matching '['");
310
311 asm.cmp(
314 Mem8::indirect_base_index(dmem_base, dmem_idx),
315 Imm8::from(0u8),
316 );
317 asm.jnz(&mut label_pair.1);
318
319 asm.bind(&mut label_pair.0);
322 }
323 _ => unreachable!(),
324 }
325
326 pc += 1;
328 }
329
330 let mut epilogue = Label::new();
331
332 asm.xor(Reg64::rax, Reg64::rax);
334 asm.bind(&mut epilogue);
335 asm.pop(dmem_idx);
337 asm.pop(dmem_size);
338 asm.pop(dmem_base);
339 asm.ret();
340
341 asm.bind(&mut oob_ov);
343 asm.mov(Reg64::rax, Imm64::from(1));
344 asm.jmp(&mut epilogue);
345
346 asm.bind(&mut oob_uv);
348 asm.mov(Reg64::rax, Imm64::from(2));
349 asm.jmp(&mut epilogue);
350
351 if !label_stack.is_empty() {
352 panic!("encountered un-balanced brackets, left-over '[' after jitting bf program")
353 }
354
355 let mut rt = Runtime::new();
357 let bf_entry = unsafe { rt.add_code::<extern "C" fn(*mut u8) -> u64>(asm.into_code()) };
358
359 match bf_entry(&mut vm.dmem as *mut u8) {
361 0 => { }
362 1 => panic!("oob: data pointer overflow"),
363 2 => panic!("oob: data pointer underflow"),
364 _ => unreachable!(),
365 }
366}
367
368fn main() {
371 let inp = "++++++++[>++++[>++>+++>+++>+<<<<-]>+>+>->>+[<]<-]>>.>---.+++++++..+++.>>.<-.<.+++.------.--------.>>+.>++.";
373 println!("hello-world (wikipedia.org) - interp");
374 run_interp(inp);
375 println!("hello-world (wikipedia.org) - jit");
376 run_jit(inp);
377
378 let inp = ">+++++++++[<++++++++>-]<.>+++++++[<++++>-]<+.+++++++..+++.[-]>++++++++[<++++>-] <.>+++++++++++[<++++++++>-]<-.--------.+++.------.--------.[-]>++++++++[<++++>- ]<+.[-]++++++++++.";
380 println!("hello-world (programmingwiki.de) - interp");
381 run_interp(inp);
382 println!("hello-world (programmingwiki.de) - jit");
383 run_jit(inp);
384}
385
/// Tests for the data pointer bounds checks of the jitted code.
#[cfg(test)]
mod test {
    use super::*;

    /// Builds a program made of `n` consecutive `>` instructions.
    fn shift_right(n: usize) -> String {
        (0..n).map(|_| '>').collect()
    }

    /// 255 moves to the right end on the last cell of the 256 cell data memory.
    #[test]
    fn data_ptr_no_overflow() {
        run_jit(&shift_right(255));
    }

    /// One more move to the right leaves the data memory.
    #[test]
    #[should_panic]
    fn data_ptr_overflow() {
        run_jit(&shift_right(255 + 1));
    }

    /// Every move to the left is preceded by a move to the right.
    #[test]
    fn data_ptr_no_underflow() {
        run_jit(">><< ><");
    }

    /// The final move to the left happens while on cell 0.
    #[test]
    #[should_panic]
    fn data_ptr_underflow() {
        run_jit(">><< >< <");
    }
}
415}