From dd433e174674300c5f89d6ef1820110028cba933 Mon Sep 17 00:00:00 2001 From: Devin Bidwell Date: Tue, 25 Nov 2025 00:26:17 -0700 Subject: [PATCH] Basic support for logical expressions --- libs/compiler/src/test/logic_expression.rs | 119 ++++++++++++ libs/compiler/src/test/mod.rs | 1 + libs/compiler/src/v1.rs | 109 ++++++++++- libs/parser/src/lib.rs | 203 ++++++++++++++++----- libs/parser/src/tree_node.rs | 2 + 5 files changed, 392 insertions(+), 42 deletions(-) create mode 100644 libs/compiler/src/test/logic_expression.rs diff --git a/libs/compiler/src/test/logic_expression.rs b/libs/compiler/src/test/logic_expression.rs new file mode 100644 index 0000000..e4a2ead --- /dev/null +++ b/libs/compiler/src/test/logic_expression.rs @@ -0,0 +1,119 @@ +use crate::compile; +use indoc::indoc; +use pretty_assertions::assert_eq; + +#[test] +fn test_comparison_expressions() -> anyhow::Result<()> { + let compiled = compile! { + debug + " + let isGreater = 10 > 5; + let isLess = 5 < 10; + let isEqual = 5 == 5; + let isNotEqual = 5 != 10; + let isGreaterOrEqual = 10 >= 10; + let isLessOrEqual = 5 <= 5; + " + }; + + assert_eq!( + compiled, + indoc! { + " + j main + main: + sgt r1 10 5 + move r8 r1 #isGreater + slt r2 5 10 + move r9 r2 #isLess + seq r3 5 5 + move r10 r3 #isEqual + sne r4 5 10 + move r11 r4 #isNotEqual + sge r5 10 10 + move r12 r5 #isGreaterOrEqual + sle r6 5 5 + move r13 r6 #isLessOrEqual + " + } + ); + + Ok(()) +} + +#[test] +fn test_logical_and_or_not() -> anyhow::Result<()> { + let compiled = compile! { + debug + " + let logic1 = 1 && 1; + let logic2 = 1 || 0; + let logic3 = !1; + " + }; + + assert_eq!( + compiled, + indoc! { + " + j main + main: + and r1 1 1 + move r8 r1 #logic1 + or r2 1 0 + move r9 r2 #logic2 + seq r3 1 0 + move r10 r3 #logic3 + " + } + ); + + Ok(()) +} + +#[test] +fn test_complex_logic() -> anyhow::Result<()> { + let compiled = compile! { + debug + " + let logic = (10 > 5) && (5 < 10); + " + }; + + assert_eq!( + compiled, + indoc! { + " + j main + main: + sgt r1 10 5 + slt r2 5 10 + and r3 r1 r2 + move r8 r3 #logic + " + } + ); + + Ok(()) +} + +#[test] +fn test_math_with_logic() -> anyhow::Result<()> { + let compiled = compile! { + debug + " + let logic = (1 + 2) > 1; + " + }; + + assert_eq!( + compiled, + indoc! { + " + + " + } + ); + + Ok(()) +} diff --git a/libs/compiler/src/test/mod.rs b/libs/compiler/src/test/mod.rs index 09fa1aa..05e91a0 100644 --- a/libs/compiler/src/test/mod.rs +++ b/libs/compiler/src/test/mod.rs @@ -44,3 +44,4 @@ mod binary_expression; mod declaration_function_invocation; mod declaration_literal; mod function_declaration; +mod logic_expression; diff --git a/libs/compiler/src/v1.rs b/libs/compiler/src/v1.rs index 8e0fd28..b7fc5d1 100644 --- a/libs/compiler/src/v1.rs +++ b/libs/compiler/src/v1.rs @@ -3,7 +3,7 @@ use parser::{ Parser as ASTParser, tree_node::{ BinaryExpression, BlockExpression, DeviceDeclarationExpression, Expression, - FunctionExpression, InvocationExpression, Literal, + FunctionExpression, InvocationExpression, Literal, LogicalExpression, }, }; use quick_error::quick_error; @@ -165,6 +165,10 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { let result = self.expression_binary(bin_expr, scope)?; Ok(Some(result)) } + Expression::Logical(log_expr) => { + let result = self.expression_logical(log_expr, scope)?; + Ok(Some(result)) + } Expression::Literal(Literal::Number(num)) => { let temp_name = self.next_temp_name(); let loc = scope.add_variable(&temp_name, LocationRequest::Temp)?; @@ -280,6 +284,20 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { } var_loc } + Expression::Logical(log_expr) => { + let result = self.expression_logical(log_expr, scope)?; + let var_loc = scope.add_variable(&var_name, LocationRequest::Persist)?; + + // Move result from temp to new persistent variable + let result_reg = self.resolve_register(&result.location)?; + self.emit_variable_assignment(&var_name, &var_loc, result_reg)?; + + // Free the temp result + if let Some(name) = result.temp_name { + scope.free_temp(name)?; + } + var_loc + } Expression::Variable(name) => { let src_loc = scope.get_location_of(&name)?; let var_loc = scope.add_variable(&var_name, LocationRequest::Persist)?; @@ -374,6 +392,15 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { stack.free_temp(name)?; } } + Expression::Logical(log_expr) => { + // Compile the logical expression to a temp register + let result = self.expression_logical(log_expr, stack)?; + let reg_str = self.resolve_register(&result.location)?; + self.write_output(format!("push {reg_str}"))?; + if let Some(name) = result.temp_name { + stack.free_temp(name)?; + } + } _ => { return Err(Error::Unknown(format!( "Attempted to call `{}` with an unsupported argument type", @@ -524,6 +551,73 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { }) } + fn expression_logical<'v>( + &mut self, + expr: LogicalExpression, + scope: &mut VariableScope<'v>, + ) -> Result { + match expr { + LogicalExpression::Not(inner) => { + let (inner_str, cleanup) = self.compile_operand(*inner, scope)?; + + let result_name = self.next_temp_name(); + let result_loc = scope.add_variable(&result_name, LocationRequest::Temp)?; + let result_reg = self.resolve_register(&result_loc)?; + + // seq rX rY 0 => if rY == 0 set rX = 1 else rX = 0 + self.write_output(format!("seq {result_reg} {inner_str} 0"))?; + + if let Some(name) = cleanup { + scope.free_temp(name)?; + } + + Ok(CompilationResult { + location: result_loc, + temp_name: Some(result_name), + }) + } + _ => { + let (op_str, left_expr, right_expr) = match expr { + LogicalExpression::And(l, r) => ("and", l, r), + LogicalExpression::Or(l, r) => ("or", l, r), + LogicalExpression::Equal(l, r) => ("seq", l, r), + LogicalExpression::NotEqual(l, r) => ("sne", l, r), + LogicalExpression::GreaterThan(l, r) => ("sgt", l, r), + LogicalExpression::GreaterThanOrEqual(l, r) => ("sge", l, r), + LogicalExpression::LessThan(l, r) => ("slt", l, r), + LogicalExpression::LessThanOrEqual(l, r) => ("sle", l, r), + LogicalExpression::Not(_) => unreachable!(), + }; + + // Compile LHS + let (lhs_str, lhs_cleanup) = self.compile_operand(*left_expr, scope)?; + // Compile RHS + let (rhs_str, rhs_cleanup) = self.compile_operand(*right_expr, scope)?; + + // Allocate result register + let result_name = self.next_temp_name(); + let result_loc = scope.add_variable(&result_name, LocationRequest::Temp)?; + let result_reg = self.resolve_register(&result_loc)?; + + // Emit instruction: op result lhs rhs + self.write_output(format!("{op_str} {result_reg} {lhs_str} {rhs_str}"))?; + + // Clean up operand temps + if let Some(name) = lhs_cleanup { + scope.free_temp(name)?; + } + if let Some(name) = rhs_cleanup { + scope.free_temp(name)?; + } + + Ok(CompilationResult { + location: result_loc, + temp_name: Some(result_name), + }) + } + } + } + fn expression_block<'v>( &mut self, mut expr: BlockExpression, @@ -623,6 +717,18 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { scope.free_temp(name)?; } } + Expression::Logical(log_expr) => { + let result = self.expression_logical(log_expr, scope)?; + let result_reg = self.resolve_register(&result.location)?; + self.write_output(format!( + "move r{} {}", + VariableScope::RETURN_REGISTER, + result_reg + ))?; + if let Some(name) = result.temp_name { + scope.free_temp(name)?; + } + } _ => { return Err(Error::Unknown(format!( "Unsupported `return` statement: {:?}", @@ -748,3 +854,4 @@ impl<'a, W: std::io::Write> Compiler<'a, W> { Ok(()) } } + diff --git a/libs/parser/src/lib.rs b/libs/parser/src/lib.rs index 9e2a635..e504361 100644 --- a/libs/parser/src/lib.rs +++ b/libs/parser/src/lib.rs @@ -164,15 +164,22 @@ impl Parser { return Ok(None); }; - // check if the next or current token is an operator - if self_matches_peek!(self, TokenType::Symbol(s) if s.is_operator()) { - return Ok(Some(Expression::Binary(self.binary(lhs)?))); + // check if the next or current token is an operator, comparison, or logical symbol + if self_matches_peek!( + self, + TokenType::Symbol(s) if s.is_operator() || s.is_comparison() || s.is_logical() + ) { + return Ok(Some(self.infix(lhs)?)); } - // This is an edge case. We need to move back one token if the current token is an operator - // so the binary expression can pick up the operator - else if self_matches_current!(self, TokenType::Symbol(s) if s.is_operator()) { + // This is an edge case. We need to move back one token if the current token is an + // operator, comparison, or logical symbol so the binary expression can pick up + // the operator + else if self_matches_current!( + self, + TokenType::Symbol(s) if s.is_operator() || s.is_comparison() || s.is_logical() + ) { self.tokenizer.seek(SeekFrom::Current(-1))?; - return Ok(Some(Expression::Binary(self.binary(lhs)?))); + return Ok(Some(self.infix(lhs)?)); } Ok(Some(lhs)) @@ -254,6 +261,13 @@ impl Parser { Expression::Negation(boxed!(inner_expr)) } + // match logical NOT `!` + TokenType::Symbol(Symbol::LogicalNot) => { + self.assign_next()?; // consume the `!` symbol + let inner_expr = self.unary()?.ok_or(Error::UnexpectedEOF)?; + Expression::Logical(LogicalExpression::Not(boxed!(inner_expr))) + } + _ => { return Err(Error::UnexpectedToken(current_token.clone())); } @@ -262,7 +276,7 @@ impl Parser { Ok(Some(expr)) } - fn get_binary_child_node(&mut self) -> Result { + fn get_infix_child_node(&mut self) -> Result { let current_token = token_from_option!(self.current_token); match current_token.token_type { @@ -286,9 +300,15 @@ impl Parser { TokenType::Symbol(Symbol::Minus) => { self.assign_next()?; // recurse to handle double negation or simple negation of atoms - let inner = self.get_binary_child_node()?; + let inner = self.get_infix_child_node()?; Ok(Expression::Negation(boxed!(inner))) } + // Handle Logical Not + TokenType::Symbol(Symbol::LogicalNot) => { + self.assign_next()?; + let inner = self.get_infix_child_node()?; + Ok(Expression::Logical(LogicalExpression::Not(boxed!(inner)))) + } _ => Err(Error::UnexpectedToken(current_token.clone())), } } @@ -345,8 +365,8 @@ impl Parser { }) } - /// Handles mathmatical expressions in the explicit order of PEMDAS - fn binary(&mut self, previous: Expression) -> Result { + /// Handles mathmatical and logical expressions in the explicit order of operations + fn infix(&mut self, previous: Expression) -> Result { // We cannot use recursion here, as we need to handle the precedence of the operators // We need to use a loop to parse the binary expressions. @@ -354,15 +374,18 @@ impl Parser { // first, make sure the previous expression supports binary expressions match previous { - Expression::Binary(_) // 1 + 2 + 3 - | Expression::Invocation(_) // add() + 3 - | Expression::Priority(_) // (1 + 2) + 3 - | Expression::Literal(Literal::Number(_)) // 1 + 2 (no addition of strings) - | Expression::Variable(_) // x + 2 - | Expression::Negation(_) // -1 + 2 - => {} + Expression::Binary(_) + | Expression::Logical(_) + | Expression::Invocation(_) + | Expression::Priority(_) + | Expression::Literal(Literal::Number(_)) + | Expression::Variable(_) + | Expression::Negation(_) => {} _ => { - return Err(Error::InvalidSyntax(current_token.clone(), String::from("Invalid expression for binary operation"))) + return Err(Error::InvalidSyntax( + current_token.clone(), + String::from("Invalid expression for binary/logical operation"), + )); } } @@ -372,12 +395,15 @@ impl Parser { let mut operators = Vec::::new(); // +, + // build the expressions and operators vectors - while token_matches!(current_token, TokenType::Symbol(s) if s.is_operator()) { - // We are guaranteed to have an operator symbol here as we checked in the while loop + while token_matches!( + current_token, + TokenType::Symbol(s) if s.is_operator() || s.is_comparison() || s.is_logical() + ) { + // We are guaranteed to have an operator/comparison/logical symbol here as we checked in the while loop let operator = extract_token_data!(current_token, TokenType::Symbol(s), s); operators.push(operator); self.assign_next()?; - expressions.push(self.get_binary_child_node()?); + expressions.push(self.get_infix_child_node()?); current_token = token_from_option!(self.get_next()?).clone(); } @@ -394,7 +420,7 @@ impl Parser { // This means that we need to keep track of the current iteration to ensure we are // removing the correct expressions from the vector - // Loop through operators, and build the binary expressions for exponential operators only + // --- PRECEDENCE LEVEL 1: Exponent (**) --- for (i, operator) in operators.iter().enumerate().rev() { if operator == &Symbol::Exp { let right = expressions.remove(i + 1); @@ -405,12 +431,10 @@ impl Parser { ); } } - - // remove all the exponential operators from the operators vector operators.retain(|symbol| symbol != &Symbol::Exp); - let mut current_iteration = 0; - // Loop through operators, and build the binary expressions for multiplication and division operators + // --- PRECEDENCE LEVEL 2: Multiplicative (*, /, %) --- + let mut current_iteration = 0; for (i, operator) in operators.iter().enumerate() { if matches!(operator, Symbol::Slash | Symbol::Asterisk | Symbol::Percent) { let index = i - current_iteration; @@ -430,21 +454,18 @@ impl Parser { index, Expression::Binary(BinaryExpression::Modulo(boxed!(left), boxed!(right))), ), - // safety: we have already checked for the operator _ => unreachable!(), } current_iteration += 1; } } - - // remove all the multiplication and division operators from the operators vector operators .retain(|symbol| !matches!(symbol, Symbol::Asterisk | Symbol::Percent | Symbol::Slash)); - current_iteration = 0; - // Loop through operators, and build the binary expressions for addition and subtraction operators + // --- PRECEDENCE LEVEL 3: Additive (+, -) --- + current_iteration = 0; for (i, operator) in operators.iter().enumerate() { - if operator == &Symbol::Plus || operator == &Symbol::Minus { + if matches!(operator, Symbol::Plus | Symbol::Minus) { let index = i - current_iteration; let left = expressions.remove(index); let right = expressions.remove(index); @@ -458,16 +479,120 @@ impl Parser { index, Expression::Binary(BinaryExpression::Subtract(boxed!(left), boxed!(right))), ), - // safety: we have already checked for the operator _ => unreachable!(), } current_iteration += 1; } } - - // remove all the addition and subtraction operators from the operators vector operators.retain(|symbol| !matches!(symbol, Symbol::Plus | Symbol::Minus)); + // --- PRECEDENCE LEVEL 4: Comparison (<, >, <=, >=) --- + current_iteration = 0; + for (i, operator) in operators.iter().enumerate() { + if operator.is_comparison() && !matches!(operator, Symbol::Equal | Symbol::NotEqual) { + let index = i - current_iteration; + let left = expressions.remove(index); + let right = expressions.remove(index); + + match operator { + Symbol::LessThan => expressions.insert( + index, + Expression::Logical(LogicalExpression::LessThan( + boxed!(left), + boxed!(right), + )), + ), + Symbol::GreaterThan => expressions.insert( + index, + Expression::Logical(LogicalExpression::GreaterThan( + boxed!(left), + boxed!(right), + )), + ), + Symbol::LessThanOrEqual => expressions.insert( + index, + Expression::Logical(LogicalExpression::LessThanOrEqual( + boxed!(left), + boxed!(right), + )), + ), + Symbol::GreaterThanOrEqual => expressions.insert( + index, + Expression::Logical(LogicalExpression::GreaterThanOrEqual( + boxed!(left), + boxed!(right), + )), + ), + _ => unreachable!(), + } + current_iteration += 1; + } + } + operators.retain(|symbol| { + !symbol.is_comparison() || matches!(symbol, Symbol::Equal | Symbol::NotEqual) + }); + + // --- PRECEDENCE LEVEL 5: Equality (==, !=) --- + current_iteration = 0; + for (i, operator) in operators.iter().enumerate() { + if matches!(operator, Symbol::Equal | Symbol::NotEqual) { + let index = i - current_iteration; + let left = expressions.remove(index); + let right = expressions.remove(index); + + match operator { + Symbol::Equal => expressions.insert( + index, + Expression::Logical(LogicalExpression::Equal(boxed!(left), boxed!(right))), + ), + Symbol::NotEqual => expressions.insert( + index, + Expression::Logical(LogicalExpression::NotEqual( + boxed!(left), + boxed!(right), + )), + ), + _ => unreachable!(), + } + current_iteration += 1; + } + } + operators.retain(|symbol| !matches!(symbol, Symbol::Equal | Symbol::NotEqual)); + + // --- PRECEDENCE LEVEL 6: Logical AND (&&) --- + current_iteration = 0; + for (i, operator) in operators.iter().enumerate() { + if matches!(operator, Symbol::LogicalAnd) { + let index = i - current_iteration; + let left = expressions.remove(index); + let right = expressions.remove(index); + + expressions.insert( + index, + Expression::Logical(LogicalExpression::And(boxed!(left), boxed!(right))), + ); + current_iteration += 1; + } + } + operators.retain(|symbol| !matches!(symbol, Symbol::LogicalAnd)); + + // --- PRECEDENCE LEVEL 7: Logical OR (||) --- + current_iteration = 0; + for (i, operator) in operators.iter().enumerate() { + if matches!(operator, Symbol::LogicalOr) { + let index = i - current_iteration; + let left = expressions.remove(index); + let right = expressions.remove(index); + + expressions.insert( + index, + Expression::Logical(LogicalExpression::Or(boxed!(left), boxed!(right))), + ); + current_iteration += 1; + } + } + operators.retain(|symbol| !matches!(symbol, Symbol::LogicalOr)); + // Ensure there is only one expression left in the expressions vector, and no operators left if expressions.len() != 1 || !operators.is_empty() { return Err(Error::InvalidSyntax( @@ -484,11 +609,7 @@ impl Parser { self.tokenizer.seek(SeekFrom::Current(-1))?; } - // Ensure the last expression is a binary expression - match expressions.pop().unwrap() { - Expression::Binary(binary) => Ok(binary), - _ => unreachable!(), - } + Ok(expressions.pop().unwrap()) } fn priority(&mut self) -> Result, Error> { diff --git a/libs/parser/src/tree_node.rs b/libs/parser/src/tree_node.rs index 456a9c4..dbb4b94 100644 --- a/libs/parser/src/tree_node.rs +++ b/libs/parser/src/tree_node.rs @@ -5,6 +5,7 @@ use tokenizer::token::Number; pub enum Literal { Number(Number), String(String), + Boolean(bool), } impl std::fmt::Display for Literal { @@ -12,6 +13,7 @@ impl std::fmt::Display for Literal { match self { Literal::Number(n) => write!(f, "{}", n), Literal::String(s) => write!(f, "\"{}\"", s), + Literal::Boolean(b) => write!(f, "{}", if *b { 1 } else { 0 }), } } }