More optimizations and snapshot integration tests
212
rust_compiler/libs/optimizer/OPTIMIZATION_IDEAS.md
Normal file
@@ -0,0 +1,212 @@
# Additional Optimization Opportunities for Slang IL Optimizer

## Currently Implemented ✓

1. Constant Propagation - Folds math operations with known values
2. Register Forwarding - Eliminates intermediate moves
3. Function Call Optimization - Removes unnecessary push/pop around calls
4. Leaf Function Optimization - Removes RA save/restore for non-calling functions
5. Redundant Move Elimination - Removes `move rx rx`
6. Dead Code Elimination - Removes unreachable code after jumps

## Proposed Additional Optimizations

### 1. **Algebraic Simplification** 🔥 HIGH IMPACT

Simplify mathematical identities:

- `x + 0` → `x` (move)
- `x - 0` → `x` (move)
- `x * 1` → `x` (move)
- `x * 0` → `0` (move to constant)
- `x / 1` → `x` (move)
- `x - x` → `0` (move to constant)
- `x % 1` → `0` (move to constant)

**Example:**

```
add r1 r2 0  →  move r1 r2
mul r3 r4 1  →  move r3 r4
mul r5 r6 0  →  move r5 0
```

### 2. **Strength Reduction** 🔥 HIGH IMPACT

Replace expensive operations with cheaper ones:

- `x * 2` → `x + x` (addition is cheaper than multiplication)
- `x * power_of_2` → bit shifts (if IC10 supports them; only valid when `x` is known to be an integer)
- `x / 2` → bit shifts (if IC10 supports them; only valid when `x` is known to be an integer)

**Example:**

```
mul r1 r2 2  →  add r1 r2 r2
```
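
The rewrite above could be sketched as a small pattern match. This is only an illustration: the `Operand` and `Instruction` enums below are simplified stand-ins defined for the example (with integer literals), not the real `il` crate types, and `reduce_mul_by_two` is a hypothetical helper name.

```rust
/// Simplified stand-in for an IL operand (assumption, not the real `il::Operand`).
#[derive(Debug, Clone, PartialEq)]
enum Operand {
    Register(u8),
    Number(i64),
}

/// Simplified stand-in for an IL instruction: (destination, lhs, rhs).
#[derive(Debug, Clone, PartialEq)]
enum Instruction {
    Add(Operand, Operand, Operand),
    Mul(Operand, Operand, Operand),
}

/// Rewrites `mul dst x 2` (or `mul dst 2 x`) into `add dst x x`.
/// Returns `None` when the instruction is not a multiplication by 2.
fn reduce_mul_by_two(instr: &Instruction) -> Option<Instruction> {
    match instr {
        Instruction::Mul(dst, x, Operand::Number(2))
        | Instruction::Mul(dst, Operand::Number(2), x) => {
            Some(Instruction::Add(dst.clone(), x.clone(), x.clone()))
        }
        _ => None,
    }
}
```

A pass built on this would return the rewritten instruction plus a `changed` flag, in the same shape as the existing passes.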

### 3. **Peephole Optimization - Instruction Sequences** 🔥 MEDIUM-HIGH IMPACT

Recognize and optimize common instruction patterns:

#### Pattern: Conditional Branch Simplification

`beqz` branches when the comparison result is 0 (false), so the fused branch uses the inverted condition:

```
seq r1 ra rb     →  bne ra rb label
beqz r1 label       (remove the seq entirely if r1 is not read afterwards)

sne r1 ra rb     →  beq ra rb label
beqz r1 label       (remove the sne entirely if r1 is not read afterwards)
```
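
A minimal sketch of this fusion follows, assuming a toy instruction type (`Instr`, defined here for illustration and not part of the `il` crate) with registers as plain `u8` values. It also assumes the caller has already established that the temporary register is not read after the branch; the sketch does not check that.

```rust
/// Toy instruction set for this sketch only (hypothetical, not `il::Instruction`).
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    SetEq(u8, u8, u8),        // seq dst a b
    SetNe(u8, u8, u8),        // sne dst a b
    BranchEqZero(u8, String), // beqz r label
    BranchEq(u8, u8, String), // beq a b label
    BranchNe(u8, u8, String), // bne a b label
}

/// Fuses `seq`/`sne` followed by `beqz` on the same register into one branch.
/// `beqz` fires when the comparison was false, so the condition is inverted.
fn fuse_compare_and_branch(input: &[Instr]) -> Vec<Instr> {
    let mut out = Vec::with_capacity(input.len());
    let mut i = 0;
    while i < input.len() {
        let fused = match (&input[i], input.get(i + 1)) {
            (Instr::SetEq(t, a, b), Some(Instr::BranchEqZero(r, label))) if t == r => {
                Some(Instr::BranchNe(*a, *b, label.clone()))
            }
            (Instr::SetNe(t, a, b), Some(Instr::BranchEqZero(r, label))) if t == r => {
                Some(Instr::BranchEq(*a, *b, label.clone()))
            }
            _ => None,
        };
        match fused {
            Some(branch) => {
                out.push(branch);
                i += 2; // consumed both the compare and the branch
            }
            None => {
                out.push(input[i].clone());
                i += 1;
            }
        }
    }
    out
}
```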

#### Pattern: Double Move Elimination

```
move r1 r2       →  move r1 r3
move r1 r3          (remove first move if r1 not used between)
```

#### Pattern: Redundant Load Elimination

If a register's value is already loaded and hasn't been clobbered:

```
l r1 d0 Temperature
... (no writes to r1)
l r1 d0 Temperature  →  (remove second load)
```

### 4. **Copy Propagation Enhancement** 🔥 MEDIUM IMPACT

Current register forwarding is good, but we can extend it:

- Track `move` chains: if `r1 = r2` and `r2 = 5`, propagate the `5` directly
- Eliminate the intermediate register if possible
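
One way the chain tracking might look, as a sketch over a straight-line block that contains only moves. `Src`, `Move`, and `propagate_copies` are hypothetical names for this example; a real pass would also have to reset its table at labels and calls, as the existing constant propagation pass does.

```rust
use std::collections::HashMap;

/// Source of a move in this sketch: another register or a literal.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Src {
    Register(u8),
    Number(i64),
}

/// `move dst src` in a toy, moves-only block (illustrative type).
#[derive(Debug, Clone, Copy, PartialEq)]
struct Move {
    dst: u8,
    src: Src,
}

/// Resolves each move's source through earlier moves, so that
/// `move r2 5; move r1 r2` becomes `move r2 5; move r1 5`.
fn propagate_copies(block: &[Move]) -> Vec<Move> {
    // Register -> fully resolved value it currently holds.
    let mut known: HashMap<u8, Src> = HashMap::new();
    let mut out = Vec::with_capacity(block.len());
    for mv in block {
        let resolved = match mv.src {
            Src::Register(r) => known.get(&r).copied().unwrap_or(mv.src),
            number => number,
        };
        // Overwriting `dst` invalidates every entry that still refers to it.
        known.retain(|_, v| *v != Src::Register(mv.dst));
        known.insert(mv.dst, resolved);
        out.push(Move { dst: mv.dst, src: resolved });
    }
    out
}
```

Once sources are resolved like this, the intermediate register's own move often becomes a dead store that a later pass can remove.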

### 5. **Dead Store Elimination** 🔥 MEDIUM IMPACT

Remove writes to registers that are never read before being overwritten:

```
move r1 5
move r1 10       →  move r1 10
                    (first write is dead)
```

### 6. **Common Subexpression Elimination (CSE)** 🔥 MEDIUM-HIGH IMPACT

Recognize when the same computation is done multiple times:

```
add r1 r8 r9
add r2 r8 r9     →  add r1 r8 r9
                    move r2 r1
```

This is especially valuable for expensive operations like:

- Device loads (`l`)
- Math functions (sqrt, sin, cos, etc.)
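
A local CSE pass over one basic block could be sketched as below. The `Op` and `Instr` types are assumptions made for the example (register-only operands, two operators), and the sketch ignores commutativity, so `add r8 r9` and `add r9 r8` are treated as different expressions.

```rust
use std::collections::HashMap;

/// Operators covered by this sketch (illustrative subset).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum Op {
    Add,
    Mul,
}

/// Toy register-only instructions (hypothetical, not the `il` crate types).
#[derive(Debug, Clone, Copy, PartialEq)]
enum Instr {
    Binary { op: Op, dst: u8, a: u8, b: u8 },
    Move { dst: u8, src: u8 },
}

/// Replaces a repeated `op a b` with a move from the register that already holds it.
fn eliminate_common_subexpressions(block: &[Instr]) -> Vec<Instr> {
    // (op, a, b) -> register currently holding that result.
    let mut available: HashMap<(Op, u8, u8), u8> = HashMap::new();
    let mut out = Vec::with_capacity(block.len());
    for &instr in block {
        let rewritten = match instr {
            Instr::Binary { op, dst, a, b } => match available.get(&(op, a, b)) {
                Some(&holder) => Instr::Move { dst, src: holder },
                None => instr,
            },
            Instr::Move { .. } => instr,
        };
        let dst = match rewritten {
            Instr::Binary { dst, .. } | Instr::Move { dst, .. } => dst,
        };
        // Writing `dst` invalidates expressions that read it or are held in it.
        available.retain(|&(_, a, b), holder| a != dst && b != dst && *holder != dst);
        if let Instr::Binary { op, dst, a, b } = rewritten {
            // Only record the result if the write did not clobber an operand.
            if a != dst && b != dst {
                available.insert((op, a, b), dst);
            }
        }
        out.push(rewritten);
    }
    out
}
```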

### 7. **Jump Threading** 🔥 LOW-MEDIUM IMPACT

Optimize jump-to-jump sequences:

```
j label1
...
label1:
j label2         →  j label2 (rewrite first jump)
```
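
The jump-to-jump rewrite might be sketched like this, assuming a toy `Instr` type with string labels (defined here for illustration only). The hop count is bounded so that a cycle of labels that jump to each other cannot loop forever.

```rust
use std::collections::HashMap;

/// Toy instruction type for this sketch (hypothetical).
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    Label(String),
    Jump(String),
    Other(String), // any instruction this sketch does not inspect
}

/// Retargets each `j X` to the final destination when `X:` is immediately
/// followed by another unconditional jump.
fn thread_jumps(program: &[Instr]) -> Vec<Instr> {
    // Label -> target of the unconditional jump directly after it.
    let mut forward: HashMap<&str, &str> = HashMap::new();
    for pair in program.windows(2) {
        if let [Instr::Label(name), Instr::Jump(target)] = pair {
            forward.insert(name.as_str(), target.as_str());
        }
    }
    program
        .iter()
        .map(|instr| match instr {
            Instr::Jump(target) => {
                let mut dest = target.as_str();
                // At most `forward.len()` hops, which guards against cycles.
                for _ in 0..forward.len() {
                    match forward.get(dest) {
                        Some(&next) if next != dest => dest = next,
                        _ => break,
                    }
                }
                Instr::Jump(dest.to_string())
            }
            other => other.clone(),
        })
        .collect()
}
```

The now-unreferenced `label1: j label2` pair is left in place here; removing it would be a job for a separate unused-label or dead code pass.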

### 8. **Branch Folding** 🔥 LOW-MEDIUM IMPACT

Merge consecutive branches to the same target:

```
bgt r1 r2 label
bgt r3 r4 label  →  Could potentially be optimized based on conditions
```

### 9. **Loop Invariant Code Motion** 🔥 MEDIUM-HIGH IMPACT

Move calculations out of loops if they don't change:

```
loop:
  mul r2 5 10    →  mul r2 5 10  (hoisted before loop)
  add r1 r1 r2      loop:
  ...                 add r1 r1 r2
  j loop              ...
                      j loop
```

### 10. **Select Instruction Optimization** 🔥 LOW-MEDIUM IMPACT

The `select` instruction can sometimes replace branch patterns:

```
beq r1 r2 else
move r3 r4
j end
else:
move r3 r5       →  seq r6 r1 r2
end:                select r3 r6 r5 r4
```

### 11. **Stack Access Pattern Optimization** 🔥 LOW IMPACT

If we see repeated `sub r0 sp N` + `get`, we might be able to optimize by:

- Caching the stack address in a register if used multiple times
- Combining sequential gets from adjacent stack slots

### 12. **Inline Small Functions** 🔥 HIGH IMPACT (Complex to implement)

For very small leaf functions (1-2 instructions), inline them at the call site:

```
calculateSum:
  add r15 r8 r9
  j ra

main:
  push 5             →  main:
  push 10                 add r15 5 10
  jal calculateSum
```

### 13. **Branch Prediction Hints** 🔥 LOW IMPACT

Reorganize code to put likely branches inline (fall-through) and unlikely branches as jumps.

### 14. **Register Coalescing** 🔥 MEDIUM IMPACT

Reduce register pressure by reusing registers that have non-overlapping lifetimes.

## Priority Implementation Order

### Phase 1 (Quick Wins):

1. Algebraic Simplification (easy, high impact)
2. Strength Reduction (easy, high impact)
3. Dead Store Elimination (medium complexity, good impact)

### Phase 2 (Medium Effort):

4. Peephole Optimizations - seq/beq pattern (medium, high impact)
5. Common Subexpression Elimination (medium, high impact)
6. Copy Propagation Enhancement (medium, medium impact)

### Phase 3 (Advanced):

7. Loop Invariant Code Motion (complex, high impact for loop-heavy code)
8. Function Inlining (complex, high impact)
9. Register Coalescing (complex, medium impact)

## Testing Strategy

- Add test cases for each optimization
- Ensure optimization preserves semantics (run existing tests after each)
- Measure code size reduction
- Consider adding benchmarks to measure game performance impact
161
rust_compiler/libs/optimizer/src/algebraic_simplification.rs
Normal file
@@ -0,0 +1,161 @@
use il::{Instruction, InstructionNode, Operand};
use rust_decimal::Decimal;

/// Pass: Algebraic Simplification
/// Simplifies mathematical identities like x+0, x*1, x*0, etc.
pub fn algebraic_simplification<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut output = Vec::with_capacity(input.len());
    let mut changed = false;

    for mut node in input {
        let simplified = match &node.instruction {
            // x + 0 = x
            Instruction::Add(dst, a, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }
            Instruction::Add(dst, Operand::Number(n), b) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), b.clone()))
            }

            // x - 0 = x
            Instruction::Sub(dst, a, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }

            // x * 1 = x
            Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }
            Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(1) => {
                Some(Instruction::Move(dst.clone(), b.clone()))
            }

            // x * 0 = 0
            Instruction::Mul(dst, _, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }
            Instruction::Mul(dst, Operand::Number(n), _) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }

            // x / 1 = x
            Instruction::Div(dst, a, Operand::Number(n)) if *n == Decimal::from(1) => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }

            // 0 / x = 0 (if x != 0, but we can't check at compile time for non-literals)
            Instruction::Div(dst, Operand::Number(n), _) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }

            // x % 1 = 0
            Instruction::Mod(dst, _, Operand::Number(n)) if *n == Decimal::from(1) => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }

            // 0 % x = 0
            Instruction::Mod(dst, Operand::Number(n), _) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }

            // x AND 0 = 0
            Instruction::And(dst, _, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }
            Instruction::And(dst, Operand::Number(n), _) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), Operand::Number(Decimal::ZERO)))
            }

            // x OR 0 = x
            Instruction::Or(dst, a, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }
            Instruction::Or(dst, Operand::Number(n), b) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), b.clone()))
            }

            // x XOR 0 = x
            Instruction::Xor(dst, a, Operand::Number(n)) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), a.clone()))
            }
            Instruction::Xor(dst, Operand::Number(n), b) if n.is_zero() => {
                Some(Instruction::Move(dst.clone(), b.clone()))
            }

            _ => None,
        };

        if let Some(new) = simplified {
            node.instruction = new;
            changed = true;
        }

        output.push(node);
    }

    (output, changed)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add_zero() {
        let input = vec![InstructionNode::new(
            Instruction::Add(
                Operand::Register(1),
                Operand::Register(2),
                Operand::Number(Decimal::ZERO),
            ),
            None,
        )];

        let (output, changed) = algebraic_simplification(input);
        assert!(changed);
        assert!(matches!(
            output[0].instruction,
            Instruction::Move(Operand::Register(1), Operand::Register(2))
        ));
    }

    #[test]
    fn test_mul_one() {
        let input = vec![InstructionNode::new(
            Instruction::Mul(
                Operand::Register(3),
                Operand::Register(4),
                Operand::Number(Decimal::ONE),
            ),
            None,
        )];

        let (output, changed) = algebraic_simplification(input);
        assert!(changed);
        assert!(matches!(
            output[0].instruction,
            Instruction::Move(Operand::Register(3), Operand::Register(4))
        ));
    }

    #[test]
    fn test_mul_zero() {
        let input = vec![InstructionNode::new(
            Instruction::Mul(
                Operand::Register(5),
                Operand::Register(6),
                Operand::Number(Decimal::ZERO),
            ),
            None,
        )];

        let (output, changed) = algebraic_simplification(input);
        assert!(changed);
        assert!(matches!(
            output[0].instruction,
            Instruction::Move(Operand::Register(5), Operand::Number(_))
        ));
    }
}
168
rust_compiler/libs/optimizer/src/constant_propagation.rs
Normal file
@@ -0,0 +1,168 @@
use crate::helpers::get_destination_reg;
use il::{Instruction, InstructionNode, Operand};
use rust_decimal::Decimal;

/// Pass: Constant Propagation
/// Folds arithmetic operations when both operands are constant.
/// Also tracks register values and propagates them forward.
pub fn constant_propagation<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut output = Vec::with_capacity(input.len());
    let mut changed = false;
    let mut registers: [Option<Decimal>; 16] = [None; 16];

    for mut node in input {
        // Invalidate register tracking on label/call boundaries
        match &node.instruction {
            Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16],
            _ => {}
        }

        let simplified = match &node.instruction {
            // Only rewrite register sources: a literal source is already folded, and
            // rewriting it again would report `changed` on every pass.
            Instruction::Move(dst, src) if matches!(src, Operand::Register(_)) => {
                resolve_value(src, &registers)
                    .map(|val| Instruction::Move(dst.clone(), Operand::Number(val)))
            }
            Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x + y),
            Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x - y),
            Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x * y),
            // Division or modulo by a known zero is left for the runtime to evaluate
            // instead of being folded to `x`, which would change the result.
            Instruction::Div(dst, a, b) => match resolve_value(b, &registers) {
                Some(d) if d.is_zero() => None,
                _ => try_fold_math(dst, a, b, &registers, |x, y| x / y),
            },
            Instruction::Mod(dst, a, b) => match resolve_value(b, &registers) {
                Some(d) if d.is_zero() => None,
                _ => try_fold_math(dst, a, b, &registers, |x, y| x % y),
            },
            Instruction::BranchEq(a, b, l) => {
                try_resolve_branch(a, b, l, &registers, |x, y| x == y)
            }
            Instruction::BranchNe(a, b, l) => {
                try_resolve_branch(a, b, l, &registers, |x, y| x != y)
            }
            Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, &registers, |x, y| x > y),
            Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, &registers, |x, y| x < y),
            Instruction::BranchGe(a, b, l) => {
                try_resolve_branch(a, b, l, &registers, |x, y| x >= y)
            }
            Instruction::BranchLe(a, b, l) => {
                try_resolve_branch(a, b, l, &registers, |x, y| x <= y)
            }
            Instruction::BranchEqZero(a, l) => {
                try_resolve_branch(a, &Operand::Number(0.into()), l, &registers, |x, y| x == y)
            }
            Instruction::BranchNeZero(a, l) => {
                try_resolve_branch(a, &Operand::Number(0.into()), l, &registers, |x, y| x != y)
            }
            _ => None,
        };

        if let Some(new) = simplified {
            node.instruction = new;
            changed = true;
        }

        // Update register tracking
        match &node.instruction {
            Instruction::Move(Operand::Register(r), src) => {
                registers[*r as usize] = resolve_value(src, &registers)
            }
            _ => {
                if let Some(r) = get_destination_reg(&node.instruction) {
                    registers[r as usize] = None;
                }
            }
        }

        // Filter out NOPs (empty labels from branch resolution)
        if let Instruction::LabelDef(l) = &node.instruction
            && l.is_empty()
        {
            changed = true;
            continue;
        }

        output.push(node);
    }
    (output, changed)
}

fn resolve_value(op: &Operand, regs: &[Option<Decimal>; 16]) -> Option<Decimal> {
    match op {
        Operand::Number(n) => Some(*n),
        Operand::Register(r) => regs[*r as usize],
        _ => None,
    }
}

fn try_fold_math<'a, F>(
    dst: &Operand<'a>,
    a: &Operand<'a>,
    b: &Operand<'a>,
    regs: &[Option<Decimal>; 16],
    op: F,
) -> Option<Instruction<'a>>
where
    F: Fn(Decimal, Decimal) -> Decimal,
{
    let val_a = resolve_value(a, regs)?;
    let val_b = resolve_value(b, regs)?;
    Some(Instruction::Move(
        dst.clone(),
        Operand::Number(op(val_a, val_b)),
    ))
}

fn try_resolve_branch<'a, F>(
    a: &Operand<'a>,
    b: &Operand<'a>,
    label: &Operand<'a>,
    regs: &[Option<Decimal>; 16],
    check: F,
) -> Option<Instruction<'a>>
where
    F: Fn(Decimal, Decimal) -> bool,
{
    let val_a = resolve_value(a, regs)?;
    let val_b = resolve_value(b, regs)?;
    if check(val_a, val_b) {
        Some(Instruction::Jump(label.clone()))
    } else {
        Some(Instruction::LabelDef("".into())) // NOP
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use il::InstructionNode;

    #[test]
    fn test_fold_add() {
        let input = vec![InstructionNode::new(
            Instruction::Add(
                Operand::Register(1),
                Operand::Number(5.into()),
                Operand::Number(3.into()),
            ),
            None,
        )];

        let (output, changed) = constant_propagation(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        assert!(matches!(
            output[0].instruction,
            Instruction::Move(Operand::Register(1), Operand::Number(_))
        ));
    }
}
82
rust_compiler/libs/optimizer/src/dead_code.rs
Normal file
@@ -0,0 +1,82 @@
use il::{Instruction, InstructionNode};

/// Pass: Redundant Move Elimination
/// Removes moves where source and destination are the same: `move rx rx`
pub fn remove_redundant_moves<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut output = Vec::with_capacity(input.len());
    let mut changed = false;
    for node in input {
        if let Instruction::Move(dst, src) = &node.instruction
            && dst == src
        {
            changed = true;
            continue;
        }
        output.push(node);
    }
    (output, changed)
}

/// Pass: Dead Code Elimination
/// Removes unreachable code after unconditional jumps.
pub fn remove_unreachable_code<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut output = Vec::with_capacity(input.len());
    let mut changed = false;
    let mut dead = false;
    for node in input {
        if let Instruction::LabelDef(_) = node.instruction {
            dead = false;
        }
        if dead {
            changed = true;
            continue;
        }
        if let Instruction::Jump(_) = node.instruction {
            dead = true
        }
        output.push(node);
    }
    (output, changed)
}

#[cfg(test)]
mod tests {
    use super::*;
    use il::{Instruction, InstructionNode, Operand};

    #[test]
    fn test_remove_redundant_move() {
        let input = vec![InstructionNode::new(
            Instruction::Move(Operand::Register(1), Operand::Register(1)),
            None,
        )];

        let (output, changed) = remove_redundant_moves(input);
        assert!(changed);
        assert_eq!(output.len(), 0);
    }

    #[test]
    fn test_remove_unreachable() {
        let input = vec![
            InstructionNode::new(Instruction::Jump(Operand::Label("main".into())), None),
            InstructionNode::new(
                Instruction::Add(
                    Operand::Register(1),
                    Operand::Number(1.into()),
                    Operand::Number(2.into()),
                ),
                None,
            ),
            InstructionNode::new(Instruction::LabelDef("main".into()), None),
        ];

        let (output, changed) = remove_unreachable_code(input);
        assert!(changed);
        assert_eq!(output.len(), 2);
    }
}
109
rust_compiler/libs/optimizer/src/dead_store_elimination.rs
Normal file
@@ -0,0 +1,109 @@
use crate::helpers::get_destination_reg;
use il::{Instruction, InstructionNode};
use std::collections::HashMap;

/// Pass: Dead Store Elimination
/// Removes writes to registers that are never read before being overwritten.
pub fn dead_store_elimination<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut changed = false;
    let mut last_write: HashMap<u8, usize> = HashMap::new();
    let mut to_remove = Vec::new();

    // Scan for dead writes
    for (i, node) in input.iter().enumerate() {
        if let Some(dest_reg) = get_destination_reg(&node.instruction) {
            // If this register was written before and hasn't been read, previous write is dead
            if let Some(&prev_idx) = last_write.get(&dest_reg) {
                // Check if the value was ever used after prev_idx, up to and including
                // the current instruction: `add r1 r1 1` reads r1 before overwriting it.
                let was_used = input[prev_idx + 1..=i]
                    .iter()
                    .any(|n| reg_is_read_or_affects_control(&n.instruction, dest_reg));

                if !was_used {
                    // Previous write was dead
                    to_remove.push(prev_idx);
                    changed = true;
                }
            }

            // Update last write location
            last_write.insert(dest_reg, i);
        }

        // On labels/jumps, conservatively clear tracking (value might be used elsewhere)
        if matches!(
            node.instruction,
            Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_)
        ) {
            last_write.clear();
        }
    }

    if changed {
        let output = input
            .into_iter()
            .enumerate()
            .filter_map(|(i, node)| {
                if to_remove.contains(&i) {
                    None
                } else {
                    Some(node)
                }
            })
            .collect();
        (output, true)
    } else {
        (input, false)
    }
}

/// Simplified check: Does this instruction read the register or affect control flow?
fn reg_is_read_or_affects_control(instr: &Instruction, reg: u8) -> bool {
    use crate::helpers::reg_is_read;

    // If it reads the register, it's used
    if reg_is_read(instr, reg) {
        return true;
    }

    // Conservatively assume register might be used if there's control flow
    matches!(
        instr,
        Instruction::Jump(_)
            | Instruction::JumpAndLink(_)
            | Instruction::BranchEq(_, _, _)
            | Instruction::BranchNe(_, _, _)
            | Instruction::BranchGt(_, _, _)
            | Instruction::BranchLt(_, _, _)
            | Instruction::BranchGe(_, _, _)
            | Instruction::BranchLe(_, _, _)
            | Instruction::BranchEqZero(_, _)
            | Instruction::BranchNeZero(_, _)
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use il::Operand;

    #[test]
    fn test_dead_store() {
        let input = vec![
            InstructionNode::new(
                Instruction::Move(Operand::Register(1), Operand::Number(5.into())),
                None,
            ),
            InstructionNode::new(
                Instruction::Move(Operand::Register(1), Operand::Number(10.into())),
                None,
            ),
        ];

        let (output, changed) = dead_store_elimination(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
    }
}
160
rust_compiler/libs/optimizer/src/function_call_optimization.rs
Normal file
@@ -0,0 +1,160 @@
use crate::helpers::get_destination_reg;
use il::{Instruction, InstructionNode, Operand};
use rust_decimal::Decimal;
use std::collections::{HashMap, HashSet};

/// Analyzes which registers are written to by each function label.
fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap<String, HashSet<u8>> {
    let mut clobbers = HashMap::new();
    let mut current_label = None;

    for node in instructions {
        if let Instruction::LabelDef(label) = &node.instruction {
            current_label = Some(label.to_string());
            clobbers.insert(label.to_string(), HashSet::new());
        }

        if let Some(label) = &current_label
            && let Some(reg) = get_destination_reg(&node.instruction)
            && let Some(set) = clobbers.get_mut(label)
        {
            set.insert(reg);
        }
    }
    clobbers
}

/// Pass: Function Call Optimization
/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register.
pub fn optimize_function_calls<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let clobbers = analyze_clobbers(&input);
    let mut changed = false;
    let mut to_remove = HashSet::new();
    let mut stack_adjustments = HashMap::new();

    let mut i = 0;
    while i < input.len() {
        if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction {
            let target_key = target.to_string();

            if let Some(func_clobbers) = clobbers.get(&target_key) {
                // 1. Identify Pushes immediately preceding the JAL
                let mut pushes = Vec::new(); // (index, register)
                let mut scan_back = i.saturating_sub(1);
                while scan_back > 0 {
                    if to_remove.contains(&scan_back) {
                        scan_back -= 1;
                        continue;
                    }
                    if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction {
                        pushes.push((scan_back, *r));
                        scan_back -= 1;
                    } else {
                        break;
                    }
                }

                // 2. Identify Restores immediately following the JAL
                let mut restores = Vec::new(); // (index_of_get, register, index_of_sub)
                let mut scan_fwd = i + 1;
                while scan_fwd < input.len() {
                    // Skip 'sub r0 sp X'
                    if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) =
                        &input[scan_fwd].instruction
                    {
                        // Check next instruction for the Get
                        if scan_fwd + 1 < input.len()
                            && let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) =
                                &input[scan_fwd + 1].instruction
                        {
                            restores.push((scan_fwd + 1, *r, scan_fwd));
                            scan_fwd += 2;
                            continue;
                        }
                    }
                    break;
                }

                // 3. Stack Cleanup
                let cleanup_idx = scan_fwd;
                let has_cleanup = if cleanup_idx < input.len() {
                    matches!(
                        input[cleanup_idx].instruction,
                        Instruction::Sub(
                            Operand::StackPointer,
                            Operand::StackPointer,
                            Operand::Number(_)
                        )
                    )
                } else {
                    false
                };

                // SAFEGUARD: Check Counts!
                let mut push_counts = HashMap::new();
                for (_, r) in &pushes {
                    *push_counts.entry(*r).or_insert(0) += 1;
                }

                let mut restore_counts = HashMap::new();
                for (_, r, _) in &restores {
                    *restore_counts.entry(*r).or_insert(0) += 1;
                }

                let counts_match = push_counts
                    .iter()
                    .all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count);
                let counts_match_reverse = restore_counts
                    .iter()
                    .all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count);

                // Clobber Check
                let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r));

                if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse {
                    // Remove all pushes/restores
                    for (p_idx, _) in pushes {
                        to_remove.insert(p_idx);
                    }
                    for (g_idx, _, s_idx) in restores {
                        to_remove.insert(g_idx);
                        to_remove.insert(s_idx);
                    }

                    // Reduce stack cleanup amount
                    let num_removed = push_counts.values().sum::<i32>() as i64;
                    stack_adjustments.insert(cleanup_idx, num_removed);
                    changed = true;
                }
            }
        }
        i += 1;
    }

    if changed {
        let mut clean = Vec::with_capacity(input.len());
        for (idx, mut node) in input.into_iter().enumerate() {
            if to_remove.contains(&idx) {
                continue;
            }

            // Apply stack adjustment
            if let Some(reduction) = stack_adjustments.get(&idx)
                && let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction
            {
                let new_n = n - Decimal::from(*reduction);
                if new_n.is_zero() {
                    continue;
                }
                node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n));
            }

            clean.push(node);
        }
        return (clean, changed);
    }

    (input, false)
}
172
rust_compiler/libs/optimizer/src/helpers.rs
Normal file
@@ -0,0 +1,172 @@
use il::{Instruction, Operand};

/// Returns the register number written to by an instruction, if any.
pub fn get_destination_reg(instr: &Instruction) -> Option<u8> {
    match instr {
        Instruction::Move(Operand::Register(r), _)
        | Instruction::Add(Operand::Register(r), _, _)
        | Instruction::Sub(Operand::Register(r), _, _)
        | Instruction::Mul(Operand::Register(r), _, _)
        | Instruction::Div(Operand::Register(r), _, _)
        | Instruction::Mod(Operand::Register(r), _, _)
        | Instruction::Pow(Operand::Register(r), _, _)
        | Instruction::Load(Operand::Register(r), _, _)
        | Instruction::LoadSlot(Operand::Register(r), _, _, _)
        | Instruction::LoadBatch(Operand::Register(r), _, _, _)
        | Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _)
        | Instruction::SetEq(Operand::Register(r), _, _)
        | Instruction::SetNe(Operand::Register(r), _, _)
        | Instruction::SetGt(Operand::Register(r), _, _)
        | Instruction::SetLt(Operand::Register(r), _, _)
        | Instruction::SetGe(Operand::Register(r), _, _)
        | Instruction::SetLe(Operand::Register(r), _, _)
        | Instruction::And(Operand::Register(r), _, _)
        | Instruction::Or(Operand::Register(r), _, _)
        | Instruction::Xor(Operand::Register(r), _, _)
        | Instruction::Peek(Operand::Register(r))
        | Instruction::Get(Operand::Register(r), _, _)
        | Instruction::Select(Operand::Register(r), _, _, _)
        | Instruction::Rand(Operand::Register(r))
        | Instruction::Acos(Operand::Register(r), _)
        | Instruction::Asin(Operand::Register(r), _)
        | Instruction::Atan(Operand::Register(r), _)
        | Instruction::Atan2(Operand::Register(r), _, _)
        | Instruction::Abs(Operand::Register(r), _)
        | Instruction::Ceil(Operand::Register(r), _)
        | Instruction::Cos(Operand::Register(r), _)
        | Instruction::Floor(Operand::Register(r), _)
        | Instruction::Log(Operand::Register(r), _)
        | Instruction::Max(Operand::Register(r), _, _)
        | Instruction::Min(Operand::Register(r), _, _)
        | Instruction::Sin(Operand::Register(r), _)
        | Instruction::Sqrt(Operand::Register(r), _)
        | Instruction::Tan(Operand::Register(r), _)
        | Instruction::Trunc(Operand::Register(r), _)
        | Instruction::LoadReagent(Operand::Register(r), _, _, _)
        | Instruction::Pop(Operand::Register(r)) => Some(*r),
        _ => None,
    }
}

/// Creates a new instruction with the destination register changed.
pub fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option<Instruction<'a>> {
    let r = Operand::Register(new_reg);
    match instr {
        Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())),
        Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())),
        Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())),
        Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())),
        Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())),
        Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())),
        Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())),
        Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())),
        Instruction::LoadSlot(_, a, b, c) => {
            Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone()))
        }
        Instruction::LoadBatch(_, a, b, c) => {
            Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone()))
        }
        Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed(
            r,
            a.clone(),
            b.clone(),
            c.clone(),
            d.clone(),
        )),
        Instruction::LoadReagent(_, b, c, d) => {
            Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone()))
        }
        Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())),
        Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())),
        Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())),
        Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())),
        Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())),
        Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())),
        Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())),
        Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())),
        Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())),
        Instruction::Peek(_) => Some(Instruction::Peek(r)),
        Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())),
        Instruction::Select(_, a, b, c) => {
            Some(Instruction::Select(r, a.clone(), b.clone(), c.clone()))
        }
        Instruction::Rand(_) => Some(Instruction::Rand(r)),
|
||||
Instruction::Pop(_) => Some(Instruction::Pop(r)),
|
||||
Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())),
|
||||
Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())),
|
||||
Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())),
|
||||
Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())),
|
||||
Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())),
|
||||
Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())),
|
||||
Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())),
|
||||
Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())),
|
||||
Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())),
|
||||
Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())),
|
||||
Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())),
|
||||
Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())),
|
||||
Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())),
|
||||
Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())),
|
||||
Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks if a register is read by an instruction.
|
||||
pub fn reg_is_read(instr: &Instruction, reg: u8) -> bool {
|
||||
let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg);
|
||||
|
||||
match instr {
|
||||
Instruction::Move(_, a) => check(a),
|
||||
Instruction::Add(_, a, b)
|
||||
| Instruction::Sub(_, a, b)
|
||||
| Instruction::Mul(_, a, b)
|
||||
| Instruction::Div(_, a, b)
|
||||
| Instruction::Mod(_, a, b)
|
||||
| Instruction::Pow(_, a, b) => check(a) || check(b),
|
||||
Instruction::Load(_, a, _) => check(a),
|
||||
Instruction::Store(a, _, b) => check(a) || check(b),
|
||||
Instruction::BranchEq(a, b, _)
|
||||
| Instruction::BranchNe(a, b, _)
|
||||
| Instruction::BranchGt(a, b, _)
|
||||
| Instruction::BranchLt(a, b, _)
|
||||
| Instruction::BranchGe(a, b, _)
|
||||
| Instruction::BranchLe(a, b, _) => check(a) || check(b),
|
||||
Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a),
|
||||
Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash),
|
||||
Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot),
|
||||
Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode),
|
||||
Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => {
|
||||
check(d_hash) || check(n_hash) || check(mode)
|
||||
}
|
||||
Instruction::SetEq(_, a, b)
|
||||
| Instruction::SetNe(_, a, b)
|
||||
| Instruction::SetGt(_, a, b)
|
||||
| Instruction::SetLt(_, a, b)
|
||||
| Instruction::SetGe(_, a, b)
|
||||
| Instruction::SetLe(_, a, b)
|
||||
| Instruction::And(_, a, b)
|
||||
| Instruction::Or(_, a, b)
|
||||
| Instruction::Xor(_, a, b) => check(a) || check(b),
|
||||
Instruction::Push(a) => check(a),
|
||||
Instruction::Get(_, a, b) => check(a) || check(b),
|
||||
Instruction::Put(a, b, c) => check(a) || check(b) || check(c),
|
||||
Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c),
|
||||
Instruction::Sleep(a) => check(a),
|
||||
Instruction::Acos(_, a)
|
||||
| Instruction::Asin(_, a)
|
||||
| Instruction::Atan(_, a)
|
||||
| Instruction::Abs(_, a)
|
||||
| Instruction::Ceil(_, a)
|
||||
| Instruction::Cos(_, a)
|
||||
| Instruction::Floor(_, a)
|
||||
| Instruction::Log(_, a)
|
||||
| Instruction::Sin(_, a)
|
||||
| Instruction::Sqrt(_, a)
|
||||
| Instruction::Tan(_, a)
|
||||
| Instruction::Trunc(_, a) => check(a),
|
||||
Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => {
|
||||
check(a) || check(b)
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
70
rust_compiler/libs/optimizer/src/label_resolution.rs
Normal file
@@ -0,0 +1,70 @@
use il::{Instruction, InstructionNode, Operand};
use rust_decimal::Decimal;
use std::collections::HashMap;

/// Pass: Resolve Labels
/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs.
pub fn resolve_labels<'a>(input: Vec<InstructionNode<'a>>) -> Vec<InstructionNode<'a>> {
    let mut label_map: HashMap<String, usize> = HashMap::new();
    let mut line_number = 0;

    // Build Label Map (filtering out LabelDefs from the count)
    for node in &input {
        if let Instruction::LabelDef(name) = &node.instruction {
            label_map.insert(name.to_string(), line_number);
        } else {
            line_number += 1;
        }
    }

    let mut output = Vec::with_capacity(input.len());

    // Rewrite Jumps and Filter Labels
    for mut node in input {
        // Helper to get line number as Decimal operand
        let get_line = |lbl: &Operand| -> Option<Operand<'a>> {
            if let Operand::Label(name) = lbl {
                label_map
                    .get(name.as_ref())
                    .map(|&l| Operand::Number(Decimal::from(l)))
            } else {
                None
            }
        };

        match &mut node.instruction {
            Instruction::LabelDef(_) => continue, // Strip labels

            // Jumps
            Instruction::Jump(op) => {
                if let Some(num) = get_line(op) {
                    *op = num;
                }
            }
            Instruction::JumpAndLink(op) => {
                if let Some(num) = get_line(op) {
                    *op = num;
                }
            }
            Instruction::BranchEq(_, _, op)
            | Instruction::BranchNe(_, _, op)
            | Instruction::BranchGt(_, _, op)
            | Instruction::BranchLt(_, _, op)
            | Instruction::BranchGe(_, _, op)
            | Instruction::BranchLe(_, _, op) => {
                if let Some(num) = get_line(op) {
                    *op = num;
                }
            }
            Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => {
                if let Some(num) = get_line(op) {
                    *op = num;
                }
            }
            _ => {}
        }
        output.push(node);
    }

    output
}
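The two-pass idea behind `resolve_labels` (first record where each label lands, then rewrite jump targets and drop the label definitions) can be sketched on a toy instruction set. This is a minimal illustration only: `Instr`, `Target`, and `resolve` are hypothetical stand-ins, not the `il` crate's real types, and plain `usize` line numbers replace `Decimal`.

```rust
use std::collections::HashMap;

/// Hypothetical, simplified stand-in for `il::Instruction`.
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    Label(String),
    Jump(Target),
    Other(&'static str),
}

/// Hypothetical jump target: a symbolic label or a resolved line number.
#[derive(Debug, Clone, PartialEq)]
enum Target {
    Label(String),
    Line(usize),
}

fn resolve(input: Vec<Instr>) -> Vec<Instr> {
    // Pass 1: map each label to the index of the next real instruction.
    // Label definitions themselves do not advance the line counter.
    let mut map: HashMap<String, usize> = HashMap::new();
    let mut line = 0usize;
    for instr in &input {
        match instr {
            Instr::Label(name) => {
                map.insert(name.clone(), line);
            }
            _ => line += 1,
        }
    }

    // Pass 2: drop label definitions and rewrite targets that are known.
    input
        .into_iter()
        .filter(|i| !matches!(i, Instr::Label(_)))
        .map(|i| match i {
            Instr::Jump(Target::Label(name)) => match map.get(&name) {
                Some(&l) => Instr::Jump(Target::Line(l)),
                // Unknown label: leave the operand untouched.
                None => Instr::Jump(Target::Label(name)),
            },
            other => other,
        })
        .collect()
}
```

Given a label `start` followed by one ordinary instruction and a jump back to `start`, `resolve` returns the ordinary instruction followed by a jump to line 0, because the label definition is not counted as a line.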
150
rust_compiler/libs/optimizer/src/leaf_function_optimization.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
use crate::leaf_function::find_leaf_functions;
|
||||
use il::{Instruction, InstructionNode, Operand};
|
||||
use rust_decimal::Decimal;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Helper: Check if a function body contains unsafe stack manipulation.
|
||||
fn function_has_complex_stack_ops(
|
||||
instructions: &[InstructionNode],
|
||||
start_idx: usize,
|
||||
end_idx: usize,
|
||||
) -> bool {
|
||||
for instruction in instructions.iter().take(end_idx).skip(start_idx) {
|
||||
match instruction.instruction {
|
||||
Instruction::Push(_) | Instruction::Pop(_) => return true,
|
||||
Instruction::Add(Operand::StackPointer, _, _)
|
||||
| Instruction::Sub(Operand::StackPointer, _, _)
|
||||
| Instruction::Mul(Operand::StackPointer, _, _)
|
||||
| Instruction::Div(Operand::StackPointer, _, _)
|
||||
| Instruction::Move(Operand::StackPointer, _) => return true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Pass: Leaf Function Optimization
|
||||
/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`.
|
||||
pub fn optimize_leaf_functions<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let leaves = find_leaf_functions(&input);
|
||||
if leaves.is_empty() {
|
||||
return (input, false);
|
||||
}
|
||||
|
||||
let mut changed = false;
|
||||
let mut to_remove = HashSet::new();
|
||||
let mut func_restore_indices = HashMap::new();
|
||||
let mut func_ra_offsets = HashMap::new();
|
||||
let mut current_function: Option<String> = None;
|
||||
let mut function_start_indices = HashMap::new();
|
||||
|
||||
// First scan: Identify instructions to remove and capture RA offsets
|
||||
for (i, node) in input.iter().enumerate() {
|
||||
match &node.instruction {
|
||||
Instruction::LabelDef(label) if !label.starts_with("__internal_L") => {
|
||||
current_function = Some(label.to_string());
|
||||
function_start_indices.insert(label.to_string(), i);
|
||||
}
|
||||
Instruction::Push(Operand::ReturnAddress) => {
|
||||
if let Some(func) = &current_function
|
||||
&& leaves.contains(func)
|
||||
{
|
||||
to_remove.insert(i);
|
||||
}
|
||||
}
|
||||
Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => {
|
||||
if let Some(func) = &current_function
|
||||
&& leaves.contains(func)
|
||||
{
|
||||
to_remove.insert(i);
|
||||
func_restore_indices.insert(func.clone(), i);
|
||||
|
||||
// Look back for the address calc: `sub r0 sp OFFSET`
|
||||
if i > 0
|
||||
&& let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) =
|
||||
&input[i - 1].instruction
|
||||
{
|
||||
func_ra_offsets.insert(func.clone(), *n);
|
||||
to_remove.insert(i - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Safety Check: Verify functions don't have complex stack ops
|
||||
let mut safe_functions = HashSet::new();
|
||||
|
||||
for (func, start_idx) in &function_start_indices {
|
||||
if let Some(restore_idx) = func_restore_indices.get(func) {
|
||||
let check_start = if to_remove.contains(&(start_idx + 1)) {
|
||||
start_idx + 2
|
||||
} else {
|
||||
start_idx + 1
|
||||
};
|
||||
|
||||
if !function_has_complex_stack_ops(&input, check_start, *restore_idx) {
|
||||
safe_functions.insert(func.clone());
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return (input, false);
|
||||
}
|
||||
|
||||
// Second scan: Rebuild with adjustments
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut processing_function: Option<String> = None;
|
||||
|
||||
for (i, mut node) in input.into_iter().enumerate() {
|
||||
if to_remove.contains(&i)
|
||||
&& let Some(func) = &processing_function
|
||||
&& safe_functions.contains(func)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Instruction::LabelDef(l) = &node.instruction
|
||||
&& !l.starts_with("__internal_L")
|
||||
{
|
||||
processing_function = Some(l.to_string());
|
||||
}
|
||||
|
||||
// Apply Stack Adjustments
|
||||
if let Some(func) = &processing_function
|
||||
&& safe_functions.contains(func)
|
||||
&& let Some(ra_offset) = func_ra_offsets.get(func)
|
||||
{
|
||||
// Stack Cleanup Adjustment
|
||||
if let Instruction::Sub(
|
||||
Operand::StackPointer,
|
||||
Operand::StackPointer,
|
||||
Operand::Number(n),
|
||||
) = &mut node.instruction
|
||||
{
|
||||
let new_n = *n - Decimal::from(1);
|
||||
if new_n.is_zero() {
|
||||
continue;
|
||||
}
|
||||
*n = new_n;
|
||||
}
|
||||
|
||||
// Stack Variable Offset Adjustment
|
||||
if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) =
|
||||
&mut node.instruction
|
||||
&& *n > *ra_offset
|
||||
{
|
||||
*n -= Decimal::from(1);
|
||||
}
|
||||
}
|
||||
|
||||
output.push(node);
|
||||
}
|
||||
|
||||
(output, true)
|
||||
}
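The two numeric adjustments made in the second scan can be shown in isolation. The sketch below uses plain `i64` instead of `Decimal`, and the helper names `shift_offset` and `shrink_cleanup` are hypothetical. It also treats the two rules as separate functions, whereas the pass applies both pattern checks to each instruction in turn.

```rust
/// Hypothetical helper: new offset for a `sub rX sp N` address calculation
/// once the saved `ra` slot at `ra_offset` has been removed.
/// Only offsets strictly greater than `ra_offset` move down by one.
fn shift_offset(n: i64, ra_offset: i64) -> i64 {
    if n > ra_offset {
        n - 1
    } else {
        n
    }
}

/// Hypothetical helper: new operand for the `sub sp sp N` cleanup.
/// Returns `None` when nothing is left to pop, which corresponds to
/// dropping the cleanup instruction entirely.
fn shrink_cleanup(n: i64) -> Option<i64> {
    let new_n = n - 1;
    if new_n == 0 {
        None
    } else {
        Some(new_n)
    }
}
```

With `ra_offset = 2`, an access at offset 3 becomes offset 2, while accesses at offsets 1 and 2 are unchanged. A cleanup of 1 disappears and a cleanup of 3 becomes 2.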
|
||||
@@ -1,9 +1,30 @@
|
||||
use il::{Instruction, InstructionNode, Instructions, Operand};
|
||||
use rust_decimal::Decimal;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use il::Instructions;
|
||||
|
||||
// Optimization pass modules
|
||||
mod helpers;
|
||||
mod leaf_function;
|
||||
use leaf_function::find_leaf_functions;
|
||||
|
||||
mod algebraic_simplification;
|
||||
mod constant_propagation;
|
||||
mod dead_code;
|
||||
mod dead_store_elimination;
|
||||
mod function_call_optimization;
|
||||
mod label_resolution;
|
||||
mod leaf_function_optimization;
|
||||
mod peephole_optimization;
|
||||
mod register_forwarding;
|
||||
mod strength_reduction;
|
||||
|
||||
use algebraic_simplification::algebraic_simplification;
|
||||
use constant_propagation::constant_propagation;
|
||||
use dead_code::{remove_redundant_moves, remove_unreachable_code};
|
||||
use dead_store_elimination::dead_store_elimination;
|
||||
use function_call_optimization::optimize_function_calls;
|
||||
use label_resolution::resolve_labels;
|
||||
use leaf_function_optimization::optimize_leaf_functions;
|
||||
use peephole_optimization::peephole_optimization;
|
||||
use register_forwarding::register_forwarding;
|
||||
use strength_reduction::strength_reduction;
|
||||
|
||||
/// Entry point for the optimizer.
|
||||
pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> {
|
||||
@@ -38,845 +59,37 @@ pub fn optimize<'a>(instructions: Instructions<'a>) -> Instructions<'a> {
|
||||
instructions = new_inst;
|
||||
changed |= c4;
|
||||
|
||||
// Pass 5: Redundant Move Elimination
|
||||
let (new_inst, c5) = remove_redundant_moves(instructions);
|
||||
// Pass 5: Algebraic Simplification (Identity operations)
|
||||
let (new_inst, c5) = algebraic_simplification(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c5;
|
||||
|
||||
// Pass 6: Dead Code Elimination
|
||||
let (new_inst, c6) = remove_unreachable_code(instructions);
|
||||
// Pass 6: Strength Reduction (Replace expensive ops with cheaper ones)
|
||||
let (new_inst, c6) = strength_reduction(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c6;
|
||||
|
||||
// Pass 7: Peephole Optimizations (Common patterns)
|
||||
let (new_inst, c7) = peephole_optimization(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c7;
|
||||
|
||||
// Pass 8: Dead Store Elimination
|
||||
let (new_inst, c8) = dead_store_elimination(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c8;
|
||||
|
||||
// Pass 9: Redundant Move Elimination
|
||||
let (new_inst, c9) = remove_redundant_moves(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c9;
|
||||
|
||||
// Pass 10: Dead Code Elimination
|
||||
let (new_inst, c10) = remove_unreachable_code(instructions);
|
||||
instructions = new_inst;
|
||||
changed |= c10;
|
||||
}
|
||||
|
||||
// Final Pass: Resolve Labels to Line Numbers
|
||||
Instructions::new(resolve_labels(instructions))
|
||||
}
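Each pass in `optimize` returns the rewritten instructions together with a `changed` flag, and the flags are OR-ed together. Assuming the elided part of the function repeats the pass list until no pass reports a change, the driver can be sketched as below. Everything here is hypothetical illustration: the passes work on a `Vec<i64>` rather than on IL, and `Pass`, `cancel_pairs`, `drop_zeros`, and `run_to_fixed_point` are invented names.

```rust
/// Hypothetical pass signature mirroring the `(instructions, changed)` pairs in the diff.
type Pass = fn(Vec<i64>) -> (Vec<i64>, bool);

/// Toy pass: replace an adjacent pair `(x, -x)` with a single `0`.
fn cancel_pairs(input: Vec<i64>) -> (Vec<i64>, bool) {
    let mut out = Vec::with_capacity(input.len());
    let mut changed = false;
    let mut i = 0;
    while i < input.len() {
        if i + 1 < input.len() && input[i] != 0 && input[i] == -input[i + 1] {
            out.push(0);
            changed = true;
            i += 2;
        } else {
            out.push(input[i]);
            i += 1;
        }
    }
    (out, changed)
}

/// Toy pass: remove zeros.
fn drop_zeros(input: Vec<i64>) -> (Vec<i64>, bool) {
    let before = input.len();
    let out: Vec<i64> = input.into_iter().filter(|&x| x != 0).collect();
    let changed = out.len() != before;
    (out, changed)
}

/// Runs every pass in order and repeats until a full round changes nothing.
/// Returns the final data and the number of rounds executed.
fn run_to_fixed_point(mut data: Vec<i64>, passes: &[Pass]) -> (Vec<i64>, usize) {
    let mut rounds = 0;
    loop {
        let mut changed = false;
        for pass in passes {
            let (next, c) = pass(data);
            data = next;
            changed |= c;
        }
        rounds += 1;
        if !changed {
            break;
        }
    }
    (data, rounds)
}
```

The point of looping is that one pass can expose work for another: on `[1, 2, -2, -1]`, the inner pair cancels in the first round, and only after the resulting zero is dropped does the outer pair become adjacent and cancel in the second round.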
|
||||
|
||||
/// Helper: Check if a function body contains unsafe stack manipulation.
|
||||
/// Returns true if the function modifies SP in a way that makes static RA offset analysis unsafe.
|
||||
fn function_has_complex_stack_ops(
|
||||
instructions: &[InstructionNode],
|
||||
start_idx: usize,
|
||||
end_idx: usize,
|
||||
) -> bool {
|
||||
for instruction in instructions.iter().take(end_idx).skip(start_idx) {
|
||||
match instruction.instruction {
|
||||
Instruction::Push(_) | Instruction::Pop(_) => return true,
|
||||
// Check for explicit SP modification
|
||||
Instruction::Add(Operand::StackPointer, _, _)
|
||||
| Instruction::Sub(Operand::StackPointer, _, _)
|
||||
| Instruction::Mul(Operand::StackPointer, _, _)
|
||||
| Instruction::Div(Operand::StackPointer, _, _)
|
||||
| Instruction::Move(Operand::StackPointer, _) => return true,
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Pass: Leaf Function Optimization
|
||||
/// If a function makes no calls (is a leaf), it doesn't need to save/restore `ra`.
|
||||
fn optimize_leaf_functions<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let leaves = find_leaf_functions(&input);
|
||||
if leaves.is_empty() {
|
||||
return (input, false);
|
||||
}
|
||||
|
||||
let mut changed = false;
|
||||
let mut to_remove = HashSet::new();
|
||||
|
||||
// We map function names to the INDEX of the instruction that restores RA.
|
||||
// We use this to validate the function body later.
|
||||
let mut func_restore_indices = HashMap::new();
|
||||
let mut func_ra_offsets = HashMap::new();
|
||||
|
||||
let mut current_function: Option<String> = None;
|
||||
let mut function_start_indices = HashMap::new();
|
||||
|
||||
// First scan: Identify instructions to remove and capture RA offsets
|
||||
for (i, node) in input.iter().enumerate() {
|
||||
match &node.instruction {
|
||||
Instruction::LabelDef(label) if !label.starts_with("__internal_L") => {
|
||||
current_function = Some(label.to_string());
|
||||
function_start_indices.insert(label.to_string(), i);
|
||||
}
|
||||
Instruction::Push(Operand::ReturnAddress) => {
|
||||
if let Some(func) = &current_function
|
||||
&& leaves.contains(func)
|
||||
{
|
||||
to_remove.insert(i);
|
||||
}
|
||||
}
|
||||
Instruction::Get(Operand::ReturnAddress, _, Operand::Register(_)) => {
|
||||
// This is the restore instruction: `get ra db r0`
|
||||
if let Some(func) = &current_function
|
||||
&& leaves.contains(func)
|
||||
{
|
||||
to_remove.insert(i);
|
||||
func_restore_indices.insert(func.clone(), i);
|
||||
|
||||
// Look back for the address calc: `sub r0 sp OFFSET`
|
||||
if i > 0
|
||||
&& let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) =
|
||||
&input[i - 1].instruction
|
||||
{
|
||||
func_ra_offsets.insert(func.clone(), *n);
|
||||
to_remove.insert(i - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Safety Check: Verify that functions marked for optimization don't have complex stack ops.
|
||||
// If they do, unmark them.
|
||||
let mut safe_functions = HashSet::new();
|
||||
|
||||
for (func, start_idx) in &function_start_indices {
|
||||
if let Some(restore_idx) = func_restore_indices.get(func) {
|
||||
// Check instructions between start and restore using the helper function.
|
||||
// We need to skip the `push ra` we just marked for removal, otherwise the helper
|
||||
// will flag it as a complex op (Push).
|
||||
// `start_idx` is the LabelDef. `start_idx + 1` is typically `push ra`.
|
||||
|
||||
let check_start = if to_remove.contains(&(start_idx + 1)) {
|
||||
start_idx + 2
|
||||
} else {
|
||||
start_idx + 1
|
||||
};
|
||||
|
||||
// `restore_idx` points to the `get ra` instruction. The helper scans up to `end_idx` exclusive,
|
||||
// so we don't need to worry about the restore instruction itself.
|
||||
if !function_has_complex_stack_ops(&input, check_start, *restore_idx) {
|
||||
safe_functions.insert(func.clone());
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !changed {
|
||||
return (input, false);
|
||||
}
|
||||
|
||||
// Second scan: Rebuild with adjustments, but only for SAFE functions
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut processing_function: Option<String> = None;
|
||||
|
||||
for (i, mut node) in input.into_iter().enumerate() {
|
||||
if to_remove.contains(&i)
|
||||
&& let Some(func) = &processing_function
|
||||
&& safe_functions.contains(func)
|
||||
{
|
||||
continue; // SKIP (Remove)
|
||||
}
|
||||
|
||||
if let Instruction::LabelDef(l) = &node.instruction
|
||||
&& !l.starts_with("__internal_L")
|
||||
{
|
||||
processing_function = Some(l.to_string());
|
||||
}
|
||||
|
||||
// Apply Stack Adjustments
|
||||
if let Some(func) = &processing_function
|
||||
&& safe_functions.contains(func)
|
||||
&& let Some(ra_offset) = func_ra_offsets.get(func)
|
||||
{
|
||||
// 1. Stack Cleanup Adjustment
|
||||
if let Instruction::Sub(
|
||||
Operand::StackPointer,
|
||||
Operand::StackPointer,
|
||||
Operand::Number(n),
|
||||
) = &mut node.instruction
|
||||
{
|
||||
// Decrease cleanup amount by 1 (for the removed RA)
|
||||
let new_n = *n - Decimal::from(1);
|
||||
if new_n.is_zero() {
|
||||
continue;
|
||||
}
|
||||
*n = new_n;
|
||||
}
|
||||
|
||||
// 2. Stack Variable Offset Adjustment
|
||||
// Since we verified the function is "Simple" (no nested stack mods),
|
||||
// we can safely assume offsets > ra_offset need shifting.
|
||||
if let Instruction::Sub(_, Operand::StackPointer, Operand::Number(n)) =
|
||||
&mut node.instruction
|
||||
&& *n > *ra_offset
|
||||
{
|
||||
*n -= Decimal::from(1);
|
||||
}
|
||||
}
|
||||
|
||||
output.push(node);
|
||||
}
|
||||
|
||||
(output, true)
|
||||
}
|
||||
|
||||
/// Analyzes which registers are written to by each function label.
|
||||
fn analyze_clobbers(instructions: &[InstructionNode]) -> HashMap<String, HashSet<u8>> {
|
||||
let mut clobbers = HashMap::new();
|
||||
let mut current_label = None;
|
||||
|
||||
for node in instructions {
|
||||
if let Instruction::LabelDef(label) = &node.instruction {
|
||||
current_label = Some(label.to_string());
|
||||
clobbers.insert(label.to_string(), HashSet::new());
|
||||
}
|
||||
|
||||
if let Some(label) = &current_label
|
||||
&& let Some(reg) = get_destination_reg(&node.instruction)
|
||||
&& let Some(set) = clobbers.get_mut(label)
|
||||
{
|
||||
set.insert(reg);
|
||||
}
|
||||
}
|
||||
clobbers
|
||||
}
|
||||
|
||||
/// Pass: Function Call Optimization
|
||||
/// Removes Push/Restore pairs surrounding a JAL if the target function does not clobber that register.
|
||||
fn optimize_function_calls<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let clobbers = analyze_clobbers(&input);
|
||||
let mut changed = false;
|
||||
let mut to_remove = HashSet::new();
|
||||
let mut stack_adjustments = HashMap::new();
|
||||
|
||||
let mut i = 0;
|
||||
while i < input.len() {
|
||||
if let Instruction::JumpAndLink(Operand::Label(target)) = &input[i].instruction {
|
||||
let target_key = target.to_string();
|
||||
|
||||
if let Some(func_clobbers) = clobbers.get(&target_key) {
|
||||
// 1. Identify Pushes immediately preceding the JAL
|
||||
let mut pushes = Vec::new(); // (index, register)
|
||||
let mut scan_back = i.saturating_sub(1);
|
||||
while scan_back > 0 {
|
||||
if to_remove.contains(&scan_back) {
|
||||
scan_back -= 1;
|
||||
continue;
|
||||
}
|
||||
if let Instruction::Push(Operand::Register(r)) = &input[scan_back].instruction {
|
||||
pushes.push((scan_back, *r));
|
||||
scan_back -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Identify Restores immediately following the JAL
|
||||
let mut restores = Vec::new(); // (index_of_get, register, index_of_sub)
|
||||
let mut scan_fwd = i + 1;
|
||||
while scan_fwd < input.len() {
|
||||
// Skip 'sub r0 sp X'
|
||||
if let Instruction::Sub(Operand::Register(0), Operand::StackPointer, _) =
|
||||
&input[scan_fwd].instruction
|
||||
{
|
||||
// Check next instruction for the Get
|
||||
if scan_fwd + 1 < input.len()
|
||||
&& let Instruction::Get(Operand::Register(r), _, Operand::Register(0)) =
|
||||
&input[scan_fwd + 1].instruction
|
||||
{
|
||||
restores.push((scan_fwd + 1, *r, scan_fwd));
|
||||
scan_fwd += 2;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
// 3. Stack Cleanup
|
||||
let cleanup_idx = scan_fwd;
|
||||
let has_cleanup = if cleanup_idx < input.len() {
|
||||
matches!(
|
||||
input[cleanup_idx].instruction,
|
||||
Instruction::Sub(
|
||||
Operand::StackPointer,
|
||||
Operand::StackPointer,
|
||||
Operand::Number(_)
|
||||
)
|
||||
)
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
// SAFEGUARD: Check Counts!
|
||||
// If we pushed r8 twice but only restored it once, we have an argument.
|
||||
// We must ensure the number of pushes for each register MATCHES the number of restores.
|
||||
let mut push_counts = HashMap::new();
|
||||
for (_, r) in &pushes {
|
||||
*push_counts.entry(*r).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let mut restore_counts = HashMap::new();
|
||||
for (_, r, _) in &restores {
|
||||
*restore_counts.entry(*r).or_insert(0) += 1;
|
||||
}
|
||||
|
||||
let counts_match = push_counts
|
||||
.iter()
|
||||
.all(|(reg, count)| restore_counts.get(reg).unwrap_or(&0) == count);
|
||||
// Also check reverse to ensure we didn't restore something we didn't push (unlikely but possible)
|
||||
let counts_match_reverse = restore_counts
|
||||
.iter()
|
||||
.all(|(reg, count)| push_counts.get(reg).unwrap_or(&0) == count);
|
||||
|
||||
// Clobber Check
|
||||
let all_pushes_safe = pushes.iter().all(|(_, r)| !func_clobbers.contains(r));
|
||||
|
||||
if all_pushes_safe && has_cleanup && counts_match && counts_match_reverse {
|
||||
// We can remove ALL found pushes/restores safely
|
||||
for (p_idx, _) in pushes {
|
||||
to_remove.insert(p_idx);
|
||||
}
|
||||
for (g_idx, _, s_idx) in restores {
|
||||
to_remove.insert(g_idx);
|
||||
to_remove.insert(s_idx);
|
||||
}
|
||||
|
||||
// Reduce stack cleanup amount
|
||||
let num_removed = push_counts.values().sum::<i32>() as i64;
|
||||
stack_adjustments.insert(cleanup_idx, num_removed);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if changed {
|
||||
let mut clean = Vec::with_capacity(input.len());
|
||||
for (idx, mut node) in input.into_iter().enumerate() {
|
||||
if to_remove.contains(&idx) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Apply stack adjustment
|
||||
if let Some(reduction) = stack_adjustments.get(&idx)
|
||||
&& let Instruction::Sub(dst, a, Operand::Number(n)) = &node.instruction
|
||||
{
|
||||
let new_n = n - Decimal::from(*reduction);
|
||||
if new_n.is_zero() {
|
||||
continue; // Remove the sub entirely if 0
|
||||
}
|
||||
node.instruction = Instruction::Sub(dst.clone(), a.clone(), Operand::Number(new_n));
|
||||
}
|
||||
|
||||
clean.push(node);
|
||||
}
|
||||
return (clean, changed);
|
||||
}
|
||||
|
||||
(input, false)
|
||||
}
|
||||
|
||||
/// Pass: Register Forwarding
|
||||
/// Eliminates intermediate moves by writing directly to the final destination.
|
||||
/// Example: `l r1 d0 T` + `move r9 r1` -> `l r9 d0 T`
|
||||
fn register_forwarding<'a>(
|
||||
mut input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut changed = false;
|
||||
let mut i = 0;
|
||||
|
||||
// We use a while loop to manually control index so we can peek ahead
|
||||
while i < input.len().saturating_sub(1) {
|
||||
let next_idx = i + 1;
|
||||
|
||||
// Check if current instruction defines a register
|
||||
// and the NEXT instruction is a move from that register.
|
||||
let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) {
|
||||
if let Instruction::Move(Operand::Register(dest_reg), Operand::Register(src_reg)) =
|
||||
&input[next_idx].instruction
|
||||
{
|
||||
if *src_reg == def_reg {
|
||||
// Candidate found: Instruction `i` defines `src_reg`, Instruction `i+1` moves `src_reg` to `dest_reg`.
|
||||
// We can optimize if `src_reg` (the temp) is NOT used after this move.
|
||||
Some((def_reg, *dest_reg))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some((temp_reg, final_reg)) = forward_candidate {
|
||||
// Check liveness: Is temp_reg used after i+1?
|
||||
// We scan from i+2 onwards.
|
||||
let mut temp_is_dead = true;
|
||||
for node in input.iter().skip(i + 2) {
|
||||
if reg_is_read(&node.instruction, temp_reg) {
|
||||
temp_is_dead = false;
|
||||
break;
|
||||
}
|
||||
// If the temp is redefined, then the old value is dead, so we are safe.
|
||||
if let Some(redef) = get_destination_reg(&node.instruction)
|
||||
&& redef == temp_reg
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// If we hit a label/jump, we assume liveness might leak (conservative safety)
|
||||
if matches!(
|
||||
node.instruction,
|
||||
Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_)
|
||||
) {
|
||||
temp_is_dead = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if temp_is_dead {
|
||||
// Perform the swap
|
||||
// 1. Rewrite input[i] to write to final_reg
|
||||
if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) {
|
||||
input[i].instruction = new_instr;
|
||||
// 2. Remove input[i+1] (The Move)
|
||||
input.remove(next_idx);
|
||||
changed = true;
|
||||
// Don't increment i, re-evaluate current index (which is now a new neighbor)
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
(input, changed)
|
||||
}
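The forwarding rule and its liveness check can be sketched on a toy IR. This is a minimal illustration under stated assumptions: `Op`, `dest`, `reads`, `temp_is_dead`, and `forward` are hypothetical names, only a load may be forwarded, and a single `Label` variant stands in for the label and jump instructions that make the real pass give up.

```rust
/// Hypothetical toy IR; not the `il` crate's `Instruction`.
#[derive(Debug, Clone, PartialEq)]
enum Op {
    /// `dst = load(src)`: defines `dst`, reads no registers.
    Load { dst: u8, src: &'static str },
    /// `dst = src`.
    Move { dst: u8, src: u8 },
    /// Stand-in for any instruction that reads `src`.
    Use { src: u8 },
    /// Control-flow boundary (label or jump).
    Label,
}

fn dest(op: &Op) -> Option<u8> {
    match op {
        Op::Load { dst, .. } | Op::Move { dst, .. } => Some(*dst),
        _ => None,
    }
}

fn reads(op: &Op, reg: u8) -> bool {
    match op {
        Op::Move { src, .. } | Op::Use { src } => *src == reg,
        _ => false,
    }
}

/// Conservative liveness check over the instructions after the move.
fn temp_is_dead(rest: &[Op], tmp: u8) -> bool {
    for op in rest {
        if reads(op, tmp) {
            return false; // old value is still needed
        }
        if dest(op) == Some(tmp) {
            return true; // redefined before any read
        }
        if matches!(op, Op::Label) {
            return false; // control flow: assume the value may escape
        }
    }
    true
}

fn forward(mut ops: Vec<Op>) -> (Vec<Op>, bool) {
    let mut changed = false;
    let mut i = 0;
    while i + 1 < ops.len() {
        // Copy the fields out first so `ops` can be mutated afterwards.
        let candidate = match (&ops[i], &ops[i + 1]) {
            (Op::Load { dst: tmp, src }, Op::Move { dst: fin, src: from }) if tmp == from => {
                Some((*tmp, *fin, *src))
            }
            _ => None,
        };
        if let Some((tmp, fin, src)) = candidate {
            if temp_is_dead(&ops[i + 2..], tmp) {
                ops[i] = Op::Load { dst: fin, src }; // write straight to the final register
                ops.remove(i + 1); // drop the now-redundant move
                changed = true;
                continue; // re-examine index i with its new neighbour
            }
        }
        i += 1;
    }
    (ops, changed)
}
```

A load into `r1` followed by `move r9 r1` collapses into a load into `r9` when nothing later reads `r1`. If a later instruction still reads `r1`, or a `Label` is reached first, the pair is left alone.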
|
||||
|
||||
/// Pass: Resolve Labels
|
||||
/// Converts all Jump/Branch labels to absolute line numbers and removes LabelDefs.
|
||||
fn resolve_labels<'a>(input: Vec<InstructionNode<'a>>) -> Vec<InstructionNode<'a>> {
|
||||
let mut label_map: HashMap<String, usize> = HashMap::new();
|
||||
let mut line_number = 0;
|
||||
|
||||
// 1. Build Label Map (filtering out LabelDefs from the count)
|
||||
for node in &input {
|
||||
if let Instruction::LabelDef(name) = &node.instruction {
|
||||
label_map.insert(name.to_string(), line_number);
|
||||
} else {
|
||||
line_number += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
|
||||
// 2. Rewrite Jumps and Filter Labels
|
||||
for mut node in input {
|
||||
// Helper to get line number as Decimal operand
|
||||
let get_line = |lbl: &Operand| -> Option<Operand<'a>> {
|
||||
if let Operand::Label(name) = lbl {
|
||||
label_map
|
||||
.get(name.as_ref())
|
||||
.map(|&l| Operand::Number(Decimal::from(l)))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
match &mut node.instruction {
|
||||
Instruction::LabelDef(_) => continue, // Strip labels
|
||||
|
||||
// Jumps
|
||||
Instruction::Jump(op) => {
|
||||
if let Some(num) = get_line(op) {
|
||||
*op = num;
|
||||
}
|
||||
}
|
||||
Instruction::JumpAndLink(op) => {
|
||||
if let Some(num) = get_line(op) {
|
||||
*op = num;
|
||||
}
|
||||
}
|
||||
Instruction::BranchEq(_, _, op)
|
||||
| Instruction::BranchNe(_, _, op)
|
||||
| Instruction::BranchGt(_, _, op)
|
||||
| Instruction::BranchLt(_, _, op)
|
||||
| Instruction::BranchGe(_, _, op)
|
||||
| Instruction::BranchLe(_, _, op) => {
|
||||
if let Some(num) = get_line(op) {
|
||||
*op = num;
|
||||
}
|
||||
}
|
||||
Instruction::BranchEqZero(_, op) | Instruction::BranchNeZero(_, op) => {
|
||||
if let Some(num) = get_line(op) {
|
||||
*op = num;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
output.push(node);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
// --- Helpers for Register Analysis ---
|
||||
|
||||
fn get_destination_reg(instr: &Instruction) -> Option<u8> {
|
||||
match instr {
|
||||
Instruction::Move(Operand::Register(r), _)
|
||||
| Instruction::Add(Operand::Register(r), _, _)
|
||||
| Instruction::Sub(Operand::Register(r), _, _)
|
||||
| Instruction::Mul(Operand::Register(r), _, _)
|
||||
| Instruction::Div(Operand::Register(r), _, _)
|
||||
| Instruction::Mod(Operand::Register(r), _, _)
|
||||
| Instruction::Pow(Operand::Register(r), _, _)
|
||||
| Instruction::Load(Operand::Register(r), _, _)
|
||||
| Instruction::LoadSlot(Operand::Register(r), _, _, _)
|
||||
| Instruction::LoadBatch(Operand::Register(r), _, _, _)
|
||||
| Instruction::LoadBatchNamed(Operand::Register(r), _, _, _, _)
|
||||
| Instruction::SetEq(Operand::Register(r), _, _)
|
||||
| Instruction::SetNe(Operand::Register(r), _, _)
|
||||
| Instruction::SetGt(Operand::Register(r), _, _)
|
||||
| Instruction::SetLt(Operand::Register(r), _, _)
|
||||
| Instruction::SetGe(Operand::Register(r), _, _)
|
||||
| Instruction::SetLe(Operand::Register(r), _, _)
|
||||
| Instruction::And(Operand::Register(r), _, _)
|
||||
| Instruction::Or(Operand::Register(r), _, _)
|
||||
| Instruction::Xor(Operand::Register(r), _, _)
|
||||
| Instruction::Peek(Operand::Register(r))
|
||||
| Instruction::Get(Operand::Register(r), _, _)
|
||||
| Instruction::Select(Operand::Register(r), _, _, _)
|
||||
| Instruction::Rand(Operand::Register(r))
|
||||
| Instruction::Acos(Operand::Register(r), _)
|
||||
| Instruction::Asin(Operand::Register(r), _)
|
||||
| Instruction::Atan(Operand::Register(r), _)
|
||||
| Instruction::Atan2(Operand::Register(r), _, _)
|
||||
| Instruction::Abs(Operand::Register(r), _)
|
||||
| Instruction::Ceil(Operand::Register(r), _)
|
||||
| Instruction::Cos(Operand::Register(r), _)
|
||||
| Instruction::Floor(Operand::Register(r), _)
|
||||
| Instruction::Log(Operand::Register(r), _)
|
||||
| Instruction::Max(Operand::Register(r), _, _)
|
||||
| Instruction::Min(Operand::Register(r), _, _)
|
||||
| Instruction::Sin(Operand::Register(r), _)
|
||||
| Instruction::Sqrt(Operand::Register(r), _)
|
||||
| Instruction::Tan(Operand::Register(r), _)
|
||||
| Instruction::Trunc(Operand::Register(r), _)
|
||||
| Instruction::LoadReagent(Operand::Register(r), _, _, _)
|
||||
| Instruction::Pop(Operand::Register(r)) => Some(*r),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn set_destination_reg<'a>(instr: &Instruction<'a>, new_reg: u8) -> Option<Instruction<'a>> {
|
||||
// Helper to easily recreate instruction with new dest
|
||||
let r = Operand::Register(new_reg);
|
||||
match instr {
|
||||
Instruction::Move(_, b) => Some(Instruction::Move(r, b.clone())),
|
||||
Instruction::Add(_, a, b) => Some(Instruction::Add(r, a.clone(), b.clone())),
|
||||
Instruction::Sub(_, a, b) => Some(Instruction::Sub(r, a.clone(), b.clone())),
|
||||
Instruction::Mul(_, a, b) => Some(Instruction::Mul(r, a.clone(), b.clone())),
|
||||
Instruction::Div(_, a, b) => Some(Instruction::Div(r, a.clone(), b.clone())),
|
||||
Instruction::Mod(_, a, b) => Some(Instruction::Mod(r, a.clone(), b.clone())),
|
||||
Instruction::Pow(_, a, b) => Some(Instruction::Pow(r, a.clone(), b.clone())),
|
||||
Instruction::Load(_, a, b) => Some(Instruction::Load(r, a.clone(), b.clone())),
|
||||
Instruction::LoadSlot(_, a, b, c) => {
|
||||
Some(Instruction::LoadSlot(r, a.clone(), b.clone(), c.clone()))
|
||||
}
|
||||
Instruction::LoadBatch(_, a, b, c) => {
|
||||
Some(Instruction::LoadBatch(r, a.clone(), b.clone(), c.clone()))
|
||||
}
|
||||
Instruction::LoadBatchNamed(_, a, b, c, d) => Some(Instruction::LoadBatchNamed(
|
||||
r,
|
||||
a.clone(),
|
||||
b.clone(),
|
||||
c.clone(),
|
||||
d.clone(),
|
||||
)),
|
||||
Instruction::LoadReagent(_, b, c, d) => {
|
||||
Some(Instruction::LoadReagent(r, b.clone(), c.clone(), d.clone()))
|
||||
}
|
||||
Instruction::SetEq(_, a, b) => Some(Instruction::SetEq(r, a.clone(), b.clone())),
|
||||
Instruction::SetNe(_, a, b) => Some(Instruction::SetNe(r, a.clone(), b.clone())),
|
||||
Instruction::SetGt(_, a, b) => Some(Instruction::SetGt(r, a.clone(), b.clone())),
|
||||
Instruction::SetLt(_, a, b) => Some(Instruction::SetLt(r, a.clone(), b.clone())),
|
||||
Instruction::SetGe(_, a, b) => Some(Instruction::SetGe(r, a.clone(), b.clone())),
|
||||
Instruction::SetLe(_, a, b) => Some(Instruction::SetLe(r, a.clone(), b.clone())),
|
||||
Instruction::And(_, a, b) => Some(Instruction::And(r, a.clone(), b.clone())),
|
||||
Instruction::Or(_, a, b) => Some(Instruction::Or(r, a.clone(), b.clone())),
|
||||
Instruction::Xor(_, a, b) => Some(Instruction::Xor(r, a.clone(), b.clone())),
|
||||
Instruction::Peek(_) => Some(Instruction::Peek(r)),
|
||||
Instruction::Get(_, a, b) => Some(Instruction::Get(r, a.clone(), b.clone())),
|
||||
Instruction::Select(_, a, b, c) => {
|
||||
Some(Instruction::Select(r, a.clone(), b.clone(), c.clone()))
|
||||
}
|
||||
Instruction::Rand(_) => Some(Instruction::Rand(r)),
|
||||
Instruction::Pop(_) => Some(Instruction::Pop(r)),
|
||||
|
||||
// Math funcs
|
||||
Instruction::Acos(_, a) => Some(Instruction::Acos(r, a.clone())),
|
||||
Instruction::Asin(_, a) => Some(Instruction::Asin(r, a.clone())),
|
||||
Instruction::Atan(_, a) => Some(Instruction::Atan(r, a.clone())),
|
||||
Instruction::Atan2(_, a, b) => Some(Instruction::Atan2(r, a.clone(), b.clone())),
|
||||
Instruction::Abs(_, a) => Some(Instruction::Abs(r, a.clone())),
|
||||
Instruction::Ceil(_, a) => Some(Instruction::Ceil(r, a.clone())),
|
||||
Instruction::Cos(_, a) => Some(Instruction::Cos(r, a.clone())),
|
||||
Instruction::Floor(_, a) => Some(Instruction::Floor(r, a.clone())),
|
||||
Instruction::Log(_, a) => Some(Instruction::Log(r, a.clone())),
|
||||
Instruction::Max(_, a, b) => Some(Instruction::Max(r, a.clone(), b.clone())),
|
||||
Instruction::Min(_, a, b) => Some(Instruction::Min(r, a.clone(), b.clone())),
|
||||
Instruction::Sin(_, a) => Some(Instruction::Sin(r, a.clone())),
|
||||
Instruction::Sqrt(_, a) => Some(Instruction::Sqrt(r, a.clone())),
|
||||
Instruction::Tan(_, a) => Some(Instruction::Tan(r, a.clone())),
|
||||
Instruction::Trunc(_, a) => Some(Instruction::Trunc(r, a.clone())),
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn reg_is_read(instr: &Instruction, reg: u8) -> bool {
|
||||
let check = |op: &Operand| matches!(op, Operand::Register(r) if *r == reg);
|
||||
|
||||
match instr {
|
||||
Instruction::Move(_, a) => check(a),
|
||||
Instruction::Add(_, a, b)
|
||||
| Instruction::Sub(_, a, b)
|
||||
| Instruction::Mul(_, a, b)
|
||||
| Instruction::Div(_, a, b)
|
||||
| Instruction::Mod(_, a, b)
|
||||
| Instruction::Pow(_, a, b) => check(a) || check(b),
|
||||
|
||||
Instruction::Load(_, a, _) => check(a), // The device operand can itself be a register
|
||||
Instruction::Store(a, _, b) => check(a) || check(b),
|
||||
|
||||
Instruction::BranchEq(a, b, _)
|
||||
| Instruction::BranchNe(a, b, _)
|
||||
| Instruction::BranchGt(a, b, _)
|
||||
| Instruction::BranchLt(a, b, _)
|
||||
| Instruction::BranchGe(a, b, _)
|
||||
| Instruction::BranchLe(a, b, _) => check(a) || check(b),
|
||||
|
||||
Instruction::BranchEqZero(a, _) | Instruction::BranchNeZero(a, _) => check(a),
|
||||
|
||||
Instruction::LoadReagent(_, device, _, item_hash) => check(device) || check(item_hash),
|
||||
|
||||
Instruction::LoadSlot(_, dev, slot, _) => check(dev) || check(slot),
|
||||
Instruction::LoadBatch(_, dev, _, mode) => check(dev) || check(mode),
|
||||
Instruction::LoadBatchNamed(_, d_hash, n_hash, _, mode) => {
|
||||
check(d_hash) || check(n_hash) || check(mode)
|
||||
}
|
||||
|
||||
Instruction::SetEq(_, a, b)
|
||||
| Instruction::SetNe(_, a, b)
|
||||
| Instruction::SetGt(_, a, b)
|
||||
| Instruction::SetLt(_, a, b)
|
||||
| Instruction::SetGe(_, a, b)
|
||||
| Instruction::SetLe(_, a, b)
|
||||
| Instruction::And(_, a, b)
|
||||
| Instruction::Or(_, a, b)
|
||||
| Instruction::Xor(_, a, b) => check(a) || check(b),
|
||||
|
||||
Instruction::Push(a) => check(a),
|
||||
Instruction::Get(_, a, b) => check(a) || check(b),
|
||||
Instruction::Put(a, b, c) => check(a) || check(b) || check(c),
|
||||
|
||||
Instruction::Select(_, a, b, c) => check(a) || check(b) || check(c),
|
||||
Instruction::Sleep(a) => check(a),
|
||||
|
||||
// Math single arg
|
||||
Instruction::Acos(_, a)
|
||||
| Instruction::Asin(_, a)
|
||||
| Instruction::Atan(_, a)
|
||||
| Instruction::Abs(_, a)
|
||||
| Instruction::Ceil(_, a)
|
||||
| Instruction::Cos(_, a)
|
||||
| Instruction::Floor(_, a)
|
||||
| Instruction::Log(_, a)
|
||||
| Instruction::Sin(_, a)
|
||||
| Instruction::Sqrt(_, a)
|
||||
| Instruction::Tan(_, a)
|
||||
| Instruction::Trunc(_, a) => check(a),
|
||||
|
||||
// Math double arg
|
||||
Instruction::Atan2(_, a, b) | Instruction::Max(_, a, b) | Instruction::Min(_, a, b) => {
|
||||
check(a) || check(b)
|
||||
}
|
||||
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
// --- Constant Propagation & Dead Code ---
|
||||
fn constant_propagation<'a>(input: Vec<InstructionNode<'a>>) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut changed = false;
|
||||
let mut registers: [Option<Decimal>; 16] = [None; 16];
|
||||
|
||||
for mut node in input {
|
||||
match &node.instruction {
|
||||
Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16],
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let simplified = match &node.instruction {
|
||||
Instruction::Move(dst, src) => resolve_value(src, &registers)
|
||||
.map(|val| Instruction::Move(dst.clone(), Operand::Number(val))),
|
||||
Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x + y),
|
||||
Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x - y),
|
||||
Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, &registers, |x, y| x * y),
|
||||
Instruction::Div(dst, a, b) => {
|
||||
try_fold_math(
|
||||
dst,
|
||||
a,
|
||||
b,
|
||||
&registers,
|
||||
|x, y| if y.is_zero() { x } else { x / y },
|
||||
)
|
||||
}
|
||||
Instruction::Mod(dst, a, b) => {
|
||||
try_fold_math(
|
||||
dst,
|
||||
a,
|
||||
b,
|
||||
&registers,
|
||||
|x, y| if y.is_zero() { x } else { x % y },
|
||||
)
|
||||
}
|
||||
Instruction::BranchEq(a, b, l) => {
|
||||
try_resolve_branch(a, b, l, &registers, |x, y| x == y)
|
||||
}
|
||||
Instruction::BranchNe(a, b, l) => {
|
||||
try_resolve_branch(a, b, l, &registers, |x, y| x != y)
|
||||
}
|
||||
Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, &registers, |x, y| x > y),
|
||||
Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, &registers, |x, y| x < y),
|
||||
Instruction::BranchGe(a, b, l) => {
|
||||
try_resolve_branch(a, b, l, &registers, |x, y| x >= y)
|
||||
}
|
||||
Instruction::BranchLe(a, b, l) => {
|
||||
try_resolve_branch(a, b, l, &registers, |x, y| x <= y)
|
||||
}
|
||||
Instruction::BranchEqZero(a, l) => {
|
||||
try_resolve_branch(a, &Operand::Number(0.into()), l, &registers, |x, y| x == y)
|
||||
}
|
||||
Instruction::BranchNeZero(a, l) => {
|
||||
try_resolve_branch(a, &Operand::Number(0.into()), l, &registers, |x, y| x != y)
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some(new) = simplified {
|
||||
node.instruction = new;
|
||||
changed = true;
|
||||
}
|
||||
|
||||
// Update tracking
|
||||
match &node.instruction {
|
||||
Instruction::Move(Operand::Register(r), src) => {
|
||||
registers[*r as usize] = resolve_value(src, &registers)
|
||||
}
|
||||
// Invalidate if destination is register
|
||||
_ => {
|
||||
if let Some(r) = get_destination_reg(&node.instruction) {
|
||||
registers[r as usize] = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Filter out NOPs (Empty LabelDefs from branch resolution)
|
||||
if let Instruction::LabelDef(l) = &node.instruction
|
||||
&& l.is_empty()
|
||||
{
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
output.push(node);
|
||||
}
|
||||
(output, changed)
|
||||
}
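The tracking scheme used by `constant_propagation` (remember known register values, fold when every input is known, forget everything at a label) can be sketched in isolation. This is a simplified illustration under stated assumptions: `Val`, `Op`, `value`, and `fold_constants` are hypothetical names, there are only four registers, and values are `i64` instead of `Decimal`.

```rust
// Hypothetical toy IL; the real pass works on `il::Instruction` and `Decimal`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Val {
    Reg(usize),
    Num(i64),
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Op {
    Move(usize, Val),
    Add(usize, Val, Val),
    Label,
}

// Resolve an operand to a constant if its value is currently known.
fn value(v: Val, known: &[Option<i64>; 4]) -> Option<i64> {
    match v {
        Val::Num(n) => Some(n),
        Val::Reg(r) => known[r],
    }
}

fn fold_constants(input: &[Op]) -> Vec<Op> {
    let mut known: [Option<i64>; 4] = [None; 4];
    let mut out = Vec::with_capacity(input.len());
    for &op in input {
        let folded = match op {
            // A label can be reached from elsewhere, so nothing is known after it.
            Op::Label => {
                known = [None; 4];
                op
            }
            Op::Move(d, src) => match value(src, &known) {
                Some(n) => Op::Move(d, Val::Num(n)),
                None => op,
            },
            Op::Add(d, a, b) => match (value(a, &known), value(b, &known)) {
                (Some(x), Some(y)) => Op::Move(d, Val::Num(x + y)),
                _ => op,
            },
        };
        // Update tracking: record constants, invalidate any other write.
        match folded {
            Op::Move(d, Val::Num(n)) => known[d] = Some(n),
            Op::Move(d, _) | Op::Add(d, _, _) => known[d] = None,
            Op::Label => {}
        }
        out.push(folded);
    }
    out
}
```

In this sketch an `Add` whose inputs are both known becomes a `Move` of the folded constant, while the same `Add` after a label is left alone because the register state was reset.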
|
||||
|
||||
fn resolve_value(op: &Operand, regs: &[Option<Decimal>; 16]) -> Option<Decimal> {
|
||||
match op {
|
||||
Operand::Number(n) => Some(*n),
|
||||
Operand::Register(r) => regs[*r as usize],
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn try_fold_math<'a, F>(
|
||||
dst: &Operand<'a>,
|
||||
a: &Operand<'a>,
|
||||
b: &Operand<'a>,
|
||||
regs: &[Option<Decimal>; 16],
|
||||
op: F,
|
||||
) -> Option<Instruction<'a>>
|
||||
where
|
||||
F: Fn(Decimal, Decimal) -> Decimal,
|
||||
{
|
||||
let val_a = resolve_value(a, regs)?;
|
||||
let val_b = resolve_value(b, regs)?;
|
||||
Some(Instruction::Move(
|
||||
dst.clone(),
|
||||
Operand::Number(op(val_a, val_b)),
|
||||
))
|
||||
}
|
||||
|
||||
fn try_resolve_branch<'a, F>(
|
||||
a: &Operand<'a>,
|
||||
b: &Operand<'a>,
|
||||
label: &Operand<'a>,
|
||||
regs: &[Option<Decimal>; 16],
|
||||
check: F,
|
||||
) -> Option<Instruction<'a>>
|
||||
where
|
||||
F: Fn(Decimal, Decimal) -> bool,
|
||||
{
|
||||
let val_a = resolve_value(a, regs)?;
|
||||
let val_b = resolve_value(b, regs)?;
|
||||
if check(val_a, val_b) {
|
||||
Some(Instruction::Jump(label.clone()))
|
||||
} else {
|
||||
Some(Instruction::LabelDef("".into())) // NOP
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_redundant_moves<'a>(input: Vec<InstructionNode<'a>>) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut changed = false;
|
||||
for node in input {
|
||||
if let Instruction::Move(dst, src) = &node.instruction
|
||||
&& dst == src
|
||||
{
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
output.push(node);
|
||||
}
|
||||
(output, changed)
|
||||
}
|
||||
|
||||
fn remove_unreachable_code<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut changed = false;
|
||||
let mut dead = false;
|
||||
for node in input {
|
||||
if let Instruction::LabelDef(_) = node.instruction {
|
||||
dead = false;
|
||||
}
|
||||
if dead {
|
||||
changed = true;
|
||||
continue;
|
||||
}
|
||||
if let Instruction::Jump(_) = node.instruction {
|
||||
dead = true;
|
||||
}
|
||||
output.push(node);
|
||||
}
|
||||
(output, changed)
|
||||
}
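The `dead` flag in `remove_unreachable_code` is a small state machine: an unconditional jump switches it on, and the next label switches it off. A minimal sketch on a hypothetical `Ins` type (not the real IL) shows the same logic:

```rust
// Hypothetical toy instruction type used only for this illustration.
#[derive(Debug, Clone, PartialEq)]
enum Ins {
    Label(&'static str),
    Jump(&'static str),
    Other(&'static str),
}

fn strip_unreachable(input: Vec<Ins>) -> Vec<Ins> {
    let mut dead = false;
    let mut out = Vec::with_capacity(input.len());
    for ins in input {
        // A label may be a jump target, so code after it is reachable again.
        if matches!(ins, Ins::Label(_)) {
            dead = false;
        }
        if dead {
            continue;
        }
        // Everything after an unconditional jump is dead until the next label.
        if matches!(ins, Ins::Jump(_)) {
            dead = true;
        }
        out.push(ins);
    }
    out
}
```

Only unconditional jumps start a dead region here; conditional branches can fall through, so the code after them stays.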
|
||||
|
||||
756
rust_compiler/libs/optimizer/src/peephole_optimization.rs
Normal file
@@ -0,0 +1,756 @@
|
||||
use il::{Instruction, InstructionNode, Operand};
|
||||
|
||||
/// Pass: Peephole Optimization
|
||||
/// Recognizes and optimizes common instruction patterns.
|
||||
pub fn peephole_optimization<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut changed = false;
|
||||
let mut i = 0;
|
||||
|
||||
while i < input.len() {
|
||||
// Pattern: push sp; push ra ... pop ra; pop sp (with no jal in between)
|
||||
// If we push sp and ra and later pop them, but never call a function in between, remove all four
|
||||
// and adjust any stack pointer offsets in between by -2
|
||||
if i + 1 < input.len() {
|
||||
if let (
|
||||
Instruction::Push(Operand::StackPointer),
|
||||
Instruction::Push(Operand::ReturnAddress),
|
||||
) = (&input[i].instruction, &input[i + 1].instruction)
|
||||
{
|
||||
// Look for matching pop ra; pop sp pattern
|
||||
if let Some((ra_pop_idx, instructions_between)) =
|
||||
find_matching_ra_pop(&input[i + 1..])
|
||||
{
|
||||
let absolute_ra_pop = i + 1 + ra_pop_idx;
|
||||
// Check if the next instruction is pop sp
|
||||
if absolute_ra_pop + 1 < input.len() {
|
||||
if let Instruction::Pop(Operand::StackPointer) =
|
||||
&input[absolute_ra_pop + 1].instruction
|
||||
{
|
||||
// Check if there's any jal between push and pop
|
||||
let has_call = instructions_between.iter().any(|node| {
|
||||
matches!(node.instruction, Instruction::JumpAndLink(_))
|
||||
});
|
||||
|
||||
if !has_call {
|
||||
// Safe to remove all four: push sp, push ra, pop ra, pop sp
|
||||
// Also need to adjust stack pointer offsets in between by -2
|
||||
let absolute_sp_pop = absolute_ra_pop + 1;
|
||||
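// The loop below rebuilds the whole output from `input`, so drop anything
// pushed in earlier iterations to avoid emitting it twice.
output.clear();
for (idx, node) in input.iter().enumerate() {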
|
||||
if idx == i
|
||||
|| idx == i + 1
|
||||
|| idx == absolute_ra_pop
|
||||
|| idx == absolute_sp_pop
|
||||
{
|
||||
// Skip all four push/pop instructions
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this instruction is between the pushes and pops, adjust its stack offsets
|
||||
if idx > i + 1 && idx < absolute_ra_pop {
|
||||
let adjusted_instruction =
|
||||
adjust_stack_offset(node.instruction.clone(), 2);
|
||||
output.push(InstructionNode::new(
|
||||
adjusted_instruction,
|
||||
node.span,
|
||||
));
|
||||
} else {
|
||||
output.push(node.clone());
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
// We've processed the entire input, so break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: push ra ... pop ra (with no jal in between)
|
||||
// Fallback for when there's only ra push/pop without sp
|
||||
if let Instruction::Push(Operand::ReturnAddress) = &input[i].instruction {
|
||||
if let Some((pop_idx, instructions_between)) = find_matching_ra_pop(&input[i..]) {
|
||||
// Check if there's any jal between push and pop
|
||||
let has_call = instructions_between
|
||||
.iter()
|
||||
.any(|node| matches!(node.instruction, Instruction::JumpAndLink(_)));
|
||||
|
||||
if !has_call {
|
||||
// Safe to remove both push and pop
|
||||
// Also need to adjust stack pointer offsets in between
|
||||
let absolute_pop_idx = i + pop_idx;
|
||||
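// The loop below rebuilds the whole output from `input`, so drop anything
// pushed in earlier iterations to avoid emitting it twice.
output.clear();
for (idx, node) in input.iter().enumerate() {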
|
||||
if idx == i || idx == absolute_pop_idx {
|
||||
// Skip the push and pop
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this instruction is between push and pop, adjust its stack offsets
|
||||
if idx > i && idx < absolute_pop_idx {
|
||||
let adjusted_instruction =
|
||||
adjust_stack_offset(node.instruction.clone(), 1);
|
||||
output.push(InstructionNode::new(adjusted_instruction, node.span));
|
||||
} else {
|
||||
output.push(node.clone());
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
// We've processed the entire input, so break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: Branch-Move-Jump-Label-Move-Label -> Select
|
||||
// beqz r1 else_label
|
||||
// move r2 val1
|
||||
// j end_label
|
||||
// else_label:
|
||||
// move r2 val2
|
||||
// end_label:
|
||||
// Converts to: select r2 r1 val1 val2
|
||||
if i + 5 < input.len() {
|
||||
let select_pattern = try_match_select_pattern(&input[i..i + 6]);
|
||||
if let Some((dst, cond, true_val, false_val, skip_count)) = select_pattern {
|
||||
output.push(InstructionNode::new(
|
||||
Instruction::Select(dst, cond, true_val, false_val),
|
||||
input[i].span,
|
||||
));
|
||||
changed = true;
|
||||
i += skip_count;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: seq + beqz -> beq
|
||||
if i + 1 < input.len() {
|
||||
let pattern = match (&input[i].instruction, &input[i + 1].instruction) {
|
||||
(
|
||||
Instruction::SetEq(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Eq, true)), // invert: beqz means "if NOT equal"
|
||||
|
||||
(
|
||||
Instruction::SetNe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ne, true)),
|
||||
|
||||
(
|
||||
Instruction::SetGt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Gt, true)),
|
||||
|
||||
(
|
||||
Instruction::SetLt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Lt, true)),
|
||||
|
||||
(
|
||||
Instruction::SetGe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ge, true)),
|
||||
|
||||
(
|
||||
Instruction::SetLe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Le, true)),
|
||||
|
||||
// Pattern: seq + bnez -> bne
|
||||
(
|
||||
Instruction::SetEq(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Eq, false)),
|
||||
|
||||
(
|
||||
Instruction::SetNe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ne, false)),
|
||||
|
||||
(
|
||||
Instruction::SetGt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Gt, false)),
|
||||
|
||||
(
|
||||
Instruction::SetLt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Lt, false)),
|
||||
|
||||
(
|
||||
Instruction::SetGe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ge, false)),
|
||||
|
||||
(
|
||||
Instruction::SetLe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Le, false)),
|
||||
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some((a, b, label, branch_type, invert)) = pattern {
|
||||
// Create optimized branch instruction
|
||||
let new_instr = if invert {
|
||||
// beqz after seq means "branch if NOT equal" -> bne
|
||||
match branch_type {
|
||||
BranchType::Eq => {
|
||||
Instruction::BranchNe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ne => {
|
||||
Instruction::BranchEq(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Gt => {
|
||||
Instruction::BranchLe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Lt => {
|
||||
Instruction::BranchGe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ge => {
|
||||
Instruction::BranchLt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Le => {
|
||||
Instruction::BranchGt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// bnez after seq means "branch if equal" -> beq
|
||||
match branch_type {
|
||||
BranchType::Eq => {
|
||||
Instruction::BranchEq(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ne => {
|
||||
Instruction::BranchNe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Gt => {
|
||||
Instruction::BranchGt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Lt => {
|
||||
Instruction::BranchLt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ge => {
|
||||
Instruction::BranchGe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Le => {
|
||||
Instruction::BranchLe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
output.push(InstructionNode::new(new_instr, input[i].span));
|
||||
changed = true;
|
||||
i += 2; // Skip both instructions
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
output.push(input[i].clone());
|
||||
i += 1;
|
||||
}
|
||||
|
||||
(output, changed)
|
||||
}
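The set-then-branch fusion above relies on one rule: a `beqz` on the flag branches when the comparison is false, so the fused branch uses the negated comparison, while a `bnez` keeps the comparison as is. The sketch below checks that rule on integers. `Cmp`, `ALL`, `holds`, `negate`, and `fuse` are hypothetical names for this illustration, and because it uses `i64` it says nothing about NaN operands.

```rust
// Hypothetical stand-in for the comparison kinds handled by the peephole pass.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Cmp {
    Eq,
    Ne,
    Gt,
    Lt,
    Ge,
    Le,
}

const ALL: [Cmp; 6] = [Cmp::Eq, Cmp::Ne, Cmp::Gt, Cmp::Lt, Cmp::Ge, Cmp::Le];

// Evaluate a comparison on two integer operands.
fn holds(c: Cmp, a: i64, b: i64) -> bool {
    match c {
        Cmp::Eq => a == b,
        Cmp::Ne => a != b,
        Cmp::Gt => a > b,
        Cmp::Lt => a < b,
        Cmp::Ge => a >= b,
        Cmp::Le => a <= b,
    }
}

// Logical negation of a comparison.
fn negate(c: Cmp) -> Cmp {
    match c {
        Cmp::Eq => Cmp::Ne,
        Cmp::Ne => Cmp::Eq,
        Cmp::Gt => Cmp::Le,
        Cmp::Lt => Cmp::Ge,
        Cmp::Ge => Cmp::Lt,
        Cmp::Le => Cmp::Gt,
    }
}

// Comparison for the fused branch: `beqz` (branch_on_zero) inverts, `bnez` keeps.
fn fuse(set: Cmp, branch_on_zero: bool) -> Cmp {
    if branch_on_zero {
        negate(set)
    } else {
        set
    }
}
```

For example, `fuse(Cmp::Gt, true)` gives `Cmp::Le`, which matches the `sgt` + `beqz` to `ble` case covered by the tests in this file.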
|
||||
|
||||
/// Tries to match a select pattern in the instruction sequence.
|
||||
/// Pattern (6 instructions):
|
||||
/// beqz/bnez cond else_label (i+0)
|
||||
/// move dst val1 (i+1)
|
||||
/// j end_label (i+2)
|
||||
/// else_label: (i+3)
|
||||
/// move dst val2 (i+4)
|
||||
/// end_label: (i+5)
|
||||
/// Returns: (dst, cond, true_val, false_val, instruction_count)
|
||||
fn try_match_select_pattern<'a>(
|
||||
instructions: &[InstructionNode<'a>],
|
||||
) -> Option<(Operand<'a>, Operand<'a>, Operand<'a>, Operand<'a>, usize)> {
|
||||
if instructions.len() < 6 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check for beqz pattern
|
||||
if let Instruction::BranchEqZero(cond, Operand::Label(else_label)) =
|
||||
&instructions[0].instruction
|
||||
{
|
||||
if let Instruction::Move(dst1, val1) = &instructions[1].instruction {
|
||||
if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction {
|
||||
if let Instruction::LabelDef(label3) = &instructions[3].instruction {
|
||||
if label3 == else_label {
|
||||
if let Instruction::Move(dst2, val2) = &instructions[4].instruction {
|
||||
if dst1 == dst2 {
|
||||
if let Instruction::LabelDef(label5) = &instructions[5].instruction
|
||||
{
|
||||
if label5 == end_label {
|
||||
// beqz means: if cond==0, goto else, so val1 is for true, val2 for false
|
||||
// select dst cond true_val false_val
|
||||
// When cond is non-zero (true), use val1, otherwise val2
|
||||
return Some((
|
||||
dst1.clone(),
|
||||
cond.clone(),
|
||||
val1.clone(),
|
||||
val2.clone(),
|
||||
6,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for bnez pattern
|
||||
if let Instruction::BranchNeZero(cond, Operand::Label(then_label)) =
|
||||
&instructions[0].instruction
|
||||
{
|
||||
if let Instruction::Move(dst1, val_false) = &instructions[1].instruction {
|
||||
if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction {
|
||||
if let Instruction::LabelDef(label3) = &instructions[3].instruction {
|
||||
if label3 == then_label {
|
||||
if let Instruction::Move(dst2, val_true) = &instructions[4].instruction {
|
||||
if dst1 == dst2 {
|
||||
if let Instruction::LabelDef(label5) = &instructions[5].instruction
|
||||
{
|
||||
if label5 == end_label {
|
||||
// bnez means: if cond!=0, goto then, so val_true for true, val_false for false
|
||||
return Some((
|
||||
dst1.clone(),
|
||||
cond.clone(),
|
||||
val_true.clone(),
|
||||
val_false.clone(),
|
||||
6,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
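The operand order returned by `try_match_select_pattern` can be checked against a plain model of the two branchy forms. This is a sketch under the assumption, stated in the comments above, that `select dst cond a b` yields `a` when `cond` is non-zero and `b` otherwise; `select`, `branchy_beqz`, and `branchy_bnez` are hypothetical helpers, not IL functions.

```rust
// Model of `select dst cond if_nonzero if_zero`.
fn select(cond: i64, if_nonzero: i64, if_zero: i64) -> i64 {
    if cond != 0 {
        if_nonzero
    } else {
        if_zero
    }
}

// Model of: beqz cond else / move dst val1 / j end / else: / move dst val2 / end:
fn branchy_beqz(cond: i64, val1: i64, val2: i64) -> i64 {
    if cond == 0 {
        // Branch taken: the `else` block runs.
        val2
    } else {
        // Fall through: the first move runs, then the jump skips the else block.
        val1
    }
}

// Model of: bnez cond then / move dst val_false / j end / then: / move dst val_true / end:
fn branchy_bnez(cond: i64, val_false: i64, val_true: i64) -> i64 {
    if cond != 0 {
        // Branch taken: the `then` block runs.
        val_true
    } else {
        val_false
    }
}
```

In the `beqz` form the fall-through value becomes the first `select` operand, and in the `bnez` form the two moves swap roles, which is why the second match arm returns `val_true` before `val_false`.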
|
||||
|
||||
/// Finds a matching `pop ra` for a `push ra` at the start of the slice.
|
||||
/// Returns the index of the pop and the instructions in between.
|
||||
fn find_matching_ra_pop<'a>(
|
||||
instructions: &'a [InstructionNode<'a>],
|
||||
) -> Option<(usize, &'a [InstructionNode<'a>])> {
|
||||
if instructions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip the push itself
|
||||
for (idx, node) in instructions.iter().enumerate().skip(1) {
|
||||
if let Instruction::Pop(Operand::ReturnAddress) = &node.instruction {
|
||||
// Found matching pop
|
||||
return Some((idx, &instructions[1..idx]));
|
||||
}
|
||||
|
||||
// Stop searching if we hit a jump (different control flow)
|
||||
// Labels are OK - they're just markers
|
||||
if matches!(
|
||||
node.instruction,
|
||||
Instruction::Jump(_) | Instruction::JumpRelative(_)
|
||||
) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Checks if an instruction uses or modifies the stack pointer.
|
||||
#[allow(dead_code)]
|
||||
fn uses_stack_pointer(instruction: &Instruction) -> bool {
|
||||
match instruction {
|
||||
Instruction::Push(_) | Instruction::Pop(_) | Instruction::Peek(_) => true,
|
||||
Instruction::Add(Operand::StackPointer, _, _)
|
||||
| Instruction::Sub(Operand::StackPointer, _, _)
|
||||
| Instruction::Mul(Operand::StackPointer, _, _)
|
||||
| Instruction::Div(Operand::StackPointer, _, _)
|
||||
| Instruction::Mod(Operand::StackPointer, _, _) => true,
|
||||
Instruction::Add(_, Operand::StackPointer, _)
|
||||
| Instruction::Sub(_, Operand::StackPointer, _)
|
||||
| Instruction::Mul(_, Operand::StackPointer, _)
|
||||
| Instruction::Div(_, Operand::StackPointer, _)
|
||||
| Instruction::Mod(_, Operand::StackPointer, _) => true,
|
||||
Instruction::Add(_, _, Operand::StackPointer)
|
||||
| Instruction::Sub(_, _, Operand::StackPointer)
|
||||
| Instruction::Mul(_, _, Operand::StackPointer)
|
||||
| Instruction::Div(_, _, Operand::StackPointer)
|
||||
| Instruction::Mod(_, _, Operand::StackPointer) => true,
|
||||
Instruction::Move(Operand::StackPointer, _)
|
||||
| Instruction::Move(_, Operand::StackPointer) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts stack pointer offsets in an instruction by decrementing them by a given amount.
|
||||
/// This is necessary when removing push operations that would have increased the stack size.
|
||||
fn adjust_stack_offset<'a>(instruction: Instruction<'a>, decrement: i64) -> Instruction<'a> {
|
||||
use rust_decimal::prelude::*;
|
||||
|
||||
match instruction {
|
||||
// Adjust arithmetic operations on sp that use literal offsets
|
||||
Instruction::Sub(dst, Operand::StackPointer, Operand::Number(n)) => {
|
||||
let new_n = n - Decimal::from(decrement);
|
||||
// If the result is 0 or negative, we may want to skip this entirely
|
||||
// but for now, just adjust the value
|
||||
Instruction::Sub(dst, Operand::StackPointer, Operand::Number(new_n))
|
||||
}
|
||||
Instruction::Add(dst, Operand::StackPointer, Operand::Number(n)) => {
|
||||
let new_n = n - Decimal::from(decrement);
|
||||
Instruction::Add(dst, Operand::StackPointer, Operand::Number(new_n))
|
||||
}
|
||||
// Return the instruction unchanged if it doesn't need adjustment
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum BranchType {
|
||||
Eq,
|
||||
Ne,
|
||||
Gt,
|
||||
Lt,
|
||||
Ge,
|
||||
Le,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_seq_beqz_to_bne() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetEq(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchNe(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sne_beqz_to_beq() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetNe(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchEq(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seq_bnez_to_beq() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetEq(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchNeZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchEq(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
    #[test]
    fn test_sgt_beqz_to_ble() {
        let input = vec![
            InstructionNode::new(
                Instruction::SetGt(
                    Operand::Register(1),
                    Operand::Register(2),
                    Operand::Register(3),
                ),
                None,
            ),
            InstructionNode::new(
                Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
                None,
            ),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        assert!(matches!(
            output[0].instruction,
            Instruction::BranchLe(_, _, _)
        ));
    }

    #[test]
    fn test_branch_move_jump_to_select_beqz() {
        // Pattern: beqz r1 else / move r2 10 / j end / else: / move r2 20 / end:
        // Should convert to: select r2 r1 10 20
        let input = vec![
            InstructionNode::new(
                Instruction::BranchEqZero(Operand::Register(1), Operand::Label("else".into())),
                None,
            ),
            InstructionNode::new(
                Instruction::Move(Operand::Register(2), Operand::Number(10.into())),
                None,
            ),
            InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None),
            InstructionNode::new(Instruction::LabelDef("else".into()), None),
            InstructionNode::new(
                Instruction::Move(Operand::Register(2), Operand::Number(20.into())),
                None,
            ),
            InstructionNode::new(Instruction::LabelDef("end".into()), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction {
            assert!(matches!(dst, Operand::Register(2)));
            assert!(matches!(cond, Operand::Register(1)));
            assert!(matches!(true_val, Operand::Number(_)));
            assert!(matches!(false_val, Operand::Number(_)));
        } else {
            panic!("Expected Select instruction");
        }
    }

    #[test]
    fn test_branch_move_jump_to_select_bnez() {
        // Pattern: bnez r1 then / move r2 20 / j end / then: / move r2 10 / end:
        // Should convert to: select r2 r1 10 20
        let input = vec![
            InstructionNode::new(
                Instruction::BranchNeZero(Operand::Register(1), Operand::Label("then".into())),
                None,
            ),
            InstructionNode::new(
                Instruction::Move(Operand::Register(2), Operand::Number(20.into())),
                None,
            ),
            InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None),
            InstructionNode::new(Instruction::LabelDef("then".into()), None),
            InstructionNode::new(
                Instruction::Move(Operand::Register(2), Operand::Number(10.into())),
                None,
            ),
            InstructionNode::new(Instruction::LabelDef("end".into()), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction {
            assert!(matches!(dst, Operand::Register(2)));
            assert!(matches!(cond, Operand::Register(1)));
            assert!(matches!(true_val, Operand::Number(_)));
            assert!(matches!(false_val, Operand::Number(_)));
        } else {
            panic!("Expected Select instruction");
        }
    }

    #[test]
    fn test_remove_useless_ra_push_pop() {
        // Pattern: push ra / add r1 r2 r3 / pop ra
        // Should remove both push and pop since no jal in between
        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::Add(
                    Operand::Register(1),
                    Operand::Register(2),
                    Operand::Register(3),
                ),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        assert!(matches!(output[0].instruction, Instruction::Add(_, _, _)));
    }

    #[test]
    fn test_keep_ra_push_pop_with_jal() {
        // Pattern: push ra / jal func / pop ra
        // Should keep both since there's a jal in between
        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::JumpAndLink(Operand::Label("func".into())),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(!changed);
        assert_eq!(output.len(), 3);
    }

    #[test]
    fn test_ra_push_pop_with_stack_offset_adjustment() {
        // Pattern: push ra / sub r1 sp 2 / pop ra
        // Should remove push/pop AND adjust the stack offset from 2 to 1
        use rust_decimal::prelude::*;

        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::Sub(
                    Operand::Register(1),
                    Operand::StackPointer,
                    Operand::Number(Decimal::from(2)),
                ),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);

        if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction {
            assert!(matches!(dst, Operand::Register(1)));
            assert!(matches!(src, Operand::StackPointer));
            assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 2 to 1
        } else {
            panic!("Expected Sub instruction with adjusted offset");
        }
    }

    #[test]
    fn test_remove_sp_and_ra_push_pop() {
        // Pattern: push sp / push ra / move r8 10 / pop ra / pop sp
        // Should remove all four push/pop instructions since no jal in between
        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::Move(Operand::Register(8), Operand::Number(10.into())),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
            InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        assert!(matches!(
            output[0].instruction,
            Instruction::Move(Operand::Register(8), _)
        ));
    }

    #[test]
    fn test_keep_sp_and_ra_push_pop_with_jal() {
        // Pattern: push sp / push ra / jal func / pop ra / pop sp
        // Should keep all since there's a jal in between
        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::JumpAndLink(Operand::Label("func".into())),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
            InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(!changed);
        assert_eq!(output.len(), 5);
    }

    #[test]
    fn test_sp_and_ra_with_stack_offset_adjustment() {
        // Pattern: push sp / push ra / sub r1 sp 3 / pop ra / pop sp
        // Should remove all push/pop AND adjust the stack offset from 3 to 1 (decrement by 2)
        use rust_decimal::prelude::*;

        let input = vec![
            InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
            InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
            InstructionNode::new(
                Instruction::Sub(
                    Operand::Register(1),
                    Operand::StackPointer,
                    Operand::Number(Decimal::from(3)),
                ),
                None,
            ),
            InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
            InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
        ];

        let (output, changed) = peephole_optimization(input);
        assert!(changed);
        assert_eq!(output.len(), 1);

        if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction {
            assert!(matches!(dst, Operand::Register(1)));
            assert!(matches!(src, Operand::StackPointer));
            assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 3 to 1
        } else {
            panic!("Expected Sub instruction with adjusted offset");
        }
    }
}
108
rust_compiler/libs/optimizer/src/register_forwarding.rs
Normal file
@@ -0,0 +1,108 @@
use crate::helpers::{get_destination_reg, reg_is_read, set_destination_reg};
use il::{Instruction, InstructionNode};

/// Pass: Register Forwarding
/// Eliminates intermediate moves by writing directly to the final destination.
/// Example: `l r1 d0 Temperature` + `move r9 r1` -> `l r9 d0 Temperature`
pub fn register_forwarding<'a>(
    mut input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut changed = false;
    let mut i = 0;

    while i < input.len().saturating_sub(1) {
        let next_idx = i + 1;

        // Check if current instruction defines a register
        // and the NEXT instruction is a move from that register.
        let forward_candidate = if let Some(def_reg) = get_destination_reg(&input[i].instruction) {
            if let Instruction::Move(
                il::Operand::Register(dest_reg),
                il::Operand::Register(src_reg),
            ) = &input[next_idx].instruction
            {
                if *src_reg == def_reg {
                    Some((def_reg, *dest_reg))
                } else {
                    None
                }
            } else {
                None
            }
        } else {
            None
        };

        if let Some((temp_reg, final_reg)) = forward_candidate {
            // Check liveness: Is temp_reg used after i+1?
            let mut temp_is_dead = true;
            for node in input.iter().skip(i + 2) {
                if reg_is_read(&node.instruction, temp_reg) {
                    temp_is_dead = false;
                    break;
                }
                // If the temp is redefined, then the old value is dead
                if let Some(redef) = get_destination_reg(&node.instruction)
                    && redef == temp_reg
                {
                    break;
                }

                // Conservative: assume liveness might leak at labels/jumps
                if matches!(
                    node.instruction,
                    Instruction::LabelDef(_) | Instruction::Jump(_) | Instruction::JumpAndLink(_)
                ) {
                    temp_is_dead = false;
                    break;
                }
            }

            if temp_is_dead {
                // Rewrite to use final destination directly
                if let Some(new_instr) = set_destination_reg(&input[i].instruction, final_reg) {
                    input[i].instruction = new_instr;
                    input.remove(next_idx);
                    changed = true;
                    continue;
                }
            }
        }

        i += 1;
    }

    (input, changed)
}

#[cfg(test)]
mod tests {
    use super::*;
    use il::{Instruction, InstructionNode, Operand};

    #[test]
    fn test_forward_simple_move() {
        let input = vec![
            InstructionNode::new(
                Instruction::Add(
                    Operand::Register(1),
                    Operand::Register(2),
                    Operand::Register(3),
                ),
                None,
            ),
            InstructionNode::new(
                Instruction::Move(Operand::Register(5), Operand::Register(1)),
                None,
            ),
        ];

        let (output, changed) = register_forwarding(input);
        assert!(changed);
        assert_eq!(output.len(), 1);
        assert!(matches!(
            output[0].instruction,
            Instruction::Add(Operand::Register(5), _, _)
        ));
    }
}
63
rust_compiler/libs/optimizer/src/strength_reduction.rs
Normal file
@@ -0,0 +1,63 @@
use il::{Instruction, InstructionNode, Operand};
use rust_decimal::Decimal;

/// Pass: Strength Reduction
/// Replaces expensive operations with cheaper equivalents.
/// Example: `mul r1 r2 2` -> `add r1 r2 r2` (addition is typically faster than multiplication)
pub fn strength_reduction<'a>(
    input: Vec<InstructionNode<'a>>,
) -> (Vec<InstructionNode<'a>>, bool) {
    let mut output = Vec::with_capacity(input.len());
    let mut changed = false;

    for mut node in input {
        let reduced = match &node.instruction {
            // x * 2 = x + x
            Instruction::Mul(dst, a, Operand::Number(n)) if *n == Decimal::from(2) => {
                Some(Instruction::Add(dst.clone(), a.clone(), a.clone()))
            }
            // 2 * x = x + x (constant as the first operand)
            Instruction::Mul(dst, Operand::Number(n), b) if *n == Decimal::from(2) => {
                Some(Instruction::Add(dst.clone(), b.clone(), b.clone()))
            }

            // Future: Could add power-of-2 optimizations using bit shifts if IC10 supports them
            // x * 4 = (x + x) + (x + x) or x << 2
            // x / 2 = x >> 1

            _ => None,
        };

        if let Some(new) = reduced {
            node.instruction = new;
            changed = true;
        }

        output.push(node);
    }

    (output, changed)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mul_two_to_add() {
        let input = vec![InstructionNode::new(
            Instruction::Mul(
                Operand::Register(1),
                Operand::Register(2),
                Operand::Number(Decimal::from(2)),
            ),
            None,
        )];

        let (output, changed) = strength_reduction(input);
        assert!(changed);
        assert!(matches!(
            output[0].instruction,
            Instruction::Add(Operand::Register(1), Operand::Register(2), Operand::Register(2))
        ));
    }
}
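The strength-reduction rewrite in `strength_reduction.rs` can be tried in isolation with the dependency-free sketch below. It is not the crate's code: `Op`, `Ins`, `reduce`, and `demo` are hypothetical stand-ins for `il::Operand`, `il::Instruction`, and `strength_reduction`, and `i64` replaces `Decimal` so that it builds without external crates. It covers both operand orders of a multiply by 2 and leaves other multiplies untouched.

```rust
// Hypothetical stand-ins for the IL types (assumption: the real crate
// uses il::Operand and il::Instruction with rust_decimal::Decimal values).
#[derive(Clone, Debug, PartialEq)]
enum Op {
    Reg(u8),
    Num(i64),
}

#[derive(Clone, Debug, PartialEq)]
enum Ins {
    Mul(Op, Op, Op),
    Add(Op, Op, Op),
}

/// Rewrites `mul dst x 2` and `mul dst 2 x` into `add dst x x`.
/// Returns the rewritten program and whether anything changed.
fn reduce(input: Vec<Ins>) -> (Vec<Ins>, bool) {
    let mut changed = false;
    let output = input
        .into_iter()
        .map(|ins| match ins {
            // The constant 2 may sit in either source position.
            Ins::Mul(dst, a, Op::Num(2)) | Ins::Mul(dst, Op::Num(2), a) => {
                changed = true;
                Ins::Add(dst, a.clone(), a)
            }
            // Everything else passes through unchanged.
            other => other,
        })
        .collect();
    (output, changed)
}

/// Builds a three-instruction sample program and runs the pass over it.
fn demo() -> (Vec<Ins>, bool) {
    reduce(vec![
        Ins::Mul(Op::Reg(1), Op::Reg(2), Op::Num(2)),
        Ins::Mul(Op::Reg(3), Op::Num(2), Op::Reg(4)),
        Ins::Mul(Op::Reg(5), Op::Reg(6), Op::Num(3)),
    ])
}
```

Calling `demo()` from a `main` or a test returns the rewritten program together with the `changed` flag. The existing unit test only exercises the `x * 2` order, so a second test for `2 * x` in the real module would likely be a cheap addition.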