More optimizations and snapshot integration tests
This commit is contained in:
756
rust_compiler/libs/optimizer/src/peephole_optimization.rs
Normal file
756
rust_compiler/libs/optimizer/src/peephole_optimization.rs
Normal file
@@ -0,0 +1,756 @@
|
||||
use il::{Instruction, InstructionNode, Operand};
|
||||
|
||||
/// Pass: Peephole Optimization
|
||||
/// Recognizes and optimizes common instruction patterns.
|
||||
pub fn peephole_optimization<'a>(
|
||||
input: Vec<InstructionNode<'a>>,
|
||||
) -> (Vec<InstructionNode<'a>>, bool) {
|
||||
let mut output = Vec::with_capacity(input.len());
|
||||
let mut changed = false;
|
||||
let mut i = 0;
|
||||
|
||||
while i < input.len() {
|
||||
// Pattern: push sp; push ra ... pop ra; pop sp (with no jal in between)
|
||||
// If we push sp and ra and later pop them, but never call a function in between, remove all four
|
||||
// and adjust any stack pointer offsets in between by -2
|
||||
if i + 1 < input.len() {
|
||||
if let (
|
||||
Instruction::Push(Operand::StackPointer),
|
||||
Instruction::Push(Operand::ReturnAddress),
|
||||
) = (&input[i].instruction, &input[i + 1].instruction)
|
||||
{
|
||||
// Look for matching pop ra; pop sp pattern
|
||||
if let Some((ra_pop_idx, instructions_between)) =
|
||||
find_matching_ra_pop(&input[i + 1..])
|
||||
{
|
||||
let absolute_ra_pop = i + 1 + ra_pop_idx;
|
||||
// Check if the next instruction is pop sp
|
||||
if absolute_ra_pop + 1 < input.len() {
|
||||
if let Instruction::Pop(Operand::StackPointer) =
|
||||
&input[absolute_ra_pop + 1].instruction
|
||||
{
|
||||
// Check if there's any jal between push and pop
|
||||
let has_call = instructions_between.iter().any(|node| {
|
||||
matches!(node.instruction, Instruction::JumpAndLink(_))
|
||||
});
|
||||
|
||||
if !has_call {
|
||||
// Safe to remove all four: push sp, push ra, pop ra, pop sp
|
||||
// Also need to adjust stack pointer offsets in between by -2
|
||||
let absolute_sp_pop = absolute_ra_pop + 1;
|
||||
for (idx, node) in input.iter().enumerate() {
|
||||
if idx == i
|
||||
|| idx == i + 1
|
||||
|| idx == absolute_ra_pop
|
||||
|| idx == absolute_sp_pop
|
||||
{
|
||||
// Skip all four push/pop instructions
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this instruction is between the pushes and pops, adjust its stack offsets
|
||||
if idx > i + 1 && idx < absolute_ra_pop {
|
||||
let adjusted_instruction =
|
||||
adjust_stack_offset(node.instruction.clone(), 2);
|
||||
output.push(InstructionNode::new(
|
||||
adjusted_instruction,
|
||||
node.span,
|
||||
));
|
||||
} else {
|
||||
output.push(node.clone());
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
// We've processed the entire input, so break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: push ra ... pop ra (with no jal in between)
|
||||
// Fallback for when there's only ra push/pop without sp
|
||||
if let Instruction::Push(Operand::ReturnAddress) = &input[i].instruction {
|
||||
if let Some((pop_idx, instructions_between)) = find_matching_ra_pop(&input[i..]) {
|
||||
// Check if there's any jal between push and pop
|
||||
let has_call = instructions_between
|
||||
.iter()
|
||||
.any(|node| matches!(node.instruction, Instruction::JumpAndLink(_)));
|
||||
|
||||
if !has_call {
|
||||
// Safe to remove both push and pop
|
||||
// Also need to adjust stack pointer offsets in between
|
||||
let absolute_pop_idx = i + pop_idx;
|
||||
for (idx, node) in input.iter().enumerate() {
|
||||
if idx == i || idx == absolute_pop_idx {
|
||||
// Skip the push and pop
|
||||
continue;
|
||||
}
|
||||
|
||||
// If this instruction is between push and pop, adjust its stack offsets
|
||||
if idx > i && idx < absolute_pop_idx {
|
||||
let adjusted_instruction =
|
||||
adjust_stack_offset(node.instruction.clone(), 1);
|
||||
output.push(InstructionNode::new(adjusted_instruction, node.span));
|
||||
} else {
|
||||
output.push(node.clone());
|
||||
}
|
||||
}
|
||||
changed = true;
|
||||
// We've processed the entire input, so break
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: Branch-Move-Jump-Label-Move-Label -> Select
|
||||
// beqz r1 else_label
|
||||
// move r2 val1
|
||||
// j end_label
|
||||
// else_label:
|
||||
// move r2 val2
|
||||
// end_label:
|
||||
// Converts to: select r2 r1 val1 val2
|
||||
if i + 5 < input.len() {
|
||||
let select_pattern = try_match_select_pattern(&input[i..i + 6]);
|
||||
if let Some((dst, cond, true_val, false_val, skip_count)) = select_pattern {
|
||||
output.push(InstructionNode::new(
|
||||
Instruction::Select(dst, cond, true_val, false_val),
|
||||
input[i].span,
|
||||
));
|
||||
changed = true;
|
||||
i += skip_count;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern: seq + beqz -> beq
|
||||
if i + 1 < input.len() {
|
||||
let pattern = match (&input[i].instruction, &input[i + 1].instruction) {
|
||||
(
|
||||
Instruction::SetEq(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Eq, true)), // invert: beqz means "if NOT equal"
|
||||
|
||||
(
|
||||
Instruction::SetNe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ne, true)),
|
||||
|
||||
(
|
||||
Instruction::SetGt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Gt, true)),
|
||||
|
||||
(
|
||||
Instruction::SetLt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Lt, true)),
|
||||
|
||||
(
|
||||
Instruction::SetGe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ge, true)),
|
||||
|
||||
(
|
||||
Instruction::SetLe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchEqZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Le, true)),
|
||||
|
||||
// Pattern: seq + bnez -> bne
|
||||
(
|
||||
Instruction::SetEq(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Eq, false)),
|
||||
|
||||
(
|
||||
Instruction::SetNe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ne, false)),
|
||||
|
||||
(
|
||||
Instruction::SetGt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Gt, false)),
|
||||
|
||||
(
|
||||
Instruction::SetLt(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Lt, false)),
|
||||
|
||||
(
|
||||
Instruction::SetGe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Ge, false)),
|
||||
|
||||
(
|
||||
Instruction::SetLe(Operand::Register(temp), a, b),
|
||||
Instruction::BranchNeZero(Operand::Register(cond), label),
|
||||
) if temp == cond => Some((a, b, label, BranchType::Le, false)),
|
||||
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if let Some((a, b, label, branch_type, invert)) = pattern {
|
||||
// Create optimized branch instruction
|
||||
let new_instr = if invert {
|
||||
// beqz after seq means "branch if NOT equal" -> bne
|
||||
match branch_type {
|
||||
BranchType::Eq => {
|
||||
Instruction::BranchNe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ne => {
|
||||
Instruction::BranchEq(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Gt => {
|
||||
Instruction::BranchLe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Lt => {
|
||||
Instruction::BranchGe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ge => {
|
||||
Instruction::BranchLt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Le => {
|
||||
Instruction::BranchGt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// bnez after seq means "branch if equal" -> beq
|
||||
match branch_type {
|
||||
BranchType::Eq => {
|
||||
Instruction::BranchEq(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ne => {
|
||||
Instruction::BranchNe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Gt => {
|
||||
Instruction::BranchGt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Lt => {
|
||||
Instruction::BranchLt(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Ge => {
|
||||
Instruction::BranchGe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
BranchType::Le => {
|
||||
Instruction::BranchLe(a.clone(), b.clone(), label.clone())
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
output.push(InstructionNode::new(new_instr, input[i].span));
|
||||
changed = true;
|
||||
i += 2; // Skip both instructions
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
output.push(input[i].clone());
|
||||
i += 1;
|
||||
}
|
||||
|
||||
(output, changed)
|
||||
}
|
||||
|
||||
/// Tries to match a select pattern in the instruction sequence.
|
||||
/// Pattern (6 instructions):
|
||||
/// beqz/bnez cond else_label (i+0)
|
||||
/// move dst val1 (i+1)
|
||||
/// j end_label (i+2)
|
||||
/// else_label: (i+3)
|
||||
/// move dst val2 (i+4)
|
||||
/// end_label: (i+5)
|
||||
/// Returns: (dst, cond, true_val, false_val, instruction_count)
|
||||
fn try_match_select_pattern<'a>(
|
||||
instructions: &[InstructionNode<'a>],
|
||||
) -> Option<(Operand<'a>, Operand<'a>, Operand<'a>, Operand<'a>, usize)> {
|
||||
if instructions.len() < 6 {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Check for beqz pattern
|
||||
if let Instruction::BranchEqZero(cond, Operand::Label(else_label)) =
|
||||
&instructions[0].instruction
|
||||
{
|
||||
if let Instruction::Move(dst1, val1) = &instructions[1].instruction {
|
||||
if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction {
|
||||
if let Instruction::LabelDef(label3) = &instructions[3].instruction {
|
||||
if label3 == else_label {
|
||||
if let Instruction::Move(dst2, val2) = &instructions[4].instruction {
|
||||
if dst1 == dst2 {
|
||||
if let Instruction::LabelDef(label5) = &instructions[5].instruction
|
||||
{
|
||||
if label5 == end_label {
|
||||
// beqz means: if cond==0, goto else, so val1 is for true, val2 for false
|
||||
// select dst cond true_val false_val
|
||||
// When cond is non-zero (true), use val1, otherwise val2
|
||||
return Some((
|
||||
dst1.clone(),
|
||||
cond.clone(),
|
||||
val1.clone(),
|
||||
val2.clone(),
|
||||
6,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for bnez pattern
|
||||
if let Instruction::BranchNeZero(cond, Operand::Label(then_label)) =
|
||||
&instructions[0].instruction
|
||||
{
|
||||
if let Instruction::Move(dst1, val_false) = &instructions[1].instruction {
|
||||
if let Instruction::Jump(Operand::Label(end_label)) = &instructions[2].instruction {
|
||||
if let Instruction::LabelDef(label3) = &instructions[3].instruction {
|
||||
if label3 == then_label {
|
||||
if let Instruction::Move(dst2, val_true) = &instructions[4].instruction {
|
||||
if dst1 == dst2 {
|
||||
if let Instruction::LabelDef(label5) = &instructions[5].instruction
|
||||
{
|
||||
if label5 == end_label {
|
||||
// bnez means: if cond!=0, goto then, so val_true for true, val_false for false
|
||||
return Some((
|
||||
dst1.clone(),
|
||||
cond.clone(),
|
||||
val_true.clone(),
|
||||
val_false.clone(),
|
||||
6,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Finds a matching `pop ra` for a `push ra` at the start of the slice.
|
||||
/// Returns the index of the pop and the instructions in between.
|
||||
fn find_matching_ra_pop<'a>(
|
||||
instructions: &'a [InstructionNode<'a>],
|
||||
) -> Option<(usize, &'a [InstructionNode<'a>])> {
|
||||
if instructions.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Skip the push itself
|
||||
for (idx, node) in instructions.iter().enumerate().skip(1) {
|
||||
if let Instruction::Pop(Operand::ReturnAddress) = &node.instruction {
|
||||
// Found matching pop
|
||||
return Some((idx, &instructions[1..idx]));
|
||||
}
|
||||
|
||||
// Stop searching if we hit a jump (different control flow)
|
||||
// Labels are OK - they're just markers
|
||||
if matches!(
|
||||
node.instruction,
|
||||
Instruction::Jump(_) | Instruction::JumpRelative(_)
|
||||
) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Checks if an instruction uses or modifies the stack pointer.
|
||||
#[allow(dead_code)]
|
||||
fn uses_stack_pointer(instruction: &Instruction) -> bool {
|
||||
match instruction {
|
||||
Instruction::Push(_) | Instruction::Pop(_) | Instruction::Peek(_) => true,
|
||||
Instruction::Add(Operand::StackPointer, _, _)
|
||||
| Instruction::Sub(Operand::StackPointer, _, _)
|
||||
| Instruction::Mul(Operand::StackPointer, _, _)
|
||||
| Instruction::Div(Operand::StackPointer, _, _)
|
||||
| Instruction::Mod(Operand::StackPointer, _, _) => true,
|
||||
Instruction::Add(_, Operand::StackPointer, _)
|
||||
| Instruction::Sub(_, Operand::StackPointer, _)
|
||||
| Instruction::Mul(_, Operand::StackPointer, _)
|
||||
| Instruction::Div(_, Operand::StackPointer, _)
|
||||
| Instruction::Mod(_, Operand::StackPointer, _) => true,
|
||||
Instruction::Add(_, _, Operand::StackPointer)
|
||||
| Instruction::Sub(_, _, Operand::StackPointer)
|
||||
| Instruction::Mul(_, _, Operand::StackPointer)
|
||||
| Instruction::Div(_, _, Operand::StackPointer)
|
||||
| Instruction::Mod(_, _, Operand::StackPointer) => true,
|
||||
Instruction::Move(Operand::StackPointer, _)
|
||||
| Instruction::Move(_, Operand::StackPointer) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Adjusts stack pointer offsets in an instruction by decrementing them by a given amount.
|
||||
/// This is necessary when removing push operations that would have increased the stack size.
|
||||
fn adjust_stack_offset<'a>(instruction: Instruction<'a>, decrement: i64) -> Instruction<'a> {
|
||||
use rust_decimal::prelude::*;
|
||||
|
||||
match instruction {
|
||||
// Adjust arithmetic operations on sp that use literal offsets
|
||||
Instruction::Sub(dst, Operand::StackPointer, Operand::Number(n)) => {
|
||||
let new_n = n - Decimal::from(decrement);
|
||||
// If the result is 0 or negative, we may want to skip this entirely
|
||||
// but for now, just adjust the value
|
||||
Instruction::Sub(dst, Operand::StackPointer, Operand::Number(new_n))
|
||||
}
|
||||
Instruction::Add(dst, Operand::StackPointer, Operand::Number(n)) => {
|
||||
let new_n = n - Decimal::from(decrement);
|
||||
Instruction::Add(dst, Operand::StackPointer, Operand::Number(new_n))
|
||||
}
|
||||
// Return the instruction unchanged if it doesn't need adjustment
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
enum BranchType {
|
||||
Eq,
|
||||
Ne,
|
||||
Gt,
|
||||
Lt,
|
||||
Ge,
|
||||
Le,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_seq_beqz_to_bne() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetEq(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchNe(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sne_beqz_to_beq() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetNe(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchEq(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_seq_bnez_to_beq() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetEq(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchNeZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchEq(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sgt_beqz_to_ble() {
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::SetGt(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("target".into())),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::BranchLe(_, _, _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_move_jump_to_select_beqz() {
|
||||
// Pattern: beqz r1 else / move r2 10 / j end / else: / move r2 20 / end:
|
||||
// Should convert to: select r2 r1 10 20
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::BranchEqZero(Operand::Register(1), Operand::Label("else".into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::Move(Operand::Register(2), Operand::Number(10.into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None),
|
||||
InstructionNode::new(Instruction::LabelDef("else".into()), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Move(Operand::Register(2), Operand::Number(20.into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::LabelDef("end".into()), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction {
|
||||
assert!(matches!(dst, Operand::Register(2)));
|
||||
assert!(matches!(cond, Operand::Register(1)));
|
||||
assert!(matches!(true_val, Operand::Number(_)));
|
||||
assert!(matches!(false_val, Operand::Number(_)));
|
||||
} else {
|
||||
panic!("Expected Select instruction");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_move_jump_to_select_bnez() {
|
||||
// Pattern: bnez r1 then / move r2 20 / j end / then: / move r2 10 / end:
|
||||
// Should convert to: select r2 r1 10 20
|
||||
let input = vec![
|
||||
InstructionNode::new(
|
||||
Instruction::BranchNeZero(Operand::Register(1), Operand::Label("then".into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(
|
||||
Instruction::Move(Operand::Register(2), Operand::Number(20.into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Jump(Operand::Label("end".into())), None),
|
||||
InstructionNode::new(Instruction::LabelDef("then".into()), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Move(Operand::Register(2), Operand::Number(10.into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::LabelDef("end".into()), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
if let Instruction::Select(dst, cond, true_val, false_val) = &output[0].instruction {
|
||||
assert!(matches!(dst, Operand::Register(2)));
|
||||
assert!(matches!(cond, Operand::Register(1)));
|
||||
assert!(matches!(true_val, Operand::Number(_)));
|
||||
assert!(matches!(false_val, Operand::Number(_)));
|
||||
} else {
|
||||
panic!("Expected Select instruction");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_useless_ra_push_pop() {
|
||||
// Pattern: push ra / add r1 r2 r3 / pop ra
|
||||
// Should remove both push and pop since no jal in between
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Add(
|
||||
Operand::Register(1),
|
||||
Operand::Register(2),
|
||||
Operand::Register(3),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(output[0].instruction, Instruction::Add(_, _, _)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keep_ra_push_pop_with_jal() {
|
||||
// Pattern: push ra / jal func / pop ra
|
||||
// Should keep both since there's a jal in between
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::JumpAndLink(Operand::Label("func".into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(!changed);
|
||||
assert_eq!(output.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ra_push_pop_with_stack_offset_adjustment() {
|
||||
// Pattern: push ra / sub r1 sp 2 / pop ra
|
||||
// Should remove push/pop AND adjust the stack offset from 2 to 1
|
||||
use rust_decimal::prelude::*;
|
||||
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Sub(
|
||||
Operand::Register(1),
|
||||
Operand::StackPointer,
|
||||
Operand::Number(Decimal::from(2)),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
|
||||
if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction {
|
||||
assert!(matches!(dst, Operand::Register(1)));
|
||||
assert!(matches!(src, Operand::StackPointer));
|
||||
assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 2 to 1
|
||||
} else {
|
||||
panic!("Expected Sub instruction with adjusted offset");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_sp_and_ra_push_pop() {
|
||||
// Pattern: push sp / push ra / move r8 10 / pop ra / pop sp
|
||||
// Should remove all four push/pop instructions since no jal in between
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Move(Operand::Register(8), Operand::Number(10.into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
assert!(matches!(
|
||||
output[0].instruction,
|
||||
Instruction::Move(Operand::Register(8), _)
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_keep_sp_and_ra_push_pop_with_jal() {
|
||||
// Pattern: push sp / push ra / jal func / pop ra / pop sp
|
||||
// Should keep all since there's a jal in between
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::JumpAndLink(Operand::Label("func".into())),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(!changed);
|
||||
assert_eq!(output.len(), 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sp_and_ra_with_stack_offset_adjustment() {
|
||||
// Pattern: push sp / push ra / sub r1 sp 3 / pop ra / pop sp
|
||||
// Should remove all push/pop AND adjust the stack offset from 3 to 1 (decrement by 2)
|
||||
use rust_decimal::prelude::*;
|
||||
|
||||
let input = vec![
|
||||
InstructionNode::new(Instruction::Push(Operand::StackPointer), None),
|
||||
InstructionNode::new(Instruction::Push(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(
|
||||
Instruction::Sub(
|
||||
Operand::Register(1),
|
||||
Operand::StackPointer,
|
||||
Operand::Number(Decimal::from(3)),
|
||||
),
|
||||
None,
|
||||
),
|
||||
InstructionNode::new(Instruction::Pop(Operand::ReturnAddress), None),
|
||||
InstructionNode::new(Instruction::Pop(Operand::StackPointer), None),
|
||||
];
|
||||
|
||||
let (output, changed) = peephole_optimization(input);
|
||||
assert!(changed);
|
||||
assert_eq!(output.len(), 1);
|
||||
|
||||
if let Instruction::Sub(dst, src, Operand::Number(offset)) = &output[0].instruction {
|
||||
assert!(matches!(dst, Operand::Register(1)));
|
||||
assert!(matches!(src, Operand::StackPointer));
|
||||
assert_eq!(*offset, Decimal::from(1)); // Should be decremented from 3 to 1
|
||||
} else {
|
||||
panic!("Expected Sub instruction with adjusted offset");
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user