From 0be2e644e46231a2657aa693e5ea86e82eee6f34 Mon Sep 17 00:00:00 2001 From: Devin Bidwell Date: Fri, 12 Dec 2025 17:23:04 -0700 Subject: [PATCH] WIP optimization code --- rust_compiler/Cargo.lock | 7 + rust_compiler/Cargo.toml | 10 +- rust_compiler/libs/compiler/src/v1.rs | 5 + rust_compiler/libs/optimizer/Cargo.toml | 4 + .../libs/optimizer/src/leaf_function.rs | 44 + rust_compiler/libs/optimizer/src/lib.rs | 818 +++++++++++++++++- rust_compiler/src/main.rs | 16 +- 7 files changed, 889 insertions(+), 15 deletions(-) create mode 100644 rust_compiler/libs/optimizer/src/leaf_function.rs diff --git a/rust_compiler/Cargo.lock b/rust_compiler/Cargo.lock index 30b4270..20fd557 100644 --- a/rust_compiler/Cargo.lock +++ b/rust_compiler/Cargo.lock @@ -576,6 +576,12 @@ checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" [[package]] name = "optimizer" version = "0.1.0" +dependencies = [ + "anyhow", + "helpers", + "il", + "rust_decimal", +] [[package]] name = "parser" @@ -931,6 +937,7 @@ dependencies = [ "compiler", "helpers", "lsp-types", + "optimizer", "parser", "rust_decimal", "safer-ffi", diff --git a/rust_compiler/Cargo.toml b/rust_compiler/Cargo.toml index fdf339b..2bf6da4 100644 --- a/rust_compiler/Cargo.toml +++ b/rust_compiler/Cargo.toml @@ -9,9 +9,10 @@ members = ["libs/*"] [workspace.dependencies] thiserror = "2" rust_decimal = "1" -safer-ffi = { version = "0.1" } # Safely share structs in memory between C# and Rust -lsp-types = { version = "0.97" } # Allows for LSP style reporting to the frontend -crc32fast = "1.5" # This is for `HASH(..)` calls to be optimized away +safer-ffi = { version = "0.1" } # Safely share structs in memory between C# and Rust +lsp-types = { version = "0.97" } # Allows for LSP style reporting to the frontend +crc32fast = "1.5" # This is for `HASH(..)` calls to be optimized away +anyhow = { version = "^1.0", features = ["backtrace"] } [features] headers = ["safer-ffi/headers"] @@ -42,7 +43,8 @@ tokenizer = { path = "libs/tokenizer" } parser = { path = "libs/parser" } compiler = { path = "libs/compiler" } helpers = { path = "libs/helpers" } +optimizer = { path = "libs/optimizer" } safer-ffi = { workspace = true } -anyhow = { version = "^1.0", features = ["backtrace"] } +anyhow = { workspace = true } [dev-dependencies] diff --git a/rust_compiler/libs/compiler/src/v1.rs b/rust_compiler/libs/compiler/src/v1.rs index 5ed8986..c54fac0 100644 --- a/rust_compiler/libs/compiler/src/v1.rs +++ b/rust_compiler/libs/compiler/src/v1.rs @@ -140,6 +140,7 @@ struct CompileLocation<'a> { pub struct CompilationResult<'a> { pub errors: Vec>, pub source_map: HashMap>, + pub instructions: Vec>, } pub struct Compiler<'a, 'w, W: std::io::Write> { @@ -206,6 +207,7 @@ impl<'a, 'w, W: std::io::Write> Compiler<'a, 'w, W> { return CompilationResult { source_map: self.source_map, errors: self.errors, + instructions: self.instructions, }; } Err(e) => { @@ -214,6 +216,7 @@ impl<'a, 'w, W: std::io::Write> Compiler<'a, 'w, W> { return CompilationResult { errors: self.errors, source_map: self.source_map, + instructions: self.instructions, }; } }; @@ -239,6 +242,7 @@ impl<'a, 'w, W: std::io::Write> Compiler<'a, 'w, W> { return CompilationResult { errors: self.errors, source_map: self.source_map, + instructions: self.instructions, }; } @@ -252,6 +256,7 @@ impl<'a, 'w, W: std::io::Write> Compiler<'a, 'w, W> { CompilationResult { errors: self.errors, source_map: self.source_map, + instructions: self.instructions, } } diff --git a/rust_compiler/libs/optimizer/Cargo.toml b/rust_compiler/libs/optimizer/Cargo.toml index 29f308d..8194092 100644 --- a/rust_compiler/libs/optimizer/Cargo.toml +++ b/rust_compiler/libs/optimizer/Cargo.toml @@ -4,3 +4,7 @@ version = "0.1.0" edition = "2024" [dependencies] +il = { path = "../il" } +helpers = { path = "../helpers" } +rust_decimal = { workspace = true } +anyhow = { workspace = true } diff --git a/rust_compiler/libs/optimizer/src/leaf_function.rs b/rust_compiler/libs/optimizer/src/leaf_function.rs new file mode 100644 index 0000000..f8783f1 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/leaf_function.rs @@ -0,0 +1,44 @@ +use il::{Instruction, InstructionNode}; +use std::collections::HashSet; + +/// Scans the instruction set to identify "leaf functions". +/// A leaf function is defined as a function (delimited by LabelDefs) that does not +/// contain any `jal` (JumpAndLink) instructions. +/// +/// Returns a Set containing the names of all identified leaf functions. +pub fn find_leaf_functions(instructions: &[InstructionNode]) -> HashSet { + let mut leaf_functions = HashSet::new(); + let mut current_label: Option = None; + let mut is_current_leaf = true; + + for node in instructions { + match &node.instruction { + Instruction::LabelDef(label) => { + // If we were tracking a function, and it remained a leaf until now, save it. + if let Some(name) = current_label.take() + && is_current_leaf + { + leaf_functions.insert(name); + } + + // Start tracking the new function + current_label = Some(label.to_string()); + is_current_leaf = true; + } + Instruction::JumpAndLink(_) => { + // If we see a JAL, this function is NOT a leaf. + is_current_leaf = false; + } + _ => {} + } + } + + // Handle the final function in the file + if let Some(name) = current_label + && is_current_leaf + { + leaf_functions.insert(name); + } + + leaf_functions +} diff --git a/rust_compiler/libs/optimizer/src/lib.rs b/rust_compiler/libs/optimizer/src/lib.rs index b93cf3f..88a016d 100644 --- a/rust_compiler/libs/optimizer/src/lib.rs +++ b/rust_compiler/libs/optimizer/src/lib.rs @@ -1,14 +1,814 @@ -pub fn add(left: u64, right: u64) -> u64 { - left + right +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::{HashMap, HashSet}; + +mod leaf_function; +use leaf_function::find_leaf_functions; + +/// Entry point for the optimizer. +pub fn optimize<'a>(mut instructions: Vec>) -> Vec> { + let mut changed = true; + let mut pass_count = 0; + const MAX_PASSES: usize = 10; + + // Iterative passes for code simplification + while changed && pass_count < MAX_PASSES { + changed = false; + pass_count += 1; + + // Pass 1: Constant Propagation + let (new_inst, c1) = constant_propagation(instructions); + instructions = new_inst; + changed |= c1; + + // Pass 2: Register Forwarding (Intermediate Move Elimination) + let (new_inst, c2) = register_forwarding(instructions); + instructions = new_inst; + changed |= c2; + + // Pass 3: Function Call Optimization (Remove unused push/pop around calls) + let (new_inst, c3) = optimize_function_calls(instructions); + instructions = new_inst; + changed |= c3; + + // Pass 4: Leaf Function Optimization (Remove RA save/restore for leaf functions) + // This is separate from pass 3 as it deals with the function *definition*, not the call site. + let (new_inst, c4) = optimize_leaf_functions(instructions); + instructions = new_inst; + changed |= c4; + + // Pass 5: Redundant Move Elimination + let (new_inst, c5) = remove_redundant_moves(instructions); + instructions = new_inst; + changed |= c5; + + // Pass 6: Dead Code Elimination + let (new_inst, c6) = remove_unreachable_code(instructions); + instructions = new_inst; + changed |= c6; + } + + // Final Pass: Resolve Labels to Line Numbers + resolve_labels(instructions) } -#[cfg(test)] -mod tests { - use super::*; +/// Pass: Leaf Function Optimization +/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`. +fn optimize_leaf_functions<'a>( + mut input: Vec>, +) -> (Vec>, bool) { + let leaves = find_leaf_functions(&input); + if leaves.is_empty() { + return (input, false); + } - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); + let mut changed = false; + let mut to_remove = HashSet::new(); + let mut current_function: Option = None; + + // Map of FunctionName -> The stack offset where RA was stored. + // We need this to adjust other stack accesses (arguments vs locals). + let mut func_ra_offsets = HashMap::new(); + + // First scan: Identify instructions to remove and capture RA offsets + for (i, node) in input.iter().enumerate() { + match &node.instruction { + Instruction::LabelDef(label) => { + current_function = Some(label.to_string()); + } + Instruction::Push(Operand::ReturnAddress) => { + if let Some(func) = ¤t_function { + if leaves.contains(func) { + to_remove.insert(i); + changed = true; + } + } + } + Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => { + // This is the restore instruction: `get ra db r0` + if let Some(func) = ¤t_function { + if leaves.contains(func) { + to_remove.insert(i); + // Look back for the address calc: `sub r0 sp OFFSET` + if i > 0 { + if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = + &input[i - 1].instruction + { + func_ra_offsets.insert(func.clone(), *n); + to_remove.insert(i - 1); + } + } + } + } + } + _ => {} + } + } + + if !changed { + return (input, false); + } + + // Second scan: Rebuild with adjustments + let mut output = Vec::with_capacity(input.len()); + let mut processing_function: Option = None; + + for (i, mut node) in input.into_iter().enumerate() { + if to_remove.contains(&i) { + continue; + } + + if let Instruction::LabelDef(l) = &node.instruction { + processing_function = Some(l.to_string()); + } + + // Apply Stack Adjustments if we are inside a leaf function that we optimized + if let Some(func) = &processing_function { + if let Some(ra_offset) = func_ra_offsets.get(func) { + // If this is the stack cleanup `sub sp sp N`, decrement N by 1 (since we removed push ra) + if let Instruction::Sub( + Operand::StackPointer, + Operand::StackPointer, + Operand::Number(n), + ) = &mut node.instruction + { + let new_n = *n - Decimal::from(1); + if new_n.is_zero() { + continue; // Remove instruction if 0 + } + *n = new_n; + } + + // Adjust stack variable accesses relative to the removed RA. + // Compiler layout: [Args] [RA] [Locals/Temps] + // Stack grows up (increment sp on push). + // Access is `sp - offset`. + // Deeper items (Args) have LARGER offsets than RA. + // Shallower items (Locals) have SMALLER offsets than RA. + // Since RA is gone, items deeper than RA (Args) effectively shift "down" (index - 1). + if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = + &mut node.instruction + { + if *n > *ra_offset { + *n -= Decimal::from(1); + } + } + } + } + + output.push(node); + } + + (output, true) +} + +/// Analyzes which registers are written to by each function label. +fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap> { + let mut clobbers = HashMap::new(); + let mut current_label = None; + + for node in instructions { + if let Instruction::LabelDef(label) = &node.instruction { + current_label = Some(label.to_string()); + clobbers.insert(label.to_string(), HashSet::new()); + } + + if let Some(label) = ¤t_label { + if let Some(reg) = get_destination_reg(&node.instruction) { + if let Some(set) = clobbers.get_mut(label) { + set.insert(reg); + } + } + + // Note: If we call another function, we technically clobber whatever IT clobbers + // (unless we save/restore it, which counts as a write anyway). + // This simple pass relies on the fact that any register modification (including restore) is a 'write'. + } + } + clobbers +} + +/// Pass: Function Call Optimization +/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register. +fn optimize_function_calls<'a>( + mut input: Vec>, +) -> (Vec>, bool) { + let clobbers = analyze_clobbers(&input); + let mut changed = false; + let mut to_remove = HashSet::new(); + let mut stack_adjustments = HashMap::new(); // Index of `sub sp sp N` -> amount to subtract + + let mut i = 0; + while i < input.len() { + if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction { + let target_key = target.to_string(); + // If we don't have info on the function (e.g. extern or complex), skip + if let Some(func_clobbers) = clobbers.get(&target_key) { + // 1. Identify Pushes immediately preceding the JAL + let mut pushes = Vec::new(); // (index, register) + let mut scan_back = i.saturating_sub(1); + while scan_back > 0 { + if to_remove.contains(&scan_back) { + scan_back -= 1; + continue; + } + if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction { + pushes.push((scan_back, *r)); + scan_back -= 1; + } else { + break; + } + } + + // 2. Identify Restores immediately following the JAL + // Compiler emits: sub r0 sp Offset, get Reg db r0. + let mut restores = Vec::new(); // (index_of_get, register, index_of_sub) + let mut scan_fwd = i + 1; + while scan_fwd < input.len() { + // Skip over the 'sub r0 sp X' address calculation lines + if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) = + &input[scan_fwd].instruction + { + // Check next instruction for the Get + if scan_fwd + 1 < input.len() { + if let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) = + &input[scan_fwd + 1].instruction + { + restores.push((scan_fwd + 1, *r, scan_fwd)); + scan_fwd += 2; + continue; + } + } + } + break; + } + + // 3. Check for Stack Cleanup `sub sp sp N` + let cleanup_idx = scan_fwd; + + let has_cleanup = if cleanup_idx < input.len() { + if let Instruction::Sub( + Operand::StackPointer, + Operand::StackPointer, + Operand::Number(_), + ) = &input[cleanup_idx].instruction + { + true + } else { + false + } + } else { + false + }; + + // "All or Nothing" strategy for the safe subset: + let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r)); + + let push_set: HashSet = pushes.iter().map(|(_, r)| *r).collect(); + let restore_set: HashSet = restores.iter().map(|(_, r, _)| *r).collect(); + + if all_pushes_safe && has_cleanup && push_set == restore_set { + // We can remove ALL saves/restores for this call! + for (p_idx, _) in pushes { + to_remove.insert(p_idx); + } + for (g_idx, _, s_idx) in restores { + to_remove.insert(g_idx); + to_remove.insert(s_idx); + } + + // Reduce stack cleanup amount + let num_removed = push_set.len() as i64; + stack_adjustments.insert(cleanup_idx, num_removed); + changed = true; + } + } + } + i += 1; + } + + if changed { + let mut clean = Vec::with_capacity(input.len()); + for (idx, mut node) in input.into_iter().enumerate() { + if to_remove.contains(&idx) { + continue; + } + + // Apply stack adjustment + if let Some(reduction) = stack_adjustments.get(&idx) { + if let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction { + let new_n = n - Decimal::from(*reduction); + if new_n.is_zero() { + continue; // Remove the sub entirely if 0 + } + node.instruction = + Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n)); + } + } + + clean.push(node); + } + return (clean, changed); + } + + (input, false) +} + +/// Pass: Register Forwarding +/// Eliminates intermediate moves by writing directly to the final destination. +/// Example: `l r1 d0 T` + `move r9 r1` -> `l r9 d0 T` +fn register_forwarding<'a>( + mut input: Vec>, +) -> (Vec>, bool) { + let mut changed = false; + let mut i = 0; + + // We use a while loop to manually control index so we can peek ahead + while i < input.len().saturating_sub(1) { + let next_idx = i + 1; + + // Check if current instruction defines a register + // and the NEXT instruction is a move from that register. + let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) { + if let Instruction::Move(Operand::Register(dest_reg), Operand::Register(src_reg)) = + &input[next_idx].instruction + { + if *src_reg == def_reg { + // Candidate found: Instruction `i` defines `src_reg`, Instruction `i+1` moves `src_reg` to `dest_reg`. + // We can optimize if `src_reg` (the temp) is NOT used after this move. + Some((def_reg, *dest_reg)) + } else { + None + } + } else { + None + } + } else { + None + }; + + if let Some((temp_reg, final_reg)) = forward_candidate { + // Check liveness: Is temp_reg used after i+1? + // We scan from i+2 onwards. + let mut temp_is_dead = true; + for node in input.iter().skip(i + 2) { + if reg_is_read(&node.instruction, temp_reg) { + temp_is_dead = false; + break; + } + // If the temp is redefined, then the old value is dead, so we are safe. + if let Some(redef) = get_destination_reg(&node.instruction) { + if redef == temp_reg { + break; + } + } + // If we hit a label/jump, we assume liveness might leak (conservative safety) + if matches!( + node.instruction, + Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_) + ) { + temp_is_dead = false; + break; + } + } + + if temp_is_dead { + // Perform the swap + // 1. Rewrite input[i] to write to final_reg + if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) { + input[i].instruction = new_instr; + // 2. Remove input[i+1] (The Move) + input.remove(next_idx); + changed = true; + // Don't increment i, re-evaluate current index (which is now a new neighbor) + continue; + } + } + } + + i += 1; + } + + (input, changed) +} + +/// Pass: Resolve Labels +/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs. +fn resolve_labels<'a>(input: Vec>) -> Vec> { + let mut label_map: HashMap = HashMap::new(); + let mut line_number = 0; + + // 1. Build Label Map (filtering out LabelDefs from the count) + for node in &input { + if let Instruction::LabelDef(name) = &node.instruction { + label_map.insert(name.to_string(), line_number); + } else { + line_number += 1; + } + } + + let mut output = Vec::with_capacity(input.len()); + + // 2. Rewrite Jumps and Filter Labels + for mut node in input { + // Helper to get line number as Decimal operand + let get_line = |lbl: &Operand| -> Option> { + if let Operand::Label(name) = lbl { + label_map + .get(name.as_ref()) + .map(|&l| Operand::Number(Decimal::from(l))) + } else { + None + } + }; + + match &mut node.instruction { + Instruction::LabelDef(_) => continue, // Strip labels + + // Jumps + Instruction::Jump(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::JumpAndLink(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEq(a, b, op) + | Instruction::BranchNe(a, b, op) + | Instruction::BranchGt(a, b, op) + | Instruction::BranchLt(a, b, op) + | Instruction::BranchGe(a, b, op) + | Instruction::BranchLe(a, b, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEqZero(a, op) | Instruction::BranchNeZero(a, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + _ => {} + } + output.push(node); + } + + output +} + +// --- Helpers for Register Analysis --- + +fn get_destination_reg(instr: &Instruction) -> Option { + match instr { + Instruction::Move(Operand::Register(r), _) + | Instruction::Add(Operand::Register(r), _, _) + | Instruction::Sub(Operand::Register(r), _, _) + | Instruction::Mul(Operand::Register(r), _, _) + | Instruction::Div(Operand::Register(r), _, _) + | Instruction::Mod(Operand::Register(r), _, _) + | Instruction::Pow(Operand::Register(r), _, _) + | Instruction::Load(Operand::Register(r), _, _) + | Instruction::LoadSlot(Operand::Register(r), _, _, _) + | Instruction::LoadBatch(Operand::Register(r), _, _, _) + | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _) + | Instruction::SetEq(Operand::Register(r), _, _) + | Instruction::SetNe(Operand::Register(r), _, _) + | Instruction::SetGt(Operand::Register(r), _, _) + | Instruction::SetLt(Operand::Register(r), _, _) + | Instruction::SetGe(Operand::Register(r), _, _) + | Instruction::SetLe(Operand::Register(r), _, _) + | Instruction::And(Operand::Register(r), _, _) + | Instruction::Or(Operand::Register(r), _, _) + | Instruction::Xor(Operand::Register(r), _, _) + | Instruction::Peek(Operand::Register(r)) + | Instruction::Get(Operand::Register(r), _, _) + | Instruction::Select(Operand::Register(r), _, _, _) + | Instruction::Rand(Operand::Register(r)) + | Instruction::Acos(Operand::Register(r), _) + | Instruction::Asin(Operand::Register(r), _) + | Instruction::Atan(Operand::Register(r), _) + | Instruction::Atan2(Operand::Register(r), _, _) + | Instruction::Abs(Operand::Register(r), _) + | Instruction::Ceil(Operand::Register(r), _) + | Instruction::Cos(Operand::Register(r), _) + | Instruction::Floor(Operand::Register(r), _) + | Instruction::Log(Operand::Register(r), _) + | Instruction::Max(Operand::Register(r), _, _) + | Instruction::Min(Operand::Register(r), _, _) + | Instruction::Sin(Operand::Register(r), _) + | Instruction::Sqrt(Operand::Register(r), _) + | Instruction::Tan(Operand::Register(r), _) + | Instruction::Trunc(Operand::Register(r), _) + | Instruction::Pop(Operand::Register(r)) => Some(*r), + _ => None, } } + +fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option> { + // Helper to easily recreate instruction with new dest + let r = Operand::Register(new_reg); + match instr { + Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())), + Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())), + Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())), + Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())), + Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())), + Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())), + Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())), + Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())), + Instruction::LoadSlot(_, a, b, c) => { + Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatch(_, a, b, c) => { + Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed( + r, + a.clone(), + b.clone(), + c.clone(), + d.clone(), + )), + Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())), + Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())), + Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())), + Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())), + Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())), + Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())), + Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())), + Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())), + Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())), + Instruction::Peek(_) => Some(Instruction::Peek(r)), + Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())), + Instruction::Select(_, a, b, c) => { + Some(Instruction::Select(r, a.clone(), b.clone(), c.clone())) + } + Instruction::Rand(_) => Some(Instruction::Rand(r)), + Instruction::Pop(_) => Some(Instruction::Pop(r)), + + // Math funcs + Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())), + Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())), + Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())), + Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())), + Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())), + Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())), + Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())), + Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())), + Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())), + Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())), + Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())), + Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())), + Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())), + Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())), + Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())), + + _ => None, + } +} + +fn reg_is_read(instr: &Instruction, reg: u8) -> bool { + let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg); + + match instr { + Instruction::Move(_, a) => check(a), + Instruction::Add(_, a, b) + | Instruction::Sub(_, a, b) + | Instruction::Mul(_, a, b) + | Instruction::Div(_, a, b) + | Instruction::Mod(_, a, b) + | Instruction::Pow(_, a, b) => check(a) || check(b), + + Instruction::Load(_, a, _) => check(a), // Load reads device? Device can be reg? Yes. + Instruction::Store(a, _, b) => check(a) || check(b), + + Instruction::BranchEq(a, b, _) + | Instruction::BranchNe(a, b, _) + | Instruction::BranchGt(a, b, _) + | Instruction::BranchLt(a, b, _) + | Instruction::BranchGe(a, b, _) + | Instruction::BranchLe(a, b, _) => check(a) || check(b), + + Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a), + + Instruction::SetEq(_, a, b) + | Instruction::SetNe(_, a, b) + | Instruction::SetGt(_, a, b) + | Instruction::SetLt(_, a, b) + | Instruction::SetGe(_, a, b) + | Instruction::SetLe(_, a, b) + | Instruction::And(_, a, b) + | Instruction::Or(_, a, b) + | Instruction::Xor(_, a, b) => check(a) || check(b), + + Instruction::Push(a) => check(a), + Instruction::Get(_, a, b) => check(a) || check(b), + Instruction::Put(a, b, c) => check(a) || check(b) || check(c), + + Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c), + Instruction::Sleep(a) => check(a), + + // Math single arg + Instruction::Acos(_, a) + | Instruction::Asin(_, a) + | Instruction::Atan(_, a) + | Instruction::Abs(_, a) + | Instruction::Ceil(_, a) + | Instruction::Cos(_, a) + | Instruction::Floor(_, a) + | Instruction::Log(_, a) + | Instruction::Sin(_, a) + | Instruction::Sqrt(_, a) + | Instruction::Tan(_, a) + | Instruction::Trunc(_, a) => check(a), + + // Math double arg + Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => { + check(a) || check(b) + } + + _ => false, + } +} + +// --- Constant Propagation & Dead Code (Same as before) --- + +fn constant_propagation<'a>(input: Vec>) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut registers: [Option; 16] = [None; 16]; + + for mut node in input { + match &node.instruction { + Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16], + _ => {} + } + + let simplified = match &node.instruction { + Instruction::Move(dst, src) => { + if let Some(val) = resolve_value(src, ®isters) { + Some(Instruction::Move(dst.clone(), Operand::Number(val))) + } else { + None + } + } + Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y), + Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y), + Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y), + Instruction::Div(dst, a, b) => { + try_fold_math( + dst, + a, + b, + ®isters, + |x, y| if y.is_zero() { x } else { x / y }, + ) + } + Instruction::Mod(dst, a, b) => { + try_fold_math( + dst, + a, + b, + ®isters, + |x, y| if y.is_zero() { x } else { x % y }, + ) + } + Instruction::BranchEq(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x == y) + } + Instruction::BranchNe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x != y) + } + Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y), + Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y), + Instruction::BranchGe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x >= y) + } + Instruction::BranchLe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x <= y) + } + Instruction::BranchEqZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y) + } + Instruction::BranchNeZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y) + } + _ => None, + }; + + if let Some(new) = simplified { + node.instruction = new; + changed = true; + } + + // Update tracking + match &node.instruction { + Instruction::Move(Operand::Register(r), src) => { + registers[*r as usize] = resolve_value(src, ®isters) + } + // Invalidate if destination is register + _ => { + if let Some(r) = get_destination_reg(&node.instruction) { + registers[r as usize] = None; + } + } + } + + // Filter out NOPs (Empty LabelDefs from branch resolution) + if let Instruction::LabelDef(l) = &node.instruction { + if l.is_empty() { + changed = true; + continue; + } + } + + output.push(node); + } + (output, changed) +} + +fn resolve_value(op: &Operand, regs: &[Option; 16]) -> Option { + match op { + Operand::Number(n) => Some(*n), + Operand::Register(r) => regs[*r as usize], + _ => None, + } +} + +fn try_fold_math<'a, F>( + dst: &Operand<'a>, + a: &Operand<'a>, + b: &Operand<'a>, + regs: &[Option; 16], + op: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> Decimal, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + Some(Instruction::Move( + dst.clone(), + Operand::Number(op(val_a, val_b)), + )) +} + +fn try_resolve_branch<'a, F>( + a: &Operand<'a>, + b: &Operand<'a>, + label: &Operand<'a>, + regs: &[Option; 16], + check: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> bool, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + if check(val_a, val_b) { + Some(Instruction::Jump(label.clone())) + } else { + Some(Instruction::LabelDef("".into())) // NOP + } +} + +fn remove_redundant_moves<'a>(input: Vec>) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + for node in input { + if let Instruction::Move(dst, src) = &node.instruction { + if dst == src { + changed = true; + continue; + } + } + output.push(node); + } + (output, changed) +} + +fn remove_unreachable_code<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut dead = false; + for node in input { + if let Instruction::LabelDef(_) = node.instruction { + dead = false; + } + if dead { + changed = true; + continue; + } + match node.instruction { + Instruction::Jump(_) | Instruction::Jump(Operand::ReturnAddress) => dead = true, + _ => {} + } + output.push(node); + } + (output, changed) +} diff --git a/rust_compiler/src/main.rs b/rust_compiler/src/main.rs index cdd7bc0..b6ddb9d 100644 --- a/rust_compiler/src/main.rs +++ b/rust_compiler/src/main.rs @@ -88,9 +88,15 @@ fn run_logic<'a>() -> Result<(), Error<'a>> { None => BufWriter::new(Box::new(std::io::stdout())), }; - let compiler = Compiler::new(parser, &mut writer, None); + let mut tmp = BufWriter::new(vec![]); - let CompilationResult { errors, .. } = compiler.compile(); + let compiler = Compiler::new(parser, &mut tmp, None); + + let CompilationResult { + errors, + instructions, + .. + } = compiler.compile(); if !errors.is_empty() { let mut std_error = stderr(); @@ -103,6 +109,12 @@ fn run_logic<'a>() -> Result<(), Error<'a>> { } } + let instructions = optimizer::optimize(instructions); + + for instruction in instructions { + writer.write_all(format!("{}\n", instruction.instruction).as_bytes())?; + } + writer.flush()?; Ok(())