diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..dab5683 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,239 @@ +# Slang Language Compiler - AI Agent Instructions + +## Project Overview + +**Slang** is a high-level programming language that compiles to IC10 assembly for the game Stationeers. The compiler is a multi-stage Rust system with a C# BepInEx mod integration layer. + +**Key Goal:** Reduce manual IC10 assembly writing by providing C-like syntax with automatic register allocation and device abstraction. + +## Architecture Overview + +### Compilation Pipeline + +The compiler follows a strict 4-stage pipeline (in [rust_compiler/libs/compiler/src/v1.rs](rust_compiler/libs/compiler/src/v1.rs)): + +1. **Tokenizer** (libs/tokenizer/src/lib.rs) - Lexical analysis using `logos` crate + + - Converts source text into tokens + - Tracks line/span information for error reporting + - Supports temperature literals (c/f/k suffixes) + +2. **Parser** (libs/parser/src/lib.rs) - AST construction + + - Recursive descent parser producing `Expression` tree + - Validates syntax, handles device declarations, function definitions + - Output: `Expression` enum containing tree nodes + +3. **Compiler (v1)** (libs/compiler/src/v1.rs) - Semantic analysis & code generation + + - Variable scope management and register allocation via `VariableManager` + - Emits IL instructions to `il::Instructions` + - Error types use `lsp_types::Diagnostic` for editor integration + +4. **Optimizer** (libs/optimizer/src/lib.rs) - Post-generation optimization + - Currently optimizes leaf functions + - Optional pass before final output + +### Cross-Language Integration + +- **Rust Library** (`slang.dll`/`.so`): Core compiler logic via `safer-ffi` C FFI bindings +- **C# Mod** (`StationeersSlang.dll`): BepInEx plugin integrating with game UI +- **Generated Headers** (via `generate-headers` binary): Auto-generated C# bindings from Rust + +### Key Types & Data Flow + +- `Expression` tree (parser) → `v1::Compiler` processes → `il::Instructions` output +- `InstructionNode` wraps IC10 assembly with optional source span for debugging +- `VariableManager` tracks scopes, tracks const/device/let distinctions +- `Operand` enum represents register/literal/device-property values + +## Critical Workflows + +### Building + +```bash +cd rust_compiler +# Build for both Linux and Windows targets +cargo build --release --target=x86_64-unknown-linux-gnu +cargo build --release --target=x86_64-pc-windows-gnu + +# Generate C# FFI headers (requires "headers" feature) +cargo run --features headers --bin generate-headers + +# Full build (run from root) +./build.sh +``` + +### Testing + +```bash +cd rust_compiler +# Run all tests +cargo test --package compiler --lib + +# Run specific test file +cargo test --package compiler --lib tuple_literals + +# Run single test +cargo test --package compiler --lib -- test::tuple_literals::test::test_tuple_literal_size_mismatch --exact --nocapture +``` + +### Quick Compilation + +!IMPORTANT: make sure you use these commands instead of creating temporary files. + +```bash +cd rust_compiler +# Compile Slang code to IC10 using current compiler changes +echo 'let x = 5;' | cargo run --bin slang +# Compile Slang code to IC10 with optimization +echo 'let x = 5;' | cargo run --bin slang -z +# Or from file +cargo run --bin slang -- input.slang -o output.ic10 +# Optimize the output with -z flag +cargo run --bin slang -- input.slang -o output.ic10 -z +``` + +## Codebase Patterns + +### Test Structure + +Tests follow a macro pattern in [libs/compiler/src/test/mod.rs](rust_compiler/libs/compiler/src/test/mod.rs): + +```rust +#[test] +fn test_name() -> Result<()> { + let output = compile!(debug "slang code here"); + assert_eq!( + output, + indoc! { + "Expected IC10 output here" + } + ); + Ok(()) +} +``` + +- `compile!()` macro: full pipeline from source to IC10 +- `compile!(result ...)` for error checking +- `compile!(debug ...)` for intermediate IR inspection +- Test files organize by feature: `binary_expression.rs`, `syscall.rs`, `tuple_literals.rs`, etc. + +### Error Handling + +All stages return custom Error types implementing `From`: + +- `tokenizer::Error` - Lexical errors +- `parser::Error<'a>` - Syntax errors +- `compiler::Error<'a>` - Semantic errors (unknown identifier, type mismatch) +- Device assignment prevention: `DeviceAssignment` error if reassigning device const + +### Variable Scope Management + +[variable_manager.rs](rust_compiler/libs/compiler/src/variable_manager.rs) handles: + +- Tracking const vs mutable (let) distinction +- Device declarations as special scope items +- Function-local scopes with parameter handling +- Register allocation via `VariableLocation` + +### LSP Integration + +Error types implement conversion to `lsp_types::Diagnostic` for IDE feedback: + +```rust +impl<'a> From> for lsp_types::Diagnostic { ... } +``` + +This enables real-time error reporting in the Stationeers IC10 Editor mod. + +## Project-Specific Conventions + +### Tuple Destructuring + +The compiler supports tuple returns and multi-assignment: + +```rust +let (x, y) = func(); // TupleDeclarationExpression +(x, y) = another_func(); // TupleAssignmentExpression +``` + +Compiler validates size matching with `TupleSizeMismatch` error. + +### Device Property Access + +Devices are first-class with property access: + +```rust +device ac = "d0"; +ac.On = true; +ac.Temperature > 20c; +``` + +Parsed as `MemberAccessExpression`, compiled to device I/O syscalls. + +### Temperature Literals + +Unique language feature - automatic unit conversion at compile time: + +```rust +20c → 293.15k // Celsius to Kelvin +68f → 293.15k // Fahrenheit to Kelvin +``` + +Tokenizer produces `Literal::Number(Number(decimal, Some(Unit::Celsius)))`. + +### Constants are Immutable + +Once declared with `const`, reassignment is a compile error. Device assignment prevention is critical (prevents game logic bugs). + +## Integration Points + +### C# FFI (`csharp_mod/FfiGlue.cs`) + +- Calls Rust compiler via marshaled FFI +- Passes source code, receives IC10 output +- Marshals errors as `Diagnostic` objects + +### BepInEx Plugin Lifecycle + +[csharp_mod/Plugin.cs](csharp_mod/Plugin.cs): + +- Harmony patches for IC10 Editor integration +- Cleanup code for live-reload support (mod destruction) +- Logger integration for debug output + +### CI/Build Target Matrix + +- Linux: `x86_64-unknown-linux-gnu` +- Windows: `x86_64-pc-windows-gnu` (cross-compile from Linux) +- Both produce dynamic libraries + CLI binary + +## Debugging Tips + +1. **Print source spans:** `Span` type tracks line/column for error reporting +2. **IL inspection:** Use `compile!(debug source)` to view intermediate instructions +3. **Register allocation:** `VariableManager` logs scope changes; check for conflicts +4. **Syscall validation:** [parser/src/sys_call.rs](rust_compiler/libs/parser/src/sys_call.rs) lists all valid syscalls +5. **Tokenizer issues:** Check [tokenizer/src/token.rs](rust_compiler/libs/tokenizer/src/token.rs) for supported keywords/symbols + +## Key Files for Common Tasks + +| Task | File | +| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | +| Add language feature | [libs/parser/src/lib.rs](rust_compiler/libs/parser/src/lib.rs) + test in [libs/compiler/src/test/](rust_compiler/libs/compiler/src/test/) | +| Fix codegen bug | [libs/compiler/src/v1.rs](rust_compiler/libs/compiler/src/v1.rs) (~3500 lines) | +| Add syscall | [libs/parser/src/sys_call.rs](rust_compiler/libs/parser/src/sys_call.rs) | +| Optimize output | [libs/optimizer/src/lib.rs](rust_compiler/libs/optimizer/src/lib.rs) | +| Mod integration | [csharp_mod/](csharp_mod/) | +| Language docs | [docs/language-reference.md](docs/language-reference.md) | + +## Dependencies to Know + +- `logos` - Tokenizer with derive macros +- `rust_decimal` - Precise decimal arithmetic for temperature conversion +- `safer-ffi` - Safe C FFI between Rust and C# +- `lsp-types` - Standard for editor diagnostics +- `thiserror` - Error type derivation +- `clap` - CLI argument parsing +- `anyhow` - Error handling in main binary diff --git a/.gitignore b/.gitignore index 78b5f00..4566789 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ target *.ic10 +*.snap.new release csharp_mod/bin obj diff --git a/Changelog.md b/Changelog.md index 681bb50..290ce09 100644 --- a/Changelog.md +++ b/Changelog.md @@ -1,5 +1,12 @@ # Changelog +[0.5.0] + +- Added full tuple support: declarations, assignments, and returns +- Refactored optimizer into modular passes with improved code generation +- Enhanced peephole optimizations and pattern recognition +- Comprehensive test coverage for edge cases and error handling + [0.4.7] - Added support for Windows CRLF endings diff --git a/ModData/About/About.xml b/ModData/About/About.xml index ad84f68..b6f8562 100644 --- a/ModData/About/About.xml +++ b/ModData/About/About.xml @@ -2,7 +2,7 @@ Slang JoeDiertay - 0.4.7 + 0.5.0 [h1]Slang: High-Level Programming for Stationeers[/h1] diff --git a/csharp_mod/Plugin.cs b/csharp_mod/Plugin.cs index 17ac98d..15b0053 100644 --- a/csharp_mod/Plugin.cs +++ b/csharp_mod/Plugin.cs @@ -39,7 +39,7 @@ namespace Slang { public const string PluginGuid = "com.biddydev.slang"; public const string PluginName = "Slang"; - public const string PluginVersion = "0.4.7"; + public const string PluginVersion = "0.5.0"; private static Harmony? _harmony; diff --git a/csharp_mod/stationeersSlang.csproj b/csharp_mod/stationeersSlang.csproj index f62a17e..7bd7e07 100644 --- a/csharp_mod/stationeersSlang.csproj +++ b/csharp_mod/stationeersSlang.csproj @@ -5,7 +5,7 @@ enable StationeersSlang Slang Compiler Bridge - 0.4.2 + 0.5.0 true latest diff --git a/rust_compiler/Cargo.lock b/rust_compiler/Cargo.lock index bc28d3e..aed6d37 100644 --- a/rust_compiler/Cargo.lock +++ b/rust_compiler/Cargo.lock @@ -23,7 +23,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "891477e0c6a8957309ee5c45a6368af3ae14bb510732d2684ffa19af310920f9" dependencies = [ - "getrandom", + "getrandom 0.2.16", "once_cell", "version_check", ] @@ -73,7 +73,7 @@ version = "1.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" dependencies = [ - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -84,7 +84,7 @@ checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" dependencies = [ "anstyle", "once_cell_polyfill", - "windows-sys", + "windows-sys 0.61.2", ] [[package]] @@ -135,6 +135,12 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "bitflags" +version = "2.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" + [[package]] name = "bitvec" version = "1.0.1" @@ -278,6 +284,18 @@ dependencies = [ "tokenizer", ] +[[package]] +name = "console" +version = "0.15.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8" +dependencies = [ + "encode_unicode", + "libc", + "once_cell", + "windows-sys 0.59.0", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -293,12 +311,28 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "encode_unicode" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34aa73646ffb006b8f5147f3dc182bd4bcb190227ce861fc4a4844bf8e3cb2c0" + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + [[package]] name = "ext-trait" version = "1.0.1" @@ -334,13 +368,19 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320bea982e85d42441eb25c49b41218e7eaa2657e8f90bc4eca7437376751e23" +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "fluent-uri" version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "17c704e9dbe1ddd863da1e6ff3567795087b1eb201ce80d8fa81162e1516500d" dependencies = [ - "bitflags", + "bitflags 1.3.2", ] [[package]] @@ -366,6 +406,18 @@ dependencies = [ "wasi", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "gimli" version = "0.32.3" @@ -428,6 +480,32 @@ dependencies = [ "rustversion", ] +[[package]] +name = "insta" +version = "1.45.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "983e3b24350c84ab8a65151f537d67afbbf7153bb9f1110e03e9fa9b07f67a5c" +dependencies = [ + "console", + "once_cell", + "similar", + "tempfile", +] + +[[package]] +name = "integration_tests" +version = "0.1.0" +dependencies = [ + "anyhow", + "compiler", + "il", + "indoc", + "insta", + "optimizer", + "parser", + "tokenizer", +] + [[package]] name = "inventory" version = "0.3.21" @@ -465,6 +543,12 @@ version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "logos" version = "0.16.0" @@ -505,7 +589,7 @@ version = "0.97.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53353550a17c04ac46c585feb189c2db82154fc84b79c7a66c96c2c644f66071" dependencies = [ - "bitflags", + "bitflags 1.3.2", "fluent-uri", "serde", "serde_json", @@ -678,6 +762,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "radium" version = "0.7.0" @@ -711,7 +801,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom", + "getrandom 0.2.16", ] [[package]] @@ -800,6 +890,19 @@ dependencies = [ "semver", ] +[[package]] +name = "rustix" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +dependencies = [ + "bitflags 2.10.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -928,9 +1031,15 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "slang" -version = "0.4.7" +version = "0.5.0" dependencies = [ "anyhow", "clap", @@ -1014,6 +1123,19 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" +[[package]] +name = "tempfile" +version = "3.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655da9c7eb6305c55742045d5a8d2037996d61d8de95806335c7c86ce0f82e9c" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -1140,6 +1262,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.106" @@ -1191,6 +1322,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.61.2" @@ -1200,6 +1340,70 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "winnow" version = "0.7.14" @@ -1209,6 +1413,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "with_builtin_macros" version = "0.0.3" diff --git a/rust_compiler/Cargo.toml b/rust_compiler/Cargo.toml index e880ddf..550171e 100644 --- a/rust_compiler/Cargo.toml +++ b/rust_compiler/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "slang" -version = "0.4.7" +version = "0.5.0" edition = "2021" [workspace] diff --git a/rust_compiler/libs/compiler/src/test/binary_expression.rs b/rust_compiler/libs/compiler/src/test/binary_expression.rs index e4aebef..8acf677 100644 --- a/rust_compiler/libs/compiler/src/test/binary_expression.rs +++ b/rust_compiler/libs/compiler/src/test/binary_expression.rs @@ -4,15 +4,21 @@ use pretty_assertions::assert_eq; #[test] fn simple_binary_expression() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let i = 1 + 2; " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -27,8 +33,8 @@ fn simple_binary_expression() -> Result<()> { #[test] fn nested_binary_expressions() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn calculateArgs(arg1, arg2, arg3) { return (arg1 + arg2) * arg3; @@ -38,8 +44,14 @@ fn nested_binary_expressions() -> Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -47,6 +59,7 @@ fn nested_binary_expressions() -> Result<()> { pop r8 pop r9 pop r10 + push sp push ra add r1 r10 r9 mul r2 r1 r8 @@ -54,6 +67,7 @@ fn nested_binary_expressions() -> Result<()> { j __internal_L1 __internal_L1: pop ra + pop sp j ra main: push 10 @@ -72,15 +86,21 @@ fn nested_binary_expressions() -> Result<()> { #[test] fn stress_test_constant_folding() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let negationHell = (-1 + -2) * (-3 + (-4 * (-5 + -6))); " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -95,16 +115,22 @@ fn stress_test_constant_folding() -> Result<()> { #[test] fn test_constant_folding_with_variables_mixed_in() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check r#" device self = "db"; let i = 1 - 3 * (1 + 123.4) * self.Setting + 245c; "# }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -123,15 +149,21 @@ fn test_constant_folding_with_variables_mixed_in() -> Result<()> { #[test] fn test_ternary_expression() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check r#" let i = 1 > 2 ? 15 : 20; "# }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -148,16 +180,22 @@ fn test_ternary_expression() -> Result<()> { #[test] fn test_ternary_expression_assignment() -> Result<()> { - let compiled = compile! { - debug + let result = compile! { + check r#" let i = 0; i = 1 > 2 ? 15 : 20; "# }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -175,15 +213,21 @@ fn test_ternary_expression_assignment() -> Result<()> { #[test] fn test_negative_literals() -> Result<()> { - let compiled = compile!( - debug + let result = compile!( + check r#" let item = -10c - 20c; "# ); + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -198,16 +242,22 @@ fn test_negative_literals() -> Result<()> { #[test] fn test_mismatched_temperature_literals() -> Result<()> { - let compiled = compile!( - debug + let result = compile!( + check r#" let item = -10c - 100k; let item2 = item + 500c; "# ); + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/branching.rs b/rust_compiler/libs/compiler/src/test/branching.rs index 9addbe7..c2561a0 100644 --- a/rust_compiler/libs/compiler/src/test/branching.rs +++ b/rust_compiler/libs/compiler/src/test/branching.rs @@ -3,8 +3,8 @@ use pretty_assertions::assert_eq; #[test] fn test_if_statement() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 10; if (a > 5) { @@ -13,8 +13,14 @@ fn test_if_statement() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -33,8 +39,8 @@ fn test_if_statement() -> anyhow::Result<()> { #[test] fn test_if_else_statement() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 0; if (10 > 5) { @@ -45,8 +51,14 @@ fn test_if_else_statement() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -68,8 +80,8 @@ fn test_if_else_statement() -> anyhow::Result<()> { #[test] fn test_if_else_if_statement() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 0; if (a == 1) { @@ -82,8 +94,14 @@ fn test_if_else_if_statement() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -111,8 +129,8 @@ fn test_if_else_if_statement() -> anyhow::Result<()> { #[test] fn test_spilled_variable_update_in_branch() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 1; let b = 2; @@ -129,8 +147,14 @@ fn test_spilled_variable_update_in_branch() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/declaration_function_invocation.rs b/rust_compiler/libs/compiler/src/test/declaration_function_invocation.rs index adfcdfd..c21aad4 100644 --- a/rust_compiler/libs/compiler/src/test/declaration_function_invocation.rs +++ b/rust_compiler/libs/compiler/src/test/declaration_function_invocation.rs @@ -3,21 +3,29 @@ use pretty_assertions::assert_eq; #[test] fn no_arguments() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn doSomething() {}; let i = doSomething(); " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + let to_test = indoc! { " j main doSomething: + push sp push ra __internal_L1: pop ra + pop sp j ra main: jal doSomething @@ -25,15 +33,15 @@ fn no_arguments() -> anyhow::Result<()> { " }; - assert_eq!(compiled, to_test); + assert_eq!(result.output, to_test); Ok(()) } #[test] fn let_var_args() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn mul2(arg1) { return arg1 * 2; @@ -46,19 +54,27 @@ fn let_var_args() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main mul2: pop r8 + push sp push ra mul r1 r8 2 move r15 r1 j __internal_L1 __internal_L1: pop ra + pop sp j ra main: __internal_L2: @@ -99,8 +115,8 @@ fn incorrect_args_count() -> anyhow::Result<()> { #[test] fn inline_literal_args() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn doSomething(arg1, arg2) { return 5; @@ -110,19 +126,27 @@ fn inline_literal_args() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main doSomething: pop r8 pop r9 + push sp push ra move r15 5 j __internal_L1 __internal_L1: pop ra + pop sp j ra main: move r8 123 @@ -141,8 +165,8 @@ fn inline_literal_args() -> anyhow::Result<()> { #[test] fn mixed_args() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let arg1 = 123; let returnValue = doSomething(arg1, 456); @@ -150,17 +174,25 @@ fn mixed_args() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main doSomething: pop r8 pop r9 + push sp push ra __internal_L1: pop ra + pop sp j ra main: move r8 123 @@ -179,8 +211,8 @@ fn mixed_args() -> anyhow::Result<()> { #[test] fn with_return_statement() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn doSomething(arg1) { return 456; @@ -190,18 +222,26 @@ fn with_return_statement() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main doSomething: pop r8 + push sp push ra move r15 456 j __internal_L1 __internal_L1: pop ra + pop sp j ra main: push 123 @@ -216,8 +256,8 @@ fn with_return_statement() -> anyhow::Result<()> { #[test] fn with_negative_return_literal() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " fn doSomething() { return -1; @@ -226,16 +266,24 @@ fn with_negative_return_literal() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main doSomething: + push sp push ra move r15 -1 __internal_L1: pop ra + pop sp j ra main: jal doSomething diff --git a/rust_compiler/libs/compiler/src/test/declaration_literal.rs b/rust_compiler/libs/compiler/src/test/declaration_literal.rs index 20f01f5..10ea391 100644 --- a/rust_compiler/libs/compiler/src/test/declaration_literal.rs +++ b/rust_compiler/libs/compiler/src/test/declaration_literal.rs @@ -4,13 +4,19 @@ use pretty_assertions::assert_eq; #[test] fn variable_declaration_numeric_literal() -> anyhow::Result<()> { let compiled = crate::compile! { - debug r#" + check r#" let i = 20c; "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -26,7 +32,7 @@ fn variable_declaration_numeric_literal() -> anyhow::Result<()> { #[test] fn variable_declaration_numeric_literal_stack_spillover() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" let a = 0; let b = 1; @@ -40,8 +46,14 @@ fn variable_declaration_numeric_literal_stack_spillover() -> anyhow::Result<()> let j = 9; "#}; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -67,14 +79,20 @@ fn variable_declaration_numeric_literal_stack_spillover() -> anyhow::Result<()> #[test] fn variable_declaration_negative() -> anyhow::Result<()> { let compiled = compile! { - debug + check " let i = -1; " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -90,15 +108,21 @@ fn variable_declaration_negative() -> anyhow::Result<()> { #[test] fn test_boolean_declaration() -> anyhow::Result<()> { let compiled = compile! { - debug + check " let t = true; let f = false; " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -115,7 +139,7 @@ fn test_boolean_declaration() -> anyhow::Result<()> { #[test] fn test_boolean_return() -> anyhow::Result<()> { let compiled = compile! { - debug + check " fn getTrue() { return true; @@ -125,17 +149,25 @@ fn test_boolean_return() -> anyhow::Result<()> { " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main getTrue: + push sp push ra move r15 1 j __internal_L1 __internal_L1: pop ra + pop sp j ra main: jal getTrue @@ -149,15 +181,21 @@ fn test_boolean_return() -> anyhow::Result<()> { #[test] fn test_const_hash_expr() -> anyhow::Result<()> { - let compiled = compile!(debug r#" + let compiled = compile!(check r#" const nameHash = hash("AccessCard"); device self = "db"; self.Setting = nameHash; "#); + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -172,7 +210,7 @@ fn test_const_hash_expr() -> anyhow::Result<()> { #[test] fn test_declaration_is_const() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" const MAX = 100; @@ -180,8 +218,14 @@ fn test_declaration_is_const() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/device_access.rs b/rust_compiler/libs/compiler/src/test/device_access.rs new file mode 100644 index 0000000..e4dee6b --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/device_access.rs @@ -0,0 +1,274 @@ +use indoc::indoc; +use pretty_assertions::assert_eq; + +#[test] +fn device_declaration() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + " + }; + + // Declaration only emits the jump label header + assert_eq!(compiled.output, "j main\n"); + + Ok(()) +} + +#[test] +fn device_property_read() -> anyhow::Result<()> { + let compiled = compile! { + check " + device ac = \"d0\"; + let temp = ac.Temperature; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 Temperature + move r8 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn device_property_write() -> anyhow::Result<()> { + let compiled = compile! { + check " + device ac = \"d0\"; + ac.On = 1; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + s d0 On 1 + " + } + ); + + Ok(()) +} + +#[test] +fn multiple_device_declarations() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + device d1 = \"d1\"; + device d2 = \"d2\"; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // Declarations only emit the header when unused + assert_eq!(compiled.output, "j main\n"); + + Ok(()) +} + +#[test] +fn device_with_variable_interaction() -> anyhow::Result<()> { + let compiled = compile! { + check " + device sensor = \"d0\"; + let reading = sensor.Temperature; + let threshold = 373.15; + let alert = reading > threshold; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 Temperature + move r8 r1 + move r9 373.15 + sgt r2 r8 r9 + move r10 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn device_property_in_arithmetic() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + let result = d0.Temperature + 100; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // Verify that we load property, add 100, and move to result + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 Temperature + add r2 r1 100 + move r8 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn device_used_in_function() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + + fn check_power() { + return d0.On; + }; + + let powered = check_power(); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + check_power: + push sp + push ra + l r1 d0 On + move r15 r1 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + main: + jal check_power + move r8 r15 + " + } + ); + + Ok(()) +} + +#[test] +fn device_in_conditional() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + + if (d0.On) { + let x = 1; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 On + beqz r1 __internal_L1 + move r8 1 + __internal_L1: + " + } + ); + + Ok(()) +} + +#[test] +fn device_property_with_underscore_name() -> anyhow::Result<()> { + let compiled = compile! { + check " + device cool_device = \"d0\"; + let value = cool_device.SomeProperty; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 SomeProperty + move r8 r1 + " + } + ); + + Ok(()) +} diff --git a/rust_compiler/libs/compiler/src/test/edge_cases.rs b/rust_compiler/libs/compiler/src/test/edge_cases.rs new file mode 100644 index 0000000..62019cf --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/edge_cases.rs @@ -0,0 +1,737 @@ +use indoc::indoc; +use pretty_assertions::assert_eq; + +#[test] +fn zero_value_handling() -> anyhow::Result<()> { + let result = compile! { + check " + let x = 0; + let y = x + 0; + let z = x * 100; + " + }; + + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + + assert_eq!( + result.output, + indoc! { + " + j main + main: + move r8 0 + add r1 r8 0 + move r9 r1 + mul r2 r8 100 + move r10 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn negative_number_handling() -> anyhow::Result<()> { + let result = compile! { + check " + let x = -100; + let y = -x; + let z = -(-50); + " + }; + + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + + assert_eq!( + result.output, + indoc! { + " + j main + main: + move r8 -100 + sub r1 0 r8 + move r9 r1 + move r10 50 + " + } + ); + + Ok(()) +} + +#[test] +fn large_number_constants() -> anyhow::Result<()> { + let result = compile! { + check " + let x = 999999999; + let y = x + 1; + " + }; + + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + + assert_eq!( + result.output, + indoc! { + " + j main + main: + move r8 999999999 + add r1 r8 1 + move r9 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn floating_point_precision() -> anyhow::Result<()> { + let result = compile! { + check " + let pi = 3.14159265; + let e = 2.71828182; + let sum = pi + e; + " + }; + + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + + assert_eq!( + result.output, + indoc! { + " + j main + main: + move r8 3.14159265 + move r9 2.71828182 + add r1 r8 r9 + move r10 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn temperature_unit_conversion() -> anyhow::Result<()> { + let result = compile! { + check " + let celsius = 20c; + let fahrenheit = 68f; + let kelvin = 293.15k; + " + }; + + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + + assert_eq!( + result.output, + indoc! { + " + j main + main: + move r8 293.15 + move r9 293.15 + move r10 293.15 + " + } + ); + + Ok(()) +} + +#[test] +fn mixed_temperature_units() -> anyhow::Result<()> { + let compiled = compile! { + check " + let c = 0c; + let f = 32f; + let k = 273.15k; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 273.15 + move r9 273.15 + move r10 273.15 + " + } + ); + + Ok(()) +} + +#[test] +fn boolean_constant_folding() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = true; + let y = false; + let z = true && true; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 0 + and r1 1 1 + move r10 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn empty_block() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 5; + { + } + let y = x; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + move r9 r8 + " + } + ); + + Ok(()) +} + +#[test] +fn multiple_statements_same_line() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 1; let y = 2; let z = 3; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 2 + move r10 3 + " + } + ); + + Ok(()) +} + +#[test] +fn function_with_no_return() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn no_return() { + let x = 5; + }; + + no_return(); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + no_return: + push sp + push ra + move r8 5 + __internal_L1: + pop ra + pop sp + j ra + main: + jal no_return + move r1 r15 + " + } + ); + + Ok(()) +} + +#[test] +fn deeply_nested_expressions() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = ((((((((1 + 2) + 3) + 4) + 5) + 6) + 7) + 8) + 9); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 45 + " + } + ); + + Ok(()) +} + +#[test] +fn constant_folding_with_operations() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10 * 5 + 3 - 2; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 51 + " + } + ); + + Ok(()) +} + +#[test] +fn constant_folding_with_division() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 100 / 2 / 5; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 10 + " + } + ); + + Ok(()) +} + +#[test] +fn modulo_operation() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 17 % 5; + let y = 10 % 3; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 2 + move r9 1 + " + } + ); + + Ok(()) +} + +#[test] +fn exponentiation() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 2 ** 8; + let y = 3 ** 3; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + pow r1 2 8 + move r8 r1 + pow r2 3 3 + move r9 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn comparison_with_zero() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 0 == 0; + let y = 0 < 1; + let z = 0 > -1; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + seq r1 0 0 + move r8 r1 + slt r2 0 1 + move r9 r2 + sgt r3 0 -1 + move r10 r3 + " + } + ); + + Ok(()) +} + +#[test] +fn boolean_negation_edge_cases() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = !0; + let y = !1; + let z = !100; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + seq r1 0 0 + move r8 r1 + seq r2 1 0 + move r9 r2 + seq r3 100 0 + move r10 r3 + " + } + ); + + Ok(()) +} + +#[test] +fn function_with_many_parameters() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn many_params(a, b, c, d, e, f, g, h) { + return a + b + c + d + e + f + g + h; + }; + + let result = many_params(1, 2, 3, 4, 5, 6, 7, 8); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + many_params: + pop r8 + pop r9 + pop r10 + pop r11 + pop r12 + pop r13 + pop r14 + push sp + push ra + sub r0 sp 3 + get r1 db r0 + add r2 r1 r14 + add r3 r2 r13 + add r4 r3 r12 + add r5 r4 r11 + add r6 r5 r10 + add r7 r6 r9 + add r1 r7 r8 + move r15 r1 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + main: + push 1 + push 2 + push 3 + push 4 + push 5 + push 6 + push 7 + push 8 + jal many_params + move r8 r15 + " + } + ); + + Ok(()) +} + +#[test] +fn tuple_declaration_with_functions() -> anyhow::Result<()> { + let compiled = compile! { + check + r#" + device self = "db"; + fn doSomething() { + return (self.Setting, self.Temperature); + } + + let (setting, temperature) = doSomething(); + "# + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! {" + j main + doSomething: + push sp + push ra + l r1 db Setting + push r1 + l r2 db Temperature + push r2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal doSomething + pop r9 + pop r8 + move sp r15 + "} + ); + + Ok(()) +} + +#[test] +fn tuple_from_simple_function() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn get_pair() { + return (1, 2); + } + + let (a, b) = get_pair(); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! {" + j main + get_pair: + push sp + push ra + push 1 + push 2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal get_pair + pop r9 + pop r8 + move sp r15 + "} + ); + + Ok(()) +} + +#[test] +fn tuple_from_expression_not_function() -> anyhow::Result<()> { + let compiled = compile! { + check " + let (a, b) = (5 + 3, 10 * 2); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! {" + j main + main: + move r8 8 + move r9 20 + "} + ); + + Ok(()) +} diff --git a/rust_compiler/libs/compiler/src/test/error_handling.rs b/rust_compiler/libs/compiler/src/test/error_handling.rs new file mode 100644 index 0000000..116c2a3 --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/error_handling.rs @@ -0,0 +1,197 @@ +use crate::Error; +use crate::variable_manager::Error as ScopeError; + +#[test] +fn unknown_identifier_error() { + let errors = compile! { + result "let x = unknown_var;" + }; + + assert_eq!(errors.len(), 1); + match &errors[0] { + Error::UnknownIdentifier(name, _) => { + assert_eq!(name.as_ref(), "unknown_var"); + } + _ => panic!("Expected UnknownIdentifier error, got {:?}", errors[0]), + } +} + +#[test] +fn duplicate_identifier_error() { + let errors = compile! { + result " + let x = 5; + let x = 10; + " + }; + + assert_eq!(errors.len(), 1); + match &errors[0] { + Error::Scope(ScopeError::DuplicateVariable(name, _)) => { + assert_eq!(name.as_ref(), "x"); + } + _ => panic!("Expected DuplicateIdentifier error, got {:?}", errors[0]), + } +} + +#[test] +fn const_reassignment_error() { + let errors = compile! { + result " + const PI = 3.14; + PI = 2.71; + " + }; + + assert_eq!(errors.len(), 1); + match &errors[0] { + Error::ConstAssignment(name, _) => { + assert_eq!(name.as_ref(), "PI"); + } + _ => panic!("Expected ConstAssignment error, got {:?}", errors[0]), + } +} + +#[test] +fn unknown_function_call_error() { + let errors = compile! { + result " + let result = unknown_function(); + " + }; + + assert_eq!(errors.len(), 1); + match &errors[0] { + Error::UnknownIdentifier(name, _) => { + assert_eq!(name.as_ref(), "unknown_function"); + } + _ => panic!("Expected UnknownIdentifier error, got {:?}", errors[0]), + } +} + +#[test] +fn argument_mismatch_error() { + let errors = compile! { + result " + fn add(a, b) { + return a + b; + }; + + let result = add(1); + " + }; + + // The error should be an AgrumentMismatch + assert!( + errors + .iter() + .any(|e| matches!(e, Error::AgrumentMismatch(_, _))) + ); +} + +#[test] +fn tuple_size_mismatch_error() { + let errors = compile! { + result " + fn pair() { + return (1, 2); + }; + + let (x, y, z) = pair(); + " + }; + + assert!( + errors + .iter() + .any(|e| matches!(e, Error::TupleSizeMismatch(2, 3, _))) + ); +} + +#[test] +fn multiple_errors_reported() { + let errors = compile! { + result " + let x = unknown1; + let x = 5; + let y = unknown2; + " + }; + + // Should have at least 3 errors + assert!( + errors.len() >= 2, + "Expected at least 2 errors, got {}", + errors.len() + ); +} + +#[test] +fn return_outside_function_error() { + let errors = compile! { + result " + let x = 5; + return x; + " + }; + + // Should have an error about return outside function + assert!( + !errors.is_empty(), + "Expected error for return outside function" + ); +} + +#[test] +fn break_outside_loop_error() { + let errors = compile! { + result " + break; + " + }; + + assert!(!errors.is_empty(), "Expected error for break outside loop"); +} + +#[test] +fn continue_outside_loop_error() { + let errors = compile! { + result " + continue; + " + }; + + assert!( + !errors.is_empty(), + "Expected error for continue outside loop" + ); +} + +#[test] +fn device_reassignment_error() { + let errors = compile! { + result " + device d0 = \"d0\"; + device d0 = \"d1\"; + " + }; + + assert!( + errors + .iter() + .any(|e| matches!(e, Error::DuplicateIdentifier(_, _))) + ); +} + +#[test] +fn invalid_device_error() { + let errors = compile! { + result " + device d0 = \"d0\"; + d0 = \"d1\"; + " + }; + + // Device reassignment should fail + assert!(!errors.is_empty(), "Expected error for device reassignment"); +} diff --git a/rust_compiler/libs/compiler/src/test/function_declaration.rs b/rust_compiler/libs/compiler/src/test/function_declaration.rs index 8e24a54..f9d6f17 100644 --- a/rust_compiler/libs/compiler/src/test/function_declaration.rs +++ b/rust_compiler/libs/compiler/src/test/function_declaration.rs @@ -3,7 +3,7 @@ use pretty_assertions::assert_eq; #[test] fn test_function_declaration_with_spillover_params() -> anyhow::Result<()> { - let compiled = compile!(debug r#" + let compiled = compile!(check r#" // we need more than 4 params to 'spill' into a stack var fn doSomething(arg1, arg2, arg3, arg4, arg5, arg6, arg7, arg8, arg9) { return arg1 + arg2 + arg3 + arg4 + arg5 + arg6 + arg7 + arg8 + arg9; @@ -13,8 +13,14 @@ fn test_function_declaration_with_spillover_params() -> anyhow::Result<()> { let returned = doSomething(item1, 2, 3, 4, 5, 6, 7, 8, 9); "#); + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! {" j main doSomething: @@ -25,10 +31,11 @@ fn test_function_declaration_with_spillover_params() -> anyhow::Result<()> { pop r12 pop r13 pop r14 + push sp push ra - sub r0 sp 3 + sub r0 sp 4 get r1 db r0 - sub r0 sp 2 + sub r0 sp 3 get r2 db r0 add r3 r1 r2 add r4 r3 r14 @@ -42,7 +49,7 @@ fn test_function_declaration_with_spillover_params() -> anyhow::Result<()> { j __internal_L1 __internal_L1: pop ra - sub sp sp 2 + pop sp j ra main: move r8 1 @@ -67,7 +74,7 @@ fn test_function_declaration_with_spillover_params() -> anyhow::Result<()> { #[test] fn test_early_return() -> anyhow::Result<()> { - let compiled = compile!(debug r#" + let compiled = compile!(check r#" // This is a test function declaration with no body fn doSomething() { if (1 == 1) { @@ -79,12 +86,19 @@ fn test_early_return() -> anyhow::Result<()> { doSomething(); "#); + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main doSomething: + push sp push ra seq r1 1 1 beqz r1 __internal_L2 @@ -94,6 +108,7 @@ fn test_early_return() -> anyhow::Result<()> { j __internal_L1 __internal_L1: pop ra + pop sp j ra main: jal doSomething @@ -107,22 +122,30 @@ fn test_early_return() -> anyhow::Result<()> { #[test] fn test_function_declaration_with_register_params() -> anyhow::Result<()> { - let compiled = compile!(debug r#" + let compiled = compile!(check r#" // This is a test function declaration with no body fn doSomething(arg1, arg2) { }; "#); + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! {" j main doSomething: pop r8 pop r9 + push sp push ra __internal_L1: pop ra + pop sp j ra "} ); diff --git a/rust_compiler/libs/compiler/src/test/logic_expression.rs b/rust_compiler/libs/compiler/src/test/logic_expression.rs index b7699ea..4dce125 100644 --- a/rust_compiler/libs/compiler/src/test/logic_expression.rs +++ b/rust_compiler/libs/compiler/src/test/logic_expression.rs @@ -3,8 +3,8 @@ use pretty_assertions::assert_eq; #[test] fn test_comparison_expressions() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let isGreater = 10 > 5; let isLess = 5 < 10; @@ -15,8 +15,14 @@ fn test_comparison_expressions() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -42,8 +48,8 @@ fn test_comparison_expressions() -> anyhow::Result<()> { #[test] fn test_logical_and_or_not() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let logic1 = 1 && 1; let logic2 = 1 || 0; @@ -51,8 +57,14 @@ fn test_logical_and_or_not() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -72,15 +84,21 @@ fn test_logical_and_or_not() -> anyhow::Result<()> { #[test] fn test_complex_logic() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let logic = (10 > 5) && (5 < 10); " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -98,15 +116,21 @@ fn test_complex_logic() -> anyhow::Result<()> { #[test] fn test_math_with_logic() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let logic = (1 + 2) > 1; " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -122,15 +146,21 @@ fn test_math_with_logic() -> anyhow::Result<()> { #[test] fn test_boolean_in_logic() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let res = true && false; " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main @@ -146,8 +176,8 @@ fn test_boolean_in_logic() -> anyhow::Result<()> { #[test] fn test_invert_a_boolean() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let i = true; let y = !i; @@ -156,8 +186,14 @@ fn test_invert_a_boolean() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + assert_eq!( - compiled, + result.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/loops.rs b/rust_compiler/libs/compiler/src/test/loops.rs index 63ce5d9..40335fe 100644 --- a/rust_compiler/libs/compiler/src/test/loops.rs +++ b/rust_compiler/libs/compiler/src/test/loops.rs @@ -3,8 +3,8 @@ use pretty_assertions::assert_eq; #[test] fn test_infinite_loop() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 0; loop { @@ -13,9 +13,15 @@ fn test_infinite_loop() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + // __internal_Labels: L1 (start), L2 (end) assert_eq!( - compiled, + result.output, indoc! { " j main @@ -35,8 +41,8 @@ fn test_infinite_loop() -> anyhow::Result<()> { #[test] fn test_loop_break() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 0; loop { @@ -48,9 +54,15 @@ fn test_loop_break() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + // __internal_Labels: L1 (start), L2 (end), L3 (if end - implicit else label) assert_eq!( - compiled, + result.output, indoc! { " j main @@ -74,8 +86,8 @@ fn test_loop_break() -> anyhow::Result<()> { #[test] fn test_while_loop() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check " let a = 0; while (a < 10) { @@ -84,9 +96,15 @@ fn test_while_loop() -> anyhow::Result<()> { " }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + // __internal_Labels: L1 (start), L2 (end) assert_eq!( - compiled, + result.output, indoc! { " j main @@ -108,8 +126,8 @@ fn test_while_loop() -> anyhow::Result<()> { #[test] fn test_loop_continue() -> anyhow::Result<()> { - let compiled = compile! { - debug + let result = compile! { + check r#" let a = 0; loop { @@ -122,9 +140,15 @@ fn test_loop_continue() -> anyhow::Result<()> { "# }; + assert!( + result.errors.is_empty(), + "Expected no errors, got: {:?}", + result.errors + ); + // __internal_Labels: L1 (start), L2 (end), L3 (if end) assert_eq!( - compiled, + result.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/math_syscall.rs b/rust_compiler/libs/compiler/src/test/math_syscall.rs index db9bcef..f43545b 100644 --- a/rust_compiler/libs/compiler/src/test/math_syscall.rs +++ b/rust_compiler/libs/compiler/src/test/math_syscall.rs @@ -5,14 +5,20 @@ use pretty_assertions::assert_eq; #[test] fn test_acos() -> Result<()> { let compiled = compile! { - debug + check " let i = acos(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -29,14 +35,20 @@ fn test_acos() -> Result<()> { #[test] fn test_asin() -> Result<()> { let compiled = compile! { - debug + check " let i = asin(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -53,14 +65,20 @@ fn test_asin() -> Result<()> { #[test] fn test_atan() -> Result<()> { let compiled = compile! { - debug + check " let i = atan(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -77,14 +95,20 @@ fn test_atan() -> Result<()> { #[test] fn test_atan2() -> Result<()> { let compiled = compile! { - debug + check " let i = atan2(123, 456); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -101,14 +125,20 @@ fn test_atan2() -> Result<()> { #[test] fn test_abs() -> Result<()> { let compiled = compile! { - debug + check " let i = abs(-123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -125,14 +155,20 @@ fn test_abs() -> Result<()> { #[test] fn test_ceil() -> Result<()> { let compiled = compile! { - debug + check " let i = ceil(123.90); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -149,14 +185,20 @@ fn test_ceil() -> Result<()> { #[test] fn test_cos() -> Result<()> { let compiled = compile! { - debug + check " let i = cos(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -173,14 +215,20 @@ fn test_cos() -> Result<()> { #[test] fn test_floor() -> Result<()> { let compiled = compile! { - debug + check " let i = floor(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -197,14 +245,20 @@ fn test_floor() -> Result<()> { #[test] fn test_log() -> Result<()> { let compiled = compile! { - debug + check " let i = log(123); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -221,14 +275,20 @@ fn test_log() -> Result<()> { #[test] fn test_max() -> Result<()> { let compiled = compile! { - debug + check " let i = max(123, 456); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -245,15 +305,21 @@ fn test_max() -> Result<()> { #[test] fn test_max_from_game() -> Result<()> { let compiled = compile! { - debug + check r#" let item = 0; item = max(1 + 2, 2); "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -271,14 +337,20 @@ fn test_max_from_game() -> Result<()> { #[test] fn test_min() -> Result<()> { let compiled = compile! { - debug + check " let i = min(123, 456); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -295,14 +367,20 @@ fn test_min() -> Result<()> { #[test] fn test_rand() -> Result<()> { let compiled = compile! { - debug + check " let i = rand(); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -319,14 +397,20 @@ fn test_rand() -> Result<()> { #[test] fn test_sin() -> Result<()> { let compiled = compile! { - debug + check " let i = sin(3); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -343,14 +427,20 @@ fn test_sin() -> Result<()> { #[test] fn test_sqrt() -> Result<()> { let compiled = compile! { - debug + check " let i = sqrt(3); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -367,14 +457,20 @@ fn test_sqrt() -> Result<()> { #[test] fn test_tan() -> Result<()> { let compiled = compile! { - debug + check " let i = tan(3); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -391,14 +487,20 @@ fn test_tan() -> Result<()> { #[test] fn test_trunc() -> Result<()> { let compiled = compile! { - debug + check " let i = trunc(3.234); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/mod.rs b/rust_compiler/libs/compiler/src/test/mod.rs index 68e35e9..c732971 100644 --- a/rust_compiler/libs/compiler/src/test/mod.rs +++ b/rust_compiler/libs/compiler/src/test/mod.rs @@ -6,6 +6,12 @@ macro_rules! output { }; } +/// Represents both compilation errors and compiled output +pub struct CompilationCheckResult { + pub errors: Vec>, + pub output: String, +} + #[cfg_attr(test, macro_export)] macro_rules! compile { ($source:expr) => {{ @@ -27,7 +33,7 @@ macro_rules! compile { compiler.compile().errors }}; - (debug $source:expr) => {{ + (check $source:expr) => {{ let mut writer = std::io::BufWriter::new(Vec::new()); let compiler = crate::Compiler::new( parser::Parser::new(tokenizer::Tokenizer::from($source)), @@ -35,15 +41,25 @@ macro_rules! compile { ); let res = compiler.compile(); res.instructions.write(&mut writer)?; - output!(writer) + let output = output!(writer); + crate::test::CompilationCheckResult { + errors: res.errors, + output, + } }}; } mod binary_expression; mod branching; mod declaration_function_invocation; mod declaration_literal; +mod device_access; +mod edge_cases; +mod error_handling; mod function_declaration; mod logic_expression; mod loops; mod math_syscall; +mod negation_priority; +mod scoping; mod syscall; +mod tuple_literals; diff --git a/rust_compiler/libs/compiler/src/test/negation_priority.rs b/rust_compiler/libs/compiler/src/test/negation_priority.rs new file mode 100644 index 0000000..4b4f5db --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/negation_priority.rs @@ -0,0 +1,388 @@ +use indoc::indoc; +use pretty_assertions::assert_eq; + +#[test] +fn simple_negation() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = -5; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 -5 + " + } + ); + + Ok(()) +} + +#[test] +fn negation_of_variable() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10; + let y = -x; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 10 + sub r1 0 r8 + move r9 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn double_negation() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = -(-5); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + " + } + ); + + Ok(()) +} + +#[test] +fn negation_in_expression() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10 + (-5); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + " + } + ); + + Ok(()) +} + +#[test] +fn negation_with_multiplication() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = -3 * 4; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 -12 + " + } + ); + + Ok(()) +} + +#[test] +fn parentheses_priority() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = (2 + 3) * 4; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 20 + " + } + ); + + Ok(()) +} + +#[test] +fn nested_parentheses() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = ((2 + 3) * (4 - 1)); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 15 + " + } + ); + + Ok(()) +} + +#[test] +fn parentheses_with_variables() -> anyhow::Result<()> { + let compiled = compile! { + check " + let a = 5; + let b = 10; + let c = (a + b) * 2; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // Should calculate (5 + 10) * 2 = 30 + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + move r9 10 + add r1 r8 r9 + mul r2 r1 2 + move r10 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn priority_affects_result() -> anyhow::Result<()> { + let compiled = compile! { + check " + let with_priority = (2 + 3) * 4; + let without_priority = 2 + 3 * 4; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // with_priority should be 20, without_priority should be 14 + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 20 + move r9 14 + " + } + ); + + Ok(()) +} + +#[test] +fn negation_of_expression() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = -(2 + 3); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // Should be -5 (constant folded) + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + sub r1 0 5 + move r8 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn complex_negation_and_priority() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = -((10 - 5) * 2); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // Should be -(5 * 2) = -10 (folded to constant) + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + sub r1 0 10 + move r8 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn negation_in_logical_expression() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = !(-5); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // -5 is truthy, so !(-5) should be 0 + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + sub r1 0 5 + seq r2 r1 0 + move r8 r2 + " + } + ); + + Ok(()) +} + +#[test] +fn parentheses_in_comparison() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = (10 + 5) > (3 * 4); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + // (10 + 5) = 15 > (3 * 4) = 12, so true (1) + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + sgt r1 15 12 + move r8 r1 + " + } + ); + + Ok(()) +} diff --git a/rust_compiler/libs/compiler/src/test/scoping.rs b/rust_compiler/libs/compiler/src/test/scoping.rs new file mode 100644 index 0000000..7225cec --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/scoping.rs @@ -0,0 +1,462 @@ +use indoc::indoc; +use pretty_assertions::assert_eq; + +#[test] +fn block_scope() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10; + { + let y = 20; + let z = x + y; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 10 + move r9 20 + add r1 r8 r9 + move r10 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn variable_scope_isolation() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10; + { + let x = 20; + let y = x; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 10 + move r9 20 + move r10 r9 + " + } + ); + + Ok(()) +} + +#[test] +fn function_parameter_scope() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn double(x) { + return x * 2; + }; + + let result = double(5); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + double: + pop r8 + push sp + push ra + mul r1 r8 2 + move r15 r1 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + main: + push 5 + jal double + move r8 r15 + " + } + ); + + Ok(()) +} + +#[test] +fn nested_block_scopes() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 1; + { + let x = 2; + { + let x = 3; + let y = x; + } + let z = x; + } + let w = x; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 2 + move r10 3 + move r11 r10 + move r10 r9 + move r9 r8 + " + } + ); + + Ok(()) +} + +#[test] +fn variable_shadowing_in_conditional() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 10; + + if (true) { + let x = 20; + } + + let y = x; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 10 + beqz 1 __internal_L1 + move r9 20 + __internal_L1: + move r9 r8 + " + } + ); + + Ok(()) +} + +#[test] +fn variable_shadowing_in_loop() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 0; + + loop { + let x = x + 1; + if (x > 5) { + break; + } + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 0 + __internal_L1: + add r1 r8 1 + move r9 r1 + sgt r2 r9 5 + beqz r2 __internal_L3 + j __internal_L2 + __internal_L3: + j __internal_L1 + __internal_L2: + " + } + ); + + Ok(()) +} + +#[test] +fn const_scope() -> anyhow::Result<()> { + let compiled = compile! { + check " + const PI = 3.14; + + { + const PI = 2.71; + let x = PI; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 2.71 + " + } + ); + + Ok(()) +} + +#[test] +fn device_in_scope() -> anyhow::Result<()> { + let compiled = compile! { + check " + device d0 = \"d0\"; + + { + let value = d0.Temperature; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 Temperature + move r8 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn function_scope_isolation() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn func1() { + let x = 10; + return x; + }; + + fn func2() { + let x = 20; + return x; + }; + + let a = func1(); + let b = func2(); + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + func1: + push sp + push ra + move r8 10 + move r15 r8 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + func2: + push sp + push ra + move r8 20 + move r15 r8 + j __internal_L2 + __internal_L2: + pop ra + pop sp + j ra + main: + jal func1 + move r8 r15 + push r8 + jal func2 + pop r8 + move r9 r15 + " + } + ); + + Ok(()) +} + +#[test] +fn tuple_unpacking_scope() -> anyhow::Result<()> { + let compiled = compile! { + check " + fn pair() { + return (1, 2); + }; + + { + let (x, y) = pair(); + let z = x + y; + } + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + pair: + push sp + push ra + push 1 + push 2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal pair + pop r9 + pop r8 + move sp r15 + add r1 r8 r9 + move r10 r1 + " + } + ); + + Ok(()) +} + +#[test] +fn shadowing_doesnt_affect_outer() -> anyhow::Result<()> { + let compiled = compile! { + check " + let x = 5; + let y = x; + { + let x = 10; + let z = x; + } + let w = x + y; + " + }; + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + move r9 r8 + move r10 10 + move r11 r10 + add r1 r8 r9 + move r10 r1 + " + } + ); + + Ok(()) +} diff --git a/rust_compiler/libs/compiler/src/test/syscall.rs b/rust_compiler/libs/compiler/src/test/syscall.rs index f70da30..22ce165 100644 --- a/rust_compiler/libs/compiler/src/test/syscall.rs +++ b/rust_compiler/libs/compiler/src/test/syscall.rs @@ -4,14 +4,20 @@ use pretty_assertions::assert_eq; #[test] fn test_yield() -> anyhow::Result<()> { let compiled = compile! { - debug + check " yield(); " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -27,7 +33,7 @@ fn test_yield() -> anyhow::Result<()> { #[test] fn test_sleep() -> anyhow::Result<()> { let compiled = compile! { - debug + check " sleep(3); let sleepAmount = 15; @@ -36,8 +42,14 @@ fn test_sleep() -> anyhow::Result<()> { " }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -57,7 +69,7 @@ fn test_sleep() -> anyhow::Result<()> { #[test] fn test_set_on_device() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device airConditioner = "d0"; let internalTemp = 20c; @@ -66,8 +78,14 @@ fn test_set_on_device() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -85,15 +103,21 @@ fn test_set_on_device() -> anyhow::Result<()> { #[test] fn test_set_on_device_batched() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" const doorHash = hash("Door"); setBatched(doorHash, "Lock", true); "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { r#" j main @@ -108,7 +132,7 @@ fn test_set_on_device_batched() -> anyhow::Result<()> { #[test] fn test_set_on_device_batched_named() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device dev = "d0"; const devName = hash("test"); @@ -117,8 +141,14 @@ fn test_set_on_device_batched_named() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -134,7 +164,7 @@ fn test_set_on_device_batched_named() -> anyhow::Result<()> { #[test] fn test_load_from_device() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device airCon = "d0"; @@ -142,8 +172,14 @@ fn test_load_from_device() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -160,7 +196,7 @@ fn test_load_from_device() -> anyhow::Result<()> { #[test] fn test_load_from_slot() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device airCon = "d0"; @@ -168,8 +204,14 @@ fn test_load_from_slot() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -186,7 +228,7 @@ fn test_load_from_slot() -> anyhow::Result<()> { #[test] fn test_set_slot() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device airCon = "d0"; @@ -194,8 +236,14 @@ fn test_set_slot() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main @@ -211,7 +259,7 @@ fn test_set_slot() -> anyhow::Result<()> { #[test] fn test_load_reagent() -> anyhow::Result<()> { let compiled = compile! { - debug + check r#" device thingy = "d0"; @@ -219,8 +267,14 @@ fn test_load_reagent() -> anyhow::Result<()> { "# }; + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + assert_eq!( - compiled, + compiled.output, indoc! { " j main diff --git a/rust_compiler/libs/compiler/src/test/tuple_literals.rs b/rust_compiler/libs/compiler/src/test/tuple_literals.rs new file mode 100644 index 0000000..eaa75bd --- /dev/null +++ b/rust_compiler/libs/compiler/src/test/tuple_literals.rs @@ -0,0 +1,1354 @@ +#[cfg(test)] +mod test { + use indoc::indoc; + use pretty_assertions::assert_eq; + + #[test] + fn test_tuple_literal_declaration() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let (x, y) = (1, 2); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_declaration_with_underscore() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let (x, _) = (1, 2); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_assignment() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let x = 0; + let y = 0; + (x, y) = (5, 10); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 0 + move r9 0 + move r8 5 + move r9 10 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_with_variables() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let a = 42; + let b = 99; + let (x, y) = (a, b); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 42 + move r9 99 + move r10 r8 + move r11 r9 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_three_elements() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let (x, y, z) = (1, 2, 3); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 2 + move r10 3 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_assignment_with_underscore() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let i = 0; + let x = 123; + (i, _) = (456, 789); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 0 + move r9 123 + move r8 456 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_simple() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getPair() { + return (10, 20); + }; + let (x, y) = getPair(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getPair: + push sp + push ra + push 10 + push 20 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal getPair + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_with_underscore() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getPair() { + return (5, 15); + }; + let (x, _) = getPair(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getPair: + push sp + push ra + push 5 + push 15 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal getPair + pop r0 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_three_elements() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getTriple() { + return (1, 2, 3); + }; + let (a, b, c) = getTriple(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getTriple: + push sp + push ra + push 1 + push 2 + push 3 + sub r0 sp 5 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 4 + get ra db r0 + j ra + main: + jal getTriple + pop r10 + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_assignment() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getPair() { + return (42, 84); + }; + let i = 1; + let j = 2; + (i, j) = getPair(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getPair: + push sp + push ra + push 42 + push 84 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + move r8 1 + move r9 2 + jal getPair + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_mismatch() -> anyhow::Result<()> { + let errors = compile!( + result + r#" + fn doSomething() { + return (1, 2, 3); + }; + let (x, y) = doSomething(); + "# + ); + + // Should have exactly one error about tuple size mismatch + assert_eq!(errors.len(), 1); + + // Check for the specific TupleSizeMismatch error + match &errors[0] { + crate::Error::TupleSizeMismatch(expected_size, actual_count, _) => { + assert_eq!(*expected_size, 3); + assert_eq!(*actual_count, 2); + } + e => panic!("Expected TupleSizeMismatch error, got: {:?}", e), + } + + Ok(()) + } + + #[test] + fn test_tuple_return_called_by_non_tuple_return() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn doSomething() { + return (1, 2); + }; + + fn doSomethingElse() { + let (x, y) = doSomething(); + return y; + }; + + let returnedValue = doSomethingElse(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + doSomething: + push sp + push ra + push 1 + push 2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + doSomethingElse: + push sp + push ra + jal doSomething + pop r9 + pop r8 + move sp r15 + move r15 r9 + j __internal_L2 + __internal_L2: + pop ra + pop sp + j ra + main: + jal doSomethingElse + move r8 r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_non_tuple_return_called_by_tuple_return() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getValue() { + return 42; + }; + + fn getTuple() { + let x = getValue(); + return (x, x); + }; + + let (a, b) = getTuple(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getValue: + push sp + push ra + move r15 42 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + getTuple: + push sp + push ra + jal getValue + move r8 r15 + push r8 + push r8 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L2 + __internal_L2: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal getTuple + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_size_mismatch() -> anyhow::Result<()> { + let errors = compile!( + result + r#" + let (x, y) = (1, 2, 3); + "# + ); + + // Should have exactly one error about tuple size mismatch + assert_eq!(errors.len(), 1); + assert!(matches!( + errors[0], + crate::Error::TupleSizeMismatch(_, _, _) + )); + + Ok(()) + } + + #[test] + fn test_multiple_tuple_returns_in_function() -> anyhow::Result<()> { + // Test multiple return paths in tuple-returning function + let compiled = compile!( + check + r#" + fn getValue(x) { + if (x) { + return (1, 2); + } else { + return (3, 4); + } + }; + + let (a, b) = getValue(1); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getValue: + pop r8 + push sp + push ra + beqz r8 __internal_L3 + push 1 + push 2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + j __internal_L2 + __internal_L3: + push 3 + push 4 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L2: + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + push 1 + jal getValue + pop r9 + pop r8 + move sp r15 + " + }, + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_with_expression() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn add(x, y) { + return (x, y); + }; + + let (a, b) = add(5, 10); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + add: + pop r8 + pop r9 + push sp + push ra + push r9 + push r8 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + push 5 + push 10 + jal add + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_nested_function_tuple_calls() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn inner() { + return (1, 2); + }; + + fn outer() { + let (x, y) = inner(); + return (y, x); + }; + + let (a, b) = outer(); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + inner: + push sp + push ra + push 1 + push 2 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + outer: + push sp + push ra + jal inner + pop r9 + pop r8 + move sp r15 + push r9 + push r8 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L2 + __internal_L2: + sub r0 sp 3 + get ra db r0 + j ra + main: + jal outer + pop r9 + pop r8 + move sp r15 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_with_constant_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let (a, b) = (1 + 2, 3 * 4); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 3 + move r9 12 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_with_variable_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let x = 5; + let y = 10; + let (a, b) = (x + 1, y * 2); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + move r9 10 + add r1 r8 1 + move r10 r1 + mul r2 r9 2 + move r11 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_assignment_with_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let a = 0; + let b = 0; + let x = 5; + (a, b) = (x + 1, x * 2); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 0 + move r9 0 + move r10 5 + add r1 r10 1 + move r8 r1 + mul r2 r10 2 + move r9 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_literal_with_function_calls() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getValue() { return 42; }; + fn getOther() { return 99; }; + + let (a, b) = (getValue(), getOther()); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getValue: + push sp + push ra + move r15 42 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + getOther: + push sp + push ra + move r15 99 + j __internal_L2 + __internal_L2: + pop ra + pop sp + j ra + main: + push r8 + jal getValue + pop r8 + move r1 r15 + move r8 r1 + push r8 + push r9 + jal getOther + pop r9 + pop r8 + move r2 r15 + move r9 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_logical_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let x = 1; + let y = 0; + let (a, b) = (x && y, x || y); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 1 + move r9 0 + and r1 r8 r9 + move r10 r1 + or r2 r8 r9 + move r11 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_comparison_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + let x = 5; + let y = 10; + let (a, b) = (x > y, x < y); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + move r8 5 + move r9 10 + sgt r1 r8 r9 + move r10 r1 + slt r2 r8 r9 + move r11 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_device_property_access() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + device sensor = "d0"; + device display = "d1"; + + let (temp, pressure) = (sensor.Temperature, sensor.Pressure); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + main: + l r1 d0 Temperature + move r8 r1 + l r2 d0 Pressure + move r9 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_device_property_and_function_call() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + device self = "db"; + + fn getY() { + return 42; + } + + let (x, y) = (self.Setting, getY()); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getY: + push sp + push ra + move r15 42 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + main: + l r1 db Setting + move r8 r1 + push r8 + push r9 + jal getY + pop r9 + pop r8 + move r2 r15 + move r9 r2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_function_call_expressions() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getValue() { return 10; } + fn getOther() { return 20; } + + let (a, b) = (getValue() + 5, getOther() * 2); + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getValue: + push sp + push ra + move r15 10 + j __internal_L1 + __internal_L1: + pop ra + pop sp + j ra + getOther: + push sp + push ra + move r15 20 + j __internal_L2 + __internal_L2: + pop ra + pop sp + j ra + main: + push r8 + jal getValue + pop r8 + move r1 r15 + add r2 r1 5 + move r8 r2 + push r8 + push r9 + jal getOther + pop r9 + pop r8 + move r3 r15 + mul r4 r3 2 + move r9 r4 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_with_stack_spillover() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn get8() { + return (1, 2, 3, 4, 5, 6, 7, 8); + } + + let (a, b, c, d, e, f, g, h) = get8(); + let sum = a + h; + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + get8: + push sp + push ra + push 1 + push 2 + push 3 + push 4 + push 5 + push 6 + push 7 + push 8 + sub r0 sp 10 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 9 + get ra db r0 + j ra + main: + jal get8 + pop r0 + sub r0 sp 0 + put db r0 r0 + pop r14 + pop r13 + pop r12 + pop r11 + pop r10 + pop r9 + pop r8 + move sp r15 + sub r0 sp 1 + get r1 db r0 + add r2 r8 r1 + push r2 + sub sp sp 2 + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_return_in_loop() -> anyhow::Result<()> { + let compiled = compile!( + check + r#" + fn getValues(i) { + return (i, i * 2); + }; + + let sum = 0; + let i = 0; + loop { + let (a, b) = getValues(i); + sum = sum + a + b; + i = i + 1; + if (i > 3) { + break; + } + } + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + getValues: + pop r8 + push sp + push ra + push r8 + mul r1 r8 2 + push r1 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + move r8 0 + move r9 0 + __internal_L2: + push r9 + jal getValues + pop r11 + pop r10 + move sp r15 + add r1 r8 r10 + add r2 r1 r11 + move r8 r2 + add r3 r9 1 + move r9 r3 + sgt r4 r9 3 + beqz r4 __internal_L4 + j __internal_L3 + __internal_L4: + j __internal_L2 + __internal_L3: + " + } + ); + + Ok(()) + } + + #[test] + fn test_tuple_all_forms_combined() -> anyhow::Result<()> { + // Test all three tuple forms in one test: + // 1. Tuple literal declaration: let (x, y) = (1, 2); + // 2. Tuple literal assignment: (x, y) = (3, 4); + // 3. Function return tuple: let (a, b) = func(); + let compiled = compile!( + check + r#" + fn swap(x, y) { + return (y, x); + }; + + let (a, b) = (10, 20); // Literal declaration + (a, b) = (30, 40); // Literal assignment + let (c, d) = swap(a, b); // Function return + "# + ); + + assert!( + compiled.errors.is_empty(), + "Expected no errors, got: {:?}", + compiled.errors + ); + + assert_eq!( + compiled.output, + indoc! { + " + j main + swap: + pop r8 + pop r9 + push sp + push ra + push r8 + push r9 + sub r0 sp 4 + get r0 db r0 + move r15 r0 + j __internal_L1 + __internal_L1: + sub r0 sp 3 + get ra db r0 + j ra + main: + move r8 10 + move r9 20 + move r8 30 + move r9 40 + push r8 + push r9 + jal swap + pop r11 + pop r10 + move sp r15 + " + } + ); + + Ok(()) + } +} diff --git a/rust_compiler/libs/compiler/src/v1.rs b/rust_compiler/libs/compiler/src/v1.rs index ea049f2..d5daaa5 100644 --- a/rust_compiler/libs/compiler/src/v1.rs +++ b/rust_compiler/libs/compiler/src/v1.rs @@ -9,7 +9,8 @@ use parser::{ AssignmentExpression, BinaryExpression, BlockExpression, ConstDeclarationExpression, DeviceDeclarationExpression, Expression, FunctionExpression, IfExpression, InvocationExpression, Literal, LiteralOr, LiteralOrVariable, LogicalExpression, - LoopExpression, MemberAccessExpression, Spanned, TernaryExpression, WhileExpression, + LoopExpression, MemberAccessExpression, Spanned, TernaryExpression, + TupleAssignmentExpression, TupleDeclarationExpression, WhileExpression, }, }; use rust_decimal::Decimal; @@ -63,6 +64,9 @@ pub enum Error<'a> { #[error("Attempted to re-assign a value to a device const `{0}`")] DeviceAssignment(Cow<'a, str>, Span), + #[error("Expected a {0}-tuple, but you're trying to destructure into {1} variables")] + TupleSizeMismatch(usize, usize, Span), + #[error("{0}")] Unknown(String, Option), } @@ -84,7 +88,8 @@ impl<'a> From> for lsp_types::Diagnostic { | InvalidDevice(_, span) | ConstAssignment(_, span) | DeviceAssignment(_, span) - | AgrumentMismatch(_, span) => Diagnostic { + | AgrumentMismatch(_, span) + | TupleSizeMismatch(_, _, span) => Diagnostic { range: span.into(), message: value.to_string(), severity: Some(DiagnosticSeverity::ERROR), @@ -138,10 +143,45 @@ pub struct CompilationResult<'a> { pub instructions: Instructions<'a>, } +/// Metadata for the currently compiling function +#[derive(Debug)] +struct FunctionMetadata<'a> { + /// Maps function name to its instruction location + locations: HashMap, usize>, + /// Maps function name to list of parameter names + params: HashMap, Vec>>, + /// Maps function name to tuple return size (if it returns a tuple) + tuple_return_sizes: HashMap, usize>, + /// Name of the function currently being compiled + current_name: Option>, + /// Return label for the current function + return_label: Option>, + /// Size of tuple return for the current function (0 if not returning tuple) + tuple_return_size: u16, + /// Whether the SP (stack pointer) has been saved for the current function + sp_saved: bool, + /// Variable name for the saved SP at function entry (for stack unwinding) + sp_backup_var: Option>, +} + +impl<'a> Default for FunctionMetadata<'a> { + fn default() -> Self { + Self { + locations: HashMap::new(), + params: HashMap::new(), + tuple_return_sizes: HashMap::new(), + current_name: None, + return_label: None, + tuple_return_size: 0, + sp_saved: false, + sp_backup_var: None, + } + } +} + pub struct Compiler<'a> { pub parser: ASTParser<'a>, - function_locations: HashMap, usize>, - function_metadata: HashMap, Vec>>, + function_meta: FunctionMetadata<'a>, devices: HashMap, Cow<'a, str>>, // This holds the IL code which will be used in the @@ -154,7 +194,6 @@ pub struct Compiler<'a> { temp_counter: usize, label_counter: usize, loop_stack: Vec<(Cow<'a, str>, Cow<'a, str>)>, // Stores (start_label, end_label) - current_return_label: Option>, /// stores (IC10 `line_num`, `Vec`) pub source_map: HashMap>, /// Accumulative errors from the compilation process @@ -165,8 +204,7 @@ impl<'a> Compiler<'a> { pub fn new(parser: ASTParser<'a>, config: Option) -> Self { Self { parser, - function_locations: HashMap::new(), - function_metadata: HashMap::new(), + function_meta: FunctionMetadata::default(), devices: HashMap::new(), instructions: Instructions::default(), current_line: 1, @@ -175,7 +213,6 @@ impl<'a> Compiler<'a> { temp_counter: 0, label_counter: 0, loop_stack: Vec::new(), - current_return_label: None, source_map: HashMap::new(), errors: Vec::new(), } @@ -267,6 +304,29 @@ impl<'a> Compiler<'a> { Cow::from(format!("__internal_L{}", self.label_counter)) } + /// Merges two spans into a single span covering both + fn merge_spans(start: Span, end: Span) -> Span { + Span { + start_line: start.start_line, + start_col: start.start_col, + end_line: end.end_line, + end_col: end.end_col, + } + } + + /// Cleans up temporary variables, ignoring errors + fn cleanup_temps( + scope: &mut VariableScope<'a, '_>, + temps: &[Option>], + ) -> Result<(), Error<'a>> { + for temp in temps { + if let Some(name) = temp { + scope.free_temp(name.clone(), None)?; + } + } + Ok(()) + } + fn expression( &mut self, expr: Spanned>, @@ -465,6 +525,14 @@ impl<'a> Compiler<'a> { temp_name: Some(result_name), })) } + Expression::TupleDeclaration(tuple_decl) => { + self.expression_tuple_declaration(tuple_decl.node, scope)?; + Ok(None) + } + Expression::TupleAssignment(tuple_assign) => { + self.expression_tuple_assignment(tuple_assign.node, scope)?; + Ok(None) + } _ => Err(Error::Unknown( format!( "Expression type not yet supported in general expression context: {:?}", @@ -538,15 +606,13 @@ impl<'a> Compiler<'a> { let name_str = var_name.node; let name_span = var_name.span; - // optimization. Check for a negated numeric literal - if let Expression::Negation(box_expr) = &expr.node - && let Expression::Literal(spanned_lit) = &box_expr.node - && let Literal::Number(neg_num) = &spanned_lit.node - { + // optimization. Check for a negated numeric literal (including nested negations) + // e.g., -5, -(-5), -(-(5)), etc. + if let Some(num) = self.try_fold_negation(&expr.node) { let loc = scope.add_variable(name_str.clone(), LocationRequest::Persist, Some(name_span))?; - self.emit_variable_assignment(&loc, Operand::Number((-*neg_num).into()))?; + self.emit_variable_assignment(&loc, Operand::Number(num.into()))?; return Ok(Some(CompileLocation { location: loc, temp_name: None, @@ -783,6 +849,46 @@ impl<'a> Compiler<'a> { } (var_loc, None) } + Expression::Negation(_) => { + // Use try_fold_negation to see if this is a constant folded negation + if let Some(num) = self.try_fold_negation(&expr.node) { + let loc = scope.add_variable( + name_str.clone(), + LocationRequest::Persist, + Some(name_span), + )?; + self.emit_variable_assignment(&loc, Operand::Number(num.into()))?; + return Ok(Some(CompileLocation { + location: loc, + temp_name: None, + })); + } + + // Otherwise, compile the negation expression + let result = self.expression(expr, scope)?; + let var_loc = scope.add_variable( + name_str.clone(), + LocationRequest::Persist, + Some(name_span), + )?; + + if let Some(res) = result { + // Move result from temp to new persistent variable + let result_reg = self.resolve_register(&res.location)?; + self.emit_variable_assignment(&var_loc, Operand::Register(result_reg))?; + + // Free the temp result + if let Some(name) = res.temp_name { + scope.free_temp(name, None)?; + } + } else { + return Err(Error::Unknown( + format!("`{name_str}` negation expression did not produce a value"), + Some(name_span), + )); + } + (var_loc, None) + } _ => { return Err(Error::Unknown( format!("`{name_str}` declaration of this type is not supported/implemented."), @@ -932,6 +1038,492 @@ impl<'a> Compiler<'a> { Ok(()) } + fn expression_function_invocation_with_invocation( + &mut self, + invoke_expr: &InvocationExpression<'a>, + parent_scope: &mut VariableScope<'a, '_>, + backup_registers: bool, + ) -> Result<(), Error<'a>> { + let InvocationExpression { name, arguments } = invoke_expr; + + if !self + .function_meta + .locations + .contains_key(name.node.as_ref()) + { + self.errors + .push(Error::UnknownIdentifier(name.node.clone(), name.span)); + return Ok(()); + } + + let Some(args) = self.function_meta.params.get(name.node.as_ref()) else { + return Err(Error::UnknownIdentifier(name.node.clone(), name.span)); + }; + + if args.len() != arguments.len() { + self.errors + .push(Error::AgrumentMismatch(name.node.clone(), name.span)); + return Ok(()); + } + let mut stack = VariableScope::scoped(parent_scope); + + // Get the list of active registers (may or may not backup) + let active_registers = stack.registers(); + + // backup all used registers to the stack (unless this is for tuple return handling) + if backup_registers { + for register in &active_registers { + stack.add_variable( + Cow::from(format!("temp_{register}")), + LocationRequest::Stack, + None, + )?; + self.write_instruction( + Instruction::Push(Operand::Register(*register)), + Some(name.span), + )?; + } + } + for arg in arguments { + match &arg.node { + Expression::Literal(spanned_lit) => match &spanned_lit.node { + Literal::Number(num) => { + self.write_instruction( + Instruction::Push(Operand::Number((*num).into())), + Some(spanned_lit.span), + )?; + } + Literal::Boolean(b) => { + self.write_instruction( + Instruction::Push(Operand::Number(Number::from(*b).into())), + Some(spanned_lit.span), + )?; + } + _ => {} + }, + Expression::Variable(var_name) => { + let loc = match stack.get_location_of(&var_name.node, Some(var_name.span)) { + Ok(l) => l, + Err(_) => { + self.errors.push(Error::UnknownIdentifier( + var_name.node.clone(), + var_name.span, + )); + VariableLocation::Temporary(0) + } + }; + + match loc { + VariableLocation::Persistant(reg) | VariableLocation::Temporary(reg) => { + self.write_instruction( + Instruction::Push(Operand::Register(reg)), + Some(var_name.span), + )?; + } + VariableLocation::Constant(lit) => { + self.write_instruction( + Instruction::Push(extract_literal(lit, false)?), + Some(var_name.span), + )?; + } + VariableLocation::Stack(stack_offset) => { + self.write_instruction( + Instruction::Sub( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::StackPointer, + Operand::Number(stack_offset.into()), + ), + Some(var_name.span), + )?; + + self.write_instruction( + Instruction::Get( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::Device(Cow::from("db")), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + ), + Some(var_name.span), + )?; + + self.write_instruction( + Instruction::Push(Operand::Register( + VariableScope::TEMP_STACK_REGISTER, + )), + Some(var_name.span), + )?; + } + VariableLocation::Device(_) => { + self.errors.push(Error::Unknown( + "Device references not supported in function arguments".into(), + Some(var_name.span), + )); + } + } + } + _ => { + self.errors.push(Error::Unknown( + "Only literals and variables supported in function arguments".into(), + Some(arg.span), + )); + } + } + } + + let Some(_location) = self.function_meta.locations.get(&name.node) else { + self.errors + .push(Error::UnknownIdentifier(name.node.clone(), name.span)); + return Ok(()); + }; + + self.write_instruction( + Instruction::JumpAndLink(Operand::Label(name.node.clone())), + Some(name.span), + )?; + + // Pop the arguments off the stack (caller cleanup convention) + // BUT: If the function returns a tuple, it saves SP in r15 and the caller + // will restore SP with "move sp r15", which automatically cleans up everything. + // So we only pop arguments for non-tuple-returning functions. + let returns_tuple = self + .function_meta + .tuple_return_sizes + .get(&name.node) + .copied() + .unwrap_or(0) + > 0; + + if !returns_tuple { + for _ in 0..arguments.len() { + self.write_instruction( + Instruction::Pop(Operand::Register(VariableScope::TEMP_STACK_REGISTER)), + Some(name.span), + )?; + } + } + + // pop all registers back (if they were backed up) + if backup_registers { + for register in active_registers.iter().rev() { + self.write_instruction( + Instruction::Pop(Operand::Register(*register)), + Some(name.span), + )?; + } + } + + Ok(()) + } + + /// Helper: Validate tuple size from function return + fn validate_tuple_function_size( + &mut self, + func_name: &Cow<'a, str>, + expected_count: usize, + span: Span, + ) { + if let Some(&actual_size) = self.function_meta.tuple_return_sizes.get(func_name) { + if actual_size != expected_count { + self.errors + .push(Error::TupleSizeMismatch(actual_size, expected_count, span)); + } + } + } + + /// Helper: Pop tuple values from stack into variables (for function returns) + /// Variables are popped in reverse order (LIFO) + fn pop_tuple_values( + &mut self, + var_locations: Vec<(Option, Span)>, + ) -> Result<(), Error<'a>> { + for (var_loc_opt, span) in var_locations.into_iter().rev() { + if let Some(var_location) = var_loc_opt { + match var_location { + VariableLocation::Temporary(reg) | VariableLocation::Persistant(reg) => { + self.write_instruction( + Instruction::Pop(Operand::Register(reg)), + Some(span), + )?; + } + VariableLocation::Stack(offset) => { + // Pop into temp register, then write to stack + self.write_instruction( + Instruction::Pop(Operand::Register(VariableScope::TEMP_STACK_REGISTER)), + Some(span), + )?; + + self.write_instruction( + Instruction::Sub( + Operand::Register(0), + Operand::StackPointer, + Operand::Number(offset.into()), + ), + Some(span), + )?; + + self.write_instruction( + Instruction::Put( + Operand::Device(Cow::from("db")), + Operand::Register(0), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + ), + Some(span), + )?; + } + VariableLocation::Constant(_) => { + return Err(Error::ConstAssignment(Cow::from("tuple element"), span)); + } + VariableLocation::Device(_) => { + return Err(Error::DeviceAssignment(Cow::from("tuple element"), span)); + } + } + } else { + // Underscore: pop into temp register to discard + self.write_instruction( + Instruction::Pop(Operand::Register(VariableScope::TEMP_STACK_REGISTER)), + Some(span), + )?; + } + } + + // Restore stack pointer from r15 to clean up remaining tuple values + // (r15 contains the caller's SP from before the function was called) + self.write_instruction( + Instruction::Move( + Operand::StackPointer, + Operand::Register(VariableScope::RETURN_REGISTER), + ), + None, + )?; + + Ok(()) + } + + fn expression_tuple_declaration( + &mut self, + tuple_decl: TupleDeclarationExpression<'a>, + scope: &mut VariableScope<'a, '_>, + ) -> Result<(), Error<'a>> { + let TupleDeclarationExpression { names, value } = tuple_decl; + + match value.node { + Expression::Invocation(invoke_expr) => { + // Execute the function call - tuple values will be on the stack + self.expression_function_invocation_with_invocation(&invoke_expr, scope, false)?; + + // Validate tuple return size matches the declaration + self.validate_tuple_function_size( + &invoke_expr.node.name.node, + names.len(), + value.span, + ); + + // Allocate variables and collect their locations + let var_locations: Vec<_> = names + .iter() + .map(|name_spanned| { + if name_spanned.node.as_ref() == "_" { + Ok((None, name_spanned.span)) + } else { + let var_location = scope.add_variable( + name_spanned.node.clone(), + LocationRequest::Persist, + Some(name_spanned.span), + )?; + Ok((Some(var_location), name_spanned.span)) + } + }) + .collect::>>()?; + + // Pop tuple values from stack into variables + self.pop_tuple_values(var_locations)?; + } + Expression::Tuple(tuple_expr) => { + // Direct tuple literal: (value1, value2, ...) + let tuple_elements = tuple_expr.node; + + // Validate tuple size matches names + if tuple_elements.len() != names.len() { + return Err(Error::TupleSizeMismatch( + names.len(), + tuple_elements.len(), + value.span, + )); + } + + // Compile each element and assign to corresponding variable + for (name_spanned, element) in names.into_iter().zip(tuple_elements.into_iter()) { + // Skip underscores + if name_spanned.node.as_ref() == "_" { + continue; + } + + // Add variable to scope + let var_location = scope.add_variable( + name_spanned.node.clone(), + LocationRequest::Persist, + Some(name_spanned.span), + )?; + + // Compile the element expression - use compile_operand to handle all expression types + let (value_operand, cleanup) = self.compile_operand(element, scope)?; + self.emit_variable_assignment(&var_location, value_operand)?; + + // Clean up any temporary registers used for complex expressions + if let Some(temp_name) = cleanup { + scope.free_temp(temp_name, None)?; + } + } + } + _ => { + return Err(Error::Unknown( + "Tuple declaration only supports function invocations or tuple literals as RHS" + .into(), + Some(value.span), + )); + } + } + + Ok(()) + } + + fn expression_tuple_assignment( + &mut self, + tuple_assign: TupleAssignmentExpression<'a>, + scope: &mut VariableScope<'a, '_>, + ) -> Result<(), Error<'a>> { + let TupleAssignmentExpression { names, value } = tuple_assign; + + match value.node { + Expression::Invocation(invoke_expr) => { + // Execute the function call - tuple values will be on the stack + self.expression_function_invocation_with_invocation(&invoke_expr, scope, false)?; + + // Validate tuple return size matches the assignment + self.validate_tuple_function_size( + &invoke_expr.node.name.node, + names.len(), + value.span, + ); + + // Look up existing variable locations + let var_locations: Vec<_> = names + .iter() + .map(|name_spanned| { + if name_spanned.node.as_ref() == "_" { + Ok((None, name_spanned.span)) + } else { + let var_location = scope + .get_location_of(&name_spanned.node, Some(name_spanned.span)) + .unwrap_or_else(|_| { + self.errors.push(Error::UnknownIdentifier( + name_spanned.node.clone(), + name_spanned.span, + )); + VariableLocation::Temporary(0) + }); + Ok((Some(var_location), name_spanned.span)) + } + }) + .collect::>>()?; + + // Pop tuple values from stack into variables + self.pop_tuple_values(var_locations)?; + } + Expression::Tuple(tuple_expr) => { + // Direct tuple literal: (value1, value2, ...) + let tuple_elements = tuple_expr.node; + + // Validate tuple size matches names + if tuple_elements.len() != names.len() { + return Err(Error::TupleSizeMismatch( + tuple_elements.len(), + names.len(), + value.span, + )); + } + + // Compile each element and assign to corresponding variable + for (name_spanned, element) in names.into_iter().zip(tuple_elements.into_iter()) { + // Skip underscores + if name_spanned.node.as_ref() == "_" { + continue; + } + + // Get the existing variable location + let var_location = + match scope.get_location_of(&name_spanned.node, Some(name_spanned.span)) { + Ok(l) => l, + Err(_) => { + self.errors.push(Error::UnknownIdentifier( + name_spanned.node.clone(), + name_spanned.span, + )); + VariableLocation::Temporary(0) + } + }; + + // Compile the element expression - use compile_operand to handle all expression types + let (value_operand, cleanup) = self.compile_operand(element, scope)?; + + // Assign the compiled value to the target variable location + match &var_location { + VariableLocation::Temporary(reg) | VariableLocation::Persistant(reg) => { + self.write_instruction( + Instruction::Move(Operand::Register(*reg), value_operand), + Some(name_spanned.span), + )?; + } + VariableLocation::Stack(offset) => { + self.write_instruction( + Instruction::Sub( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::StackPointer, + Operand::Number((*offset).into()), + ), + Some(name_spanned.span), + )?; + + self.write_instruction( + Instruction::Put( + Operand::Device(Cow::from("db")), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + value_operand, + ), + Some(name_spanned.span), + )?; + } + VariableLocation::Constant(_) => { + return Err(Error::ConstAssignment( + name_spanned.node.clone(), + name_spanned.span, + )); + } + VariableLocation::Device(_) => { + return Err(Error::DeviceAssignment( + name_spanned.node.clone(), + name_spanned.span, + )); + } + } + + // Clean up any temporary registers used for complex expressions + if let Some(temp_name) = cleanup { + scope.free_temp(temp_name, None)?; + } + } + } + _ => { + return Err(Error::Unknown( + "Tuple assignment only supports function invocations or tuple literals as RHS" + .into(), + Some(value.span), + )); + } + } + + Ok(()) + } + fn expression_function_invocation( &mut self, invoke_expr: Spanned>, @@ -939,7 +1531,7 @@ impl<'a> Compiler<'a> { ) -> Result<(), Error<'a>> { let InvocationExpression { name, arguments } = invoke_expr.node; - if !self.function_locations.contains_key(&name.node) { + if !self.function_meta.locations.contains_key(&name.node) { self.errors .push(Error::UnknownIdentifier(name.node.clone(), name.span)); // Don't emit call, just pretend we did? @@ -949,7 +1541,7 @@ impl<'a> Compiler<'a> { return Ok(()); } - let Some(args) = self.function_metadata.get(&name.node) else { + let Some(args) = self.function_meta.params.get(&name.node) else { // Should be covered by check above return Err(Error::UnknownIdentifier(name.node, name.span)); }; @@ -1497,6 +2089,25 @@ impl<'a> Compiler<'a> { ) } + /// Recursively fold negations of numeric literals, e.g., -5 => 5, -(-5) => 5 + fn try_fold_negation(&self, expr: &Expression) -> Option { + match expr { + // Base case: plain number literal + Expression::Literal(lit) => { + if let Literal::Number(n) = lit.node { + Some(n) + } else { + None + } + } + // Recursive case: negation of something foldable + Expression::Negation(inner) => self.try_fold_negation(&inner.node).map(|n| -n), + // Parentheses just pass through + Expression::Priority(inner) => self.try_fold_negation(&inner.node), + _ => None, + } + } + fn expression_binary( &mut self, expr: Spanned>, @@ -1580,12 +2191,7 @@ impl<'a> Compiler<'a> { } }; - let span = Span { - start_line: left_expr.span.start_line, - start_col: left_expr.span.start_col, - end_line: right_expr.span.end_line, - end_col: right_expr.span.end_col, - }; + let span = Self::merge_spans(left_expr.span, right_expr.span); // Compile LHS let (lhs, lhs_cleanup) = self.compile_operand(*left_expr, scope)?; @@ -1604,12 +2210,7 @@ impl<'a> Compiler<'a> { )?; // Clean up operand temps - if let Some(name) = lhs_cleanup { - scope.free_temp(name, None)?; - } - if let Some(name) = rhs_cleanup { - scope.free_temp(name, None)?; - } + Self::cleanup_temps(scope, &[lhs_cleanup, rhs_cleanup])?; Ok(CompileLocation { location: result_loc, @@ -1685,12 +2286,7 @@ impl<'a> Compiler<'a> { LogicalExpression::Not(_) => unreachable!(), }; - let span = Span { - start_line: left_expr.span.start_line, - start_col: left_expr.span.start_col, - end_line: right_expr.span.end_line, - end_col: right_expr.span.end_col, - }; + let span = Self::merge_spans(left_expr.span, right_expr.span); // Compile LHS let (lhs, lhs_cleanup) = self.compile_operand(*left_expr, scope)?; @@ -1710,12 +2306,7 @@ impl<'a> Compiler<'a> { )?; // Clean up operand temps - if let Some(name) = lhs_cleanup { - scope.free_temp(name, None)?; - } - if let Some(name) = rhs_cleanup { - scope.free_temp(name, None)?; - } + Self::cleanup_temps(scope, &[lhs_cleanup, rhs_cleanup])?; Ok(CompileLocation { location: result_loc, @@ -1946,6 +2537,73 @@ impl<'a> Compiler<'a> { } } } + Expression::Tuple(tuple_expr) => { + let span = expr.span; + let tuple_elements = tuple_expr.node; + let tuple_size = tuple_elements.len(); + + // Push each tuple element onto the stack using compile_operand + for element in tuple_elements.into_iter() { + let (push_operand, cleanup) = self.compile_operand(element, scope)?; + + self.write_instruction(Instruction::Push(push_operand), Some(span))?; + + // Don't track the push in the scope's stack offset because these values + // are being returned to the caller, not allocated in this block's scope. + // They will be left on the stack when we return. + + if let Some(temp_name) = cleanup { + scope.free_temp(temp_name, Some(span))?; + } + } + + // Load the saved SP from stack and move to r15 for caller's stack unwinding + if let Some(sp_var_name) = &self.function_meta.sp_backup_var { + let sp_var_loc = scope.get_location_of(sp_var_name, Some(span))?; + + if let VariableLocation::Stack(offset) = sp_var_loc { + // Calculate address of saved SP, accounting for tuple values just pushed + let adjusted_offset = offset + tuple_size as u16; + self.write_instruction( + Instruction::Sub( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::StackPointer, + Operand::Number(adjusted_offset.into()), + ), + Some(span), + )?; + + // Load saved SP value + self.write_instruction( + Instruction::Get( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::Device(Cow::from("db")), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + ), + Some(span), + )?; + + // Move to r15 for caller + self.write_instruction( + Instruction::Move( + Operand::Register(VariableScope::RETURN_REGISTER), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + ), + Some(span), + )?; + } + } + + // Record the tuple return size for validation at call sites + if let Some(func_name) = &self.function_meta.current_name { + self.function_meta + .tuple_return_sizes + .insert(func_name.clone(), tuple_size); + } + + // Track tuple size for epilogue cleanup + self.function_meta.tuple_return_size = tuple_size as u16; + } _ => { return Err(Error::Unknown( format!("Unsupported `return` statement: {:?}", expr), @@ -1955,7 +2613,7 @@ impl<'a> Compiler<'a> { } } - if let Some(label) = &self.current_return_label { + if let Some(label) = &self.function_meta.return_label { self.write_instruction(Instruction::Jump(Operand::Label(label.clone())), None)?; } else { return Err(Error::Unknown( @@ -2589,22 +3247,26 @@ impl<'a> Compiler<'a> { let span = expr.span; - if self.function_locations.contains_key(&name.node) { + if self.function_meta.locations.contains_key(&name.node) { self.errors .push(Error::DuplicateIdentifier(name.node.clone(), name.span)); // Fallthrough to allow compiling the body anyway? // It might be useful to check body for errors. } - self.function_metadata.insert( + self.function_meta.params.insert( name.node.clone(), arguments.iter().map(|a| a.node.clone()).collect(), ); + // Set the current function being compiled + self.function_meta.current_name = Some(name.node.clone()); + // Declare the function as a line identifier self.write_instruction(Instruction::LabelDef(name.node.clone()), Some(span))?; - self.function_locations + self.function_meta + .locations .insert(name.node.clone(), self.current_line); // Create a new block scope for the function body @@ -2662,11 +3324,24 @@ impl<'a> Compiler<'a> { )?; } - self.write_instruction(Instruction::Push(Operand::ReturnAddress), Some(span))?; + // Save the caller's stack pointer FIRST (before any pushes modify it) + // This is crucial for proper stack unwinding in tuple returns + let sp_backup_name = self.next_temp_name(); + block_scope.add_variable( + sp_backup_name.clone(), + LocationRequest::Stack, + Some(name.span), + )?; + self.write_instruction(Instruction::Push(Operand::StackPointer), Some(span))?; + self.function_meta.sp_backup_var = Some(sp_backup_name); + self.function_meta.sp_saved = true; + // Generate return label name and track it before pushing ra let return_label = self.next_label_name(); - - let prev_return_label = self.current_return_label.replace(return_label.clone()); + let prev_return_label = self + .function_meta + .return_label + .replace(return_label.clone()); block_scope.add_variable( return_label.clone(), @@ -2674,6 +3349,8 @@ impl<'a> Compiler<'a> { Some(name.span), )?; + self.write_instruction(Instruction::Push(Operand::ReturnAddress), Some(span))?; + for expr in body.0 { match expr.node { Expression::Return(ret_expr) => { @@ -2713,30 +3390,35 @@ impl<'a> Compiler<'a> { } }; - self.current_return_label = prev_return_label; + self.function_meta.return_label = prev_return_label; + // Write the return label and epilogue self.write_instruction(Instruction::LabelDef(return_label.clone()), Some(span))?; - if ra_stack_offset == 1 { - self.write_instruction(Instruction::Pop(Operand::ReturnAddress), Some(span))?; + // Handle stack cleanup based on whether this is a tuple-returning function + let is_tuple_return = self.function_meta.tuple_return_size > 0; - let remaining_cleanup = block_scope.stack_offset() - 1; - if remaining_cleanup > 0 { - self.write_instruction( - Instruction::Sub( - Operand::StackPointer, - Operand::StackPointer, - Operand::Number(remaining_cleanup.into()), - ), - Some(span), - )?; - } + // For tuple returns, account for tuple values pushed onto the stack + let adjusted_ra_offset = if is_tuple_return { + ra_stack_offset + self.function_meta.tuple_return_size as u16 } else { + ra_stack_offset + }; + + // Load return address from stack + if adjusted_ra_offset == 1 && !is_tuple_return { + // Simple case: RA is at top, and we're not returning a tuple + // Just pop ra, then pop sp to restore + self.write_instruction(Instruction::Pop(Operand::ReturnAddress), Some(span))?; + self.write_instruction(Instruction::Pop(Operand::StackPointer), Some(span))?; + } else { + // RA is deeper in stack, or we're returning a tuple + // Load ra from offset self.write_instruction( Instruction::Sub( Operand::Register(VariableScope::TEMP_STACK_REGISTER), Operand::StackPointer, - Operand::Number(ra_stack_offset.into()), + Operand::Number(adjusted_ra_offset.into()), ), Some(span), )?; @@ -2750,19 +3432,45 @@ impl<'a> Compiler<'a> { Some(span), )?; - if block_scope.stack_offset() > 0 { + if !is_tuple_return { + // Non-tuple return: restore SP from saved value to clean up + let sp_offset = adjusted_ra_offset - 1; self.write_instruction( Instruction::Sub( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), Operand::StackPointer, + Operand::Number(sp_offset.into()), + ), + Some(span), + )?; + + self.write_instruction( + Instruction::Get( + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + Operand::Device(Cow::from("db")), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), + ), + Some(span), + )?; + + self.write_instruction( + Instruction::Move( Operand::StackPointer, - Operand::Number(block_scope.stack_offset().into()), + Operand::Register(VariableScope::TEMP_STACK_REGISTER), ), Some(span), )?; } + // else: Tuple return - leave tuple values on stack for caller to pop } self.write_instruction(Instruction::Jump(Operand::ReturnAddress), Some(span))?; + + // Reset the flags for the next function + self.function_meta.tuple_return_size = 0; + self.function_meta.sp_saved = false; + self.function_meta.sp_backup_var = None; + self.function_meta.current_name = None; Ok(()) } } diff --git a/rust_compiler/libs/compiler/test_files/deviceIo.slang b/rust_compiler/libs/compiler/test_files/deviceIo.stlg similarity index 100% rename from rust_compiler/libs/compiler/test_files/deviceIo.slang rename to rust_compiler/libs/compiler/test_files/deviceIo.stlg diff --git a/rust_compiler/libs/compiler/test_files/script.slang b/rust_compiler/libs/compiler/test_files/script.stlg similarity index 100% rename from rust_compiler/libs/compiler/test_files/script.slang rename to rust_compiler/libs/compiler/test_files/script.stlg diff --git a/rust_compiler/libs/il/src/lib.rs b/rust_compiler/libs/il/src/lib.rs index 1a3fdc3..54aadfb 100644 --- a/rust_compiler/libs/il/src/lib.rs +++ b/rust_compiler/libs/il/src/lib.rs @@ -61,6 +61,7 @@ impl<'a> std::fmt::Display for Instructions<'a> { } } +#[derive(Clone)] pub struct InstructionNode<'a> { pub instruction: Instruction<'a>, pub span: Option, diff --git a/rust_compiler/libs/integration_tests/.gitattributes b/rust_compiler/libs/integration_tests/.gitattributes new file mode 100644 index 0000000..d5f2352 --- /dev/null +++ b/rust_compiler/libs/integration_tests/.gitattributes @@ -0,0 +1,2 @@ +# Treat snapshot files as text +*.snap text diff --git a/rust_compiler/libs/integration_tests/Cargo.toml b/rust_compiler/libs/integration_tests/Cargo.toml new file mode 100644 index 0000000..bc05679 --- /dev/null +++ b/rust_compiler/libs/integration_tests/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "integration_tests" +version = "0.1.0" +edition = "2024" +publish = false + +[dependencies] +compiler = { path = "../compiler" } +parser = { path = "../parser" } +tokenizer = { path = "../tokenizer" } +optimizer = { path = "../optimizer" } +il = { path = "../il" } +anyhow = { workspace = true } +indoc = "2" +insta = "1.40" + +[lib] +# This is a test-only crate +path = "src/lib.rs" diff --git a/rust_compiler/libs/integration_tests/README.md b/rust_compiler/libs/integration_tests/README.md new file mode 100644 index 0000000..00dce8a --- /dev/null +++ b/rust_compiler/libs/integration_tests/README.md @@ -0,0 +1,92 @@ +# Integration Tests for Slang Compiler with Optimizer + +This crate contains end-to-end integration tests for the Slang compiler that verify the complete compilation pipeline including all optimization passes. + +## Snapshot Testing with Insta + +These tests use [insta](https://insta.rs/) for snapshot testing, which captures the entire compiled output and stores it in snapshot files for comparison. + +### Running Tests + +```bash +# Run all integration tests +cargo test --package integration_tests + +# Run a specific test +cargo test --package integration_tests test_simple_leaf_function +``` + +### Updating Snapshots + +When you make changes to the compiler or optimizer that affect the output: + +```bash +# Update all snapshots automatically +INSTA_UPDATE=always cargo test --package integration_tests + +# Or use cargo-insta for interactive review (install first: cargo install cargo-insta) +cargo insta test --package integration_tests +cargo insta review --package integration_tests +``` + +### Understanding Snapshots + +Snapshot files are stored in `src/snapshots/` and contain: + +- The full IC10 assembly output from compiling Slang source code +- Metadata about which test generated them +- The expression that produced the output + +Example snapshot structure: + +``` +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +j main +move r8 10 +j ra +``` + +### What We Test + +1. **Leaf Function Optimization** - Removal of unnecessary `push sp/ra` and `pop ra/sp` +2. **Function Calls** - Preservation of stack frame when calling functions +3. **Constant Folding** - Compile-time evaluation of constant expressions +4. **Algebraic Simplification** - Identity operations like `x * 1` → `x` +5. **Strength Reduction** - Converting expensive operations like `x * 2` → `x + x` +6. **Dead Code Elimination** - Removal of unused variables +7. **Peephole Comparison Fusion** - Combining comparison + branch instructions +8. **Select Optimization** - Converting if/else to single `select` instruction +9. **Complex Arithmetic** - Multiple optimizations working together +10. **Nested Function Calls** - Full program optimization + +### Adding New Tests + +To add a new integration test: + +1. Add a new `#[test]` function in `src/lib.rs` +2. Call `compile_optimized()` with your Slang source code +3. Use `insta::assert_snapshot!(output)` to capture the output +4. Run with `INSTA_UPDATE=always` to create the initial snapshot +5. Review the snapshot file to ensure it looks correct + +Example: + +```rust +#[test] +fn test_my_optimization() { + let source = "fn foo(x) { return x + 1; }"; + let output = compile_optimized(source); + insta::assert_snapshot!(output); +} +``` + +### Benefits of Snapshot Testing + +- **Full Output Verification**: Tests the entire compiled output, not just snippets +- **Easy to Review**: Visual diffs show exactly what changed in the output +- **Regression Detection**: Any change to output is immediately visible +- **Living Documentation**: Snapshots serve as examples of compiler output +- **Less Brittle**: No need to manually update expected strings when making intentional changes diff --git a/rust_compiler/libs/integration_tests/src/lib.rs b/rust_compiler/libs/integration_tests/src/lib.rs new file mode 100644 index 0000000..7295e46 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/lib.rs @@ -0,0 +1,222 @@ +//! Integration tests for the Slang compiler with optimizer +//! +//! These tests compile Slang source code and verify both the compilation +//! and optimization passes work correctly together using snapshot testing. + +#[cfg(test)] +mod tests { + use compiler::Compiler; + use indoc::indoc; + use parser::Parser; + use tokenizer::Tokenizer; + + /// Compile Slang source code and return both unoptimized and optimized output + fn compile_with_and_without_optimization(source: &str) -> String { + // Compile for unoptimized output + let tokenizer = Tokenizer::from(source); + let parser = Parser::new(tokenizer); + let compiler = Compiler::new(parser, None); + let result = compiler.compile(); + + // Get unoptimized output + let mut unoptimized_writer = std::io::BufWriter::new(Vec::new()); + result + .instructions + .write(&mut unoptimized_writer) + .expect("Failed to write unoptimized output"); + let unoptimized_bytes = unoptimized_writer + .into_inner() + .expect("Failed to get bytes"); + let unoptimized = String::from_utf8(unoptimized_bytes).expect("Invalid UTF-8"); + + // Compile again for optimized output + let tokenizer2 = Tokenizer::from(source); + let parser2 = Parser::new(tokenizer2); + let compiler2 = Compiler::new(parser2, None); + let result2 = compiler2.compile(); + + // Apply optimizations + let optimized_instructions = optimizer::optimize(result2.instructions); + + // Get optimized output + let mut optimized_writer = std::io::BufWriter::new(Vec::new()); + optimized_instructions + .write(&mut optimized_writer) + .expect("Failed to write optimized output"); + let optimized_bytes = optimized_writer.into_inner().expect("Failed to get bytes"); + let optimized = String::from_utf8(optimized_bytes).expect("Invalid UTF-8"); + + // Combine both outputs with clear separators + format!( + "## Unoptimized Output\n\n{}\n## Optimized Output\n\n{}", + unoptimized, optimized + ) + } + + #[test] + fn test_simple_leaf_function() { + let source = "fn test() { let x = 10; }"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_function_with_call() { + let source = indoc! {" + fn add(a, b) { return a + b; } + fn main() { let x = add(5, 10); } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_constant_folding() { + let source = "let x = 5 + 10;"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_algebraic_simplification() { + let source = "let x = 5; let y = x * 1;"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_strength_reduction() { + let source = "fn double(x) { return x * 2; }"; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_dead_code_elimination() { + let source = indoc! {" + fn compute(x) { + let unused = 20; + return x + 1; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_peephole_comparison_fusion() { + let source = indoc! {" + fn compare(x, y) { + if (x > y) { + let z = 1; + } + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_select_optimization() { + let source = indoc! {" + fn ternary(cond) { + let result = 0; + if (cond) { + result = 10; + } else { + result = 20; + } + return result; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_leaf_function_no_stack_frame() { + let source = indoc! {" + fn increment(x) { + x = x + 1; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_complex_arithmetic() { + let source = indoc! {" + fn compute(a, b, c) { + let x = a * 2; + let y = b + 0; + let z = c * 1; + return x + y + z; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_nested_function_calls() { + let source = indoc! {" + fn add(a, b) { return a + b; } + fn multiply(x, y) { return x * 2; } + fn complex(a, b) { + let sum = add(a, b); + let doubled = multiply(sum, 2); + return doubled; + } + "}; + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_tuples() { + let source = indoc! {r#" + device self = "db"; + device day = "d0"; + + fn getSomethingElse(input) { + return input + 1; + } + + fn getSensorData() { + return ( + day.Vertical, + day.Horizontal, + getSomethingElse(3) + ); + } + + loop { + yield(); + + let (vertical, horizontal, _) = getSensorData(); + + (horizontal, _, _) = getSensorData(); + + self.Setting = horizontal; + } + "#}; + + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_larre_script() { + let source = include_str!("./test_files/test_larre_script.stlg"); + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } + + #[test] + fn test_reagent_processing() { + let source = include_str!("./test_files/reagent_processing.stlg"); + let output = compile_with_and_without_optimization(source); + insta::assert_snapshot!(output); + } +} diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap new file mode 100644 index 0000000..ce3aa70 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__algebraic_simplification.snap @@ -0,0 +1,17 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +main: +move r8 5 +mul r1 r8 1 +move r9 r1 + +## Optimized Output + +move r8 5 +move r1 5 +move r9 5 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap new file mode 100644 index 0000000..f18794c --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__complex_arithmetic.snap @@ -0,0 +1,45 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 158 +expression: output +--- +## Unoptimized Output + +j main +compute: +pop r8 +pop r9 +pop r10 +push sp +push ra +mul r1 r10 2 +move r11 r1 +add r2 r9 0 +move r12 r2 +mul r3 r8 1 +move r13 r3 +add r4 r11 r12 +add r5 r4 r13 +move r15 r5 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +pop r10 +push sp +push ra +add r11 r10 r10 +move r12 r9 +move r13 r8 +add r4 r11 r12 +add r15 r4 r13 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap new file mode 100644 index 0000000..075159a --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__constant_folding.snap @@ -0,0 +1,13 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +main: +move r8 15 + +## Optimized Output + +move r8 15 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap new file mode 100644 index 0000000..9404104 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__dead_code_elimination.snap @@ -0,0 +1,32 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 103 +expression: output +--- +## Unoptimized Output + +j main +compute: +pop r8 +push sp +push ra +move r9 20 +add r1 r8 1 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +move r9 20 +add r15 r8 1 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap new file mode 100644 index 0000000..264b371 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__function_with_call.snap @@ -0,0 +1,52 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 70 +expression: output +--- +## Unoptimized Output + +j main +add: +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra +main: +push sp +push ra +push 5 +push 10 +jal add +move r8 r15 +__internal_L2: +pop ra +pop sp +j ra + +## Optimized Output + +j 9 +pop r8 +pop r9 +push sp +push ra +add r15 r9 r8 +pop ra +pop sp +j ra +push sp +push ra +push 5 +push 10 +jal 1 +move r8 r15 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__larre_script.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__larre_script.snap new file mode 100644 index 0000000..4a7c529 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__larre_script.snap @@ -0,0 +1,223 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +waitForIdle: +push sp +push ra +yield +__internal_L2: +l r1 d0 Idle +seq r2 r1 0 +beqz r2 __internal_L3 +yield +j __internal_L2 +__internal_L3: +__internal_L1: +pop ra +pop sp +j ra +deposit: +push sp +push ra +s d0 Setting 1 +jal waitForIdle +move r1 r15 +s d0 Activate 1 +jal waitForIdle +move r2 r15 +s d1 Open 0 +__internal_L4: +pop ra +pop sp +j ra +checkAndHarvest: +pop r8 +push sp +push ra +sle r1 r8 1 +ls r15 d0 255 Seeding +slt r2 r15 1 +or r3 r1 r2 +beqz r3 __internal_L6 +j __internal_L5 +__internal_L6: +__internal_L7: +ls r15 d0 255 Mature +beqz r15 __internal_L8 +yield +s d0 Activate 1 +j __internal_L7 +__internal_L8: +ls r15 d0 255 Occupied +move r9 r15 +s d0 Setting 1 +push r8 +push r9 +jal waitForIdle +pop r9 +pop r8 +move r4 r15 +push r8 +push r9 +jal deposit +pop r9 +pop r8 +move r5 r15 +beqz r9 __internal_L9 +push r8 +push r9 +jal deposit +pop r9 +pop r8 +move r6 r15 +__internal_L9: +s d0 Setting r8 +push r8 +push r9 +jal waitForIdle +pop r9 +pop r8 +move r6 r15 +ls r15 d0 0 Occupied +beqz r15 __internal_L10 +s d0 Activate 1 +__internal_L10: +push r8 +push r9 +jal waitForIdle +pop r9 +pop r8 +move r7 r15 +__internal_L5: +pop ra +pop sp +j ra +main: +move r8 0 +__internal_L11: +yield +l r1 d0 Idle +seq r2 r1 0 +beqz r2 __internal_L13 +j __internal_L11 +__internal_L13: +add r3 r8 1 +sgt r4 r3 19 +add r5 r8 1 +select r6 r4 2 r5 +move r9 r6 +push r8 +push r9 +push r8 +jal checkAndHarvest +pop r9 +pop r8 +move r7 r15 +s d0 Setting r9 +move r8 r9 +j __internal_L11 +__internal_L12: + +## Optimized Output + +j 77 +push sp +push ra +yield +l r1 d0 Idle +bne r1 0 8 +yield +j 4 +pop ra +pop sp +j ra +push sp +push ra +s d0 Setting 1 +jal 1 +move r1 r15 +s d0 Activate 1 +jal 1 +move r2 r15 +s d1 Open 0 +pop ra +pop sp +j ra +pop r8 +push sp +push ra +sle r1 r8 1 +ls r15 d0 255 Seeding +slt r2 r15 1 +or r3 r1 r2 +beqz r3 32 +j 74 +ls r15 d0 255 Mature +beqz r15 37 +yield +s d0 Activate 1 +j 32 +ls r9 d0 255 Occupied +s d0 Setting 1 +push r8 +push r9 +jal 1 +pop r9 +pop r8 +move r4 r15 +push r8 +push r9 +jal 11 +pop r9 +pop r8 +move r5 r15 +beqz r9 58 +push r8 +push r9 +jal 11 +pop r9 +pop r8 +move r6 r15 +s d0 Setting r8 +push r8 +push r9 +jal 1 +pop r9 +pop r8 +move r6 r15 +ls r15 d0 0 Occupied +beqz r15 68 +s d0 Activate 1 +push r8 +push r9 +jal 1 +pop r9 +pop r8 +move r7 r15 +pop ra +pop sp +j ra +move r8 0 +yield +l r1 d0 Idle +bne r1 0 82 +j 78 +add r3 r8 1 +sgt r4 r3 19 +add r5 r8 1 +select r6 r4 2 r5 +move r9 r6 +push r8 +push r9 +push r8 +jal 23 +pop r9 +pop r8 +move r7 r15 +s d0 Setting r9 +move r8 r9 +j 78 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap new file mode 100644 index 0000000..3362dcd --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__leaf_function_no_stack_frame.snap @@ -0,0 +1,30 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 144 +expression: output +--- +## Unoptimized Output + +j main +increment: +pop r8 +push sp +push ra +add r1 r8 1 +move r8 r1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +add r1 r8 1 +move r8 r1 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap new file mode 100644 index 0000000..a9fbdd7 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__nested_function_calls.snap @@ -0,0 +1,107 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 173 +expression: output +--- +## Unoptimized Output + +j main +add: +pop r8 +pop r9 +push sp +push ra +add r1 r9 r8 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra +multiply: +pop r8 +pop r9 +push sp +push ra +mul r1 r9 2 +move r15 r1 +j __internal_L2 +__internal_L2: +pop ra +pop sp +j ra +complex: +pop r8 +pop r9 +push sp +push ra +push r8 +push r9 +push r9 +push r8 +jal add +pop r9 +pop r8 +move r10 r15 +push r8 +push r9 +push r10 +push r10 +push 2 +jal multiply +pop r10 +pop r9 +pop r8 +move r11 r15 +move r15 r11 +j __internal_L3 +__internal_L3: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +push sp +push ra +add r15 r9 r8 +pop ra +pop sp +j ra +pop r8 +pop r9 +push sp +push ra +add r15 r9 r9 +pop ra +pop sp +j ra +pop r8 +pop r9 +push sp +push ra +push r8 +push r9 +push r9 +push r8 +jal 1 +pop r9 +pop r8 +move r10 r15 +push r8 +push r9 +push r10 +push r10 +push 2 +jal 9 +pop r10 +pop r9 +pop r8 +move r11 r15 +move r15 r11 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap new file mode 100644 index 0000000..880034c --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__peephole_comparison_fusion.snap @@ -0,0 +1,34 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 116 +expression: output +--- +## Unoptimized Output + +j main +compare: +pop r8 +pop r9 +push sp +push ra +sgt r1 r9 r8 +beqz r1 __internal_L2 +move r10 1 +__internal_L2: +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +pop r9 +push sp +push ra +ble r9 r8 7 +move r10 1 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__reagent_processing.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__reagent_processing.snap new file mode 100644 index 0000000..e58b6e7 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__reagent_processing.snap @@ -0,0 +1,111 @@ +--- +source: libs/integration_tests/src/lib.rs +expression: output +--- +## Unoptimized Output + +j main +main: +s d2 Mode 1 +s d2 On 0 +move r8 0 +move r9 0 +__internal_L1: +yield +l r1 d0 Reagents +move r10 r1 +sge r2 r10 100 +sge r3 r9 2 +or r4 r2 r3 +beqz r4 __internal_L3 +move r8 1 +__internal_L3: +slt r5 r10 100 +ls r15 d0 0 Occupied +seq r6 r15 0 +and r7 r5 r6 +add r1 r9 1 +select r2 r7 r1 0 +move r9 r2 +l r3 d0 Rpm +slt r4 r3 1 +and r5 r8 r4 +beqz r5 __internal_L4 +s d0 Open 1 +seq r6 r10 0 +ls r15 d0 0 Occupied +and r7 r6 r15 +seq r1 r7 0 +move r8 r1 +__internal_L4: +seq r6 r8 0 +s d0 On r6 +ls r15 d1 0 Quantity +move r11 r15 +l r7 d3 Pressure +sgt r1 r7 200 +beqz r1 __internal_L5 +j __internal_L1 +__internal_L5: +sgt r2 r11 0 +s d1 On r2 +sgt r3 r11 0 +s d1 Activate r3 +l r4 d3 Pressure +sgt r5 r4 0 +l r6 d1 Activate +or r7 r5 r6 +s d2 On r7 +l r1 d1 Activate +s db Setting r1 +j __internal_L1 +__internal_L2: + +## Optimized Output + +s d2 Mode 1 +s d2 On 0 +move r8 0 +move r9 0 +yield +l r10 d0 Reagents +sge r2 r10 100 +sge r3 r9 2 +or r4 r2 r3 +beqz r4 11 +move r8 1 +slt r5 r10 100 +ls r15 d0 0 Occupied +seq r6 r15 0 +and r7 r5 r6 +add r1 r9 1 +select r2 r7 r1 0 +move r9 r2 +l r3 d0 Rpm +slt r4 r3 1 +and r5 r8 r4 +beqz r5 27 +s d0 Open 1 +seq r6 r10 0 +ls r15 d0 0 Occupied +and r7 r6 r15 +seq r8 r7 0 +seq r6 r8 0 +s d0 On r6 +ls r15 d1 0 Quantity +move r11 r15 +l r7 d3 Pressure +ble r7 200 34 +j 4 +sgt r2 r11 0 +s d1 On r2 +sgt r3 r11 0 +s d1 Activate r3 +l r4 d3 Pressure +sgt r5 r4 0 +l r6 d1 Activate +or r7 r5 r6 +s d2 On r7 +l r1 d1 Activate +s db Setting r1 +j 4 diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap new file mode 100644 index 0000000..20172da --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__select_optimization.snap @@ -0,0 +1,36 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 133 +expression: output +--- +## Unoptimized Output + +j main +ternary: +pop r8 +push sp +push ra +move r9 0 +beqz r8 __internal_L3 +move r9 10 +j __internal_L2 +__internal_L3: +move r9 20 +__internal_L2: +move r15 r9 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +select r15 r8 10 20 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap new file mode 100644 index 0000000..f093e06 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__simple_leaf_function.snap @@ -0,0 +1,26 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 60 +expression: output +--- +## Unoptimized Output + +j main +test: +push sp +push ra +move r8 10 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +push sp +push ra +move r8 10 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap new file mode 100644 index 0000000..a2615e0 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__strength_reduction.snap @@ -0,0 +1,30 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 91 +expression: output +--- +## Unoptimized Output + +j main +double: +pop r8 +push sp +push ra +mul r1 r8 2 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra + +## Optimized Output + +j main +pop r8 +push sp +push ra +add r15 r8 r8 +pop ra +pop sp +j ra diff --git a/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__tuples.snap b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__tuples.snap new file mode 100644 index 0000000..a525293 --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/snapshots/integration_tests__tests__tuples.snap @@ -0,0 +1,93 @@ +--- +source: libs/integration_tests/src/lib.rs +assertion_line: 206 +expression: output +--- +## Unoptimized Output + +j main +getSomethingElse: +pop r8 +push sp +push ra +add r1 r8 1 +move r15 r1 +j __internal_L1 +__internal_L1: +pop ra +pop sp +j ra +getSensorData: +push sp +push ra +l r1 d0 Vertical +push r1 +l r2 d0 Horizontal +push r2 +push 3 +jal getSomethingElse +move r3 r15 +push r3 +sub r0 sp 5 +get r0 db r0 +move r15 r0 +j __internal_L2 +__internal_L2: +sub r0 sp 4 +get ra db r0 +j ra +main: +__internal_L3: +yield +jal getSensorData +pop r0 +pop r9 +pop r8 +move sp r15 +jal getSensorData +pop r0 +pop r0 +pop r9 +move sp r15 +s db Setting r9 +j __internal_L3 +__internal_L4: + +## Optimized Output + +j 23 +pop r8 +push sp +push ra +add r15 r8 1 +pop ra +pop sp +j ra +push sp +push ra +l r1 d0 Vertical +push r1 +l r2 d0 Horizontal +push r2 +push 3 +jal 1 +move r3 r15 +push r3 +sub r0 sp 5 +get r15 db r0 +sub r0 sp 4 +get ra db r0 +j ra +yield +jal 8 +pop r0 +pop r9 +pop r8 +move sp r15 +jal 8 +pop r0 +pop r0 +pop r9 +move sp r15 +s db Setting r9 +j 23 diff --git a/rust_compiler/libs/integration_tests/src/test_files/reagent_processing.stlg b/rust_compiler/libs/integration_tests/src/test_files/reagent_processing.stlg new file mode 100644 index 0000000..dfe7a7f --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/test_files/reagent_processing.stlg @@ -0,0 +1,49 @@ +device combustion = "d0"; +device furnace = "d1"; +device vent = "d2"; +device gasSensor = "d3"; +device self = "db"; + +const MAX_WAIT_ITER = 2; +const STACK_SIZE = 100; + +vent.Mode = 1; // Vent inward into pipes +vent.On = false; +let ejecting = false; +let combustionWaitIter = 0; + +loop { + yield(); + + let reagentCount = combustion.Reagents; + + if (reagentCount >= STACK_SIZE || combustionWaitIter >= MAX_WAIT_ITER) { + ejecting = true; + } + + combustionWaitIter = (reagentCount < STACK_SIZE && !ls(combustion, 0, "Occupied")) + ? combustionWaitIter + 1 + : 0; + + if (ejecting && combustion.Rpm < 1) { + combustion.Open = true; + ejecting = !(reagentCount == 0 && ls(combustion, 0, "Occupied")); + } + + combustion.On = !ejecting; + + let furnaceAmt = ls(furnace, 0, "Quantity"); + + if (gasSensor.Pressure > 200) { + // safety: don't turn this on if we have gas still to process + // This should prevent pipes from blowing. This will NOT hault + // The in-progress burn job, but it'll prevent new jobs from + // blowing the walls or pipes. + continue; + } + + furnace.On = furnaceAmt > 0; + furnace.Activate = furnaceAmt > 0; + vent.On = gasSensor.Pressure > 0 || furnace.Activate; + self.Setting = furnace.Activate; +} \ No newline at end of file diff --git a/rust_compiler/libs/integration_tests/src/test_files/test_larre_script.stlg b/rust_compiler/libs/integration_tests/src/test_files/test_larre_script.stlg new file mode 100644 index 0000000..c17ea5c --- /dev/null +++ b/rust_compiler/libs/integration_tests/src/test_files/test_larre_script.stlg @@ -0,0 +1,72 @@ +/// Laree script V1 + +device self = "db"; +device larre = "d0"; +device exportChute = "d1"; + +const TOTAL_SLOTS = 19; +const EXPORT_CHUTE = 1; +const START_STATION = 2; + +let currentIndex = 0; + +/// Waits for the larre to be idle before continuing +fn waitForIdle() { + yield(); + while (!larre.Idle) { + yield(); + } +} + +/// Instructs the Larre to go to the chute and deposit +/// what is currently in its arm +fn deposit() { + larre.Setting = EXPORT_CHUTE; + waitForIdle(); + larre.Activate = true; + waitForIdle(); + exportChute.Open = false; +} + +/// This function is responsible for checking the plant under +/// the larre at this index, and harvesting if applicable +fn checkAndHarvest(currentIndex) { + if (currentIndex <= EXPORT_CHUTE || ls(larre, 255, "Seeding") < 1) { + return; + } + + // harvest from this device + while (ls(larre, 255, "Mature")) { + yield(); + larre.Activate = true; + } + let hasRemainingPlant = ls(larre, 255, "Occupied"); + + // move to the export chute + larre.Setting = EXPORT_CHUTE; + waitForIdle(); + deposit(); + if (hasRemainingPlant) { + deposit(); + } + + larre.Setting = currentIndex; + waitForIdle(); + + if (ls(larre, 0, "Occupied")) { + larre.Activate = true; + } + waitForIdle(); +} + +loop { + yield(); + if (!larre.Idle) { + continue; + } + let newIndex = currentIndex + 1 > TOTAL_SLOTS ? START_STATION : currentIndex + 1; + + checkAndHarvest(currentIndex); + larre.Setting = newIndex; + currentIndex = newIndex; +} \ No newline at end of file diff --git a/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md b/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md new file mode 100644 index 0000000..8d970ab --- /dev/null +++ b/rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md @@ -0,0 +1,212 @@ +# Additional Optimization Opportunities for Slang IL Optimizer + +## Currently Implemented ✓ + +1. Constant Propagation - Folds math operations with known values +2. Register Forwarding - Eliminates intermediate moves +3. Function Call Optimization - Removes unnecessary push/pop around calls +4. Leaf Function Optimization - Removes RA save/restore for non-calling functions +5. Redundant Move Elimination - Removes `move rx rx` +6. Dead Code Elimination - Removes unreachable code after jumps + +## Proposed Additional Optimizations + +### 1. **Algebraic Simplification** 🔥 HIGH IMPACT + +Simplify mathematical identities: + +- `x + 0` → `x` (move) +- `x - 0` → `x` (move) +- `x * 1` → `x` (move) +- `x * 0` → `0` (move to constant) +- `x / 1` → `x` (move) +- `x - x` → `0` (move to constant) +- `x % 1` → `0` (move to constant) + +**Example:** + +``` +add r1 r2 0 → move r1 r2 +mul r3 r4 1 → move r3 r4 +mul r5 r6 0 → move r5 0 +``` + +### 2. **Strength Reduction** 🔥 HIGH IMPACT + +Replace expensive operations with cheaper ones: + +- `x * 2` → `add x x x` (addition is cheaper than multiplication) +- `x * power_of_2` → bit shifts (if IC10 supports) +- `x / 2` → bit shifts (if IC10 supports) + +**Example:** + +``` +mul r1 r2 2 → add r1 r2 r2 +``` + +### 3. **Peephole Optimization - Instruction Sequences** 🔥 MEDIUM-HIGH IMPACT + +Recognize and optimize common instruction patterns: + +#### Pattern: Conditional Branch Simplification + +``` +seq r1 ra rb → beq ra rb label +beqz r1 label (remove the seq entirely) + +sne r1 ra rb → bne ra rb label +beqz r1 label (remove the sne entirely) +``` + +#### Pattern: Double Move Elimination + +``` +move r1 r2 → move r1 r3 +move r1 r3 (remove first move if r1 not used between) +``` + +#### Pattern: Redundant Load Elimination + +If a register's value is already loaded and hasn't been clobbered: + +``` +l r1 d0 Temperature +... (no writes to r1) +l r1 d0 Temperature → (remove second load) +``` + +### 4. **Copy Propagation Enhancement** 🔥 MEDIUM IMPACT + +Current register forwarding is good, but we can extend it: + +- Track `move` chains: if `r1 = r2` and `r2 = 5`, propagate the `5` directly +- Eliminate the intermediate register if possible + +### 5. **Dead Store Elimination** 🔥 MEDIUM IMPACT + +Remove writes to registers that are never read before being overwritten: + +``` +move r1 5 +move r1 10 → move r1 10 + (first write is dead) +``` + +### 6. **Common Subexpression Elimination (CSE)** 🔥 MEDIUM-HIGH IMPACT + +Recognize when the same computation is done multiple times: + +``` +add r1 r8 r9 +add r2 r8 r9 → add r1 r8 r9 + move r2 r1 +``` + +This is especially valuable for expensive operations like: + +- Device loads (`l`) +- Math functions (sqrt, sin, cos, etc.) + +### 7. **Jump Threading** 🔥 LOW-MEDIUM IMPACT + +Optimize jump-to-jump sequences: + +``` +j label1 +... +label1: +j label2 → j label2 (rewrite first jump) +``` + +### 8. **Branch Folding** 🔥 LOW-MEDIUM IMPACT + +Merge consecutive branches to the same target: + +``` +bgt r1 r2 label +bgt r3 r4 label → Could potentially be optimized based on conditions +``` + +### 9. **Loop Invariant Code Motion** 🔥 MEDIUM-HIGH IMPACT + +Move calculations out of loops if they don't change: + +``` +loop: + mul r2 5 10 → mul r2 5 10 (hoisted before loop) + add r1 r1 r2 loop: + ... add r1 r1 r2 + j loop ... + j loop +``` + +### 10. **Select Instruction Optimization** 🔥 LOW-MEDIUM IMPACT + +The `select` instruction can sometimes replace branch patterns: + +``` +beq r1 r2 else +move r3 r4 +j end +else: +move r3 r5 → seq r6 r1 r2 +end: select r3 r6 r5 r4 +``` + +### 11. **Stack Access Pattern Optimization** 🔥 LOW IMPACT + +If we see repeated `sub r0 sp N` + `get`, we might be able to optimize by: + +- Caching the stack address in a register if used multiple times +- Combining sequential gets from adjacent stack slots + +### 12. **Inline Small Functions** 🔥 HIGH IMPACT (Complex to implement) + +For very small leaf functions (1-2 instructions), inline them at the call site: + +``` +calculateSum: + add r15 r8 r9 + j ra + +main: + push 5 → main: + push 10 add r15 5 10 + jal calculateSum +``` + +### 13. **Branch Prediction Hints** 🔥 LOW IMPACT + +Reorganize code to put likely branches inline (fall-through) and unlikely branches as jumps. + +### 14. **Register Coalescing** 🔥 MEDIUM IMPACT + +Reduce register pressure by reusing registers that have non-overlapping lifetimes. + +## Priority Implementation Order + +### Phase 1 (Quick Wins): + +1. Algebraic Simplification (easy, high impact) +2. Strength Reduction (easy, high impact) +3. Dead Store Elimination (medium complexity, good impact) + +### Phase 2 (Medium Effort): + +4. Peephole Optimizations - seq/beq pattern (medium, high impact) +5. Common Subexpression Elimination (medium, high impact) +6. Copy Propagation Enhancement (medium, medium impact) + +### Phase 3 (Advanced): + +7. Loop Invariant Code Motion (complex, high impact for loop-heavy code) +8. Function Inlining (complex, high impact) +9. Register Coalescing (complex, medium impact) + +## Testing Strategy + +- Add test cases for each optimization +- Ensure optimization preserves semantics (run existing tests after each) +- Measure code size reduction +- Consider adding benchmarks to measure game performance impact diff --git a/rust_compiler/libs/optimizer/src/algebraic_simplification.rs b/rust_compiler/libs/optimizer/src/algebraic_simplification.rs new file mode 100644 index 0000000..82bd39d --- /dev/null +++ b/rust_compiler/libs/optimizer/src/algebraic_simplification.rs @@ -0,0 +1,161 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Algebraic Simplification +/// Simplifies mathematical identities like x+0, x*1, x*0, etc. +pub fn algebraic_simplification<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + + for mut node in input { + let simplified = match &node.instruction { + // x + 0 = x + Instruction::Add(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Add(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x - 0 = x + Instruction::Sub(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + + // x * 1 = x + Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x * 0 = 0 + Instruction::Mul(dst, _, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + Instruction::Mul(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x / 1 = x + Instruction::Div(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + + // 0 / x = 0 (if x != 0, but we can't check at compile time for non-literals) + Instruction::Div(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x % 1 = 0 + Instruction::Mod(dst, _, Operand::Number(n)) if *n == Decimal::from(1) => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // 0 % x = 0 + Instruction::Mod(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x AND 0 = 0 + Instruction::And(dst, _, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + Instruction::And(dst, Operand::Number(n), _) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO))) + } + + // x OR 0 = x + Instruction::Or(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Or(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + // x XOR 0 = x + Instruction::Xor(dst, a, Operand::Number(n)) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), a.clone())) + } + Instruction::Xor(dst, Operand::Number(n), b) if n.is_zero() => { + Some(Instruction::Move(dst.clone(), b.clone())) + } + + _ => None, + }; + + if let Some(new) = simplified { + node.instruction = new; + changed = true; + } + + output.push(node); + } + + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_add_zero() { + let input = vec![InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Number(Decimal::ZERO), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(1), Operand::Register(2)) + )); + } + + #[test] + fn test_mul_one() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(3), + Operand::Register(4), + Operand::Number(Decimal::ONE), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(3), Operand::Register(4)) + )); + } + + #[test] + fn test_mul_zero() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(5), + Operand::Register(6), + Operand::Number(Decimal::ZERO), + ), + None, + )]; + + let (output, changed) = algebraic_simplification(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(5), Operand::Number(_)) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/constant_propagation.rs b/rust_compiler/libs/optimizer/src/constant_propagation.rs new file mode 100644 index 0000000..c9fb9cd --- /dev/null +++ b/rust_compiler/libs/optimizer/src/constant_propagation.rs @@ -0,0 +1,156 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Constant Propagation +/// Folds arithmetic operations when both operands are constant. +/// Also tracks register values and propagates them forward. +pub fn constant_propagation<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut registers: [Option; 16] = [None; 16]; + + for mut node in input { + // Invalidate register tracking on label/call boundaries + match &node.instruction { + Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16], + _ => {} + } + + let simplified = match &node.instruction { + Instruction::Move(dst, src) => resolve_value(src, ®isters) + .map(|val| Instruction::Move(dst.clone(), Operand::Number(val))), + Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y), + Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y), + Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y), + Instruction::Div(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| { + if y.is_zero() { Decimal::ZERO } else { x / y } + }), + Instruction::Mod(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| { + if y.is_zero() { Decimal::ZERO } else { x % y } + }), + Instruction::BranchEq(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x == y) + } + Instruction::BranchNe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x != y) + } + Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y), + Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y), + Instruction::BranchGe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x >= y) + } + Instruction::BranchLe(a, b, l) => { + try_resolve_branch(a, b, l, ®isters, |x, y| x <= y) + } + Instruction::BranchEqZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y) + } + Instruction::BranchNeZero(a, l) => { + try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y) + } + _ => None, + }; + + if let Some(new) = simplified { + node.instruction = new; + changed = true; + } + + // Update register tracking + match &node.instruction { + Instruction::Move(Operand::Register(r), src) => { + registers[*r as usize] = resolve_value(src, ®isters) + } + _ => { + if let Some(r) = get_destination_reg(&node.instruction) { + registers[r as usize] = None; + } + } + } + + // Filter out NOPs (empty labels from branch resolution) + if let Instruction::LabelDef(l) = &node.instruction + && l.is_empty() + { + changed = true; + continue; + } + + output.push(node); + } + (output, changed) +} + +fn resolve_value(op: &Operand, regs: &[Option; 16]) -> Option { + match op { + Operand::Number(n) => Some(*n), + Operand::Register(r) => regs[*r as usize], + _ => None, + } +} + +fn try_fold_math<'a, F>( + dst: &Operand<'a>, + a: &Operand<'a>, + b: &Operand<'a>, + regs: &[Option; 16], + op: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> Decimal, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + Some(Instruction::Move( + dst.clone(), + Operand::Number(op(val_a, val_b)), + )) +} + +fn try_resolve_branch<'a, F>( + a: &Operand<'a>, + b: &Operand<'a>, + label: &Operand<'a>, + regs: &[Option; 16], + check: F, +) -> Option> +where + F: Fn(Decimal, Decimal) -> bool, +{ + let val_a = resolve_value(a, regs)?; + let val_b = resolve_value(b, regs)?; + if check(val_a, val_b) { + Some(Instruction::Jump(label.clone())) + } else { + Some(Instruction::LabelDef("".into())) // NOP + } +} + +#[cfg(test)] +mod tests { + use super::*; + use il::InstructionNode; + + #[test] + fn test_fold_add() { + let input = vec![InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Number(5.into()), + Operand::Number(3.into()), + ), + None, + )]; + + let (output, changed) = constant_propagation(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(1), Operand::Number(_)) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/dead_code.rs b/rust_compiler/libs/optimizer/src/dead_code.rs new file mode 100644 index 0000000..b7dae22 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/dead_code.rs @@ -0,0 +1,148 @@ +use il::{Instruction, InstructionNode, Operand}; + +/// Pass: Redundant Move Elimination +/// Removes moves where source and destination are the same: `move rx rx` +pub fn remove_redundant_moves<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + for node in input { + if let Instruction::Move(dst, src) = &node.instruction + && dst == src + { + changed = true; + continue; + } + output.push(node); + } + (output, changed) +} + +/// Pass: Dead Code Elimination +/// Removes unreachable code after unconditional jumps. +pub fn remove_unreachable_code<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut dead = false; + for node in input { + if let Instruction::LabelDef(_) = node.instruction { + dead = false; + } + if dead { + changed = true; + continue; + } + if let Instruction::Jump(_) = node.instruction { + dead = true + } + output.push(node); + } + (output, changed) +} + +/// Pass: Remove Redundant Jumps +/// Removes jumps to the next instruction (after label resolution). +/// Must run AFTER label resolution since it needs line numbers. +/// This pass also adjusts all jump targets to account for removed instructions. +pub fn remove_redundant_jumps<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut removed_lines = Vec::new(); + + // First pass: identify redundant jumps + for (i, node) in input.iter().enumerate() { + // Check if this is a jump to the next line number + if let Instruction::Jump(Operand::Number(target)) = &node.instruction { + // Current line number is i, next line number is i+1 + // If jump target equals the next line, it's redundant + if target.to_string().parse::().ok() == Some(i + 1) { + changed = true; + removed_lines.push(i); + continue; // Skip this redundant jump + } + } + output.push(node.clone()); + } + + // Second pass: adjust all jump/branch targets to account for removed lines + if changed { + for node in &mut output { + // Helper to adjust a target line number + let adjust_target = |target_line: usize| -> usize { + // Count how many removed lines are before the target + let offset = removed_lines + .iter() + .filter(|&&removed| removed < target_line) + .count(); + target_line.saturating_sub(offset) + }; + + match &mut node.instruction { + Instruction::Jump(Operand::Number(target)) + | Instruction::JumpAndLink(Operand::Number(target)) => { + if let Ok(target_line) = target.to_string().parse::() { + *target = rust_decimal::Decimal::from(adjust_target(target_line)); + } + } + Instruction::BranchEq(_, _, Operand::Number(target)) + | Instruction::BranchNe(_, _, Operand::Number(target)) + | Instruction::BranchGt(_, _, Operand::Number(target)) + | Instruction::BranchLt(_, _, Operand::Number(target)) + | Instruction::BranchGe(_, _, Operand::Number(target)) + | Instruction::BranchLe(_, _, Operand::Number(target)) + | Instruction::BranchEqZero(_, Operand::Number(target)) + | Instruction::BranchNeZero(_, Operand::Number(target)) => { + if let Ok(target_line) = target.to_string().parse::() { + *target = rust_decimal::Decimal::from(adjust_target(target_line)); + } + } + _ => {} + } + } + } + + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::{Instruction, InstructionNode, Operand}; + + #[test] + fn test_remove_redundant_move() { + let input = vec![InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Register(1)), + None, + )]; + + let (output, changed) = remove_redundant_moves(input); + assert!(changed); + assert_eq!(output.len(), 0); + } + + #[test] + fn test_remove_unreachable() { + let input = vec![ + InstructionNode::new(Instruction::Jump(Operand::Label("main".into())), None), + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Number(1.into()), + Operand::Number(2.into()), + ), + None, + ), + InstructionNode::new(Instruction::LabelDef("main".into()), None), + ]; + + let (output, changed) = remove_unreachable_code(input); + assert!(changed); + assert_eq!(output.len(), 2); + } +} diff --git a/rust_compiler/libs/optimizer/src/dead_store_elimination.rs b/rust_compiler/libs/optimizer/src/dead_store_elimination.rs new file mode 100644 index 0000000..f16a034 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/dead_store_elimination.rs @@ -0,0 +1,126 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode}; +use std::collections::HashMap; + +/// Pass: Dead Store Elimination +/// Removes writes to registers that are never read before being overwritten. +pub fn dead_store_elimination<'a>( + input: Vec>, +) -> (Vec>, bool) { + // Forward pass: Remove writes that are immediately overwritten + let (input, forward_changed) = eliminate_overwritten_stores(input); + + // Note: Backward pass disabled for now - it needs more work to handle all cases correctly + // The forward pass is sufficient for most common patterns + // (e.g., move r6 r15 immediately followed by move r6 r15 again) + + (input, forward_changed) +} + +/// Forward pass: Remove stores that are overwritten before being read +fn eliminate_overwritten_stores<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut last_write: HashMap = HashMap::new(); + let mut to_remove = Vec::new(); + + // Scan for dead writes + for (i, node) in input.iter().enumerate() { + // Never remove Pop instructions - they have critical side effects on the stack pointer + if matches!(node.instruction, Instruction::Pop(_)) { + continue; + } + + if let Some(dest_reg) = get_destination_reg(&node.instruction) { + // If this register was written before and hasn't been read, previous write is dead + if let Some(&prev_idx) = last_write.get(&dest_reg) { + // Check if the value was ever used between prev_idx and current + let was_used = input[prev_idx + 1..i] + .iter() + .any(|n| reg_is_read_or_affects_control(&n.instruction, dest_reg)) + // Also check if current instruction reads the register before overwriting it + || reg_is_read_or_affects_control(&node.instruction, dest_reg); + + if !was_used { + // Previous write was dead + to_remove.push(prev_idx); + } + } + + // Update last write location + last_write.insert(dest_reg, i); + } + + // Handle control flow instructions + match &node.instruction { + // JumpAndLink (function calls) only clobbers the return register (r15) + // We can continue tracking other registers across function calls + Instruction::JumpAndLink(_) => { + last_write.remove(&15); + } + // Other control flow instructions create complexity - clear all tracking + Instruction::Jump(_) + | Instruction::LabelDef(_) + | Instruction::BranchEq(_, _, _) + | Instruction::BranchNe(_, _, _) + | Instruction::BranchGt(_, _, _) + | Instruction::BranchLt(_, _, _) + | Instruction::BranchGe(_, _, _) + | Instruction::BranchLe(_, _, _) + | Instruction::BranchEqZero(_, _) + | Instruction::BranchNeZero(_, _) => { + last_write.clear(); + } + _ => {} + } + } + + if !to_remove.is_empty() { + let output = input + .into_iter() + .enumerate() + .filter_map(|(i, node)| { + if to_remove.contains(&i) { + None + } else { + Some(node) + } + }) + .collect(); + (output, true) + } else { + (input, false) + } +} + +/// Simplified check: Does this instruction read the register? +fn reg_is_read_or_affects_control(instr: &Instruction, reg: u8) -> bool { + use crate::helpers::reg_is_read; + + // If it reads the register, it's used + reg_is_read(instr, reg) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::Operand; + + #[test] + fn test_dead_store() { + let input = vec![ + InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Number(5.into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(1), Operand::Number(10.into())), + None, + ), + ]; + + let (output, changed) = dead_store_elimination(input); + assert!(changed); + assert_eq!(output.len(), 1); + } +} diff --git a/rust_compiler/libs/optimizer/src/function_call_optimization.rs b/rust_compiler/libs/optimizer/src/function_call_optimization.rs new file mode 100644 index 0000000..3233ad4 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/function_call_optimization.rs @@ -0,0 +1,160 @@ +use crate::helpers::get_destination_reg; +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::{HashMap, HashSet}; + +/// Analyzes which registers are written to by each function label. +fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap> { + let mut clobbers = HashMap::new(); + let mut current_label = None; + + for node in instructions { + if let Instruction::LabelDef(label) = &node.instruction { + current_label = Some(label.to_string()); + clobbers.insert(label.to_string(), HashSet::new()); + } + + if let Some(label) = ¤t_label + && let Some(reg) = get_destination_reg(&node.instruction) + && let Some(set) = clobbers.get_mut(label) + { + set.insert(reg); + } + } + clobbers +} + +/// Pass: Function Call Optimization +/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register. +pub fn optimize_function_calls<'a>( + input: Vec>, +) -> (Vec>, bool) { + let clobbers = analyze_clobbers(&input); + let mut changed = false; + let mut to_remove = HashSet::new(); + let mut stack_adjustments = HashMap::new(); + + let mut i = 0; + while i < input.len() { + if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction { + let target_key = target.to_string(); + + if let Some(func_clobbers) = clobbers.get(&target_key) { + // 1. Identify Pushes immediately preceding the JAL + let mut pushes = Vec::new(); // (index, register) + let mut scan_back = i.saturating_sub(1); + while scan_back > 0 { + if to_remove.contains(&scan_back) { + scan_back -= 1; + continue; + } + if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction { + pushes.push((scan_back, *r)); + scan_back -= 1; + } else { + break; + } + } + + // 2. Identify Restores immediately following the JAL + let mut restores = Vec::new(); // (index_of_get, register, index_of_sub) + let mut scan_fwd = i + 1; + while scan_fwd < input.len() { + // Skip 'sub r0 sp X' + if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) = + &input[scan_fwd].instruction + { + // Check next instruction for the Get + if scan_fwd + 1 < input.len() + && let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) = + &input[scan_fwd + 1].instruction + { + restores.push((scan_fwd + 1, *r, scan_fwd)); + scan_fwd += 2; + continue; + } + } + break; + } + + // 3. Stack Cleanup + let cleanup_idx = scan_fwd; + let has_cleanup = if cleanup_idx < input.len() { + matches!( + input[cleanup_idx].instruction, + Instruction::Sub( + Operand::StackPointer, + Operand::StackPointer, + Operand::Number(_) + ) + ) + } else { + false + }; + + // SAFEGUARD: Check Counts! + let mut push_counts = HashMap::new(); + for (_, r) in &pushes { + *push_counts.entry(*r).or_insert(0) += 1; + } + + let mut restore_counts = HashMap::new(); + for (_, r, _) in &restores { + *restore_counts.entry(*r).or_insert(0) += 1; + } + + let counts_match = push_counts + .iter() + .all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count); + let counts_match_reverse = restore_counts + .iter() + .all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count); + + // Clobber Check + let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r)); + + if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse { + // Remove all pushes/restores + for (p_idx, _) in pushes { + to_remove.insert(p_idx); + } + for (g_idx, _, s_idx) in restores { + to_remove.insert(g_idx); + to_remove.insert(s_idx); + } + + // Reduce stack cleanup amount + let num_removed = push_counts.values().sum::() as i64; + stack_adjustments.insert(cleanup_idx, num_removed); + changed = true; + } + } + } + i += 1; + } + + if changed { + let mut clean = Vec::with_capacity(input.len()); + for (idx, mut node) in input.into_iter().enumerate() { + if to_remove.contains(&idx) { + continue; + } + + // Apply stack adjustment + if let Some(reduction) = stack_adjustments.get(&idx) + && let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction + { + let new_n = n - Decimal::from(*reduction); + if new_n.is_zero() { + continue; + } + node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n)); + } + + clean.push(node); + } + return (clean, changed); + } + + (input, false) +} diff --git a/rust_compiler/libs/optimizer/src/helpers.rs b/rust_compiler/libs/optimizer/src/helpers.rs new file mode 100644 index 0000000..f396969 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/helpers.rs @@ -0,0 +1,172 @@ +use il::{Instruction, Operand}; + +/// Returns the register number written to by an instruction, if any. +pub fn get_destination_reg(instr: &Instruction) -> Option { + match instr { + Instruction::Move(Operand::Register(r), _) + | Instruction::Add(Operand::Register(r), _, _) + | Instruction::Sub(Operand::Register(r), _, _) + | Instruction::Mul(Operand::Register(r), _, _) + | Instruction::Div(Operand::Register(r), _, _) + | Instruction::Mod(Operand::Register(r), _, _) + | Instruction::Pow(Operand::Register(r), _, _) + | Instruction::Load(Operand::Register(r), _, _) + | Instruction::LoadSlot(Operand::Register(r), _, _, _) + | Instruction::LoadBatch(Operand::Register(r), _, _, _) + | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _) + | Instruction::SetEq(Operand::Register(r), _, _) + | Instruction::SetNe(Operand::Register(r), _, _) + | Instruction::SetGt(Operand::Register(r), _, _) + | Instruction::SetLt(Operand::Register(r), _, _) + | Instruction::SetGe(Operand::Register(r), _, _) + | Instruction::SetLe(Operand::Register(r), _, _) + | Instruction::And(Operand::Register(r), _, _) + | Instruction::Or(Operand::Register(r), _, _) + | Instruction::Xor(Operand::Register(r), _, _) + | Instruction::Peek(Operand::Register(r)) + | Instruction::Get(Operand::Register(r), _, _) + | Instruction::Select(Operand::Register(r), _, _, _) + | Instruction::Rand(Operand::Register(r)) + | Instruction::Acos(Operand::Register(r), _) + | Instruction::Asin(Operand::Register(r), _) + | Instruction::Atan(Operand::Register(r), _) + | Instruction::Atan2(Operand::Register(r), _, _) + | Instruction::Abs(Operand::Register(r), _) + | Instruction::Ceil(Operand::Register(r), _) + | Instruction::Cos(Operand::Register(r), _) + | Instruction::Floor(Operand::Register(r), _) + | Instruction::Log(Operand::Register(r), _) + | Instruction::Max(Operand::Register(r), _, _) + | Instruction::Min(Operand::Register(r), _, _) + | Instruction::Sin(Operand::Register(r), _) + | Instruction::Sqrt(Operand::Register(r), _) + | Instruction::Tan(Operand::Register(r), _) + | Instruction::Trunc(Operand::Register(r), _) + | Instruction::LoadReagent(Operand::Register(r), _, _, _) + | Instruction::Pop(Operand::Register(r)) => Some(*r), + _ => None, + } +} + +/// Creates a new instruction with the destination register changed. +pub fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option> { + let r = Operand::Register(new_reg); + match instr { + Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())), + Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())), + Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())), + Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())), + Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())), + Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())), + Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())), + Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())), + Instruction::LoadSlot(_, a, b, c) => { + Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatch(_, a, b, c) => { + Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone())) + } + Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed( + r, + a.clone(), + b.clone(), + c.clone(), + d.clone(), + )), + Instruction::LoadReagent(_, b, c, d) => { + Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone())) + } + Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())), + Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())), + Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())), + Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())), + Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())), + Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())), + Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())), + Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())), + Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())), + Instruction::Peek(_) => Some(Instruction::Peek(r)), + Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())), + Instruction::Select(_, a, b, c) => { + Some(Instruction::Select(r, a.clone(), b.clone(), c.clone())) + } + Instruction::Rand(_) => Some(Instruction::Rand(r)), + Instruction::Pop(_) => Some(Instruction::Pop(r)), + Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())), + Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())), + Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())), + Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())), + Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())), + Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())), + Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())), + Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())), + Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())), + Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())), + Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())), + Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())), + Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())), + Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())), + Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())), + _ => None, + } +} + +/// Checks if a register is read by an instruction. +pub fn reg_is_read(instr: &Instruction, reg: u8) -> bool { + let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg); + + match instr { + Instruction::Move(_, a) => check(a), + Instruction::Add(_, a, b) + | Instruction::Sub(_, a, b) + | Instruction::Mul(_, a, b) + | Instruction::Div(_, a, b) + | Instruction::Mod(_, a, b) + | Instruction::Pow(_, a, b) => check(a) || check(b), + Instruction::Load(_, a, _) => check(a), + Instruction::Store(a, _, b) => check(a) || check(b), + Instruction::BranchEq(a, b, _) + | Instruction::BranchNe(a, b, _) + | Instruction::BranchGt(a, b, _) + | Instruction::BranchLt(a, b, _) + | Instruction::BranchGe(a, b, _) + | Instruction::BranchLe(a, b, _) => check(a) || check(b), + Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a), + Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash), + Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot), + Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode), + Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => { + check(d_hash) || check(n_hash) || check(mode) + } + Instruction::SetEq(_, a, b) + | Instruction::SetNe(_, a, b) + | Instruction::SetGt(_, a, b) + | Instruction::SetLt(_, a, b) + | Instruction::SetGe(_, a, b) + | Instruction::SetLe(_, a, b) + | Instruction::And(_, a, b) + | Instruction::Or(_, a, b) + | Instruction::Xor(_, a, b) => check(a) || check(b), + Instruction::Push(a) => check(a), + Instruction::Get(_, a, b) => check(a) || check(b), + Instruction::Put(a, b, c) => check(a) || check(b) || check(c), + Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c), + Instruction::Sleep(a) => check(a), + Instruction::Acos(_, a) + | Instruction::Asin(_, a) + | Instruction::Atan(_, a) + | Instruction::Abs(_, a) + | Instruction::Ceil(_, a) + | Instruction::Cos(_, a) + | Instruction::Floor(_, a) + | Instruction::Log(_, a) + | Instruction::Sin(_, a) + | Instruction::Sqrt(_, a) + | Instruction::Tan(_, a) + | Instruction::Trunc(_, a) => check(a), + Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => { + check(a) || check(b) + } + _ => false, + } +} diff --git a/rust_compiler/libs/optimizer/src/label_resolution.rs b/rust_compiler/libs/optimizer/src/label_resolution.rs new file mode 100644 index 0000000..a801166 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/label_resolution.rs @@ -0,0 +1,70 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; +use std::collections::HashMap; + +/// Pass: Resolve Labels +/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs. +pub fn resolve_labels<'a>(input: Vec>) -> Vec> { + let mut label_map: HashMap = HashMap::new(); + let mut line_number = 0; + + // Build Label Map (filtering out LabelDefs from the count) + for node in &input { + if let Instruction::LabelDef(name) = &node.instruction { + label_map.insert(name.to_string(), line_number); + } else { + line_number += 1; + } + } + + let mut output = Vec::with_capacity(input.len()); + + // Rewrite Jumps and Filter Labels + for mut node in input { + // Helper to get line number as Decimal operand + let get_line = |lbl: &Operand| -> Option> { + if let Operand::Label(name) = lbl { + label_map + .get(name.as_ref()) + .map(|&l| Operand::Number(Decimal::from(l))) + } else { + None + } + }; + + match &mut node.instruction { + Instruction::LabelDef(_) => continue, // Strip labels + + // Jumps + Instruction::Jump(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::JumpAndLink(op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEq(_, _, op) + | Instruction::BranchNe(_, _, op) + | Instruction::BranchGt(_, _, op) + | Instruction::BranchLt(_, _, op) + | Instruction::BranchGe(_, _, op) + | Instruction::BranchLe(_, _, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => { + if let Some(num) = get_line(op) { + *op = num; + } + } + _ => {} + } + output.push(node); + } + + output +} diff --git a/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs b/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs new file mode 100644 index 0000000..d65f10a --- /dev/null +++ b/rust_compiler/libs/optimizer/src/leaf_function_optimization.rs @@ -0,0 +1,41 @@ +use crate::leaf_function::find_leaf_functions; +use il::InstructionNode; + +/// Pass: Leaf Function Optimization +/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`. +/// +/// NOTE: This optimization is DISABLED due to correctness issues. +/// The optimization was designed for a specific calling convention (GET/PUT for RA) +/// but the compiler generates POP ra for return address restoration. Without proper +/// tracking of both conventions and validation of balanced push/pop pairs, this +/// optimization corrupts the stack frame by: +/// +/// 1. Removing `push ra` but not `pop ra`, leaving unbalanced push/pop pairs +/// 2. Not accounting for parameter pops that occur before `push sp` +/// 3. Assuming all RA restoration uses GET instruction, but code uses POP +/// +/// Example of broken optimization: +/// ``` +/// Unoptimized: Optimized (BROKEN): +/// compare: pop r8 +/// pop r8 pop r9 +/// pop r9 ble r9 r8 5 +/// push sp move r10 1 +/// push ra j ra +/// sgt r1 r9 r8 ^ Missing stack frame! +/// ... +/// pop ra +/// pop sp +/// j ra +/// ``` +/// +/// Future work: Fix by handling both POP and GET calling conventions, validating +/// balanced push/pop pairs, and accounting for parameter pops. +pub fn optimize_leaf_functions<'a>( + input: Vec>, +) -> (Vec>, bool) { + // Optimization disabled - returns input unchanged + #[allow(unused)] + let _leaves = find_leaf_functions(&input); + (input, false) +} diff --git a/rust_compiler/libs/optimizer/src/lib.rs b/rust_compiler/libs/optimizer/src/lib.rs index 855f913..fbb57fb 100644 --- a/rust_compiler/libs/optimizer/src/lib.rs +++ b/rust_compiler/libs/optimizer/src/lib.rs @@ -1,9 +1,30 @@ -use il::{Instruction, InstructionNode, Instructions, Operand}; -use rust_decimal::Decimal; -use std::collections::{HashMap, HashSet}; +use il::Instructions; +// Optimization pass modules +mod helpers; mod leaf_function; -use leaf_function::find_leaf_functions; + +mod algebraic_simplification; +mod constant_propagation; +mod dead_code; +mod dead_store_elimination; +mod function_call_optimization; +mod label_resolution; +mod leaf_function_optimization; +mod peephole_optimization; +mod register_forwarding; +mod strength_reduction; + +use algebraic_simplification::algebraic_simplification; +use constant_propagation::constant_propagation; +use dead_code::{remove_redundant_jumps, remove_redundant_moves, remove_unreachable_code}; +use dead_store_elimination::dead_store_elimination; +use function_call_optimization::optimize_function_calls; +use label_resolution::resolve_labels; +use leaf_function_optimization::optimize_leaf_functions; +use peephole_optimization::peephole_optimization; +use register_forwarding::register_forwarding; +use strength_reduction::strength_reduction; /// Entry point for the optimizer. pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> { @@ -38,845 +59,42 @@ pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> { instructions = new_inst; changed |= c4; - // Pass 5: Redundant Move Elimination - let (new_inst, c5) = remove_redundant_moves(instructions); + // Pass 5: Algebraic Simplification (Identity operations) + let (new_inst, c5) = algebraic_simplification(instructions); instructions = new_inst; changed |= c5; - // Pass 6: Dead Code Elimination - let (new_inst, c6) = remove_unreachable_code(instructions); + // Pass 6: Strength Reduction (Replace expensive ops with cheaper ones) + let (new_inst, c6) = strength_reduction(instructions); instructions = new_inst; changed |= c6; + + // Pass 7: Peephole Optimizations (Common patterns) + let (new_inst, c7) = peephole_optimization(instructions); + instructions = new_inst; + changed |= c7; + + // Pass 8: Dead Store Elimination + let (new_inst, c8) = dead_store_elimination(instructions); + instructions = new_inst; + changed |= c8; + + // Pass 9: Redundant Move Elimination + let (new_inst, c9) = remove_redundant_moves(instructions); + instructions = new_inst; + changed |= c9; + + // Pass 10: Dead Code Elimination + let (new_inst, c10) = remove_unreachable_code(instructions); + instructions = new_inst; + changed |= c10; } // Final Pass: Resolve Labels to Line Numbers - Instructions::new(resolve_labels(instructions)) -} - -/// Helper: Check if a function body contains unsafe stack manipulation. -/// Returns true if the function modifies SP in a way that makes static RA offset analysis unsafe. -fn function_has_complex_stack_ops( - instructions: &[InstructionNode], - start_idx: usize, - end_idx: usize, -) -> bool { - for instruction in instructions.iter().take(end_idx).skip(start_idx) { - match instruction.instruction { - Instruction::Push(_) | Instruction::Pop(_) => return true, - // Check for explicit SP modification - Instruction::Add(Operand::StackPointer, _, _) - | Instruction::Sub(Operand::StackPointer, _, _) - | Instruction::Mul(Operand::StackPointer, _, _) - | Instruction::Div(Operand::StackPointer, _, _) - | Instruction::Move(Operand::StackPointer, _) => return true, - _ => {} - } - } - false -} - -/// Pass: Leaf Function Optimization -/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`. -fn optimize_leaf_functions<'a>( - input: Vec>, -) -> (Vec>, bool) { - let leaves = find_leaf_functions(&input); - if leaves.is_empty() { - return (input, false); - } - - let mut changed = false; - let mut to_remove = HashSet::new(); - - // We map function names to the INDEX of the instruction that restores RA. - // We use this to validate the function body later. - let mut func_restore_indices = HashMap::new(); - let mut func_ra_offsets = HashMap::new(); - - let mut current_function: Option = None; - let mut function_start_indices = HashMap::new(); - - // First scan: Identify instructions to remove and capture RA offsets - for (i, node) in input.iter().enumerate() { - match &node.instruction { - Instruction::LabelDef(label) if !label.starts_with("__internal_L") => { - current_function = Some(label.to_string()); - function_start_indices.insert(label.to_string(), i); - } - Instruction::Push(Operand::ReturnAddress) => { - if let Some(func) = ¤t_function - && leaves.contains(func) - { - to_remove.insert(i); - } - } - Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => { - // This is the restore instruction: `get ra db r0` - if let Some(func) = ¤t_function - && leaves.contains(func) - { - to_remove.insert(i); - func_restore_indices.insert(func.clone(), i); - - // Look back for the address calc: `sub r0 sp OFFSET` - if i > 0 - && let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = - &input[i - 1].instruction - { - func_ra_offsets.insert(func.clone(), *n); - to_remove.insert(i - 1); - } - } - } - _ => {} - } - } - - // Safety Check: Verify that functions marked for optimization don't have complex stack ops. - // If they do, unmark them. - let mut safe_functions = HashSet::new(); - - for (func, start_idx) in &function_start_indices { - if let Some(restore_idx) = func_restore_indices.get(func) { - // Check instructions between start and restore using the helper function. - // We need to skip the `push ra` we just marked for removal, otherwise the helper - // will flag it as a complex op (Push). - // `start_idx` is the LabelDef. `start_idx + 1` is typically `push ra`. - - let check_start = if to_remove.contains(&(start_idx + 1)) { - start_idx + 2 - } else { - start_idx + 1 - }; - - // `restore_idx` points to the `get ra` instruction. The helper scans up to `end_idx` exclusive, - // so we don't need to worry about the restore instruction itself. - if !function_has_complex_stack_ops(&input, check_start, *restore_idx) { - safe_functions.insert(func.clone()); - changed = true; - } - } - } - - if !changed { - return (input, false); - } - - // Second scan: Rebuild with adjustments, but only for SAFE functions - let mut output = Vec::with_capacity(input.len()); - let mut processing_function: Option = None; - - for (i, mut node) in input.into_iter().enumerate() { - if to_remove.contains(&i) - && let Some(func) = &processing_function - && safe_functions.contains(func) - { - continue; // SKIP (Remove) - } - - if let Instruction::LabelDef(l) = &node.instruction - && !l.starts_with("__internal_L") - { - processing_function = Some(l.to_string()); - } - - // Apply Stack Adjustments - if let Some(func) = &processing_function - && safe_functions.contains(func) - && let Some(ra_offset) = func_ra_offsets.get(func) - { - // 1. Stack Cleanup Adjustment - if let Instruction::Sub( - Operand::StackPointer, - Operand::StackPointer, - Operand::Number(n), - ) = &mut node.instruction - { - // Decrease cleanup amount by 1 (for the removed RA) - let new_n = *n - Decimal::from(1); - if new_n.is_zero() { - continue; - } - *n = new_n; - } - - // 2. Stack Variable Offset Adjustment - // Since we verified the function is "Simple" (no nested stack mods), - // we can safely assume offsets > ra_offset need shifting. - if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) = - &mut node.instruction - && *n > *ra_offset - { - *n -= Decimal::from(1); - } - } - - output.push(node); - } - - (output, true) -} - -/// Analyzes which registers are written to by each function label. -fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap> { - let mut clobbers = HashMap::new(); - let mut current_label = None; - - for node in instructions { - if let Instruction::LabelDef(label) = &node.instruction { - current_label = Some(label.to_string()); - clobbers.insert(label.to_string(), HashSet::new()); - } - - if let Some(label) = ¤t_label - && let Some(reg) = get_destination_reg(&node.instruction) - && let Some(set) = clobbers.get_mut(label) - { - set.insert(reg); - } - } - clobbers -} - -/// Pass: Function Call Optimization -/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register. -fn optimize_function_calls<'a>( - input: Vec>, -) -> (Vec>, bool) { - let clobbers = analyze_clobbers(&input); - let mut changed = false; - let mut to_remove = HashSet::new(); - let mut stack_adjustments = HashMap::new(); - - let mut i = 0; - while i < input.len() { - if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction { - let target_key = target.to_string(); - - if let Some(func_clobbers) = clobbers.get(&target_key) { - // 1. Identify Pushes immediately preceding the JAL - let mut pushes = Vec::new(); // (index, register) - let mut scan_back = i.saturating_sub(1); - while scan_back > 0 { - if to_remove.contains(&scan_back) { - scan_back -= 1; - continue; - } - if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction { - pushes.push((scan_back, *r)); - scan_back -= 1; - } else { - break; - } - } - - // 2. Identify Restores immediately following the JAL - let mut restores = Vec::new(); // (index_of_get, register, index_of_sub) - let mut scan_fwd = i + 1; - while scan_fwd < input.len() { - // Skip 'sub r0 sp X' - if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) = - &input[scan_fwd].instruction - { - // Check next instruction for the Get - if scan_fwd + 1 < input.len() - && let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) = - &input[scan_fwd + 1].instruction - { - restores.push((scan_fwd + 1, *r, scan_fwd)); - scan_fwd += 2; - continue; - } - } - break; - } - - // 3. Stack Cleanup - let cleanup_idx = scan_fwd; - let has_cleanup = if cleanup_idx < input.len() { - matches!( - input[cleanup_idx].instruction, - Instruction::Sub( - Operand::StackPointer, - Operand::StackPointer, - Operand::Number(_) - ) - ) - } else { - false - }; - - // SAFEGUARD: Check Counts! - // If we pushed r8 twice but only restored it once, we have an argument. - // We must ensure the number of pushes for each register MATCHES the number of restores. - let mut push_counts = HashMap::new(); - for (_, r) in &pushes { - *push_counts.entry(*r).or_insert(0) += 1; - } - - let mut restore_counts = HashMap::new(); - for (_, r, _) in &restores { - *restore_counts.entry(*r).or_insert(0) += 1; - } - - let counts_match = push_counts - .iter() - .all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count); - // Also check reverse to ensure we didn't restore something we didn't push (unlikely but possible) - let counts_match_reverse = restore_counts - .iter() - .all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count); - - // Clobber Check - let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r)); - - if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse { - // We can remove ALL found pushes/restores safely - for (p_idx, _) in pushes { - to_remove.insert(p_idx); - } - for (g_idx, _, s_idx) in restores { - to_remove.insert(g_idx); - to_remove.insert(s_idx); - } - - // Reduce stack cleanup amount - let num_removed = push_counts.values().sum::() as i64; - stack_adjustments.insert(cleanup_idx, num_removed); - changed = true; - } - } - } - i += 1; - } - - if changed { - let mut clean = Vec::with_capacity(input.len()); - for (idx, mut node) in input.into_iter().enumerate() { - if to_remove.contains(&idx) { - continue; - } - - // Apply stack adjustment - if let Some(reduction) = stack_adjustments.get(&idx) - && let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction - { - let new_n = n - Decimal::from(*reduction); - if new_n.is_zero() { - continue; // Remove the sub entirely if 0 - } - node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n)); - } - - clean.push(node); - } - return (clean, changed); - } - - (input, false) -} - -/// Pass: Register Forwarding -/// Eliminates intermediate moves by writing directly to the final destination. -/// Example: `l r1 d0 T` + `move r9 r1` -> `l r9 d0 T` -fn register_forwarding<'a>( - mut input: Vec>, -) -> (Vec>, bool) { - let mut changed = false; - let mut i = 0; - - // We use a while loop to manually control index so we can peek ahead - while i < input.len().saturating_sub(1) { - let next_idx = i + 1; - - // Check if current instruction defines a register - // and the NEXT instruction is a move from that register. - let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) { - if let Instruction::Move(Operand::Register(dest_reg), Operand::Register(src_reg)) = - &input[next_idx].instruction - { - if *src_reg == def_reg { - // Candidate found: Instruction `i` defines `src_reg`, Instruction `i+1` moves `src_reg` to `dest_reg`. - // We can optimize if `src_reg` (the temp) is NOT used after this move. - Some((def_reg, *dest_reg)) - } else { - None - } - } else { - None - } - } else { - None - }; - - if let Some((temp_reg, final_reg)) = forward_candidate { - // Check liveness: Is temp_reg used after i+1? - // We scan from i+2 onwards. - let mut temp_is_dead = true; - for node in input.iter().skip(i + 2) { - if reg_is_read(&node.instruction, temp_reg) { - temp_is_dead = false; - break; - } - // If the temp is redefined, then the old value is dead, so we are safe. - if let Some(redef) = get_destination_reg(&node.instruction) - && redef == temp_reg - { - break; - } - - // If we hit a label/jump, we assume liveness might leak (conservative safety) - if matches!( - node.instruction, - Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_) - ) { - temp_is_dead = false; - break; - } - } - - if temp_is_dead { - // Perform the swap - // 1. Rewrite input[i] to write to final_reg - if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) { - input[i].instruction = new_instr; - // 2. Remove input[i+1] (The Move) - input.remove(next_idx); - changed = true; - // Don't increment i, re-evaluate current index (which is now a new neighbor) - continue; - } - } - } - - i += 1; - } - - (input, changed) -} - -/// Pass: Resolve Labels -/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs. -fn resolve_labels<'a>(input: Vec>) -> Vec> { - let mut label_map: HashMap = HashMap::new(); - let mut line_number = 0; - - // 1. Build Label Map (filtering out LabelDefs from the count) - for node in &input { - if let Instruction::LabelDef(name) = &node.instruction { - label_map.insert(name.to_string(), line_number); - } else { - line_number += 1; - } - } - - let mut output = Vec::with_capacity(input.len()); - - // 2. Rewrite Jumps and Filter Labels - for mut node in input { - // Helper to get line number as Decimal operand - let get_line = |lbl: &Operand| -> Option> { - if let Operand::Label(name) = lbl { - label_map - .get(name.as_ref()) - .map(|&l| Operand::Number(Decimal::from(l))) - } else { - None - } - }; - - match &mut node.instruction { - Instruction::LabelDef(_) => continue, // Strip labels - - // Jumps - Instruction::Jump(op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::JumpAndLink(op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::BranchEq(_, _, op) - | Instruction::BranchNe(_, _, op) - | Instruction::BranchGt(_, _, op) - | Instruction::BranchLt(_, _, op) - | Instruction::BranchGe(_, _, op) - | Instruction::BranchLe(_, _, op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => { - if let Some(num) = get_line(op) { - *op = num; - } - } - _ => {} - } - output.push(node); - } - - output -} - -// --- Helpers for Register Analysis --- - -fn get_destination_reg(instr: &Instruction) -> Option { - match instr { - Instruction::Move(Operand::Register(r), _) - | Instruction::Add(Operand::Register(r), _, _) - | Instruction::Sub(Operand::Register(r), _, _) - | Instruction::Mul(Operand::Register(r), _, _) - | Instruction::Div(Operand::Register(r), _, _) - | Instruction::Mod(Operand::Register(r), _, _) - | Instruction::Pow(Operand::Register(r), _, _) - | Instruction::Load(Operand::Register(r), _, _) - | Instruction::LoadSlot(Operand::Register(r), _, _, _) - | Instruction::LoadBatch(Operand::Register(r), _, _, _) - | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _) - | Instruction::SetEq(Operand::Register(r), _, _) - | Instruction::SetNe(Operand::Register(r), _, _) - | Instruction::SetGt(Operand::Register(r), _, _) - | Instruction::SetLt(Operand::Register(r), _, _) - | Instruction::SetGe(Operand::Register(r), _, _) - | Instruction::SetLe(Operand::Register(r), _, _) - | Instruction::And(Operand::Register(r), _, _) - | Instruction::Or(Operand::Register(r), _, _) - | Instruction::Xor(Operand::Register(r), _, _) - | Instruction::Peek(Operand::Register(r)) - | Instruction::Get(Operand::Register(r), _, _) - | Instruction::Select(Operand::Register(r), _, _, _) - | Instruction::Rand(Operand::Register(r)) - | Instruction::Acos(Operand::Register(r), _) - | Instruction::Asin(Operand::Register(r), _) - | Instruction::Atan(Operand::Register(r), _) - | Instruction::Atan2(Operand::Register(r), _, _) - | Instruction::Abs(Operand::Register(r), _) - | Instruction::Ceil(Operand::Register(r), _) - | Instruction::Cos(Operand::Register(r), _) - | Instruction::Floor(Operand::Register(r), _) - | Instruction::Log(Operand::Register(r), _) - | Instruction::Max(Operand::Register(r), _, _) - | Instruction::Min(Operand::Register(r), _, _) - | Instruction::Sin(Operand::Register(r), _) - | Instruction::Sqrt(Operand::Register(r), _) - | Instruction::Tan(Operand::Register(r), _) - | Instruction::Trunc(Operand::Register(r), _) - | Instruction::LoadReagent(Operand::Register(r), _, _, _) - | Instruction::Pop(Operand::Register(r)) => Some(*r), - _ => None, - } -} - -fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option> { - // Helper to easily recreate instruction with new dest - let r = Operand::Register(new_reg); - match instr { - Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())), - Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())), - Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())), - Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())), - Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())), - Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())), - Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())), - Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())), - Instruction::LoadSlot(_, a, b, c) => { - Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone())) - } - Instruction::LoadBatch(_, a, b, c) => { - Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone())) - } - Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed( - r, - a.clone(), - b.clone(), - c.clone(), - d.clone(), - )), - Instruction::LoadReagent(_, b, c, d) => { - Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone())) - } - Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())), - Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())), - Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())), - Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())), - Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())), - Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())), - Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())), - Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())), - Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())), - Instruction::Peek(_) => Some(Instruction::Peek(r)), - Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())), - Instruction::Select(_, a, b, c) => { - Some(Instruction::Select(r, a.clone(), b.clone(), c.clone())) - } - Instruction::Rand(_) => Some(Instruction::Rand(r)), - Instruction::Pop(_) => Some(Instruction::Pop(r)), - - // Math funcs - Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())), - Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())), - Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())), - Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())), - Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())), - Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())), - Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())), - Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())), - Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())), - Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())), - Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())), - Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())), - Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())), - Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())), - Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())), - - _ => None, - } -} - -fn reg_is_read(instr: &Instruction, reg: u8) -> bool { - let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg); - - match instr { - Instruction::Move(_, a) => check(a), - Instruction::Add(_, a, b) - | Instruction::Sub(_, a, b) - | Instruction::Mul(_, a, b) - | Instruction::Div(_, a, b) - | Instruction::Mod(_, a, b) - | Instruction::Pow(_, a, b) => check(a) || check(b), - - Instruction::Load(_, a, _) => check(a), // Load reads device? Device can be reg? Yes. - Instruction::Store(a, _, b) => check(a) || check(b), - - Instruction::BranchEq(a, b, _) - | Instruction::BranchNe(a, b, _) - | Instruction::BranchGt(a, b, _) - | Instruction::BranchLt(a, b, _) - | Instruction::BranchGe(a, b, _) - | Instruction::BranchLe(a, b, _) => check(a) || check(b), - - Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a), - - Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash), - - Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot), - Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode), - Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => { - check(d_hash) || check(n_hash) || check(mode) - } - - Instruction::SetEq(_, a, b) - | Instruction::SetNe(_, a, b) - | Instruction::SetGt(_, a, b) - | Instruction::SetLt(_, a, b) - | Instruction::SetGe(_, a, b) - | Instruction::SetLe(_, a, b) - | Instruction::And(_, a, b) - | Instruction::Or(_, a, b) - | Instruction::Xor(_, a, b) => check(a) || check(b), - - Instruction::Push(a) => check(a), - Instruction::Get(_, a, b) => check(a) || check(b), - Instruction::Put(a, b, c) => check(a) || check(b) || check(c), - - Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c), - Instruction::Sleep(a) => check(a), - - // Math single arg - Instruction::Acos(_, a) - | Instruction::Asin(_, a) - | Instruction::Atan(_, a) - | Instruction::Abs(_, a) - | Instruction::Ceil(_, a) - | Instruction::Cos(_, a) - | Instruction::Floor(_, a) - | Instruction::Log(_, a) - | Instruction::Sin(_, a) - | Instruction::Sqrt(_, a) - | Instruction::Tan(_, a) - | Instruction::Trunc(_, a) => check(a), - - // Math double arg - Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => { - check(a) || check(b) - } - - _ => false, - } -} - -/// --- Constant Propagation & Dead Code --- -fn constant_propagation<'a>(input: Vec>) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - let mut registers: [Option; 16] = [None; 16]; - - for mut node in input { - match &node.instruction { - Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16], - _ => {} - } - - let simplified = match &node.instruction { - Instruction::Move(dst, src) => resolve_value(src, ®isters) - .map(|val| Instruction::Move(dst.clone(), Operand::Number(val))), - Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y), - Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y), - Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y), - Instruction::Div(dst, a, b) => { - try_fold_math( - dst, - a, - b, - ®isters, - |x, y| if y.is_zero() { x } else { x / y }, - ) - } - Instruction::Mod(dst, a, b) => { - try_fold_math( - dst, - a, - b, - ®isters, - |x, y| if y.is_zero() { x } else { x % y }, - ) - } - Instruction::BranchEq(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x == y) - } - Instruction::BranchNe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x != y) - } - Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y), - Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y), - Instruction::BranchGe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x >= y) - } - Instruction::BranchLe(a, b, l) => { - try_resolve_branch(a, b, l, ®isters, |x, y| x <= y) - } - Instruction::BranchEqZero(a, l) => { - try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y) - } - Instruction::BranchNeZero(a, l) => { - try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y) - } - _ => None, - }; - - if let Some(new) = simplified { - node.instruction = new; - changed = true; - } - - // Update tracking - match &node.instruction { - Instruction::Move(Operand::Register(r), src) => { - registers[*r as usize] = resolve_value(src, ®isters) - } - // Invalidate if destination is register - _ => { - if let Some(r) = get_destination_reg(&node.instruction) { - registers[r as usize] = None; - } - } - } - - // Filter out NOPs (Empty LabelDefs from branch resolution) - if let Instruction::LabelDef(l) = &node.instruction - && l.is_empty() - { - changed = true; - continue; - } - - output.push(node); - } - (output, changed) -} - -fn resolve_value(op: &Operand, regs: &[Option; 16]) -> Option { - match op { - Operand::Number(n) => Some(*n), - Operand::Register(r) => regs[*r as usize], - _ => None, - } -} - -fn try_fold_math<'a, F>( - dst: &Operand<'a>, - a: &Operand<'a>, - b: &Operand<'a>, - regs: &[Option; 16], - op: F, -) -> Option> -where - F: Fn(Decimal, Decimal) -> Decimal, -{ - let val_a = resolve_value(a, regs)?; - let val_b = resolve_value(b, regs)?; - Some(Instruction::Move( - dst.clone(), - Operand::Number(op(val_a, val_b)), - )) -} - -fn try_resolve_branch<'a, F>( - a: &Operand<'a>, - b: &Operand<'a>, - label: &Operand<'a>, - regs: &[Option; 16], - check: F, -) -> Option> -where - F: Fn(Decimal, Decimal) -> bool, -{ - let val_a = resolve_value(a, regs)?; - let val_b = resolve_value(b, regs)?; - if check(val_a, val_b) { - Some(Instruction::Jump(label.clone())) - } else { - Some(Instruction::LabelDef("".into())) // NOP - } -} - -fn remove_redundant_moves<'a>(input: Vec>) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - for node in input { - if let Instruction::Move(dst, src) = &node.instruction - && dst == src - { - changed = true; - continue; - } - output.push(node); - } - (output, changed) -} - -fn remove_unreachable_code<'a>( - input: Vec>, -) -> (Vec>, bool) { - let mut output = Vec::with_capacity(input.len()); - let mut changed = false; - let mut dead = false; - for node in input { - if let Instruction::LabelDef(_) = node.instruction { - dead = false; - } - if dead { - changed = true; - continue; - } - if let Instruction::Jump(_) = node.instruction { - dead = true - } - output.push(node); - } - (output, changed) + let instructions = resolve_labels(instructions); + + // Post-resolution Pass: Remove redundant jumps (must run after label resolution) + let (instructions, _) = remove_redundant_jumps(instructions); + + Instructions::new(instructions) } diff --git a/rust_compiler/libs/optimizer/src/peephole_optimization.rs b/rust_compiler/libs/optimizer/src/peephole_optimization.rs new file mode 100644 index 0000000..eb78676 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/peephole_optimization.rs @@ -0,0 +1,761 @@ +use il::{Instruction, InstructionNode, Operand}; + +/// Pass: Peephole Optimization +/// Recognizes and optimizes common instruction patterns. +pub fn peephole_optimization<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + let mut i = 0; + + while i < input.len() { + // Pattern: push sp; push ra ... pop ra; pop sp (with no jal in between) + // If we push sp and ra and later pop them, but never call a function in between, remove all four + // and adjust any stack pointer offsets in between by -2 + if i + 1 < input.len() { + if let ( + Instruction::Push(Operand::StackPointer), + Instruction::Push(Operand::ReturnAddress), + ) = (&input[i].instruction, &input[i + 1].instruction) + { + // Look for matching pop ra; pop sp pattern + if let Some((ra_pop_idx, instructions_between)) = + find_matching_ra_pop(&input[i + 1..]) + { + let absolute_ra_pop = i + 1 + ra_pop_idx; + // Check if the next instruction is pop sp + if absolute_ra_pop + 1 < input.len() { + if let Instruction::Pop(Operand::StackPointer) = + &input[absolute_ra_pop + 1].instruction + { + // Check if there's any jal between push and pop + let has_call = instructions_between.iter().any(|node| { + matches!(node.instruction, Instruction::JumpAndLink(_)) + }); + + if !has_call { + // Safe to remove all four: push sp, push ra, pop ra, pop sp + // Also need to adjust stack pointer offsets in between by -2 + let absolute_sp_pop = absolute_ra_pop + 1; + // Clear output since we're going to reprocess the entire input + output.clear(); + for (idx, node) in input.iter().enumerate() { + if idx == i + || idx == i + 1 + || idx == absolute_ra_pop + || idx == absolute_sp_pop + { + // Skip all four push/pop instructions + continue; + } + + // If this instruction is between the pushes and pops, adjust its stack offsets + if idx > i + 1 && idx < absolute_ra_pop { + let adjusted_instruction = + adjust_stack_offset(node.instruction.clone(), 2); + output.push(InstructionNode::new( + adjusted_instruction, + node.span, + )); + } else { + output.push(node.clone()); + } + } + changed = true; + // We've processed the entire input, so break + break; + } + } + } + } + } + } + + // Pattern: push ra ... pop ra (with no jal in between) + // Fallback for when there's only ra push/pop without sp + if let Instruction::Push(Operand::ReturnAddress) = &input[i].instruction { + if let Some((pop_idx, instructions_between)) = find_matching_ra_pop(&input[i..]) { + // Check if there's any jal between push and pop + let has_call = instructions_between + .iter() + .any(|node| matches!(node.instruction, Instruction::JumpAndLink(_))); + + if !has_call { + // Safe to remove both push and pop + // Also need to adjust stack pointer offsets in between + let absolute_pop_idx = i + pop_idx; + // Clear output since we're going to reprocess the entire input + output.clear(); + for (idx, node) in input.iter().enumerate() { + if idx == i || idx == absolute_pop_idx { + // Skip the push and pop + continue; + } + + // If this instruction is between push and pop, adjust its stack offsets + if idx > i && idx < absolute_pop_idx { + let adjusted_instruction = + adjust_stack_offset(node.instruction.clone(), 1); + output.push(InstructionNode::new(adjusted_instruction, node.span)); + } else { + output.push(node.clone()); + } + } + changed = true; + // We've processed the entire input, so break + break; + } + } + } + + // Pattern: Branch-Move-Jump-Label-Move-Label -> Select + // beqz r1 else_label + // move r2 val1 + // j end_label + // else_label: + // move r2 val2 + // end_label: + // Converts to: select r2 r1 val1 val2 + if i + 5 < input.len() { + let select_pattern = try_match_select_pattern(&input[i..i + 6]); + if let Some((dst, cond, true_val, false_val, skip_count)) = select_pattern { + output.push(InstructionNode::new( + Instruction::Select(dst, cond, true_val, false_val), + input[i].span, + )); + changed = true; + i += skip_count; + continue; + } + } + + // Pattern: seq + beqz -> beq + if i + 1 < input.len() { + let pattern = match (&input[i].instruction, &input[i + 1].instruction) { + ( + Instruction::SetEq(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Eq, true)), // invert: beqz means "if NOT equal" + + ( + Instruction::SetNe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ne, true)), + + ( + Instruction::SetGt(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Gt, true)), + + ( + Instruction::SetLt(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Lt, true)), + + ( + Instruction::SetGe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ge, true)), + + ( + Instruction::SetLe(Operand::Register(temp), a, b), + Instruction::BranchEqZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Le, true)), + + // Pattern: seq + bnez -> bne + ( + Instruction::SetEq(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Eq, false)), + + ( + Instruction::SetNe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ne, false)), + + ( + Instruction::SetGt(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Gt, false)), + + ( + Instruction::SetLt(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Lt, false)), + + ( + Instruction::SetGe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Ge, false)), + + ( + Instruction::SetLe(Operand::Register(temp), a, b), + Instruction::BranchNeZero(Operand::Register(cond), label), + ) if temp == cond => Some((a, b, label, BranchType::Le, false)), + + _ => None, + }; + + if let Some((a, b, label, branch_type, invert)) = pattern { + // Create optimized branch instruction + let new_instr = if invert { + // beqz after seq means "branch if NOT equal" -> bne + match branch_type { + BranchType::Eq => { + Instruction::BranchNe(a.clone(), b.clone(), label.clone()) + } + BranchType::Ne => { + Instruction::BranchEq(a.clone(), b.clone(), label.clone()) + } + BranchType::Gt => { + Instruction::BranchLe(a.clone(), b.clone(), label.clone()) + } + BranchType::Lt => { + Instruction::BranchGe(a.clone(), b.clone(), label.clone()) + } + BranchType::Ge => { + Instruction::BranchLt(a.clone(), b.clone(), label.clone()) + } + BranchType::Le => { + Instruction::BranchGt(a.clone(), b.clone(), label.clone()) + } + } + } else { + // bnez after seq means "branch if equal" -> beq + match branch_type { + BranchType::Eq => { + Instruction::BranchEq(a.clone(), b.clone(), label.clone()) + } + BranchType::Ne => { + Instruction::BranchNe(a.clone(), b.clone(), label.clone()) + } + BranchType::Gt => { + Instruction::BranchGt(a.clone(), b.clone(), label.clone()) + } + BranchType::Lt => { + Instruction::BranchLt(a.clone(), b.clone(), label.clone()) + } + BranchType::Ge => { + Instruction::BranchGe(a.clone(), b.clone(), label.clone()) + } + BranchType::Le => { + Instruction::BranchLe(a.clone(), b.clone(), label.clone()) + } + } + }; + + output.push(InstructionNode::new(new_instr, input[i].span)); + changed = true; + i += 2; // Skip both instructions + continue; + } + } + + output.push(input[i].clone()); + i += 1; + } + + (output, changed) +} + +/// Tries to match a select pattern in the instruction sequence. +/// Pattern (6 instructions): +/// beqz/bnez cond else_label (i+0) +/// move dst val1 (i+1) +/// j end_label (i+2) +/// else_label: (i+3) +/// move dst val2 (i+4) +/// end_label: (i+5) +/// Returns: (dst, cond, true_val, false_val, instruction_count) +fn try_match_select_pattern<'a>( + instructions: &[InstructionNode<'a>], +) -> Option<(Operand<'a>, Operand<'a>, Operand<'a>, Operand<'a>, usize)> { + if instructions.len() < 6 { + return None; + } + + // Check for beqz pattern + if let Instruction::BranchEqZero(cond, Operand::Label(else_label)) = + &instructions[0].instruction + { + if let Instruction::Move(dst1, val1) = &instructions[1].instruction { + if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction { + if let Instruction::LabelDef(label3) = &instructions[3].instruction { + if label3 == else_label { + if let Instruction::Move(dst2, val2) = &instructions[4].instruction { + if dst1 == dst2 { + if let Instruction::LabelDef(label5) = &instructions[5].instruction + { + if label5 == end_label { + // beqz means: if cond==0, goto else, so val1 is for true, val2 for false + // select dst cond true_val false_val + // When cond is non-zero (true), use val1, otherwise val2 + return Some(( + dst1.clone(), + cond.clone(), + val1.clone(), + val2.clone(), + 6, + )); + } + } + } + } + } + } + } + } + } + + // Check for bnez pattern + if let Instruction::BranchNeZero(cond, Operand::Label(then_label)) = + &instructions[0].instruction + { + if let Instruction::Move(dst1, val_false) = &instructions[1].instruction { + if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction { + if let Instruction::LabelDef(label3) = &instructions[3].instruction { + if label3 == then_label { + if let Instruction::Move(dst2, val_true) = &instructions[4].instruction { + if dst1 == dst2 { + if let Instruction::LabelDef(label5) = &instructions[5].instruction + { + if label5 == end_label { + // bnez means: if cond!=0, goto then, so val_true for true, val_false for false + return Some(( + dst1.clone(), + cond.clone(), + val_true.clone(), + val_false.clone(), + 6, + )); + } + } + } + } + } + } + } + } + } + + None +} + +/// Finds a matching `pop ra` for a `push ra` at the start of the slice. +/// Returns the index of the pop and the instructions in between. +fn find_matching_ra_pop<'a>( + instructions: &'a [InstructionNode<'a>], +) -> Option<(usize, &'a [InstructionNode<'a>])> { + if instructions.is_empty() { + return None; + } + + // Skip the push itself + for (idx, node) in instructions.iter().enumerate().skip(1) { + if let Instruction::Pop(Operand::ReturnAddress) = &node.instruction { + // Found matching pop + return Some((idx, &instructions[1..idx])); + } + + // Stop searching if we hit a jump (different control flow) or a function label + // Labels are OK - they're just markers EXCEPT for user-defined function labels + // which indicate a function boundary + if matches!( + node.instruction, + Instruction::Jump(_) | Instruction::JumpRelative(_) | Instruction::LabelDef(_) + ) { + return None; + } + } + + None +} + +/// Checks if an instruction uses or modifies the stack pointer. +#[allow(dead_code)] +fn uses_stack_pointer(instruction: &Instruction) -> bool { + match instruction { + Instruction::Push(_) | Instruction::Pop(_) | Instruction::Peek(_) => true, + Instruction::Add(Operand::StackPointer, _, _) + | Instruction::Sub(Operand::StackPointer, _, _) + | Instruction::Mul(Operand::StackPointer, _, _) + | Instruction::Div(Operand::StackPointer, _, _) + | Instruction::Mod(Operand::StackPointer, _, _) => true, + Instruction::Add(_, Operand::StackPointer, _) + | Instruction::Sub(_, Operand::StackPointer, _) + | Instruction::Mul(_, Operand::StackPointer, _) + | Instruction::Div(_, Operand::StackPointer, _) + | Instruction::Mod(_, Operand::StackPointer, _) => true, + Instruction::Add(_, _, Operand::StackPointer) + | Instruction::Sub(_, _, Operand::StackPointer) + | Instruction::Mul(_, _, Operand::StackPointer) + | Instruction::Div(_, _, Operand::StackPointer) + | Instruction::Mod(_, _, Operand::StackPointer) => true, + Instruction::Move(Operand::StackPointer, _) + | Instruction::Move(_, Operand::StackPointer) => true, + _ => false, + } +} + +/// Adjusts stack pointer offsets in an instruction by decrementing them by a given amount. +/// This is necessary when removing push operations that would have increased the stack size. +fn adjust_stack_offset<'a>(instruction: Instruction<'a>, decrement: i64) -> Instruction<'a> { + use rust_decimal::prelude::*; + + match instruction { + // Adjust arithmetic operations on sp that use literal offsets + Instruction::Sub(dst, Operand::StackPointer, Operand::Number(n)) => { + let new_n = n - Decimal::from(decrement); + // If the result is 0 or negative, we may want to skip this entirely + // but for now, just adjust the value + Instruction::Sub(dst, Operand::StackPointer, Operand::Number(new_n)) + } + Instruction::Add(dst, Operand::StackPointer, Operand::Number(n)) => { + let new_n = n - Decimal::from(decrement); + Instruction::Add(dst, Operand::StackPointer, Operand::Number(new_n)) + } + // Return the instruction unchanged if it doesn't need adjustment + other => other, + } +} + +#[derive(Debug, Clone, Copy)] +enum BranchType { + Eq, + Ne, + Gt, + Lt, + Ge, + Le, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_seq_beqz_to_bne() { + let input = vec![ + InstructionNode::new( + Instruction::SetEq( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchNe(_, _, _) + )); + } + + #[test] + fn test_sne_beqz_to_beq() { + let input = vec![ + InstructionNode::new( + Instruction::SetNe( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchEq(_, _, _) + )); + } + + #[test] + fn test_seq_bnez_to_beq() { + let input = vec![ + InstructionNode::new( + Instruction::SetEq( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchNeZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchEq(_, _, _) + )); + } + + #[test] + fn test_sgt_beqz_to_ble() { + let input = vec![ + InstructionNode::new( + Instruction::SetGt( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())), + None, + ), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::BranchLe(_, _, _) + )); + } + + #[test] + fn test_branch_move_jump_to_select_beqz() { + // Pattern: beqz r1 else / move r2 10 / j end / else: / move r2 20 / end: + // Should convert to: select r2 r1 10 20 + let input = vec![ + InstructionNode::new( + Instruction::BranchEqZero(Operand::Register(1), Operand::Label("else".into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None), + InstructionNode::new(Instruction::LabelDef("else".into()), None), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(20.into())), + None, + ), + InstructionNode::new(Instruction::LabelDef("end".into()), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(2))); + assert!(matches!(cond, Operand::Register(1))); + assert!(matches!(true_val, Operand::Number(_))); + assert!(matches!(false_val, Operand::Number(_))); + } else { + panic!("Expected Select instruction"); + } + } + + #[test] + fn test_branch_move_jump_to_select_bnez() { + // Pattern: bnez r1 then / move r2 20 / j end / then: / move r2 10 / end: + // Should convert to: select r2 r1 10 20 + let input = vec![ + InstructionNode::new( + Instruction::BranchNeZero(Operand::Register(1), Operand::Label("then".into())), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(20.into())), + None, + ), + InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None), + InstructionNode::new(Instruction::LabelDef("then".into()), None), + InstructionNode::new( + Instruction::Move(Operand::Register(2), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::LabelDef("end".into()), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(2))); + assert!(matches!(cond, Operand::Register(1))); + assert!(matches!(true_val, Operand::Number(_))); + assert!(matches!(false_val, Operand::Number(_))); + } else { + panic!("Expected Select instruction"); + } + } + + #[test] + fn test_remove_useless_ra_push_pop() { + // Pattern: push ra / add r1 r2 r3 / pop ra + // Should remove both push and pop since no jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!(output[0].instruction, Instruction::Add(_, _, _))); + } + + #[test] + fn test_keep_ra_push_pop_with_jal() { + // Pattern: push ra / jal func / pop ra + // Should keep both since there's a jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::JumpAndLink(Operand::Label("func".into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(!changed); + assert_eq!(output.len(), 3); + } + + #[test] + fn test_ra_push_pop_with_stack_offset_adjustment() { + // Pattern: push ra / sub r1 sp 2 / pop ra + // Should remove push/pop AND adjust the stack offset from 2 to 1 + use rust_decimal::prelude::*; + + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Sub( + Operand::Register(1), + Operand::StackPointer, + Operand::Number(Decimal::from(2)), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + + if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(1))); + assert!(matches!(src, Operand::StackPointer)); + assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 2 to 1 + } else { + panic!("Expected Sub instruction with adjusted offset"); + } + } + + #[test] + fn test_remove_sp_and_ra_push_pop() { + // Pattern: push sp / push ra / move r8 10 / pop ra / pop sp + // Should remove all four push/pop instructions since no jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Move(Operand::Register(8), Operand::Number(10.into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Move(Operand::Register(8), _) + )); + } + + #[test] + fn test_keep_sp_and_ra_push_pop_with_jal() { + // Pattern: push sp / push ra / jal func / pop ra / pop sp + // Should keep all since there's a jal in between + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::JumpAndLink(Operand::Label("func".into())), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(!changed); + assert_eq!(output.len(), 5); + } + + #[test] + fn test_sp_and_ra_with_stack_offset_adjustment() { + // Pattern: push sp / push ra / sub r1 sp 3 / pop ra / pop sp + // Should remove all push/pop AND adjust the stack offset from 3 to 1 (decrement by 2) + use rust_decimal::prelude::*; + + let input = vec![ + InstructionNode::new(Instruction::Push(Operand::StackPointer), None), + InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None), + InstructionNode::new( + Instruction::Sub( + Operand::Register(1), + Operand::StackPointer, + Operand::Number(Decimal::from(3)), + ), + None, + ), + InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None), + InstructionNode::new(Instruction::Pop(Operand::StackPointer), None), + ]; + + let (output, changed) = peephole_optimization(input); + assert!(changed); + assert_eq!(output.len(), 1); + + if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction { + assert!(matches!(dst, Operand::Register(1))); + assert!(matches!(src, Operand::StackPointer)); + assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 3 to 1 + } else { + panic!("Expected Sub instruction with adjusted offset"); + } + } +} diff --git a/rust_compiler/libs/optimizer/src/register_forwarding.rs b/rust_compiler/libs/optimizer/src/register_forwarding.rs new file mode 100644 index 0000000..0eb8022 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/register_forwarding.rs @@ -0,0 +1,151 @@ +use crate::helpers::{get_destination_reg, reg_is_read, set_destination_reg}; +use il::{Instruction, InstructionNode, Operand}; +use std::collections::HashMap; + +/// Pass: Register Forwarding +/// Eliminates intermediate moves by writing directly to the final destination. +/// Example: `l r1 d0 Temperature` + `move r9 r1` -> `l r9 d0 Temperature` +pub fn register_forwarding<'a>( + mut input: Vec>, +) -> (Vec>, bool) { + let mut changed = false; + let mut i = 0; + + // Build a map of label positions to detect backward jumps + // Use String keys to avoid lifetime issues with references into input + let label_positions: HashMap = input + .iter() + .enumerate() + .filter_map(|(idx, node)| { + if let Instruction::LabelDef(label) = &node.instruction { + Some((label.to_string(), idx)) + } else { + None + } + }) + .collect(); + + while i < input.len().saturating_sub(1) { + let next_idx = i + 1; + + // Check if current instruction defines a register + // and the NEXT instruction is a move from that register. + let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) { + if let Instruction::Move( + il::Operand::Register(dest_reg), + il::Operand::Register(src_reg), + ) = &input[next_idx].instruction + { + if *src_reg == def_reg { + Some((def_reg, *dest_reg)) + } else { + None + } + } else { + None + } + } else { + None + }; + + if let Some((temp_reg, final_reg)) = forward_candidate { + // Check liveness: Is temp_reg used after i+1? + let mut temp_is_dead = true; + for node in input.iter().skip(i + 2) { + if reg_is_read(&node.instruction, temp_reg) { + temp_is_dead = false; + break; + } + // If the temp is redefined, then the old value is dead + if let Some(redef) = get_destination_reg(&node.instruction) + && redef == temp_reg + { + break; + } + + // Function calls (jal) clobber the return register (r15) + // So if we're tracking r15 and hit a function call, the old value is dead + if matches!(node.instruction, Instruction::JumpAndLink(_)) && temp_reg == 15 { + break; + } + + // Labels are just markers - they don't affect register liveness + // But backward jumps create loops we need to analyze carefully + let jump_target = match &node.instruction { + Instruction::Jump(Operand::Label(target)) => Some(target.as_ref()), + Instruction::BranchEq(_, _, Operand::Label(target)) + | Instruction::BranchNe(_, _, Operand::Label(target)) + | Instruction::BranchGt(_, _, Operand::Label(target)) + | Instruction::BranchLt(_, _, Operand::Label(target)) + | Instruction::BranchGe(_, _, Operand::Label(target)) + | Instruction::BranchLe(_, _, Operand::Label(target)) + | Instruction::BranchEqZero(_, Operand::Label(target)) + | Instruction::BranchNeZero(_, Operand::Label(target)) => Some(target.as_ref()), + _ => None, + }; + + if let Some(target) = jump_target { + // Check if this is a backward jump (target appears before current position) + if let Some(&target_pos) = label_positions.get(target) { + if target_pos < i { + // Backward jump - could loop back, register might be live + temp_is_dead = false; + break; + } + // Forward jump is OK - doesn't affect liveness before it + } + } + } + + if temp_is_dead { + // Safety check: ensure final_reg is not used as an operand in the current instruction. + // This prevents generating invalid instructions like `sub r5 r0 r5` (read and write same register). + if !reg_is_read(&input[i].instruction, final_reg) { + // Rewrite to use final destination directly + if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) { + input[i].instruction = new_instr; + input.remove(next_idx); + changed = true; + continue; + } + } + } + } + + i += 1; + } + + (input, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + use il::{Instruction, InstructionNode, Operand}; + + #[test] + fn test_forward_simple_move() { + let input = vec![ + InstructionNode::new( + Instruction::Add( + Operand::Register(1), + Operand::Register(2), + Operand::Register(3), + ), + None, + ), + InstructionNode::new( + Instruction::Move(Operand::Register(5), Operand::Register(1)), + None, + ), + ]; + + let (output, changed) = register_forwarding(input); + assert!(changed); + assert_eq!(output.len(), 1); + assert!(matches!( + output[0].instruction, + Instruction::Add(Operand::Register(5), _, _) + )); + } +} diff --git a/rust_compiler/libs/optimizer/src/strength_reduction.rs b/rust_compiler/libs/optimizer/src/strength_reduction.rs new file mode 100644 index 0000000..f5e4917 --- /dev/null +++ b/rust_compiler/libs/optimizer/src/strength_reduction.rs @@ -0,0 +1,63 @@ +use il::{Instruction, InstructionNode, Operand}; +use rust_decimal::Decimal; + +/// Pass: Strength Reduction +/// Replaces expensive operations with cheaper equivalents. +/// Example: x * 2 -> add x x x (addition is typically faster than multiplication) +pub fn strength_reduction<'a>( + input: Vec>, +) -> (Vec>, bool) { + let mut output = Vec::with_capacity(input.len()); + let mut changed = false; + + for mut node in input { + let reduced = match &node.instruction { + // x * 2 = x + x + Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(2) => { + Some(Instruction::Add(dst.clone(), a.clone(), a.clone())) + } + Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(2) => { + Some(Instruction::Add(dst.clone(), b.clone(), b.clone())) + } + + // Future: Could add power-of-2 optimizations using bit shifts if IC10 supports them + // x * 4 = (x + x) + (x + x) or x << 2 + // x / 2 = x >> 1 + + _ => None, + }; + + if let Some(new) = reduced { + node.instruction = new; + changed = true; + } + + output.push(node); + } + + (output, changed) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_mul_two_to_add() { + let input = vec![InstructionNode::new( + Instruction::Mul( + Operand::Register(1), + Operand::Register(2), + Operand::Number(Decimal::from(2)), + ), + None, + )]; + + let (output, changed) = strength_reduction(input); + assert!(changed); + assert!(matches!( + output[0].instruction, + Instruction::Add(Operand::Register(1), Operand::Register(2), Operand::Register(2)) + )); + } +} diff --git a/rust_compiler/libs/parser/src/lib.rs b/rust_compiler/libs/parser/src/lib.rs index 63e8ec2..2c45cdc 100644 --- a/rust_compiler/libs/parser/src/lib.rs +++ b/rust_compiler/libs/parser/src/lib.rs @@ -441,7 +441,13 @@ impl<'a> Parser<'a> { )); } - TokenType::Keyword(Keyword::Let) => Some(self.spanned(|p| p.declaration())?), + TokenType::Keyword(Keyword::Let) => { + if self_matches_peek!(self, TokenType::Symbol(Symbol::LParen)) { + Some(self.spanned(|p| p.tuple_declaration())?) + } else { + Some(self.spanned(|p| p.declaration())?) + } + } TokenType::Keyword(Keyword::Device) => { let spanned_dev = self.spanned(|p| p.device())?; @@ -561,9 +567,7 @@ impl<'a> Parser<'a> { }) } - TokenType::Symbol(Symbol::LParen) => { - self.spanned(|p| p.priority())?.node.map(|node| *node) - } + TokenType::Symbol(Symbol::LParen) => self.parenthesized_or_tuple()?, TokenType::Symbol(Symbol::Minus) => { let start_span = self.current_span(); @@ -642,8 +646,8 @@ impl<'a> Parser<'a> { } } TokenType::Symbol(Symbol::LParen) => *self - .spanned(|p| p.priority())? - .node + .parenthesized_or_tuple()? + .map(Box::new) .ok_or(Error::UnexpectedEOF)?, TokenType::Identifier(ref id) if SysCall::is_syscall(id) => { @@ -774,7 +778,8 @@ impl<'a> Parser<'a> { | Expression::Ternary(_) | Expression::Negation(_) | Expression::MemberAccess(_) - | Expression::MethodCall(_) => {} + | Expression::MethodCall(_) + | Expression::Tuple(_) => {} _ => { return Err(Error::InvalidSyntax( self.current_span(), @@ -1081,19 +1086,39 @@ impl<'a> Parser<'a> { end_col: right.span.end_col, }; - expressions.insert( - i, - Spanned { + // Check if the left side is a tuple, and if so, create a TupleAssignment + let node = if let Expression::Tuple(tuple_expr) = &left.node { + // Extract variable names from the tuple, handling underscores + let mut names = Vec::new(); + for item in &tuple_expr.node { + if let Expression::Variable(var) = &item.node { + names.push(var.clone()); + } else { + return Err(Error::InvalidSyntax( + item.span, + String::from("Tuple assignment can only contain variable names"), + )); + } + } + + Expression::TupleAssignment(Spanned { span, - node: Expression::Assignment(Spanned { - span, - node: AssignmentExpression { - assignee: boxed!(left), - expression: boxed!(right), - }, - }), - }, - ); + node: TupleAssignmentExpression { + names, + value: boxed!(right), + }, + }) + } else { + Expression::Assignment(Spanned { + span, + node: AssignmentExpression { + assignee: boxed!(left), + expression: boxed!(right), + }, + }) + }; + + expressions.insert(i, Spanned { span, node }); } } operators.retain(|symbol| !matches!(symbol, Symbol::Assign)); @@ -1117,8 +1142,12 @@ impl<'a> Parser<'a> { expressions.pop().ok_or(Error::UnexpectedEOF) } - fn priority(&mut self) -> Result>>>, Error<'a>> { + fn parenthesized_or_tuple( + &mut self, + ) -> Result>>, Error<'a>> { + let start_span = self.current_span(); let current_token = self.current_token.as_ref().ok_or(Error::UnexpectedEOF)?; + if !token_matches!(current_token, TokenType::Symbol(Symbol::LParen)) { return Err(Error::UnexpectedToken( self.current_span(), @@ -1127,17 +1156,112 @@ impl<'a> Parser<'a> { } self.assign_next()?; - let expression = self.expression()?.ok_or(Error::UnexpectedEOF)?; - let current_token = self.get_next()?.ok_or(Error::UnexpectedEOF)?; - if !token_matches!(current_token, TokenType::Symbol(Symbol::RParen)) { - return Err(Error::UnexpectedToken( - Self::token_to_span(¤t_token), - current_token, - )); + // Handle empty tuple '()' + if self_matches_peek!(self, TokenType::Symbol(Symbol::RParen)) { + self.assign_next()?; + let end_span = self.current_span(); + let span = Span { + start_line: start_span.start_line, + start_col: start_span.start_col, + end_line: end_span.end_line, + end_col: end_span.end_col, + }; + return Ok(Some(Spanned { + span, + node: Expression::Tuple(Spanned { span, node: vec![] }), + })); } - Ok(Some(boxed!(expression))) + let first_expression = self.expression()?.ok_or(Error::UnexpectedEOF)?; + + if self_matches_peek!(self, TokenType::Symbol(Symbol::Comma)) { + // It is a tuple + let mut items = vec![first_expression]; + while self_matches_peek!(self, TokenType::Symbol(Symbol::Comma)) { + // Next toekn is a comma, we need to consume it and advance 1 more time. + self.assign_next()?; + self.assign_next()?; + items.push(self.expression()?.ok_or(Error::UnexpectedEOF)?); + } + + let next = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + if !token_matches!(next, TokenType::Symbol(Symbol::RParen)) { + return Err(Error::UnexpectedToken(Self::token_to_span(&next), next)); + } + + let end_span = Self::token_to_span(&next); + let span = Span { + start_line: start_span.start_line, + start_col: start_span.start_col, + end_line: end_span.end_line, + end_col: end_span.end_col, + }; + + Ok(Some(Spanned { + span, + node: Expression::Tuple(Spanned { span, node: items }), + })) + } else { + // It is just priority + let next = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + if !token_matches!(next, TokenType::Symbol(Symbol::RParen)) { + return Err(Error::UnexpectedToken(Self::token_to_span(&next), next)); + } + + Ok(Some(Spanned { + span: first_expression.span, + node: Expression::Priority(boxed!(first_expression)), + })) + } + } + + fn tuple_declaration(&mut self) -> Result, Error<'a>> { + // 'let' is consumed before this call + // expect '(' + let next = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + if !token_matches!(next, TokenType::Symbol(Symbol::LParen)) { + return Err(Error::UnexpectedToken(Self::token_to_span(&next), next)); + } + + let mut names = Vec::new(); + while !self_matches_peek!(self, TokenType::Symbol(Symbol::RParen)) { + let token = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + let span = Self::token_to_span(&token); + if let TokenType::Identifier(id) = token.token_type { + names.push(Spanned { span, node: id }); + } else { + return Err(Error::UnexpectedToken(span, token)); + } + + if self_matches_peek!(self, TokenType::Symbol(Symbol::Comma)) { + self.assign_next()?; + } + } + self.assign_next()?; // consume ')' + + let assign = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + + if !token_matches!(assign, TokenType::Symbol(Symbol::Assign)) { + return Err(Error::UnexpectedToken(Self::token_to_span(&assign), assign)); + } + + self.assign_next()?; // Consume the `=` + + let value = self.expression()?.ok_or(Error::UnexpectedEOF)?; + + let semi = self.get_next()?.ok_or(Error::UnexpectedEOF)?; + if !token_matches!(semi, TokenType::Symbol(Symbol::Semicolon)) { + return Err(Error::UnexpectedToken(Self::token_to_span(&semi), semi)); + } + + Ok(Expression::TupleDeclaration(Spanned { + span: names.first().map(|n| n.span).unwrap_or(value.span), + node: TupleDeclarationExpression { + names, + value: boxed!(value), + }, + })) } fn invocation(&mut self) -> Result, Error<'a>> { diff --git a/rust_compiler/libs/parser/src/test/mod.rs b/rust_compiler/libs/parser/src/test/mod.rs index fe27c43..a08f6ce 100644 --- a/rust_compiler/libs/parser/src/test/mod.rs +++ b/rust_compiler/libs/parser/src/test/mod.rs @@ -112,7 +112,7 @@ fn test_function_invocation() -> Result<()> { #[test] fn test_priority_expression() -> Result<()> { let input = r#" - let x = (4); + let x = (4 + 3); "#; let tokenizer = Tokenizer::from(input); @@ -120,7 +120,7 @@ fn test_priority_expression() -> Result<()> { let expression = parser.parse()?.unwrap(); - assert_eq!("(let x = 4)", expression.to_string()); + assert_eq!("(let x = ((4 + 3)))", expression.to_string()); Ok(()) } @@ -137,7 +137,7 @@ fn test_binary_expression() -> Result<()> { assert_eq!("(((45 * 2) - (15 / 5)) + (5 ** 2))", expr.to_string()); let expr = parser!("(5 - 2) * 10;").parse()?.unwrap(); - assert_eq!("((5 - 2) * 10)", expr.to_string()); + assert_eq!("(((5 - 2)) * 10)", expr.to_string()); Ok(()) } @@ -170,7 +170,7 @@ fn test_ternary_expression() -> Result<()> { fn test_complex_binary_with_ternary() -> Result<()> { let expr = parser!("let i = (x ? 1 : 3) * 2;").parse()?.unwrap(); - assert_eq!("(let i = ((x ? 1 : 3) * 2))", expr.to_string()); + assert_eq!("(let i = (((x ? 1 : 3)) * 2))", expr.to_string()); Ok(()) } @@ -191,3 +191,99 @@ fn test_nested_ternary_right_associativity() -> Result<()> { assert_eq!("(let i = (a ? b : (c ? d : e)))", expr.to_string()); Ok(()) } + +#[test] +fn test_tuple_declaration() -> Result<()> { + let expr = parser!("let (x, _) = (1, 2);").parse()?.unwrap(); + + assert_eq!("(let (x, _) = (1, 2))", expr.to_string()); + + Ok(()) +} +#[test] +fn test_tuple_assignment() -> Result<()> { + let expr = parser!("(x, y) = (1, 2);").parse()?.unwrap(); + + assert_eq!("((x, y) = (1, 2))", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_assignment_with_underscore() -> Result<()> { + let expr = parser!("(x, _) = (1, 2);").parse()?.unwrap(); + + assert_eq!("((x, _) = (1, 2))", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_declaration_with_function_call() -> Result<()> { + let expr = parser!("let (x, y) = doSomething();").parse()?.unwrap(); + + assert_eq!("(let (x, y) = doSomething())", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_declaration_with_function_call_with_underscore() -> Result<()> { + let expr = parser!("let (x, _) = doSomething();").parse()?.unwrap(); + + assert_eq!("(let (x, _) = doSomething())", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_assignment_with_function_call() -> Result<()> { + let expr = parser!("(x, y) = doSomething();").parse()?.unwrap(); + + assert_eq!("((x, y) = doSomething())", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_assignment_with_function_call_with_underscore() -> Result<()> { + let expr = parser!("(x, _) = doSomething();").parse()?.unwrap(); + + assert_eq!("((x, _) = doSomething())", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_declaration_with_complex_expressions() -> Result<()> { + let expr = parser!("let (x, y) = (1 + 1, doSomething());") + .parse()? + .unwrap(); + + assert_eq!("(let (x, y) = ((1 + 1), doSomething()))", expr.to_string()); + + Ok(()) +} + +#[test] +fn test_tuple_assignment_with_complex_expressions() -> Result<()> { + let expr = parser!("(x, y) = (doSomething(), 123 / someValue.Setting);") + .parse()? + .unwrap(); + + assert_eq!( + "((x, y) = (doSomething(), (123 / someValue.Setting)))", + expr.to_string() + ); + + Ok(()) +} + +#[test] +fn test_tuple_declaration_all_complex_expressions() -> Result<()> { + let expr = parser!("let (x, y) = (a + b, c * d);").parse()?.unwrap(); + + assert_eq!("(let (x, y) = ((a + b), (c * d)))", expr.to_string()); + + Ok(()) +} diff --git a/rust_compiler/libs/parser/src/tree_node.rs b/rust_compiler/libs/parser/src/tree_node.rs index 2f21aef..3da1305 100644 --- a/rust_compiler/libs/parser/src/tree_node.rs +++ b/rust_compiler/libs/parser/src/tree_node.rs @@ -245,6 +245,42 @@ impl<'a> std::fmt::Display for DeviceDeclarationExpression<'a> { } } +#[derive(Debug, PartialEq, Eq)] +pub struct TupleDeclarationExpression<'a> { + pub names: Vec>>, + pub value: Box>>, +} + +impl<'a> std::fmt::Display for TupleDeclarationExpression<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let names = self + .names + .iter() + .map(|n| n.node.to_string()) + .collect::>() + .join(", "); + write!(f, "(let ({}) = {})", names, self.value) + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct TupleAssignmentExpression<'a> { + pub names: Vec>>, + pub value: Box>>, +} + +impl<'a> std::fmt::Display for TupleAssignmentExpression<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let names = self + .names + .iter() + .map(|n| n.node.to_string()) + .collect::>() + .join(", "); + write!(f, "(({}) = {})", names, self.value) + } +} + #[derive(Debug, PartialEq, Eq)] pub struct IfExpression<'a> { pub condition: Box>>, @@ -348,6 +384,9 @@ pub enum Expression<'a> { Return(Option>>>), Syscall(Spanned>), Ternary(Spanned>), + Tuple(Spanned>>>), + TupleAssignment(Spanned>), + TupleDeclaration(Spanned>), Variable(Spanned>), While(Spanned>), } @@ -384,8 +423,20 @@ impl<'a> std::fmt::Display for Expression<'a> { ), Expression::Syscall(e) => write!(f, "{}", e), Expression::Ternary(e) => write!(f, "{}", e), + Expression::Tuple(e) => { + let items = e + .node + .iter() + .map(|x| x.to_string()) + .collect::>() + .join(", "); + write!(f, "({})", items) + } + Expression::TupleAssignment(e) => write!(f, "{}", e), + Expression::TupleDeclaration(e) => write!(f, "{}", e), Expression::Variable(id) => write!(f, "{}", id), Expression::While(e) => write!(f, "{}", e), } } } +