fold nested literal binary expressions

This commit is contained in:
2025-12-05 23:25:23 -07:00
parent 9993bff574
commit a60e9d7dce
10 changed files with 176 additions and 60 deletions

View File

@@ -252,12 +252,13 @@ name = "compiler"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"crc32fast", "helpers",
"indoc", "indoc",
"lsp-types", "lsp-types",
"parser", "parser",
"pretty_assertions", "pretty_assertions",
"quick-error", "quick-error",
"rust_decimal",
"tokenizer", "tokenizer",
] ]
@@ -373,6 +374,9 @@ checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]] [[package]]
name = "helpers" name = "helpers"
version = "0.1.0" version = "0.1.0"
dependencies = [
"crc32fast",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"

View File

@@ -7,8 +7,9 @@ edition = "2024"
quick-error = { workspace = true } quick-error = { workspace = true }
parser = { path = "../parser" } parser = { path = "../parser" }
tokenizer = { path = "../tokenizer" } tokenizer = { path = "../tokenizer" }
helpers = { path = "../helpers" }
lsp-types = { workspace = true } lsp-types = { workspace = true }
crc32fast = { workspace = true } rust_decimal = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { version = "1.0" } anyhow = { version = "1.0" }

View File

@@ -17,8 +17,7 @@ fn simple_binary_expression() -> anyhow::Result<()> {
" "
j main j main
main: main:
add r1 1 2 move r8 3 #i
move r8 r1 #i
" "
} }
); );
@@ -72,7 +71,7 @@ fn nested_binary_expressions() -> anyhow::Result<()> {
} }
#[test] #[test]
fn stress_test_negation_with_stack_spillover() -> anyhow::Result<()> { fn stress_test_constant_folding() -> anyhow::Result<()> {
let compiled = compile! { let compiled = compile! {
debug debug
" "
@@ -86,12 +85,7 @@ fn stress_test_negation_with_stack_spillover() -> anyhow::Result<()> {
" "
j main j main
main: main:
add r1 -1 -2 move r8 -123 #negationHell
add r2 -5 -6
mul r3 -4 r2
add r4 -3 r3
mul r5 r1 r4
move r8 r5 #negationHell
" "
} }
); );

View File

@@ -112,9 +112,8 @@ fn test_math_with_logic() -> anyhow::Result<()> {
" "
j main j main
main: main:
add r1 1 2 sgt r1 3 1
sgt r2 r1 1 move r8 r1 #logic
move r8 r2 #logic
" "
} }
); );

View File

@@ -88,7 +88,7 @@ fn test_set_on_device_batched() -> anyhow::Result<()> {
let compiled = compile! { let compiled = compile! {
debug debug
r#" r#"
let doorHash = hash("Door"); const doorHash = hash("Door");
setOnDeviceBatched(doorHash, "Lock", true); setOnDeviceBatched(doorHash, "Lock", true);
"# "#
}; };
@@ -99,9 +99,7 @@ fn test_set_on_device_batched() -> anyhow::Result<()> {
r#" r#"
j main j main
main: main:
move r15 HASH("Door") #hash_ret sb 718797587 Lock 1
move r8 r15 #doorHash
sb r8 Lock 1
"# "#
} }
); );
@@ -133,27 +131,3 @@ fn test_load_from_device() -> anyhow::Result<()> {
Ok(()) Ok(())
} }
#[test]
fn test_hash() -> anyhow::Result<()> {
let compiled = compile! {
debug
r#"
let nameHash = hash("testValue");
"#
};
assert_eq!(
compiled,
indoc! {
r#"
j main
main:
move r15 HASH("testValue") #hash_ret
move r8 r15 #nameHash
"#
}
);
Ok(())
}

View File

@@ -1,6 +1,6 @@
#![allow(clippy::result_large_err)] #![allow(clippy::result_large_err)]
use crate::variable_manager::{self, LocationRequest, VariableLocation, VariableScope}; use crate::variable_manager::{self, LocationRequest, VariableLocation, VariableScope};
use crc32fast::hash as crc32_hash; use helpers::prelude::*;
use parser::{ use parser::{
Parser as ASTParser, Parser as ASTParser,
sys_call::{SysCall, System}, sys_call::{SysCall, System},
@@ -559,15 +559,24 @@ impl<'a, W: std::io::Write> Compiler<'a, W> {
let result = self.expression_binary(bin_expr, scope)?; let result = self.expression_binary(bin_expr, scope)?;
let var_loc = scope.add_variable(&name_str, LocationRequest::Persist)?; let var_loc = scope.add_variable(&name_str, LocationRequest::Persist)?;
// Move result from temp to new persistent variable if let CompilationResult {
let result_reg = self.resolve_register(&result.location)?; location: VariableLocation::Constant(Literal::Number(num)),
self.emit_variable_assignment(&name_str, &var_loc, result_reg)?; ..
} = result
{
self.emit_variable_assignment(&name_str, &var_loc, num)?;
(var_loc, None)
} else {
// Move result from temp to new persistent variable
let result_reg = self.resolve_register(&result.location)?;
self.emit_variable_assignment(&name_str, &var_loc, result_reg)?;
// Free the temp result // Free the temp result
if let Some(name) = result.temp_name { if let Some(name) = result.temp_name {
scope.free_temp(name)?; scope.free_temp(name)?;
}
(var_loc, None)
} }
(var_loc, None)
} }
Expression::Logical(log_expr) => { Expression::Logical(log_expr) => {
let result = self.expression_logical(log_expr, scope)?; let result = self.expression_logical(log_expr, scope)?;
@@ -686,14 +695,7 @@ impl<'a, W: std::io::Write> Compiler<'a, W> {
LiteralOr::Or(Spanned { LiteralOr::Or(Spanned {
node: SysCall::System(System::Hash(Literal::String(str_to_hash))), node: SysCall::System(System::Hash(Literal::String(str_to_hash))),
.. ..
}) => { }) => Literal::Number(Number::Integer(crc_hash_signed(&str_to_hash))),
let hash = crc32_hash(str_to_hash.as_bytes());
// in stationeers, crc32 is a SIGNED int.
let hash_value_i32 = i32::from_le_bytes(hash.to_le_bytes());
Literal::Number(Number::Integer(hash_value_i32 as i128))
}
LiteralOr::Or(Spanned { span, .. }) => { LiteralOr::Or(Spanned { span, .. }) => {
return Err(Error::Unknown( return Err(Error::Unknown(
"hash only supports string literals in this context.".into(), "hash only supports string literals in this context.".into(),
@@ -1232,6 +1234,58 @@ impl<'a, W: std::io::Write> Compiler<'a, W> {
expr: Spanned<BinaryExpression>, expr: Spanned<BinaryExpression>,
scope: &mut VariableScope<'v>, scope: &mut VariableScope<'v>,
) -> Result<CompilationResult, Error> { ) -> Result<CompilationResult, Error> {
fn fold_binary_expression(expr: &BinaryExpression) -> Option<Number> {
let (lhs, rhs) = match &expr {
BinaryExpression::Add(l, r)
| BinaryExpression::Subtract(l, r)
| BinaryExpression::Multiply(l, r)
| BinaryExpression::Divide(l, r)
| BinaryExpression::Exponent(l, r)
| BinaryExpression::Modulo(l, r) => (fold_expression(l)?, fold_expression(r)?),
};
match expr {
BinaryExpression::Add(..) => Some(lhs + rhs),
BinaryExpression::Subtract(..) => Some(lhs - rhs),
BinaryExpression::Multiply(..) => Some(lhs * rhs),
BinaryExpression::Divide(..) => Some(lhs / rhs), // Watch out for div by zero panics!
BinaryExpression::Modulo(..) => Some(lhs % rhs),
_ => None, // Handle Exponent separately or implement pow
}
}
fn fold_expression(expr: &Expression) -> Option<Number> {
match expr {
// 1. Base Case: It's already a number
Expression::Literal(lit) => match lit.node {
Literal::Number(n) => Some(n),
_ => None,
},
// 2. Handle Parentheses: Just recurse deeper
Expression::Priority(inner) => fold_expression(&inner.node),
// 3. Handle Negation: Recurse, then negate
Expression::Negation(inner) => {
let val = fold_expression(&inner.node)?;
Some(-val) // Requires impl Neg for Number
}
// 4. Handle Binary Ops: Recurse BOTH sides, then combine
Expression::Binary(bin) => fold_binary_expression(&bin.node),
// 5. Anything else (Variables, Function Calls) cannot be compile-time folded
_ => None,
}
}
if let Some(const_lit) = fold_binary_expression(&expr.node) {
return Ok(CompilationResult {
location: VariableLocation::Constant(Literal::Number(const_lit)),
temp_name: None,
});
};
let (op_str, left_expr, right_expr) = match expr.node { let (op_str, left_expr, right_expr) = match expr.node {
BinaryExpression::Add(l, r) => ("add", l, r), BinaryExpression::Add(l, r) => ("add", l, r),
BinaryExpression::Multiply(l, r) => ("mul", l, r), BinaryExpression::Multiply(l, r) => ("mul", l, r),
@@ -1553,8 +1607,9 @@ impl<'a, W: std::io::Write> Compiler<'a, W> {
)); ));
}; };
let loc = VariableLocation::Persistant(VariableScope::RETURN_REGISTER); let loc = VariableLocation::Constant(Literal::Number(Number::Integer(
self.emit_variable_assignment("hash_ret", &loc, format!(r#"HASH("{}")"#, str_lit))?; crc_hash_signed(&str_lit),
)));
Ok(Some(CompilationResult { Ok(Some(CompilationResult {
location: loc, location: loc,

View File

@@ -4,3 +4,4 @@ version = "0.1.0"
edition = "2024" edition = "2024"
[dependencies] [dependencies]
crc32fast = { workspace = true }

View File

@@ -0,0 +1,11 @@
use crc32fast::hash as crc32_hash;
/// This function takes an input which is meant to be hashed via the CRC32 algorithm, but it then
/// converts the generated UNSIGNED number into it's SIGNED counterpart.
pub fn crc_hash_signed(input: &str) -> i128 {
let hash = crc32_hash(input.as_bytes());
// in stationeers, crc32 is a SIGNED int.
let hash_value_i32 = i32::from_le_bytes(hash.to_le_bytes());
hash_value_i32 as i128
}

View File

@@ -1,3 +1,4 @@
mod helper_funcs;
mod macros; mod macros;
mod syscall; mod syscall;
@@ -11,5 +12,6 @@ pub trait Documentation {
} }
pub mod prelude { pub mod prelude {
pub use super::helper_funcs::*;
pub use super::{Documentation, documented, with_syscalls}; pub use super::{Documentation, documented, with_syscalls};
} }

View File

@@ -173,6 +173,81 @@ pub enum Number {
Decimal(Decimal), Decimal(Decimal),
} }
impl From<Number> for Decimal {
fn from(value: Number) -> Self {
match value {
Number::Decimal(d) => d,
Number::Integer(i) => Decimal::from(i),
}
}
}
impl std::ops::Neg for Number {
type Output = Number;
fn neg(self) -> Self::Output {
match self {
Self::Integer(i) => Self::Integer(-i),
Self::Decimal(d) => Self::Decimal(-d),
}
}
}
impl std::ops::Add for Number {
type Output = Number;
fn add(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Self::Integer(l), Self::Integer(r)) => Number::Integer(l + r),
(Self::Decimal(l), Self::Decimal(r)) => Number::Decimal(l + r),
(Self::Integer(l), Self::Decimal(r)) => Number::Decimal(Decimal::from(l) + r),
(Self::Decimal(l), Self::Integer(r)) => Number::Decimal(l + Decimal::from(r)),
}
}
}
impl std::ops::Sub for Number {
type Output = Number;
fn sub(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Self::Integer(l), Self::Integer(r)) => Self::Integer(l - r),
(Self::Decimal(l), Self::Integer(r)) => Self::Decimal(l - Decimal::from(r)),
(Self::Integer(l), Self::Decimal(r)) => Self::Decimal(Decimal::from(l) - r),
(Self::Decimal(l), Self::Decimal(r)) => Self::Decimal(l - r),
}
}
}
impl std::ops::Mul for Number {
type Output = Number;
fn mul(self, rhs: Self) -> Self::Output {
match (self, rhs) {
(Number::Integer(l), Number::Integer(r)) => Number::Integer(l * r),
(Number::Integer(l), Number::Decimal(r)) => Number::Decimal(Decimal::from(l) * r),
(Number::Decimal(l), Number::Integer(r)) => Number::Decimal(l * Decimal::from(r)),
(Number::Decimal(l), Number::Decimal(r)) => Number::Decimal(l * r),
}
}
}
impl std::ops::Div for Number {
type Output = Number;
fn div(self, rhs: Self) -> Self::Output {
Number::Decimal(Decimal::from(self) / Decimal::from(rhs))
}
}
impl std::ops::Rem for Number {
type Output = Number;
fn rem(self, rhs: Self) -> Self::Output {
Number::Decimal(Decimal::from(self) % Decimal::from(rhs))
}
}
impl std::fmt::Display for Number { impl std::fmt::Display for Number {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self { match self {