214 lines
6.9 KiB
Rust
214 lines
6.9 KiB
Rust
use crate::helpers::get_destination_reg;
|
|
use il::{Instruction, InstructionNode, Operand};
|
|
use rust_decimal::Decimal;
|
|
|
|
/// Pass: Constant Propagation
|
|
/// Folds arithmetic operations when both operands are constant.
|
|
/// Also tracks register values and propagates them forward.
|
|
pub fn constant_propagation<'a>(
|
|
input: Vec<InstructionNode<'a>>,
|
|
) -> (Vec<InstructionNode<'a>>, bool) {
|
|
let mut output = Vec::with_capacity(input.len());
|
|
let mut changed = false;
|
|
let mut registers: [Option<Decimal>; 16] = [None; 16];
|
|
|
|
for mut node in input {
|
|
// Invalidate register tracking on label/call boundaries
|
|
match &node.instruction {
|
|
Instruction::LabelDef(_) | Instruction::JumpAndLink(_) => registers = [None; 16],
|
|
_ => {}
|
|
}
|
|
|
|
let simplified = match &node.instruction {
|
|
Instruction::Move(dst, src) => resolve_value(src, ®isters)
|
|
.map(|val| Instruction::Move(dst.clone(), Operand::Number(val))),
|
|
Instruction::Add(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x + y),
|
|
Instruction::Sub(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x - y),
|
|
Instruction::Mul(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| x * y),
|
|
Instruction::Div(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| {
|
|
if y.is_zero() { Decimal::ZERO } else { x / y }
|
|
}),
|
|
Instruction::Mod(dst, a, b) => try_fold_math(dst, a, b, ®isters, |x, y| {
|
|
if y.is_zero() { Decimal::ZERO } else { x % y }
|
|
}),
|
|
Instruction::And(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| x & y),
|
|
Instruction::Or(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| x | y),
|
|
Instruction::Xor(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| x ^ y),
|
|
Instruction::Sll(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| {
|
|
if y >= 64 { 0 } else { x << y as u32 }
|
|
}),
|
|
Instruction::Sra(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| {
|
|
if y >= 64 {
|
|
if x < 0 { -1 } else { 0 }
|
|
} else {
|
|
x >> y as u32
|
|
}
|
|
}),
|
|
Instruction::Srl(dst, a, b) => try_fold_bitwise(dst, a, b, ®isters, |x, y| {
|
|
if y >= 64 {
|
|
0
|
|
} else {
|
|
(x as u64 >> y as u32) as i64
|
|
}
|
|
}),
|
|
Instruction::BranchEq(a, b, l) => {
|
|
try_resolve_branch(a, b, l, ®isters, |x, y| x == y)
|
|
}
|
|
Instruction::BranchNe(a, b, l) => {
|
|
try_resolve_branch(a, b, l, ®isters, |x, y| x != y)
|
|
}
|
|
Instruction::BranchGt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x > y),
|
|
Instruction::BranchLt(a, b, l) => try_resolve_branch(a, b, l, ®isters, |x, y| x < y),
|
|
Instruction::BranchGe(a, b, l) => {
|
|
try_resolve_branch(a, b, l, ®isters, |x, y| x >= y)
|
|
}
|
|
Instruction::BranchLe(a, b, l) => {
|
|
try_resolve_branch(a, b, l, ®isters, |x, y| x <= y)
|
|
}
|
|
Instruction::BranchEqZero(a, l) => {
|
|
try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x == y)
|
|
}
|
|
Instruction::BranchNeZero(a, l) => {
|
|
try_resolve_branch(a, &Operand::Number(0.into()), l, ®isters, |x, y| x != y)
|
|
}
|
|
_ => None,
|
|
};
|
|
|
|
if let Some(new) = simplified {
|
|
node.instruction = new;
|
|
changed = true;
|
|
}
|
|
|
|
// Update register tracking
|
|
match &node.instruction {
|
|
Instruction::Move(Operand::Register(r), src) => {
|
|
registers[*r as usize] = resolve_value(src, ®isters)
|
|
}
|
|
_ => {
|
|
if let Some(r) = get_destination_reg(&node.instruction) {
|
|
registers[r as usize] = None;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Filter out NOPs (empty labels from branch resolution)
|
|
if let Instruction::LabelDef(l) = &node.instruction
|
|
&& l.is_empty()
|
|
{
|
|
changed = true;
|
|
continue;
|
|
}
|
|
|
|
output.push(node);
|
|
}
|
|
(output, changed)
|
|
}
|
|
|
|
fn resolve_value(op: &Operand, regs: &[Option<Decimal>; 16]) -> Option<Decimal> {
|
|
match op {
|
|
Operand::Number(n) => Some(*n),
|
|
Operand::Register(r) => regs[*r as usize],
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn try_fold_math<'a, F>(
|
|
dst: &Operand<'a>,
|
|
a: &Operand<'a>,
|
|
b: &Operand<'a>,
|
|
regs: &[Option<Decimal>; 16],
|
|
op: F,
|
|
) -> Option<Instruction<'a>>
|
|
where
|
|
F: Fn(Decimal, Decimal) -> Decimal,
|
|
{
|
|
let val_a = resolve_value(a, regs)?;
|
|
let val_b = resolve_value(b, regs)?;
|
|
Some(Instruction::Move(
|
|
dst.clone(),
|
|
Operand::Number(op(val_a, val_b)),
|
|
))
|
|
}
|
|
|
|
fn decimal_to_i64(d: Decimal) -> i64 {
|
|
// Convert decimal to i64, truncating if needed
|
|
if let Ok(int_val) = d.try_into() {
|
|
int_val
|
|
} else {
|
|
// For very large or very small values, use a default
|
|
if d.is_sign_negative() {
|
|
i64::MIN
|
|
} else {
|
|
i64::MAX
|
|
}
|
|
}
|
|
}
|
|
|
|
fn i64_to_decimal(i: i64) -> Decimal {
|
|
Decimal::from(i)
|
|
}
|
|
|
|
fn try_fold_bitwise<'a, F>(
|
|
dst: &Operand<'a>,
|
|
a: &Operand<'a>,
|
|
b: &Operand<'a>,
|
|
regs: &[Option<Decimal>; 16],
|
|
op: F,
|
|
) -> Option<Instruction<'a>>
|
|
where
|
|
F: Fn(i64, i64) -> i64,
|
|
{
|
|
let val_a = resolve_value(a, regs)?;
|
|
let val_b = resolve_value(b, regs)?;
|
|
let result = op(decimal_to_i64(val_a), decimal_to_i64(val_b));
|
|
Some(Instruction::Move(
|
|
dst.clone(),
|
|
Operand::Number(i64_to_decimal(result)),
|
|
))
|
|
}
|
|
|
|
fn try_resolve_branch<'a, F>(
|
|
a: &Operand<'a>,
|
|
b: &Operand<'a>,
|
|
label: &Operand<'a>,
|
|
regs: &[Option<Decimal>; 16],
|
|
check: F,
|
|
) -> Option<Instruction<'a>>
|
|
where
|
|
F: Fn(Decimal, Decimal) -> bool,
|
|
{
|
|
let val_a = resolve_value(a, regs)?;
|
|
let val_b = resolve_value(b, regs)?;
|
|
if check(val_a, val_b) {
|
|
Some(Instruction::Jump(label.clone()))
|
|
} else {
|
|
Some(Instruction::LabelDef("".into())) // NOP
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use il::InstructionNode;

    /// An `Add` with two literal operands must collapse into a single `Move` of a number into
    /// the same destination register, and the pass must report that it changed something.
    #[test]
    fn test_fold_add() {
        let add = Instruction::Add(
            Operand::Register(1),
            Operand::Number(5.into()),
            Operand::Number(3.into()),
        );
        let program = vec![InstructionNode::new(add, None)];

        let (folded, did_change) = constant_propagation(program);

        assert!(did_change);
        assert_eq!(folded.len(), 1);
        assert!(matches!(
            folded[0].instruction,
            Instruction::Move(Operand::Register(1), Operand::Number(_))
        ));
    }
}
|