diff --git a/rust_compiler/Cargo.lock b/rust_compiler/Cargo.lock index 94fcc35..aed6d37 100644 --- a/rust_compiler/Cargo.lock +++ b/rust_compiler/Cargo.lock @@ -23,7 +23,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ - "getrandom", + "getrandom 0.2.16", "once_cell", "version_check", ] @@ -73,7 +73,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -84,7 +84,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -135,6 +135,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "bitvec" version = "1.0.1" @@ -278,6 +284,18 @@ dependencies = [ "tokenizer", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -293,12 +311,28 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "ext-trait" version = "1.0.1" @@ -334,13 +368,19 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320bea982e85d42441eb25c49b41218e7eaa2657e8f90bc4eca7437376751e23" +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fluent-uri" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -366,6 +406,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "gimli" version = "0.32.3" @@ -428,6 +480,32 @@ dependencies = [ "rustversion", ] +[[package]] +name = "insta" +version = "1.45.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "983e3b24350c84ab8a65151f537d67afbbf7153bb9f1110e03e9fa9b07f67a5c" +dependencies = [ + "console", + "once_cell", + "similar", + "tempfile", +] + +[[package]] +name = "integration_tests" +version = "0.1.0" +dependencies = [ + "anyhow", + "compiler", + "il", + "indoc", + "insta", + "optimizer", + "parser", + "tokenizer", +] + [[package]] name = "inventory" version = "0.3.21" @@ -465,6 +543,12 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "logos" version = "0.16.0" @@ -505,7 +589,7 @@ version = "0.97.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53353550a17c04ac46c585feb189c2db82154fc84b79c7a66c96c2c644f66071" dependencies = [ - "bitflags", + "bitflags 1.3.2", "fluent-uri", "serde", "serde_json", @@ -678,6 +762,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -711,7 +801,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", ] [[package]] @@ -800,6 +890,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -928,6 +1031,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "slang" version = "0.5.0" @@ -1014,6 +1123,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -1140,6 +1262,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.106" @@ -1191,6 +1322,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -1200,6 +1340,70 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.7.14" @@ -1209,6 +1413,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "with_builtin_macros" version = "0.0.3" diff --git a/rust_compiler/libs/il/src/lib.rs b/rust_compiler/libs/il/src/lib.rs index 1a3fdc3..54aadfb 100644 --- a/rust_compiler/libs/il/src/lib.rs +++ b/rust_compiler/libs/il/src/lib.rs @@ -61,6 +61,7 @@ impl<'a> std::fmt::Display for Instructions<'a> { } } +#[derive(Clone)] pub struct InstructionNode<'a> { pub instruction: Instruction<'a>, pub span: Option, diff --git a/rust_compiler/libs/integration_tests/.gitattributes b/rust_compiler/libs/integration_tests/.gitattributes new file mode 100644 index 0000000..d5f2352 --- /dev/null +++ b/rust_compiler/libs/integration_tests/.gitattributes @@ -0,0 +1,2 @@ +# Treat snapshot files as text +*.snap text diff --git a/rust_compiler/libs/integration_tests/Cargo.toml b/rust_compiler/libs/integration_tests/Cargo.toml new file mode 100644 index 0000000..bc05679 --- /dev/null +++ b/rust_compiler/libs/integration_tests/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "integration_tests" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +compiler = { path = "../compiler" } +parser = { path = "../parser" } +tokenizer = { path = "../tokenizer" } +optimizer = { path = "../optimizer" } +il = { path = "../il" } +anyhow = { workspace = true } +indoc = "2" +insta = "1.40" + +[lib] +# This is a test-only crate +path = "src/lib.rs" diff --git a/rust_compiler/libs/integration_tests/README.md b/rust_compiler/libs/integration_tests/README.md new file mode 100644 index 0000000..00dce8a --- /dev/null +++ b/rust_compiler/libs/integration_tests/README.md @@ -0,0 +1,92 @@ +# Integration Tests for Slang Compiler with Optimizer + +This crate contains end-to-end integration tests for the Slang compiler that verify the complete compilation pipeline including all optimization passes. + +## Snapshot Testing with Insta + +These tests use [insta](https://insta.rs/) for snapshot testing, which captures the entire compiled output and stores it in snapshot files for comparison. + +### Running Tests + +```bash +# Run all integration tests +cargo test --package integration_tests + +# Run a specific test +cargo test --package integration_tests test_simple_leaf_function +``` + +### Updating Snapshots + +When you make changes to the compiler or optimizer that affect the output: + +```bash +# Update all snapshots automatically +INSTA_UPDATE=always cargo test --package integration_tests + +# Or use cargo-insta for interactive review (install first: cargo install cargo-insta) +cargo insta test --package integration_tests +cargo insta review --package integration_tests +``` + +### Understanding Snapshots + +Snapshot files are stored in `src/snapshots/` and contain: + +- The full IC10 assembly output from compiling Slang source code +- Metadata about which test generated them +- The expression that produced the output + +Example snapshot structure: + +``` +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +j main +move r8 10 +j ra +``` + +### What We Test + +1. **Leaf Function Optimization** - Removal of unnecessary `push sp/ra` and `pop ra/sp` +2. **Function Calls** - Preservation of stack frame when calling functions +3. **Constant Folding** - Compile-time evaluation of constant expressions +4. **Algebraic Simplification** - Identity operations like `x * 1` → `x` +5. **Strength Reduction** - Converting expensive operations like `x * 2` → `x + x` +6. **Dead Code Elimination** - Removal of unused variables +7. **Peephole Comparison Fusion** - Combining comparison + branch instructions +8. **Select Optimization** - Converting if/else to single `select` instruction +9. **Complex Arithmetic** - Multiple optimizations working together +10. **Nested Function Calls** - Full program optimization + +### Adding New Tests + +To add a new integration test: + +1. Add a new `#[test]` function in `src/lib.rs` +2. Call `compile_optimized()` with your Slang source code +3. Use `insta::assert_snapshot!(output)` to capture the output +4. Run with `INSTA_UPDATE=always` to create the initial snapshot +5. Review the snapshot file to ensure it looks correct + +Example: + +```rust +#[test] +fn test_my_optimization() { + let source = "fn foo(x) { return x + 1; }"; + let output = compile_optimized(source); + insta::assert_snapshot!(output); +} +``` + +### Benefits of Snapshot Testing + +- **Full Output Verification**: Tests the entire compiled output, not just snippets +- **Easy to Review**: Visual diffs show exactly what changed in the output +- **Regression Detection**: Any change to output is immediately visible +- **Living Documentation**: Snapshots serve as examples of compiler output +- **Less Brittle**: No need to manually update expected strings when making intentional changes diff --git a/rust_compiler/libs/integration_tests/src/lib.rs b/rust_compiler/libs/integration_tests/src/lib.rs new file mode 100644 index 0000000..e59086a --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/lib.rs @@ -0,0 +1,175 @@ +//! Integration tests for the Slang compiler with optimizer +//! +//! These tests compile Slang source code and verify both the compilation +//! and optimization passes work correctly together using snapshot testing. + +#[cfg(test)] +mod tests { + use compiler::Compiler; + use indoc::indoc; + use parser::Parser; + use tokenizer::Tokenizer; + + /// Compile Slang source code and return both unoptimized and optimized output + fn compile_with_and_without_optimization(source: &str) -> String { + // Compile for unoptimized output + let tokenizer = Tokenizer::from(source); + let parser = Parser::new(tokenizer); + let compiler = Compiler::new(parser, None); + let result = compiler.compile(); + + // Get unoptimized output + let mut unoptimized_writer = std::io::BufWriter::new(Vec::new()); + result + .instructions + .write(&mut unoptimized_writer) + .expect("Failed to write unoptimized output"); + let unoptimized_bytes = unoptimized_writer + .into_inner() + .expect("Failed to get bytes"); + let unoptimized = String::from_utf8(unoptimized_bytes).expect("Invalid UTF-8"); + + // Compile again for optimized output + let tokenizer2 = Tokenizer::from(source); + let parser2 = Parser::new(tokenizer2); + let compiler2 = Compiler::new(parser2, None); + let result2 = compiler2.compile(); + + // Apply optimizations + let optimized_instructions = optimizer::optimize(result2.instructions); + + // Get optimized output + let mut optimized_writer = std::io::BufWriter::new(Vec::new()); + optimized_instructions + .write(&mut optimized_writer) + .expect("Failed to write optimized output"); + let optimized_bytes = optimized_writer.into_inner().expect("Failed to get bytes"); + let optimized = String::from_utf8(optimized_bytes).expect("Invalid UTF-8"); + + // Combine both outputs with clear separators + format!( + "## Unoptimized Output\n\n{}\n## Optimized Output\n\n{}", + unoptimized, optimized + ) + } + + #[test] + fn test_simple_leaf_function() { + let source = "fn test() { let x = 10; }"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_function_with_call() { + let source = indoc! {" + fn add(a, b) { return a + b; } + fn main() { let x = add(5, 10); } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_constant_folding() { + let source = "let x = 5 + 10;"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_algebraic_simplification() { + let source = "let x = 5; let y = x * 1;"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_strength_reduction() { + let source = "fn double(x) { return x * 2; }"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_dead_code_elimination() { + let source = indoc! {" + fn compute(x) { + let unused = 20; + return x + 1; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_peephole_comparison_fusion() { + let source = indoc! {" + fn compare(x, y) { + if (x > y) { + let z = 1; + } + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_select_optimization() { + let source = indoc! {" + fn ternary(cond) { + let result = 0; + if (cond) { + result = 10; + } else { + result = 20; + } + return result; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_leaf_function_no_stack_frame() { + let source = indoc! {" + fn increment(x) { + x = x + 1; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_complex_arithmetic() { + let source = indoc! {" + fn compute(a, b, c) { + let x = a * 2; + let y = b + 0; + let z = c * 1; + return x + y + z; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_nested_function_calls() { + let source = indoc! {" + fn add(a, b) { return a + b; } + fn multiply(x, y) { return x * 2; } + fn complex(a, b) { + let sum = add(a, b); + let doubled = multiply(sum, 2); + return doubled; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } +} diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap new file mode 100644 index 0000000..aa56204 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap @@ -0,0 +1,18 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +main: +move r8 5 +mul r1 r8 1 +move r9 r1 + +## Optimized Output + +j 1 +move r8 5 +move r1 5 +move r9 5 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap new file mode 100644 index 0000000..724ef55 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap @@ -0,0 +1,49 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +compute: +pop r8 +pop r9 +pop r10 +push sp +push ra +mul r1 r10 2 +move r11 r1 +add r2 r9 0 +move r12 r2 +mul r3 r8 1 +move r13 r3 +add r4 r11 r12 +add r5 r4 r13 +move r15 r5 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +pop r10 +push sp +push ra +add r1 r10 r10 +move r11 r1 +move r2 r9 +move r12 r2 +move r3 r8 +move r13 r3 +add r4 r11 r12 +add r5 r4 r13 +move r15 r5 +j 16 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap new file mode 100644 index 0000000..2f4fcf3 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap @@ -0,0 +1,14 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +main: +move r8 15 + +## Optimized Output + +j 1 +move r8 15 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap new file mode 100644 index 0000000..9f6fefe --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap @@ -0,0 +1,33 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +compute: +pop r8 +push sp +push ra +move r9 20 +add r1 r8 1 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +move r9 20 +add r1 r8 1 +move r15 r1 +j 8 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap new file mode 100644 index 0000000..093032f --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap @@ -0,0 +1,53 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +add: +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra +main: +push sp +push ra +push 5 +push 10 +jal add +move r8 r15 +__internal_L2: +pop ra +pop sp +j ra + +## Optimized Output + +j 11 +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j 8 +pop ra +pop sp +j ra +push sp +push ra +push 5 +push 10 +jal 1 +move r8 r15 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap new file mode 100644 index 0000000..86f9929 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap @@ -0,0 +1,27 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +increment: +pop r8 +push sp +push ra +add r1 r8 1 +move r8 r1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +j main +pop r8 +add r1 r8 1 +move r8 r1 +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap new file mode 100644 index 0000000..3cad2e5 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap @@ -0,0 +1,111 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +add: +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra +multiply: +pop r8 +pop r9 +push sp +push ra +mul r1 r9 2 +move r15 r1 +j __internal_L2 +__internal_L2: +pop ra +pop sp +j ra +complex: +pop r8 +pop r9 +push sp +push ra +push r8 +push r9 +push r9 +push r8 +jal add +pop r9 +pop r8 +move r10 r15 +push r8 +push r9 +push r10 +push r10 +push 2 +jal multiply +pop r10 +pop r9 +pop r8 +move r11 r15 +move r15 r11 +j __internal_L3 +__internal_L3: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j 8 +pop ra +pop sp +j ra +pop r8 +pop r9 +push sp +push ra +add r1 r9 r9 +move r15 r1 +j 18 +pop ra +pop sp +j ra +pop r8 +pop r9 +push sp +push ra +push r8 +push r9 +push r9 +push r8 +jal 1 +pop r9 +pop r8 +move r10 r15 +push r8 +push r9 +push r10 +push r10 +push 2 +jal 11 +pop r10 +pop r9 +pop r8 +move r11 r15 +move r15 r11 +j 45 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap new file mode 100644 index 0000000..f66f066 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap @@ -0,0 +1,32 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +compare: +pop r8 +pop r9 +push sp +push ra +sgt r1 r9 r8 +beqz r1 __internal_L2 +move r10 1 +__internal_L2: +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +j main +pop r8 +pop r9 +ble r9 r8 8 +move r10 1 +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap new file mode 100644 index 0000000..941527a --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap @@ -0,0 +1,37 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +ternary: +pop r8 +push sp +push ra +move r9 0 +beqz r8 __internal_L3 +move r9 10 +j __internal_L2 +__internal_L3: +move r9 20 +__internal_L2: +move r15 r9 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +select r9 r8 10 20 +move r15 r9 +j 7 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap new file mode 100644 index 0000000..621b3ac --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap @@ -0,0 +1,22 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +test: +push sp +push ra +move r8 10 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +j main +move r8 10 +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap new file mode 100644 index 0000000..73093d4 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap @@ -0,0 +1,31 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +double: +pop r8 +push sp +push ra +mul r1 r8 2 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +add r1 r8 r8 +move r15 r1 +j 7 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md b/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md new file mode 100644 index 0000000..8d970ab --- /dev/null +++ b/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md @@ -0,0 +1,212 @@ +# Additional Optimization Opportunities for Slang IL Optimizer + +## Currently Implemented ✓ + +1. Constant Propagation - Folds math operations with known values +2. Register Forwarding - Eliminates intermediate moves +3. Function Call Optimization - Removes unnecessary push/pop around calls +4. Leaf Function Optimization - Removes RA save/restore for non-calling functions +5. Redundant Move Elimination - Removes `move rx rx` +6. Dead Code Elimination - Removes unreachable code after jumps + +## Proposed Additional Optimizations + +### 1. **Algebraic Simplification** 🔥 HIGH IMPACT + +Simplify mathematical identities: + +- `x + 0` → `x` (move) +- `x - 0` → `x` (move) +- `x * 1` → `x` (move) +- `x * 0` → `0` (move to constant) +- `x / 1` → `x` (move) +- `x - x` → `0` (move to constant) +- `x % 1` → `0` (move to constant) + +**Example:** + +``` +add r1 r2 0 → move r1 r2 +mul r3 r4 1 → move r3 r4 +mul r5 r6 0 → move r5 0 +``` + +### 2. **Strength Reduction** 🔥 HIGH IMPACT + +Replace expensive operations with cheaper ones: + +- `x * 2` → `add x x x` (addition is cheaper than multiplication) +- `x * power_of_2` → bit shifts (if IC10 supports) +- `x / 2` → bit shifts (if IC10 supports) + +**Example:** + +``` +mul r1 r2 2 → add r1 r2 r2 +``` + +### 3. **Peephole Optimization - Instruction Sequences** 🔥 MEDIUM-HIGH IMPACT + +Recognize and optimize common instruction patterns: + +#### Pattern: Conditional Branch Simplification + +``` +seq r1 ra rb → beq ra rb label +beqz r1 label (remove the seq entirely) + +sne r1 ra rb → bne ra rb label +beqz r1 label (remove the sne entirely) +``` + +#### Pattern: Double Move Elimination + +``` +move r1 r2 → move r1 r3 +move r1 r3 (remove first move if r1 not used between) +``` + +#### Pattern: Redundant Load Elimination + +If a register's value is already loaded and hasn't been clobbered: + +``` +l r1 d0 Temperature +... (no writes to r1) +l r1 d0 Temperature → (remove second load) +``` + +### 4. **Copy Propagation Enhancement** 🔥 MEDIUM IMPACT + +Current register forwarding is good, but we can extend it: + +- Track `move` chains: if `r1 = r2` and `r2 = 5`, propagate the `5` directly +- Eliminate the intermediate register if possible + +### 5. **Dead Store Elimination** 🔥 MEDIUM IMPACT + +Remove writes to registers that are never read before being overwritten: + +``` +move r1 5 +move r1 10 → move r1 10 + (first write is dead) +``` + +### 6. **Common Subexpression Elimination (CSE)** 🔥 MEDIUM-HIGH IMPACT + +Recognize when the same computation is done multiple times: + +``` +add r1 r8 r9 +add r2 r8 r9 → add r1 r8 r9 + move r2 r1 +``` + +This is especially valuable for expensive operations like: + +- Device loads (`l`) +- Math functions (sqrt, sin, cos, etc.) + +### 7. **Jump Threading** 🔥 LOW-MEDIUM IMPACT + +Optimize jump-to-jump sequences: + +``` +j label1 +... +label1: +j label2 → j label2 (rewrite first jump) +``` + +### 8. **Branch Folding** 🔥 LOW-MEDIUM IMPACT + +Merge consecutive branches to the same target: + +``` +bgt r1 r2 label +bgt r3 r4 label → Could potentially be optimized based on conditions +``` + +### 9. **Loop Invariant Code Motion** 🔥 MEDIUM-HIGH IMPACT + +Move calculations out of loops if they don't change: + +``` +loop: + mul r2 5 10 → mul r2 5 10 (hoisted before loop) + add r1 r1 r2 loop: + ... add r1 r1 r2 + j loop ... + j loop +``` + +### 10. **Select Instruction Optimization** 🔥 LOW-MEDIUM IMPACT + +The `select` instruction can sometimes replace branch patterns: + +``` +beq r1 r2 else +move r3 r4 +j end +else: +move r3 r5 → seq r6 r1 r2 +end: select r3 r6 r5 r4 +``` + +### 11. **Stack Access Pattern Optimization** 🔥 LOW IMPACT + +If we see repeated `sub r0 sp N` + `get`, we might be able to optimize by: + +- Caching the stack address in a register if used multiple times +- Combining sequential gets from adjacent stack slots + +### 12. **Inline Small Functions** 🔥 HIGH IMPACT (Complex to implement) + +For very small leaf functions (1-2 instructions), inline them at the call site: + +``` +calculateSum: + add r15 r8 r9 + j ra + +main: + push 5 → main: + push 10 add r15 5 10 + jal calculateSum +``` + +### 13. **Branch Prediction Hints** 🔥 LOW IMPACT + +Reorganize code to put likely branches inline (fall-through) and unlikely branches as jumps. + +### 14. **Register Coalescing** 🔥 MEDIUM IMPACT + +Reduce register pressure by reusing registers that have non-overlapping lifetimes. + +## Priority Implementation Order + +### Phase 1 (Quick Wins): + +1. Algebraic Simplification (easy, high impact) +2. Strength Reduction (easy, high impact) +3. Dead Store Elimination (medium complexity, good impact) + +### Phase 2 (Medium Effort): + +4. Peephole Optimizations - seq/beq pattern (medium, high impact) +5. Common Subexpression Elimination (medium, high impact) +6. Copy Propagation Enhancement (medium, medium impact) + +### Phase 3 (Advanced): + +7. Loop Invariant Code Motion (complex, high impact for loop-heavy code) +8. Function Inlining (complex, high impact) +9. Register Coalescing (complex, medium impact) + +## Testing Strategy + +- Add test cases for each optimization +- Ensure optimization preserves semantics (run existing tests after each) +- Measure code size reduction +- Consider adding benchmarks to measure game performance impact diff --git a/rust_compiler/libs/optimizer/src/algebraic_simplification.rs b/rust_compiler/libs/optimizer/src/algebraic_simplification.rs new file mode 100644 index 0000000..82bd39d --- /dev/null +++ b/rust_compiler/libs/optimizer/src/algebraic_simplification.rs @@ -0,0 +1,161 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Algebraic Simplification +/// Simplifies mathematical identities like x+0, x*1, x*0, etc. +pub fn algebraic_simplification<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + + for mut node in input { + let simplified = match &node.instruction { + // x + 0 = x + Instruction::Add(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Add(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x - 0 = x + Instruction::Sub(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + + // x * 1 = x + Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x * 0 = 0 + Instruction::Mul(dst, _, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + Instruction::Mul(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x / 1 = x + Instruction::Div(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + + // 0 / x = 0 (if x != 0, but we can't check at compile time for non-literals) + Instruction::Div(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x % 1 = 0 + Instruction::Mod(dst, _, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // 0 % x = 0 + Instruction::Mod(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x AND 0 = 0 + Instruction::And(dst, _, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + Instruction::And(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x OR 0 = x + Instruction::Or(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Or(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x XOR 0 = x + Instruction::Xor(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Xor(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + _ => None, + }; + + if let Some(new) = simplified { + node.instruction = new; + changed = true; + } + + output.push(node); + } + + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_zero() { + let input = vec![InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Number(Decimal::ZERO), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(1), Operand::Register(2)) + )); + } + + #[test] + fn test_mul_one() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(3), + Operand::Register(4), + Operand::Number(Decimal::ONE), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(3), Operand::Register(4)) + )); + } + + #[test] + fn test_mul_zero() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(5), + Operand::Register(6), + Operand::Number(Decimal::ZERO), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(5), Operand::Number(_)) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/constant_propagation.rs b/rust_compiler/libs/optimizer/src/constant_propagation.rs new file mode 100644 index 0000000..6765637 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/constant_propagation.rs @@ -0,0 +1,168 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Constant Propagation +/// Folds arithmetic operations when both operands are constant. +/// Also tracks register values and propagates them forward. +pub fn constant_propagation<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut registers: [Option; 16] = [None; 16]; + + for mut node in input { + // Invalidate register tracking on label/call boundaries + match &node.instruction { + Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16], + _ => {} + } + + let simplified = match &node.instruction { + Instruction::Move(dst, src) => resolve_value(src, ®isters) + .map(|val| Instruction::Move(dst.clone(), Operand::Number(val))), + Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y), + Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y), + Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y), + Instruction::Div(dst, a, b) => { + try_fold_math( + dst, + a, + b, + ®isters, + |x, y| if y.is_zero() { x } else { x / y }, + ) + } + Instruction::Mod(dst, a, b) => { + try_fold_math( + dst, + a, + b, + ®isters, + |x, y| if y.is_zero() { x } else { x % y }, + ) + } + Instruction::BranchEq(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x == y) + } + Instruction::BranchNe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x != y) + } + Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y), + Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y), + Instruction::BranchGe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x >= y) + } + Instruction::BranchLe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x <= y) + } + Instruction::BranchEqZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y) + } + Instruction::BranchNeZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y) + } + _ => None, + }; + + if let Some(new) = simplified { + node.instruction = new; + changed = true; + } + + // Update register tracking + match &node.instruction { + Instruction::Move(Operand::Register(r), src) => { + registers[*r as usize] = resolve_value(src, ®isters) + } + _ => { + if let Some(r) = get_destination_reg(&node.instruction) { + registers[r as usize] = None; + } + } + } + + // Filter out NOPs (empty labels from branch resolution) + if let Instruction::LabelDef(l) = &node.instruction + && l.is_empty() + { + changed = true; + continue; + } + + output.push(node); + } + (output, changed) +} + +fn resolve_value(op: &Operand, regs: &[Option; 16]) -> Option { + match op { + Operand::Number(n) => Some(*n), + Operand::Register(r) => regs[*r as usize], + _ => None, + } +} + +fn try_fold_math<'a, F>( + dst: &Operand<'a>, + a: &Operand<'a>, + b: &Operand<'a>, + regs: &[Option; 16], + op: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> Decimal, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + Some(Instruction::Move( + dst.clone(), + Operand::Number(op(val_a, val_b)), + )) +} + +fn try_resolve_branch<'a, F>( + a: &Operand<'a>, + b: &Operand<'a>, + label: &Operand<'a>, + regs: &[Option; 16], + check: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> bool, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + if check(val_a, val_b) { + Some(Instruction::Jump(label.clone())) + } else { + Some(Instruction::LabelDef("".into())) // NOP + } +} + +#[cfg(test)] +mod tests { + use super::*; + use il::InstructionNode; + + #[test] + fn test_fold_add() { + let input = vec![InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Number(5.into()), + Operand::Number(3.into()), + ), + None, + )]; + + let (output, changed) = constant_propagation(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(1), Operand::Number(_)) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/dead_code.rs b/rust_compiler/libs/optimizer/src/dead_code.rs new file mode 100644 index 0000000..feb1e73 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/dead_code.rs @@ -0,0 +1,82 @@ +use il::{Instruction, InstructionNode}; + +/// Pass: Redundant Move Elimination +/// Removes moves where source and destination are the same: `move rx rx` +pub fn remove_redundant_moves<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + for node in input { + if let Instruction::Move(dst, src) = &node.instruction + && dst == src + { + changed = true; + continue; + } + output.push(node); + } + (output, changed) +} + +/// Pass: Dead Code Elimination +/// Removes unreachable code after unconditional jumps. +pub fn remove_unreachable_code<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut dead = false; + for node in input { + if let Instruction::LabelDef(_) = node.instruction { + dead = false; + } + if dead { + changed = true; + continue; + } + if let Instruction::Jump(_) = node.instruction { + dead = true + } + output.push(node); + } + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::{Instruction, InstructionNode, Operand}; + + #[test] + fn test_remove_redundant_move() { + let input = vec![InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Register(1)), + None, + )]; + + let (output, changed) = remove_redundant_moves(input); + assert!(changed); + assert_eq!(output.len(), 0); + } + + #[test] + fn test_remove_unreachable() { + let input = vec![ + InstructionNode::new(Instruction::Jump(Operand::Label("main".into())), None), + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Number(1.into()), + Operand::Number(2.into()), + ), + None, + ), + InstructionNode::new(Instruction::LabelDef("main".into()), None), + ]; + + let (output, changed) = remove_unreachable_code(input); + assert!(changed); + assert_eq!(output.len(), 2); + } +} diff --git a/rust_compiler/libs/optimizer/src/dead_store_elimination.rs b/rust_compiler/libs/optimizer/src/dead_store_elimination.rs new file mode 100644 index 0000000..91191ad --- /dev/null +++ b/rust_compiler/libs/optimizer/src/dead_store_elimination.rs @@ -0,0 +1,109 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode}; +use std::collections::HashMap; + +/// Pass: Dead Store Elimination +/// Removes writes to registers that are never read before being overwritten. +pub fn dead_store_elimination<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut changed = false; + let mut last_write: HashMap = HashMap::new(); + let mut to_remove = Vec::new(); + + // Scan for dead writes + for (i, node) in input.iter().enumerate() { + if let Some(dest_reg) = get_destination_reg(&node.instruction) { + // If this register was written before and hasn't been read, previous write is dead + if let Some(&prev_idx) = last_write.get(&dest_reg) { + // Check if the value was ever used between prev_idx and current + let was_used = input[prev_idx + 1..i] + .iter() + .any(|n| reg_is_read_or_affects_control(&n.instruction, dest_reg)); + + if !was_used { + // Previous write was dead + to_remove.push(prev_idx); + changed = true; + } + } + + // Update last write location + last_write.insert(dest_reg, i); + } + + // On labels/jumps, conservatively clear tracking (value might be used elsewhere) + if matches!( + node.instruction, + Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_) + ) { + last_write.clear(); + } + } + + if changed { + let output = input + .into_iter() + .enumerate() + .filter_map(|(i, node)| { + if to_remove.contains(&i) { + None + } else { + Some(node) + } + }) + .collect(); + (output, true) + } else { + (input, false) + } +} + +/// Simplified check: Does this instruction read the register or affect control flow? +fn reg_is_read_or_affects_control(instr: &Instruction, reg: u8) -> bool { + use crate::helpers::reg_is_read; + + // If it reads the register, it's used + if reg_is_read(instr, reg) { + return true; + } + + // Conservatively assume register might be used if there's control flow + matches!( + instr, + Instruction::Jump(_) + | Instruction::JumpAndLink(_) + | Instruction::BranchEq(_, _, _) + | Instruction::BranchNe(_, _, _) + | Instruction::BranchGt(_, _, _) + | Instruction::BranchLt(_, _, _) + | Instruction::BranchGe(_, _, _) + | Instruction::BranchLe(_, _, _) + | Instruction::BranchEqZero(_, _) + | Instruction::BranchNeZero(_, _) + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::Operand; + + #[test] + fn test_dead_store() { + let input = vec![ + InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Number(5.into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Number(10.into())), + None, + ), + ]; + + let (output, changed) = dead_store_elimination(input); + assert!(changed); + assert_eq!(output.len(), 1); + } +} diff --git a/rust_compiler/libs/optimizer/src/function_call_optimization.rs b/rust_compiler/libs/optimizer/src/function_call_optimization.rs new file mode 100644 index 0000000..3233ad4 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/function_call_optimization.rs @@ -0,0 +1,160 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::{HashMap, HashSet}; + +/// Analyzes which registers are written to by each function label. +fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap> { + let mut clobbers = HashMap::new(); + let mut current_label = None; + + for node in instructions { + if let Instruction::LabelDef(label) = &node.instruction { + current_label = Some(label.to_string()); + clobbers.insert(label.to_string(), HashSet::new()); + } + + if let Some(label) = ¤t_label + && let Some(reg) = get_destination_reg(&node.instruction) + && let Some(set) = clobbers.get_mut(label) + { + set.insert(reg); + } + } + clobbers +} + +/// Pass: Function Call Optimization +/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register. +pub fn optimize_function_calls<'a>( + input: Vec>, +) -> (Vec>, bool) { + let clobbers = analyze_clobbers(&input); + let mut changed = false; + let mut to_remove = HashSet::new(); + let mut stack_adjustments = HashMap::new(); + + let mut i = 0; + while i < input.len() { + if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction { + let target_key = target.to_string(); + + if let Some(func_clobbers) = clobbers.get(&target_key) { + // 1. Identify Pushes immediately preceding the JAL + let mut pushes = Vec::new(); // (index, register) + let mut scan_back = i.saturating_sub(1); + while scan_back > 0 { + if to_remove.contains(&scan_back) { + scan_back -= 1; + continue; + } + if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction { + pushes.push((scan_back, *r)); + scan_back -= 1; + } else { + break; + } + } + + // 2. Identify Restores immediately following the JAL + let mut restores = Vec::new(); // (index_of_get, register, index_of_sub) + let mut scan_fwd = i + 1; + while scan_fwd < input.len() { + // Skip 'sub r0 sp X' + if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) = + &input[scan_fwd].instruction + { + // Check next instruction for the Get + if scan_fwd + 1 < input.len() + && let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) = + &input[scan_fwd + 1].instruction + { + restores.push((scan_fwd + 1, *r, scan_fwd)); + scan_fwd += 2; + continue; + } + } + break; + } + + // 3. Stack Cleanup + let cleanup_idx = scan_fwd; + let has_cleanup = if cleanup_idx < input.len() { + matches!( + input[cleanup_idx].instruction, + Instruction::Sub( + Operand::StackPointer, + Operand::StackPointer, + Operand::Number(_) + ) + ) + } else { + false + }; + + // SAFEGUARD: Check Counts! + let mut push_counts = HashMap::new(); + for (_, r) in &pushes { + *push_counts.entry(*r).or_insert(0) += 1; + } + + let mut restore_counts = HashMap::new(); + for (_, r, _) in &restores { + *restore_counts.entry(*r).or_insert(0) += 1; + } + + let counts_match = push_counts + .iter() + .all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count); + let counts_match_reverse = restore_counts + .iter() + .all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count); + + // Clobber Check + let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r)); + + if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse { + // Remove all pushes/restores + for (p_idx, _) in pushes { + to_remove.insert(p_idx); + } + for (g_idx, _, s_idx) in restores { + to_remove.insert(g_idx); + to_remove.insert(s_idx); + } + + // Reduce stack cleanup amount + let num_removed = push_counts.values().sum::() as i64; + stack_adjustments.insert(cleanup_idx, num_removed); + changed = true; + } + } + } + i += 1; + } + + if changed { + let mut clean = Vec::with_capacity(input.len()); + for (idx, mut node) in input.into_iter().enumerate() { + if to_remove.contains(&idx) { + continue; + } + + // Apply stack adjustment + if let Some(reduction) = stack_adjustments.get(&idx) + && let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction + { + let new_n = n - Decimal::from(*reduction); + if new_n.is_zero() { + continue; + } + node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n)); + } + + clean.push(node); + } + return (clean, changed); + } + + (input, false) +} diff --git a/rust_compiler/libs/optimizer/src/helpers.rs b/rust_compiler/libs/optimizer/src/helpers.rs new file mode 100644 index 0000000..f396969 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/helpers.rs @@ -0,0 +1,172 @@ +use il::{Instruction, Operand}; + +/// Returns the register number written to by an instruction, if any. +pub fn get_destination_reg(instr: &Instruction) -> Option { + match instr { + Instruction::Move(Operand::Register(r), _) + | Instruction::Add(Operand::Register(r), _, _) + | Instruction::Sub(Operand::Register(r), _, _) + | Instruction::Mul(Operand::Register(r), _, _) + | Instruction::Div(Operand::Register(r), _, _) + | Instruction::Mod(Operand::Register(r), _, _) + | Instruction::Pow(Operand::Register(r), _, _) + | Instruction::Load(Operand::Register(r), _, _) + | Instruction::LoadSlot(Operand::Register(r), _, _, _) + | Instruction::LoadBatch(Operand::Register(r), _, _, _) + | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _) + | Instruction::SetEq(Operand::Register(r), _, _) + | Instruction::SetNe(Operand::Register(r), _, _) + | Instruction::SetGt(Operand::Register(r), _, _) + | Instruction::SetLt(Operand::Register(r), _, _) + | Instruction::SetGe(Operand::Register(r), _, _) + | Instruction::SetLe(Operand::Register(r), _, _) + | Instruction::And(Operand::Register(r), _, _) + | Instruction::Or(Operand::Register(r), _, _) + | Instruction::Xor(Operand::Register(r), _, _) + | Instruction::Peek(Operand::Register(r)) + | Instruction::Get(Operand::Register(r), _, _) + | Instruction::Select(Operand::Register(r), _, _, _) + | Instruction::Rand(Operand::Register(r)) + | Instruction::Acos(Operand::Register(r), _) + | Instruction::Asin(Operand::Register(r), _) + | Instruction::Atan(Operand::Register(r), _) + | Instruction::Atan2(Operand::Register(r), _, _) + | Instruction::Abs(Operand::Register(r), _) + | Instruction::Ceil(Operand::Register(r), _) + | Instruction::Cos(Operand::Register(r), _) + | Instruction::Floor(Operand::Register(r), _) + | Instruction::Log(Operand::Register(r), _) + | Instruction::Max(Operand::Register(r), _, _) + | Instruction::Min(Operand::Register(r), _, _) + | Instruction::Sin(Operand::Register(r), _) + | Instruction::Sqrt(Operand::Register(r), _) + | Instruction::Tan(Operand::Register(r), _) + | Instruction::Trunc(Operand::Register(r), _) + | Instruction::LoadReagent(Operand::Register(r), _, _, _) + | Instruction::Pop(Operand::Register(r)) => Some(*r), + _ => None, + } +} + +/// Creates a new instruction with the destination register changed. +pub fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option> { + let r = Operand::Register(new_reg); + match instr { + Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())), + Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())), + Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())), + Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())), + Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())), + Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())), + Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())), + Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())), + Instruction::LoadSlot(_, a, b, c) => { + Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatch(_, a, b, c) => { + Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed( + r, + a.clone(), + b.clone(), + c.clone(), + d.clone(), + )), + Instruction::LoadReagent(_, b, c, d) => { + Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone())) + } + Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())), + Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())), + Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())), + Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())), + Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())), + Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())), + Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())), + Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())), + Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())), + Instruction::Peek(_) => Some(Instruction::Peek(r)), + Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())), + Instruction::Select(_, a, b, c) => { + Some(Instruction::Select(r, a.clone(), b.clone(), c.clone())) + } + Instruction::Rand(_) => Some(Instruction::Rand(r)), + Instruction::Pop(_) => Some(Instruction::Pop(r)), + Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())), + Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())), + Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())), + Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())), + Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())), + Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())), + Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())), + Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())), + Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())), + Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())), + Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())), + Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())), + Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())), + Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())), + Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())), + _ => None, + } +} + +/// Checks if a register is read by an instruction. +pub fn reg_is_read(instr: &Instruction, reg: u8) -> bool { + let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg); + + match instr { + Instruction::Move(_, a) => check(a), + Instruction::Add(_, a, b) + | Instruction::Sub(_, a, b) + | Instruction::Mul(_, a, b) + | Instruction::Div(_, a, b) + | Instruction::Mod(_, a, b) + | Instruction::Pow(_, a, b) => check(a) || check(b), + Instruction::Load(_, a, _) => check(a), + Instruction::Store(a, _, b) => check(a) || check(b), + Instruction::BranchEq(a, b, _) + | Instruction::BranchNe(a, b, _) + | Instruction::BranchGt(a, b, _) + | Instruction::BranchLt(a, b, _) + | Instruction::BranchGe(a, b, _) + | Instruction::BranchLe(a, b, _) => check(a) || check(b), + Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a), + Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash), + Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot), + Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode), + Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => { + check(d_hash) || check(n_hash) || check(mode) + } + Instruction::SetEq(_, a, b) + | Instruction::SetNe(_, a, b) + | Instruction::SetGt(_, a, b) + | Instruction::SetLt(_, a, b) + | Instruction::SetGe(_, a, b) + | Instruction::SetLe(_, a, b) + | Instruction::And(_, a, b) + | Instruction::Or(_, a, b) + | Instruction::Xor(_, a, b) => check(a) || check(b), + Instruction::Push(a) => check(a), + Instruction::Get(_, a, b) => check(a) || check(b), + Instruction::Put(a, b, c) => check(a) || check(b) || check(c), + Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c), + Instruction::Sleep(a) => check(a), + Instruction::Acos(_, a) + | Instruction::Asin(_, a) + | Instruction::Atan(_, a) + | Instruction::Abs(_, a) + | Instruction::Ceil(_, a) + | Instruction::Cos(_, a) + | Instruction::Floor(_, a) + | Instruction::Log(_, a) + | Instruction::Sin(_, a) + | Instruction::Sqrt(_, a) + | Instruction::Tan(_, a) + | Instruction::Trunc(_, a) => check(a), + Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => { + check(a) || check(b) + } + _ => false, + } +} diff --git a/rust_compiler/libs/optimizer/src/label_resolution.rs b/rust_compiler/libs/optimizer/src/label_resolution.rs new file mode 100644 index 0000000..a801166 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/label_resolution.rs @@ -0,0 +1,70 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::HashMap; + +/// Pass: Resolve Labels +/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs. +pub fn resolve_labels<'a>(input: Vec>) -> Vec> { + let mut label_map: HashMap = HashMap::new(); + let mut line_number = 0; + + // Build Label Map (filtering out LabelDefs from the count) + for node in &input { + if let Instruction::LabelDef(name) = &node.instruction { + label_map.insert(name.to_string(), line_number); + } else { + line_number += 1; + } + } + + let mut output = Vec::with_capacity(input.len()); + + // Rewrite Jumps and Filter Labels + for mut node in input { + // Helper to get line number as Decimal operand + let get_line = |lbl: &Operand| -> Option> { + if let Operand::Label(name) = lbl { + label_map + .get(name.as_ref()) + .map(|&l| Operand::Number(Decimal::from(l))) + } else { + None + } + }; + + match &mut node.instruction { + Instruction::LabelDef(_) => continue, // Strip labels + + // Jumps + Instruction::Jump(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::JumpAndLink(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEq(_, _, op) + | Instruction::BranchNe(_, _, op) + | Instruction::BranchGt(_, _, op) + | Instruction::BranchLt(_, _, op) + | Instruction::BranchGe(_, _, op) + | Instruction::BranchLe(_, _, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + _ => {} + } + output.push(node); + } + + output +} diff --git a/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs b/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs new file mode 100644 index 0000000..8447e95 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs @@ -0,0 +1,150 @@ +use crate::leaf_function::find_leaf_functions; +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::{HashMap, HashSet}; + +/// Helper: Check if a function body contains unsafe stack manipulation. +fn function_has_complex_stack_ops( + instructions: &[InstructionNode], + start_idx: usize, + end_idx: usize, +) -> bool { + for instruction in instructions.iter().take(end_idx).skip(start_idx) { + match instruction.instruction { + Instruction::Push(_) | Instruction::Pop(_) => return true, + Instruction::Add(Operand::StackPointer, _, _) + | Instruction::Sub(Operand::StackPointer, _, _) + | Instruction::Mul(Operand::StackPointer, _, _) + | Instruction::Div(Operand::StackPointer, _, _) + | Instruction::Move(Operand::StackPointer, _) => return true, + _ => {} + } + } + false +} + +/// Pass: Leaf Function Optimization +/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`. +pub fn optimize_leaf_functions<'a>( + input: Vec>, +) -> (Vec>, bool) { + let leaves = find_leaf_functions(&input); + if leaves.is_empty() { + return (input, false); + } + + let mut changed = false; + let mut to_remove = HashSet::new(); + let mut func_restore_indices = HashMap::new(); + let mut func_ra_offsets = HashMap::new(); + let mut current_function: Option = None; + let mut function_start_indices = HashMap::new(); + + // First scan: Identify instructions to remove and capture RA offsets + for (i, node) in input.iter().enumerate() { + match &node.instruction { + Instruction::LabelDef(label) if !label.starts_with("__internal_L") => { + current_function = Some(label.to_string()); + function_start_indices.insert(label.to_string(), i); + } + Instruction::Push(Operand::ReturnAddress) => { + if let Some(func) = ¤t_function + && leaves.contains(func) + { + to_remove.insert(i); + } + } + Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => { + if let Some(func) = ¤t_function + && leaves.contains(func) + { + to_remove.insert(i); + func_restore_indices.insert(func.clone(), i); + + // Look back for the address calc: `sub r0 sp OFFSET` + if i > 0 + && let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = + &input[i - 1].instruction + { + func_ra_offsets.insert(func.clone(), *n); + to_remove.insert(i - 1); + } + } + } + _ => {} + } + } + + // Safety Check: Verify functions don't have complex stack ops + let mut safe_functions = HashSet::new(); + + for (func, start_idx) in &function_start_indices { + if let Some(restore_idx) = func_restore_indices.get(func) { + let check_start = if to_remove.contains(&(start_idx + 1)) { + start_idx + 2 + } else { + start_idx + 1 + }; + + if !function_has_complex_stack_ops(&input, check_start, *restore_idx) { + safe_functions.insert(func.clone()); + changed = true; + } + } + } + + if !changed { + return (input, false); + } + + // Second scan: Rebuild with adjustments + let mut output = Vec::with_capacity(input.len()); + let mut processing_function: Option = None; + + for (i, mut node) in input.into_iter().enumerate() { + if to_remove.contains(&i) + && let Some(func) = &processing_function + && safe_functions.contains(func) + { + continue; + } + + if let Instruction::LabelDef(l) = &node.instruction + && !l.starts_with("__internal_L") + { + processing_function = Some(l.to_string()); + } + + // Apply Stack Adjustments + if let Some(func) = &processing_function + && safe_functions.contains(func) + && let Some(ra_offset) = func_ra_offsets.get(func) + { + // Stack Cleanup Adjustment + if let Instruction::Sub( + Operand::StackPointer, + Operand::StackPointer, + Operand::Number(n), + ) = &mut node.instruction + { + let new_n = *n - Decimal::from(1); + if new_n.is_zero() { + continue; + } + *n = new_n; + } + + // Stack Variable Offset Adjustment + if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = + &mut node.instruction + && *n > *ra_offset + { + *n -= Decimal::from(1); + } + } + + output.push(node); + } + + (output, true) +} diff --git a/rust_compiler/libs/optimizer/src/lib.rs b/rust_compiler/libs/optimizer/src/lib.rs index 855f913..204e78d 100644 --- a/rust_compiler/libs/optimizer/src/lib.rs +++ b/rust_compiler/libs/optimizer/src/lib.rs @@ -1,9 +1,30 @@ -use il::{Instruction, InstructionNode, Instructions, Operand}; -use rust_decimal::Decimal; -use std::collections::{HashMap, HashSet}; +use il::Instructions; +// Optimization pass modules +mod helpers; mod leaf_function; -use leaf_function::find_leaf_functions; + +mod algebraic_simplification; +mod constant_propagation; +mod dead_code; +mod dead_store_elimination; +mod function_call_optimization; +mod label_resolution; +mod leaf_function_optimization; +mod peephole_optimization; +mod register_forwarding; +mod strength_reduction; + +use algebraic_simplification::algebraic_simplification; +use constant_propagation::constant_propagation; +use dead_code::{remove_redundant_moves, remove_unreachable_code}; +use dead_store_elimination::dead_store_elimination; +use function_call_optimization::optimize_function_calls; +use label_resolution::resolve_labels; +use leaf_function_optimization::optimize_leaf_functions; +use peephole_optimization::peephole_optimization; +use register_forwarding::register_forwarding; +use strength_reduction::strength_reduction; /// Entry point for the optimizer. pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> { @@ -38,845 +59,37 @@ pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> { instructions = new_inst; changed |= c4; - // Pass 5: Redundant Move Elimination - let (new_inst, c5) = remove_redundant_moves(instructions); + // Pass 5: Algebraic Simplification (Identity operations) + let (new_inst, c5) = algebraic_simplification(instructions); instructions = new_inst; changed |= c5; - // Pass 6: Dead Code Elimination - let (new_inst, c6) = remove_unreachable_code(instructions); + // Pass 6: Strength Reduction (Replace expensive ops with cheaper ones) + let (new_inst, c6) = strength_reduction(instructions); instructions = new_inst; changed |= c6; + + // Pass 7: Peephole Optimizations (Common patterns) + let (new_inst, c7) = peephole_optimization(instructions); + instructions = new_inst; + changed |= c7; + + // Pass 8: Dead Store Elimination + let (new_inst, c8) = dead_store_elimination(instructions); + instructions = new_inst; + changed |= c8; + + // Pass 9: Redundant Move Elimination + let (new_inst, c9) = remove_redundant_moves(instructions); + instructions = new_inst; + changed |= c9; + + // Pass 10: Dead Code Elimination + let (new_inst, c10) = remove_unreachable_code(instructions); + instructions = new_inst; + changed |= c10; } // Final Pass: Resolve Labels to Line Numbers Instructions::new(resolve_labels(instructions)) } - -/// Helper: Check if a function body contains unsafe stack manipulation. -/// Returns true if the function modifies SP in a way that makes static RA offset analysis unsafe. -fn function_has_complex_stack_ops( - instructions: &[InstructionNode], - start_idx: usize, - end_idx: usize, -) -> bool { - for instruction in instructions.iter().take(end_idx).skip(start_idx) { - match instruction.instruction { - Instruction::Push(_) | Instruction::Pop(_) => return true, - // Check for explicit SP modification - Instruction::Add(Operand::StackPointer, _, _) - | Instruction::Sub(Operand::StackPointer, _, _) - | Instruction::Mul(Operand::StackPointer, _, _) - | Instruction::Div(Operand::StackPointer, _, _) - | Instruction::Move(Operand::StackPointer, _) => return true, - _ => {} - } - } - false -} - -/// Pass: Leaf Function Optimization -/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`. -fn optimize_leaf_functions<'a>( - input: Vec>, -) -> (Vec>, bool) { - let leaves = find_leaf_functions(&input); - if leaves.is_empty() { - return (input, false); - } - - let mut changed = false; - let mut to_remove = HashSet::new(); - - // We map function names to the INDEX of the instruction that restores RA. - // We use this to validate the function body later. - let mut func_restore_indices = HashMap::new(); - let mut func_ra_offsets = HashMap::new(); - - let mut current_function: Option = None; - let mut function_start_indices = HashMap::new(); - - // First scan: Identify instructions to remove and capture RA offsets - for (i, node) in input.iter().enumerate() { - match &node.instruction { - Instruction::LabelDef(label) if !label.starts_with("__internal_L") => { - current_function = Some(label.to_string()); - function_start_indices.insert(label.to_string(), i); - } - Instruction::Push(Operand::ReturnAddress) => { - if let Some(func) = ¤t_function - && leaves.contains(func) - { - to_remove.insert(i); - } - } - Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => { - // This is the restore instruction: `get ra db r0` - if let Some(func) = ¤t_function - && leaves.contains(func) - { - to_remove.insert(i); - func_restore_indices.insert(func.clone(), i); - - // Look back for the address calc: `sub r0 sp OFFSET` - if i > 0 - && let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = - &input[i - 1].instruction - { - func_ra_offsets.insert(func.clone(), *n); - to_remove.insert(i - 1); - } - } - } - _ => {} - } - } - - // Safety Check: Verify that functions marked for optimization don't have complex stack ops. - // If they do, unmark them. - let mut safe_functions = HashSet::new(); - - for (func, start_idx) in &function_start_indices { - if let Some(restore_idx) = func_restore_indices.get(func) { - // Check instructions between start and restore using the helper function. - // We need to skip the `push ra` we just marked for removal, otherwise the helper - // will flag it as a complex op (Push). - // `start_idx` is the LabelDef. `start_idx + 1` is typically `push ra`. - - let check_start = if to_remove.contains(&(start_idx + 1)) { - start_idx + 2 - } else { - start_idx + 1 - }; - - // `restore_idx` points to the `get ra` instruction. The helper scans up to `end_idx` exclusive, - // so we don't need to worry about the restore instruction itself. - if !function_has_complex_stack_ops(&input, check_start, *restore_idx) { - safe_functions.insert(func.clone()); - changed = true; - } - } - } - - if !changed { - return (input, false); - } - - // Second scan: Rebuild with adjustments, but only for SAFE functions - let mut output = Vec::with_capacity(input.len()); - let mut processing_function: Option = None; - - for (i, mut node) in input.into_iter().enumerate() { - if to_remove.contains(&i) - && let Some(func) = &processing_function - && safe_functions.contains(func) - { - continue; // SKIP (Remove) - } - - if let Instruction::LabelDef(l) = &node.instruction - && !l.starts_with("__internal_L") - { - processing_function = Some(l.to_string()); - } - - // Apply Stack Adjustments - if let Some(func) = &processing_function - && safe_functions.contains(func) - && let Some(ra_offset) = func_ra_offsets.get(func) - { - // 1. Stack Cleanup Adjustment - if let Instruction::Sub( - Operand::StackPointer, - Operand::StackPointer, - Operand::Number(n), - ) = &mut node.instruction - { - // Decrease cleanup amount by 1 (for the removed RA) - let new_n = *n - Decimal::from(1); - if new_n.is_zero() { - continue; - } - *n = new_n; - } - - // 2. Stack Variable Offset Adjustment - // Since we verified the function is "Simple" (no nested stack mods), - // we can safely assume offsets > ra_offset need shifting. - if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = - &mut node.instruction - && *n > *ra_offset - { - *n -= Decimal::from(1); - } - } - - output.push(node); - } - - (output, true) -} - -/// Analyzes which registers are written to by each function label. -fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap> { - let mut clobbers = HashMap::new(); - let mut current_label = None; - - for node in instructions { - if let Instruction::LabelDef(label) = &node.instruction { - current_label = Some(label.to_string()); - clobbers.insert(label.to_string(), HashSet::new()); - } - - if let Some(label) = ¤t_label - && let Some(reg) = get_destination_reg(&node.instruction) - && let Some(set) = clobbers.get_mut(label) - { - set.insert(reg); - } - } - clobbers -} - -/// Pass: Function Call Optimization -/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register. -fn optimize_function_calls<'a>( - input: Vec>, -) -> (Vec>, bool) { - let clobbers = analyze_clobbers(&input); - let mut changed = false; - let mut to_remove = HashSet::new(); - let mut stack_adjustments = HashMap::new(); - - let mut i = 0; - while i < input.len() { - if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction { - let target_key = target.to_string(); - - if let Some(func_clobbers) = clobbers.get(&target_key) { - // 1. Identify Pushes immediately preceding the JAL - let mut pushes = Vec::new(); // (index, register) - let mut scan_back = i.saturating_sub(1); - while scan_back > 0 { - if to_remove.contains(&scan_back) { - scan_back -= 1; - continue; - } - if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction { - pushes.push((scan_back, *r)); - scan_back -= 1; - } else { - break; - } - } - - // 2. Identify Restores immediately following the JAL - let mut restores = Vec::new(); // (index_of_get, register, index_of_sub) - let mut scan_fwd = i + 1; - while scan_fwd < input.len() { - // Skip 'sub r0 sp X' - if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) = - &input[scan_fwd].instruction - { - // Check next instruction for the Get - if scan_fwd + 1 < input.len() - && let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) = - &input[scan_fwd + 1].instruction - { - restores.push((scan_fwd + 1, *r, scan_fwd)); - scan_fwd += 2; - continue; - } - } - break; - } - - // 3. Stack Cleanup - let cleanup_idx = scan_fwd; - let has_cleanup = if cleanup_idx < input.len() { - matches!( - input[cleanup_idx].instruction, - Instruction::Sub( - Operand::StackPointer, - Operand::StackPointer, - Operand::Number(_) - ) - ) - } else { - false - }; - - // SAFEGUARD: Check Counts! - // If we pushed r8 twice but only restored it once, we have an argument. - // We must ensure the number of pushes for each register MATCHES the number of restores. - let mut push_counts = HashMap::new(); - for (_, r) in &pushes { - *push_counts.entry(*r).or_insert(0) += 1; - } - - let mut restore_counts = HashMap::new(); - for (_, r, _) in &restores { - *restore_counts.entry(*r).or_insert(0) += 1; - } - - let counts_match = push_counts - .iter() - .all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count); - // Also check reverse to ensure we didn't restore something we didn't push (unlikely but possible) - let counts_match_reverse = restore_counts - .iter() - .all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count); - - // Clobber Check - let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r)); - - if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse { - // We can remove ALL found pushes/restores safely - for (p_idx, _) in pushes { - to_remove.insert(p_idx); - } - for (g_idx, _, s_idx) in restores { - to_remove.insert(g_idx); - to_remove.insert(s_idx); - } - - // Reduce stack cleanup amount - let num_removed = push_counts.values().sum::() as i64; - stack_adjustments.insert(cleanup_idx, num_removed); - changed = true; - } - } - } - i += 1; - } - - if changed { - let mut clean = Vec::with_capacity(input.len()); - for (idx, mut node) in input.into_iter().enumerate() { - if to_remove.contains(&idx) { - continue; - } - - // Apply stack adjustment - if let Some(reduction) = stack_adjustments.get(&idx) - && let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction - { - let new_n = n - Decimal::from(*reduction); - if new_n.is_zero() { - continue; // Remove the sub entirely if 0 - } - node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n)); - } - - clean.push(node); - } - return (clean, changed); - } - - (input, false) -} - -/// Pass: Register Forwarding -/// Eliminates intermediate moves by writing directly to the final destination. -/// Example: `l r1 d0 T` + `move r9 r1` -> `l r9 d0 T` -fn register_forwarding<'a>( - mut input: Vec>, -) -> (Vec>, bool) { - let mut changed = false; - let mut i = 0; - - // We use a while loop to manually control index so we can peek ahead - while i < input.len().saturating_sub(1) { - let next_idx = i + 1; - - // Check if current instruction defines a register - // and the NEXT instruction is a move from that register. - let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) { - if let Instruction::Move(Operand::Register(dest_reg), Operand::Register(src_reg)) = - &input[next_idx].instruction - { - if *src_reg == def_reg { - // Candidate found: Instruction `i` defines `src_reg`, Instruction `i+1` moves `src_reg` to `dest_reg`. - // We can optimize if `src_reg` (the temp) is NOT used after this move. - Some((def_reg, *dest_reg)) - } else { - None - } - } else { - None - } - } else { - None - }; - - if let Some((temp_reg, final_reg)) = forward_candidate { - // Check liveness: Is temp_reg used after i+1? - // We scan from i+2 onwards. - let mut temp_is_dead = true; - for node in input.iter().skip(i + 2) { - if reg_is_read(&node.instruction, temp_reg) { - temp_is_dead = false; - break; - } - // If the temp is redefined, then the old value is dead, so we are safe. - if let Some(redef) = get_destination_reg(&node.instruction) - && redef == temp_reg - { - break; - } - - // If we hit a label/jump, we assume liveness might leak (conservative safety) - if matches!( - node.instruction, - Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_) - ) { - temp_is_dead = false; - break; - } - } - - if temp_is_dead { - // Perform the swap - // 1. Rewrite input[i] to write to final_reg - if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) { - input[i].instruction = new_instr; - // 2. Remove input[i+1] (The Move) - input.remove(next_idx); - changed = true; - // Don't increment i, re-evaluate current index (which is now a new neighbor) - continue; - } - } - } - - i += 1; - } - - (input, changed) -} - -/// Pass: Resolve Labels -/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs. -fn resolve_labels<'a>(input: Vec>) -> Vec> { - let mut label_map: HashMap = HashMap::new(); - let mut line_number = 0; - - // 1. Build Label Map (filtering out LabelDefs from the count) - for node in &input { - if let Instruction::LabelDef(name) = &node.instruction { - label_map.insert(name.to_string(), line_number); - } else { - line_number += 1; - } - } - - let mut output = Vec::with_capacity(input.len()); - - // 2. Rewrite Jumps and Filter Labels - for mut node in input { - // Helper to get line number as Decimal operand - let get_line = |lbl: &Operand| -> Option> { - if let Operand::Label(name) = lbl { - label_map - .get(name.as_ref()) - .map(|&l| Operand::Number(Decimal::from(l))) - } else { - None - } - }; - - match &mut node.instruction { - Instruction::LabelDef(_) => continue, // Strip labels - - // Jumps - Instruction::Jump(op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::JumpAndLink(op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::BranchEq(_, _, op) - | Instruction::BranchNe(_, _, op) - | Instruction::BranchGt(_, _, op) - | Instruction::BranchLt(_, _, op) - | Instruction::BranchGe(_, _, op) - | Instruction::BranchLe(_, _, op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - _ => {} - } - output.push(node); - } - - output -} - -// --- Helpers for Register Analysis --- - -fn get_destination_reg(instr: &Instruction) -> Option { - match instr { - Instruction::Move(Operand::Register(r), _) - | Instruction::Add(Operand::Register(r), _, _) - | Instruction::Sub(Operand::Register(r), _, _) - | Instruction::Mul(Operand::Register(r), _, _) - | Instruction::Div(Operand::Register(r), _, _) - | Instruction::Mod(Operand::Register(r), _, _) - | Instruction::Pow(Operand::Register(r), _, _) - | Instruction::Load(Operand::Register(r), _, _) - | Instruction::LoadSlot(Operand::Register(r), _, _, _) - | Instruction::LoadBatch(Operand::Register(r), _, _, _) - | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _) - | Instruction::SetEq(Operand::Register(r), _, _) - | Instruction::SetNe(Operand::Register(r), _, _) - | Instruction::SetGt(Operand::Register(r), _, _) - | Instruction::SetLt(Operand::Register(r), _, _) - | Instruction::SetGe(Operand::Register(r), _, _) - | Instruction::SetLe(Operand::Register(r), _, _) - | Instruction::And(Operand::Register(r), _, _) - | Instruction::Or(Operand::Register(r), _, _) - | Instruction::Xor(Operand::Register(r), _, _) - | Instruction::Peek(Operand::Register(r)) - | Instruction::Get(Operand::Register(r), _, _) - | Instruction::Select(Operand::Register(r), _, _, _) - | Instruction::Rand(Operand::Register(r)) - | Instruction::Acos(Operand::Register(r), _) - | Instruction::Asin(Operand::Register(r), _) - | Instruction::Atan(Operand::Register(r), _) - | Instruction::Atan2(Operand::Register(r), _, _) - | Instruction::Abs(Operand::Register(r), _) - | Instruction::Ceil(Operand::Register(r), _) - | Instruction::Cos(Operand::Register(r), _) - | Instruction::Floor(Operand::Register(r), _) - | Instruction::Log(Operand::Register(r), _) - | Instruction::Max(Operand::Register(r), _, _) - | Instruction::Min(Operand::Register(r), _, _) - | Instruction::Sin(Operand::Register(r), _) - | Instruction::Sqrt(Operand::Register(r), _) - | Instruction::Tan(Operand::Register(r), _) - | Instruction::Trunc(Operand::Register(r), _) - | Instruction::LoadReagent(Operand::Register(r), _, _, _) - | Instruction::Pop(Operand::Register(r)) => Some(*r), - _ => None, - } -} - -fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option> { - // Helper to easily recreate instruction with new dest - let r = Operand::Register(new_reg); - match instr { - Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())), - Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())), - Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())), - Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())), - Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())), - Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())), - Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())), - Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())), - Instruction::LoadSlot(_, a, b, c) => { - Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone())) - } - Instruction::LoadBatch(_, a, b, c) => { - Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone())) - } - Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed( - r, - a.clone(), - b.clone(), - c.clone(), - d.clone(), - )), - Instruction::LoadReagent(_, b, c, d) => { - Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone())) - } - Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())), - Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())), - Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())), - Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())), - Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())), - Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())), - Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())), - Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())), - Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())), - Instruction::Peek(_) => Some(Instruction::Peek(r)), - Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())), - Instruction::Select(_, a, b, c) => { - Some(Instruction::Select(r, a.clone(), b.clone(), c.clone())) - } - Instruction::Rand(_) => Some(Instruction::Rand(r)), - Instruction::Pop(_) => Some(Instruction::Pop(r)), - - // Math funcs - Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())), - Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())), - Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())), - Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())), - Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())), - Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())), - Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())), - Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())), - Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())), - Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())), - Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())), - Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())), - Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())), - Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())), - Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())), - - _ => None, - } -} - -fn reg_is_read(instr: &Instruction, reg: u8) -> bool { - let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg); - - match instr { - Instruction::Move(_, a) => check(a), - Instruction::Add(_, a, b) - | Instruction::Sub(_, a, b) - | Instruction::Mul(_, a, b) - | Instruction::Div(_, a, b) - | Instruction::Mod(_, a, b) - | Instruction::Pow(_, a, b) => check(a) || check(b), - - Instruction::Load(_, a, _) => check(a), // Load reads device? Device can be reg? Yes. - Instruction::Store(a, _, b) => check(a) || check(b), - - Instruction::BranchEq(a, b, _) - | Instruction::BranchNe(a, b, _) - | Instruction::BranchGt(a, b, _) - | Instruction::BranchLt(a, b, _) - | Instruction::BranchGe(a, b, _) - | Instruction::BranchLe(a, b, _) => check(a) || check(b), - - Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a), - - Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash), - - Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot), - Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode), - Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => { - check(d_hash) || check(n_hash) || check(mode) - } - - Instruction::SetEq(_, a, b) - | Instruction::SetNe(_, a, b) - | Instruction::SetGt(_, a, b) - | Instruction::SetLt(_, a, b) - | Instruction::SetGe(_, a, b) - | Instruction::SetLe(_, a, b) - | Instruction::And(_, a, b) - | Instruction::Or(_, a, b) - | Instruction::Xor(_, a, b) => check(a) || check(b), - - Instruction::Push(a) => check(a), - Instruction::Get(_, a, b) => check(a) || check(b), - Instruction::Put(a, b, c) => check(a) || check(b) || check(c), - - Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c), - Instruction::Sleep(a) => check(a), - - // Math single arg - Instruction::Acos(_, a) - | Instruction::Asin(_, a) - | Instruction::Atan(_, a) - | Instruction::Abs(_, a) - | Instruction::Ceil(_, a) - | Instruction::Cos(_, a) - | Instruction::Floor(_, a) - | Instruction::Log(_, a) - | Instruction::Sin(_, a) - | Instruction::Sqrt(_, a) - | Instruction::Tan(_, a) - | Instruction::Trunc(_, a) => check(a), - - // Math double arg - Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => { - check(a) || check(b) - } - - _ => false, - } -} - -/// --- Constant Propagation & Dead Code --- -fn constant_propagation<'a>(input: Vec>) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - let mut registers: [Option; 16] = [None; 16]; - - for mut node in input { - match &node.instruction { - Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16], - _ => {} - } - - let simplified = match &node.instruction { - Instruction::Move(dst, src) => resolve_value(src, ®isters) - .map(|val| Instruction::Move(dst.clone(), Operand::Number(val))), - Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y), - Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y), - Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y), - Instruction::Div(dst, a, b) => { - try_fold_math( - dst, - a, - b, - ®isters, - |x, y| if y.is_zero() { x } else { x / y }, - ) - } - Instruction::Mod(dst, a, b) => { - try_fold_math( - dst, - a, - b, - ®isters, - |x, y| if y.is_zero() { x } else { x % y }, - ) - } - Instruction::BranchEq(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x == y) - } - Instruction::BranchNe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x != y) - } - Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y), - Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y), - Instruction::BranchGe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x >= y) - } - Instruction::BranchLe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x <= y) - } - Instruction::BranchEqZero(a, l) => { - try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y) - } - Instruction::BranchNeZero(a, l) => { - try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y) - } - _ => None, - }; - - if let Some(new) = simplified { - node.instruction = new; - changed = true; - } - - // Update tracking - match &node.instruction { - Instruction::Move(Operand::Register(r), src) => { - registers[*r as usize] = resolve_value(src, ®isters) - } - // Invalidate if destination is register - _ => { - if let Some(r) = get_destination_reg(&node.instruction) { - registers[r as usize] = None; - } - } - } - - // Filter out NOPs (Empty LabelDefs from branch resolution) - if let Instruction::LabelDef(l) = &node.instruction - && l.is_empty() - { - changed = true; - continue; - } - - output.push(node); - } - (output, changed) -} - -fn resolve_value(op: &Operand, regs: &[Option; 16]) -> Option { - match op { - Operand::Number(n) => Some(*n), - Operand::Register(r) => regs[*r as usize], - _ => None, - } -} - -fn try_fold_math<'a, F>( - dst: &Operand<'a>, - a: &Operand<'a>, - b: &Operand<'a>, - regs: &[Option; 16], - op: F, -) -> Option> -where - F: Fn(Decimal, Decimal) -> Decimal, -{ - let val_a = resolve_value(a, regs)?; - let val_b = resolve_value(b, regs)?; - Some(Instruction::Move( - dst.clone(), - Operand::Number(op(val_a, val_b)), - )) -} - -fn try_resolve_branch<'a, F>( - a: &Operand<'a>, - b: &Operand<'a>, - label: &Operand<'a>, - regs: &[Option; 16], - check: F, -) -> Option> -where - F: Fn(Decimal, Decimal) -> bool, -{ - let val_a = resolve_value(a, regs)?; - let val_b = resolve_value(b, regs)?; - if check(val_a, val_b) { - Some(Instruction::Jump(label.clone())) - } else { - Some(Instruction::LabelDef("".into())) // NOP - } -} - -fn remove_redundant_moves<'a>(input: Vec>) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - for node in input { - if let Instruction::Move(dst, src) = &node.instruction - && dst == src - { - changed = true; - continue; - } - output.push(node); - } - (output, changed) -} - -fn remove_unreachable_code<'a>( - input: Vec>, -) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - let mut dead = false; - for node in input { - if let Instruction::LabelDef(_) = node.instruction { - dead = false; - } - if dead { - changed = true; - continue; - } - if let Instruction::Jump(_) = node.instruction { - dead = true - } - output.push(node); - } - (output, changed) -} diff --git a/rust_compiler/libs/optimizer/src/peephole_optimization.rs b/rust_compiler/libs/optimizer/src/peephole_optimization.rs new file mode 100644 index 0000000..42ef65f --- /dev/null +++ b/rust_compiler/libs/optimizer/src/peephole_optimization.rs @@ -0,0 +1,756 @@ +use il::{Instruction, InstructionNode, Operand}; + +/// Pass: Peephole Optimization +/// Recognizes and optimizes common instruction patterns. +pub fn peephole_optimization<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut i = 0; + + while i < input.len() { + // Pattern: push sp; push ra ... pop ra; pop sp (with no jal in between) + // If we push sp and ra and later pop them, but never call a function in between, remove all four + // and adjust any stack pointer offsets in between by -2 + if i + 1 < input.len() { + if let ( + Instruction::Push(Operand::StackPointer), + Instruction::Push(Operand::ReturnAddress), + ) = (&input[i].instruction, &input[i + 1].instruction) + { + // Look for matching pop ra; pop sp pattern + if let Some((ra_pop_idx, instructions_between)) = + find_matching_ra_pop(&input[i + 1..]) + { + let absolute_ra_pop = i + 1 + ra_pop_idx; + // Check if the next instruction is pop sp + if absolute_ra_pop + 1 < input.len() { + if let Instruction::Pop(Operand::StackPointer) = + &input[absolute_ra_pop + 1].instruction + { + // Check if there's any jal between push and pop + let has_call = instructions_between.iter().any(|node| { + matches!(node.instruction, Instruction::JumpAndLink(_)) + }); + + if !has_call { + // Safe to remove all four: push sp, push ra, pop ra, pop sp + // Also need to adjust stack pointer offsets in between by -2 + let absolute_sp_pop = absolute_ra_pop + 1; + for (idx, node) in input.iter().enumerate() { + if idx == i + || idx == i + 1 + || idx == absolute_ra_pop + || idx == absolute_sp_pop + { + // Skip all four push/pop instructions + continue; + } + + // If this instruction is between the pushes and pops, adjust its stack offsets + if idx > i + 1 && idx < absolute_ra_pop { + let adjusted_instruction = + adjust_stack_offset(node.instruction.clone(), 2); + output.push(InstructionNode::new( + adjusted_instruction, + node.span, + )); + } else { + output.push(node.clone()); + } + } + changed = true; + // We've processed the entire input, so break + break; + } + } + } + } + } + } + + // Pattern: push ra ... pop ra (with no jal in between) + // Fallback for when there's only ra push/pop without sp + if let Instruction::Push(Operand::ReturnAddress) = &input[i].instruction { + if let Some((pop_idx, instructions_between)) = find_matching_ra_pop(&input[i..]) { + // Check if there's any jal between push and pop + let has_call = instructions_between + .iter() + .any(|node| matches!(node.instruction, Instruction::JumpAndLink(_))); + + if !has_call { + // Safe to remove both push and pop + // Also need to adjust stack pointer offsets in between + let absolute_pop_idx = i + pop_idx; + for (idx, node) in input.iter().enumerate() { + if idx == i || idx == absolute_pop_idx { + // Skip the push and pop + continue; + } + + // If this instruction is between push and pop, adjust its stack offsets + if idx > i && idx < absolute_pop_idx { + let adjusted_instruction = + adjust_stack_offset(node.instruction.clone(), 1); + output.push(InstructionNode::new(adjusted_instruction, node.span)); + } else { + output.push(node.clone()); + } + } + changed = true; + // We've processed the entire input, so break + break; + } + } + } + + // Pattern: Branch-Move-Jump-Label-Move-Label -> Select + // beqz r1 else_label + // move r2 val1 + // j end_label + // else_label: + // move r2 val2 + // end_label: + // Converts to: select r2 r1 val1 val2 + if i + 5 < input.len() { + let select_pattern = try_match_select_pattern(&input[i..i + 6]); + if let Some((dst, cond, true_val, false_val, skip_count)) = select_pattern { + output.push(InstructionNode::new( + Instruction::Select(dst, cond, true_val, false_val), + input[i].span, + )); + changed = true; + i += skip_count; + continue; + } + } + + // Pattern: seq + beqz -> beq + if i + 1 < input.len() { + let pattern = match (&input[i].instruction, &input[i + 1].instruction) { + ( + Instruction::SetEq(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Eq, true)), // invert: beqz means "if NOT equal" + + ( + Instruction::SetNe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ne, true)), + + ( + Instruction::SetGt(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Gt, true)), + + ( + Instruction::SetLt(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Lt, true)), + + ( + Instruction::SetGe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ge, true)), + + ( + Instruction::SetLe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Le, true)), + + // Pattern: seq + bnez -> bne + ( + Instruction::SetEq(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Eq, false)), + + ( + Instruction::SetNe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ne, false)), + + ( + Instruction::SetGt(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Gt, false)), + + ( + Instruction::SetLt(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Lt, false)), + + ( + Instruction::SetGe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ge, false)), + + ( + Instruction::SetLe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Le, false)), + + _ => None, + }; + + if let Some((a, b, label, branch_type, invert)) = pattern { + // Create optimized branch instruction + let new_instr = if invert { + // beqz after seq means "branch if NOT equal" -> bne + match branch_type { + BranchType::Eq => { + Instruction::BranchNe(a.clone(), b.clone(), label.clone()) + } + BranchType::Ne => { + Instruction::BranchEq(a.clone(), b.clone(), label.clone()) + } + BranchType::Gt => { + Instruction::BranchLe(a.clone(), b.clone(), label.clone()) + } + BranchType::Lt => { + Instruction::BranchGe(a.clone(), b.clone(), label.clone()) + } + BranchType::Ge => { + Instruction::BranchLt(a.clone(), b.clone(), label.clone()) + } + BranchType::Le => { + Instruction::BranchGt(a.clone(), b.clone(), label.clone()) + } + } + } else { + // bnez after seq means "branch if equal" -> beq + match branch_type { + BranchType::Eq => { + Instruction::BranchEq(a.clone(), b.clone(), label.clone()) + } + BranchType::Ne => { + Instruction::BranchNe(a.clone(), b.clone(), label.clone()) + } + BranchType::Gt => { + Instruction::BranchGt(a.clone(), b.clone(), label.clone()) + } + BranchType::Lt => { + Instruction::BranchLt(a.clone(), b.clone(), label.clone()) + } + BranchType::Ge => { + Instruction::BranchGe(a.clone(), b.clone(), label.clone()) + } + BranchType::Le => { + Instruction::BranchLe(a.clone(), b.clone(), label.clone()) + } + } + }; + + output.push(InstructionNode::new(new_instr, input[i].span)); + changed = true; + i += 2; // Skip both instructions + continue; + } + } + + output.push(input[i].clone()); + i += 1; + } + + (output, changed) +} + +/// Tries to match a select pattern in the instruction sequence. +/// Pattern (6 instructions): +/// beqz/bnez cond else_label (i+0) +/// move dst val1 (i+1) +/// j end_label (i+2) +/// else_label: (i+3) +/// move dst val2 (i+4) +/// end_label: (i+5) +/// Returns: (dst, cond, true_val, false_val, instruction_count) +fn try_match_select_pattern<'a>( + instructions: &[InstructionNode<'a>], +) -> Option<(Operand<'a>, Operand<'a>, Operand<'a>, Operand<'a>, usize)> { + if instructions.len() < 6 { + return None; + } + + // Check for beqz pattern + if let Instruction::BranchEqZero(cond, Operand::Label(else_label)) = + &instructions[0].instruction + { + if let Instruction::Move(dst1, val1) = &instructions[1].instruction { + if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction { + if let Instruction::LabelDef(label3) = &instructions[3].instruction { + if label3 == else_label { + if let Instruction::Move(dst2, val2) = &instructions[4].instruction { + if dst1 == dst2 { + if let Instruction::LabelDef(label5) = &instructions[5].instruction + { + if label5 == end_label { + // beqz means: if cond==0, goto else, so val1 is for true, val2 for false + // select dst cond true_val false_val + // When cond is non-zero (true), use val1, otherwise val2 + return Some(( + dst1.clone(), + cond.clone(), + val1.clone(), + val2.clone(), + 6, + )); + } + } + } + } + } + } + } + } + } + + // Check for bnez pattern + if let Instruction::BranchNeZero(cond, Operand::Label(then_label)) = + &instructions[0].instruction + { + if let Instruction::Move(dst1, val_false) = &instructions[1].instruction { + if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction { + if let Instruction::LabelDef(label3) = &instructions[3].instruction { + if label3 == then_label { + if let Instruction::Move(dst2, val_true) = &instructions[4].instruction { + if dst1 == dst2 { + if let Instruction::LabelDef(label5) = &instructions[5].instruction + { + if label5 == end_label { + // bnez means: if cond!=0, goto then, so val_true for true, val_false for false + return Some(( + dst1.clone(), + cond.clone(), + val_true.clone(), + val_false.clone(), + 6, + )); + } + } + } + } + } + } + } + } + } + + None +} + +/// Finds a matching `pop ra` for a `push ra` at the start of the slice. +/// Returns the index of the pop and the instructions in between. +fn find_matching_ra_pop<'a>( + instructions: &'a [InstructionNode<'a>], +) -> Option<(usize, &'a [InstructionNode<'a>])> { + if instructions.is_empty() { + return None; + } + + // Skip the push itself + for (idx, node) in instructions.iter().enumerate().skip(1) { + if let Instruction::Pop(Operand::ReturnAddress) = &node.instruction { + // Found matching pop + return Some((idx, &instructions[1..idx])); + } + + // Stop searching if we hit a jump (different control flow) + // Labels are OK - they're just markers + if matches!( + node.instruction, + Instruction::Jump(_) | Instruction::JumpRelative(_) + ) { + return None; + } + } + + None +} + +/// Checks if an instruction uses or modifies the stack pointer. +#[allow(dead_code)] +fn uses_stack_pointer(instruction: &Instruction) -> bool { + match instruction { + Instruction::Push(_) | Instruction::Pop(_) | Instruction::Peek(_) => true, + Instruction::Add(Operand::StackPointer, _, _) + | Instruction::Sub(Operand::StackPointer, _, _) + | Instruction::Mul(Operand::StackPointer, _, _) + | Instruction::Div(Operand::StackPointer, _, _) + | Instruction::Mod(Operand::StackPointer, _, _) => true, + Instruction::Add(_, Operand::StackPointer, _) + | Instruction::Sub(_, Operand::StackPointer, _) + | Instruction::Mul(_, Operand::StackPointer, _) + | Instruction::Div(_, Operand::StackPointer, _) + | Instruction::Mod(_, Operand::StackPointer, _) => true, + Instruction::Add(_, _, Operand::StackPointer) + | Instruction::Sub(_, _, Operand::StackPointer) + | Instruction::Mul(_, _, Operand::StackPointer) + | Instruction::Div(_, _, Operand::StackPointer) + | Instruction::Mod(_, _, Operand::StackPointer) => true, + Instruction::Move(Operand::StackPointer, _) + | Instruction::Move(_, Operand::StackPointer) => true, + _ => false, + } +} + +/// Adjusts stack pointer offsets in an instruction by decrementing them by a given amount. +/// This is necessary when removing push operations that would have increased the stack size. +fn adjust_stack_offset<'a>(instruction: Instruction<'a>, decrement: i64) -> Instruction<'a> { + use rust_decimal::prelude::*; + + match instruction { + // Adjust arithmetic operations on sp that use literal offsets + Instruction::Sub(dst, Operand::StackPointer, Operand::Number(n)) => { + let new_n = n - Decimal::from(decrement); + // If the result is 0 or negative, we may want to skip this entirely + // but for now, just adjust the value + Instruction::Sub(dst, Operand::StackPointer, Operand::Number(new_n)) + } + Instruction::Add(dst, Operand::StackPointer, Operand::Number(n)) => { + let new_n = n - Decimal::from(decrement); + Instruction::Add(dst, Operand::StackPointer, Operand::Number(new_n)) + } + // Return the instruction unchanged if it doesn't need adjustment + other => other, + } +} + +#[derive(Debug, Clone, Copy)] +enum BranchType { + Eq, + Ne, + Gt, + Lt, + Ge, + Le, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_seq_beqz_to_bne() { + let input = vec![ + InstructionNode::new( + Instruction::SetEq( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchNe(_, _, _) + )); + } + + #[test] + fn test_sne_beqz_to_beq() { + let input = vec![ + InstructionNode::new( + Instruction::SetNe( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchEq(_, _, _) + )); + } + + #[test] + fn test_seq_bnez_to_beq() { + let input = vec![ + InstructionNode::new( + Instruction::SetEq( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchNeZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchEq(_, _, _) + )); + } + + #[test] + fn test_sgt_beqz_to_ble() { + let input = vec![ + InstructionNode::new( + Instruction::SetGt( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchLe(_, _, _) + )); + } + + #[test] + fn test_branch_move_jump_to_select_beqz() { + // Pattern: beqz r1 else / move r2 10 / j end / else: / move r2 20 / end: + // Should convert to: select r2 r1 10 20 + let input = vec![ + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("else".into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None), + InstructionNode::new(Instruction::LabelDef("else".into()), None), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(20.into())), + None, + ), + InstructionNode::new(Instruction::LabelDef("end".into()), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(2))); + assert!(matches!(cond, Operand::Register(1))); + assert!(matches!(true_val, Operand::Number(_))); + assert!(matches!(false_val, Operand::Number(_))); + } else { + panic!("Expected Select instruction"); + } + } + + #[test] + fn test_branch_move_jump_to_select_bnez() { + // Pattern: bnez r1 then / move r2 20 / j end / then: / move r2 10 / end: + // Should convert to: select r2 r1 10 20 + let input = vec![ + InstructionNode::new( + Instruction::BranchNeZero(Operand::Register(1), Operand::Label("then".into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(20.into())), + None, + ), + InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None), + InstructionNode::new(Instruction::LabelDef("then".into()), None), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::LabelDef("end".into()), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(2))); + assert!(matches!(cond, Operand::Register(1))); + assert!(matches!(true_val, Operand::Number(_))); + assert!(matches!(false_val, Operand::Number(_))); + } else { + panic!("Expected Select instruction"); + } + } + + #[test] + fn test_remove_useless_ra_push_pop() { + // Pattern: push ra / add r1 r2 r3 / pop ra + // Should remove both push and pop since no jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!(output[0].instruction, Instruction::Add(_, _, _))); + } + + #[test] + fn test_keep_ra_push_pop_with_jal() { + // Pattern: push ra / jal func / pop ra + // Should keep both since there's a jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::JumpAndLink(Operand::Label("func".into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(!changed); + assert_eq!(output.len(), 3); + } + + #[test] + fn test_ra_push_pop_with_stack_offset_adjustment() { + // Pattern: push ra / sub r1 sp 2 / pop ra + // Should remove push/pop AND adjust the stack offset from 2 to 1 + use rust_decimal::prelude::*; + + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Sub( + Operand::Register(1), + Operand::StackPointer, + Operand::Number(Decimal::from(2)), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + + if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(1))); + assert!(matches!(src, Operand::StackPointer)); + assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 2 to 1 + } else { + panic!("Expected Sub instruction with adjusted offset"); + } + } + + #[test] + fn test_remove_sp_and_ra_push_pop() { + // Pattern: push sp / push ra / move r8 10 / pop ra / pop sp + // Should remove all four push/pop instructions since no jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Move(Operand::Register(8), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(8), _) + )); + } + + #[test] + fn test_keep_sp_and_ra_push_pop_with_jal() { + // Pattern: push sp / push ra / jal func / pop ra / pop sp + // Should keep all since there's a jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::JumpAndLink(Operand::Label("func".into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(!changed); + assert_eq!(output.len(), 5); + } + + #[test] + fn test_sp_and_ra_with_stack_offset_adjustment() { + // Pattern: push sp / push ra / sub r1 sp 3 / pop ra / pop sp + // Should remove all push/pop AND adjust the stack offset from 3 to 1 (decrement by 2) + use rust_decimal::prelude::*; + + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Sub( + Operand::Register(1), + Operand::StackPointer, + Operand::Number(Decimal::from(3)), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + + if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(1))); + assert!(matches!(src, Operand::StackPointer)); + assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 3 to 1 + } else { + panic!("Expected Sub instruction with adjusted offset"); + } + } +} diff --git a/rust_compiler/libs/optimizer/src/register_forwarding.rs b/rust_compiler/libs/optimizer/src/register_forwarding.rs new file mode 100644 index 0000000..3b9a7c1 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/register_forwarding.rs @@ -0,0 +1,108 @@ +use crate::helpers::{get_destination_reg, reg_is_read, set_destination_reg}; +use il::{Instruction, InstructionNode}; + +/// Pass: Register Forwarding +/// Eliminates intermediate moves by writing directly to the final destination. +/// Example: `l r1 d0 Temperature` + `move r9 r1` -> `l r9 d0 Temperature` +pub fn register_forwarding<'a>( + mut input: Vec>, +) -> (Vec>, bool) { + let mut changed = false; + let mut i = 0; + + while i < input.len().saturating_sub(1) { + let next_idx = i + 1; + + // Check if current instruction defines a register + // and the NEXT instruction is a move from that register. + let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) { + if let Instruction::Move( + il::Operand::Register(dest_reg), + il::Operand::Register(src_reg), + ) = &input[next_idx].instruction + { + if *src_reg == def_reg { + Some((def_reg, *dest_reg)) + } else { + None + } + } else { + None + } + } else { + None + }; + + if let Some((temp_reg, final_reg)) = forward_candidate { + // Check liveness: Is temp_reg used after i+1? + let mut temp_is_dead = true; + for node in input.iter().skip(i + 2) { + if reg_is_read(&node.instruction, temp_reg) { + temp_is_dead = false; + break; + } + // If the temp is redefined, then the old value is dead + if let Some(redef) = get_destination_reg(&node.instruction) + && redef == temp_reg + { + break; + } + + // Conservative: assume liveness might leak at labels/jumps + if matches!( + node.instruction, + Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_) + ) { + temp_is_dead = false; + break; + } + } + + if temp_is_dead { + // Rewrite to use final destination directly + if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) { + input[i].instruction = new_instr; + input.remove(next_idx); + changed = true; + continue; + } + } + } + + i += 1; + } + + (input, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::{Instruction, InstructionNode, Operand}; + + #[test] + fn test_forward_simple_move() { + let input = vec![ + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(5), Operand::Register(1)), + None, + ), + ]; + + let (output, changed) = register_forwarding(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Add(Operand::Register(5), _, _) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/strength_reduction.rs b/rust_compiler/libs/optimizer/src/strength_reduction.rs new file mode 100644 index 0000000..f5e4917 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/strength_reduction.rs @@ -0,0 +1,63 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Strength Reduction +/// Replaces expensive operations with cheaper equivalents. +/// Example: x * 2 -> add x x x (addition is typically faster than multiplication) +pub fn strength_reduction<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + + for mut node in input { + let reduced = match &node.instruction { + // x * 2 = x + x + Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(2) => { + Some(Instruction::Add(dst.clone(), a.clone(), a.clone())) + } + Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(2) => { + Some(Instruction::Add(dst.clone(), b.clone(), b.clone())) + } + + // Future: Could add power-of-2 optimizations using bit shifts if IC10 supports them + // x * 4 = (x + x) + (x + x) or x << 2 + // x / 2 = x >> 1 + + _ => None, + }; + + if let Some(new) = reduced { + node.instruction = new; + changed = true; + } + + output.push(node); + } + + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mul_two_to_add() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(1), + Operand::Register(2), + Operand::Number(Decimal::from(2)), + ), + None, + )]; + + let (output, changed) = strength_reduction(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Add(Operand::Register(1), Operand::Register(2), Operand::Register(2)) + )); + } +}