Pattern Matching in Scala

In my article on parser combinators, I developed a parser that parses simple integer sums and products into an internal syntax tree. In this article, I’ll make a simple evaluator that takes such a syntax tree as parsed and computes the value of the expression.

This is the sort of thing that in Java you might do using the Visitor Pattern, but Scala (like some other functional languages) has a powerful mechanism called “pattern matching” that makes this somewhat complex pattern unnecessary in most cases, and allows us to create quite concise code.
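For contrast, here is a rough sketch of what a Visitor-based evaluator for the same kind of tree might look like. It is written in Scala rather than Java to keep one language in this article, and every name in it (`VisitorSketch`, `ExprVisitor`, `Literal`, `Plus`, `Times`, `accept`) is hypothetical and chosen only for illustration. Each node class needs its own `accept` method, and every new operation needs a new visitor implementation.

```scala
// Hypothetical Visitor-style evaluator, shown only for contrast with pattern matching.
object VisitorSketch {
  // One visit method per node type; R is the result type of the operation.
  trait ExprVisitor[R] {
    def visitLiteral(i: Int): R
    def visitSum(e1: Expr, e2: Expr): R
    def visitProduct(e1: Expr, e2: Expr): R
  }

  // Every node must implement accept, dispatching to the matching visit method.
  sealed abstract class Expr {
    def accept[R](v: ExprVisitor[R]): R
  }
  final class Literal(val i: Int) extends Expr {
    def accept[R](v: ExprVisitor[R]): R = v.visitLiteral(i)
  }
  final class Plus(val e1: Expr, val e2: Expr) extends Expr {
    def accept[R](v: ExprVisitor[R]): R = v.visitSum(e1, e2)
  }
  final class Times(val e1: Expr, val e2: Expr) extends Expr {
    def accept[R](v: ExprVisitor[R]): R = v.visitProduct(e1, e2)
  }

  // The evaluation operation, written as a visitor that recurses via accept.
  object Evaluator extends ExprVisitor[Int] {
    def visitLiteral(i: Int): Int = i
    def visitSum(e1: Expr, e2: Expr): Int = e1.accept(this) + e2.accept(this)
    def visitProduct(e1: Expr, e2: Expr): Int = e1.accept(this) * e2.accept(this)
  }

  def main(args: Array[String]): Unit = {
    val e = new Plus(new Literal(2), new Times(new Literal(3), new Literal(4)))
    println(e.accept(Evaluator)) // prints 14
  }
}
```

Pattern matching lets the evaluator inspect the tree directly, without the `accept` plumbing in every node class.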

For reference, the syntax created in the last article was given by:

```scala
package com.donroby.parsing

trait ExpressionSyntax {

  sealed abstract class Expression

  case class IntegerLiteral(i: Int) extends Expression
  case class Sum(e1: Expression, e2: Expression) extends Expression
  case class Product(e1: Expression, e2: Expression) extends Expression
}
```

Starting as usual with a test, I write:

```scala
package com.donroby.parsing

import com.donroby.parsing.ExpressionSyntaxEvaluator._
import org.scalatest.FlatSpec

class ExpressionSyntaxEvaluatorSpec extends FlatSpec {

  "The expression syntax evaluator" should "evaluate a literal" in {
    assert(evaluate(IntegerLiteral(42)) == 42)
  }

}
```

To get it to compile, I create a Scala object that will do the evaluation, but leave the `evaluate` method unimplemented (`???` throws a `NotImplementedError` when called) so that the test is certain to fail on its first run:

```scala
package com.donroby.parsing

object ExpressionSyntaxEvaluator extends ExpressionSyntax {
  def evaluate(e: Expression): Int = ???
}
```
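The stub compiles because `???` is defined in Scala's `Predef` with result type `Nothing`, which conforms to `Int`; calling it throws a `scala.NotImplementedError`. A minimal self-contained sketch of that behaviour (the object name `StubDemo` and the helper `stubThrows` are hypothetical):

```scala
object StubDemo {
  sealed abstract class Expression
  case class IntegerLiteral(i: Int) extends Expression

  // Type-checks because ??? has type Nothing, but throws NotImplementedError when called.
  def evaluate(e: Expression): Int = ???

  // Returns true if calling the stub throws NotImplementedError.
  def stubThrows(): Boolean =
    try { evaluate(IntegerLiteral(42)); false }
    catch { case _: NotImplementedError => true }

  def main(args: Array[String]): Unit =
    println(stubThrows()) // prints true
}
```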

Since I’ve already decided to use pattern matching, I’ll start with a simple match expression that handles only the literal case the test already specifies, plus a default case:

```scala
package com.donroby.parsing

object ExpressionSyntaxEvaluator extends ExpressionSyntax {
  def evaluate(e: Expression): Int = e match {
    case IntegerLiteral(i) => i
    case _ => 0
  }
}
```

This makes the test pass, of course. It works because Scala supports pattern matching on case classes, which is what I used to define the syntax. (It also works for other classes that come with an “extractor object”, that is, an object defining an `unapply` method.)

What this match expression does depends on which Expression you pass it. The cases are tried from top to bottom, and the first one that matches is used. If the expression is an IntegerLiteral, the first case unwraps (“extracts”) the contained value and returns it; any other expression falls through to the default case and yields 0. We’ll be adding cases for the other syntax classes shortly.
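Pattern matching is not limited to case classes: any object with an `unapply` method can appear in a pattern. The sketch below is hypothetical and not part of the evaluator. It shows a plain (non-case) class `Pair` given a hand-written extractor in its companion object, and an extractor `Even` that doesn't correspond to any class at all.

```scala
object ExtractorDemo {
  // An ordinary class, which gets no pattern support of its own.
  class Pair(val left: Int, val right: Int)

  // Companion extractor: unapply makes `Pair(l, r)` usable as a pattern.
  object Pair {
    def unapply(p: Pair): Option[(Int, Int)] = Some((p.left, p.right))
  }

  // Extractor unrelated to any class: matches even Ints and yields half the value.
  object Even {
    def unapply(n: Int): Option[Int] = if (n % 2 == 0) Some(n / 2) else None
  }

  def sumOfPair(p: Pair): Int = p match {
    case Pair(l, r) => l + r
  }

  def half(n: Int): String = n match {
    case Even(h) => s"$n is even, half is $h"
    case _       => s"$n is odd"
  }
}
```

A case class gets an equivalent `unapply` generated automatically, which is why `IntegerLiteral(i)` works as a pattern without any extra code.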

Another test, for evaluating a sum:

```scala
it should "evaluate a sum of two literals" in {
  assert(evaluate(Sum(IntegerLiteral(2), IntegerLiteral(3))) == 5)
}
```

This test can be made to pass by adding a more complex case:

```scala
object ExpressionSyntaxEvaluator extends ExpressionSyntax {
  def evaluate(e: Expression): Int = e match {
    case IntegerLiteral(i) => i
    case Sum(IntegerLiteral(i), IntegerLiteral(j)) => i + j
    case _ => 0
  }
}
```

This is a bit silly, but it shows that patterns can be nested: the `Sum` pattern here contains two `IntegerLiteral` patterns, each binding its own variable. It will change soon.
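To show a little more of what patterns can express, here is a hypothetical `describe` function over the same syntax classes; it is not part of the evaluator. It combines nested patterns, literal constants inside patterns, alternatives joined with `|`, and a guard (`if ...`).

```scala
object NestedPatternDemo {
  sealed abstract class Expression
  case class IntegerLiteral(i: Int) extends Expression
  case class Sum(e1: Expression, e2: Expression) extends Expression
  case class Product(e1: Expression, e2: Expression) extends Expression

  def describe(e: Expression): String = e match {
    // Alternatives: either operand is the literal 0.
    case Product(IntegerLiteral(0), _) | Product(_, IntegerLiteral(0)) =>
      "product with a zero factor"
    // Guard: both operands are literals with the same value.
    case Sum(IntegerLiteral(i), IntegerLiteral(j)) if i == j =>
      s"doubling of $i"
    // Nested type pattern: a sum whose left operand is itself a sum.
    case Sum(Sum(_, _), _) =>
      "sum whose left operand is a sum"
    case _ =>
      "something else"
  }
}
```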

Next, a test that forces us to handle nested sums rather than just sums of two literals:

```scala
it should "evaluate a sum of two sum expressions" in {
  assert(
    evaluate(
      Sum(
        Sum(
          IntegerLiteral(2),
          IntegerLiteral(2)),
        Sum(
          IntegerLiteral(3),
          IntegerLiteral(1))
        )
    ) == 8)
}
```

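This fails, of course: the outer Sum contains other Sums rather than literals, so it doesn't match the `Sum(IntegerLiteral(i), IntegerLiteral(j))` case and falls through to the default, returning 0 instead of 8.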

We can add a case for sums of arbitrary expressions:

```scala
case Sum(e1, e2) => evaluate(e1) + evaluate(e2)
```

But since a literal is itself an expression, and evaluating a literal already does the right thing, this general case also covers sums of two literals. That means we can remove the messier `Sum(IntegerLiteral(i), IntegerLiteral(j))` case, leaving us with:

```scala
object ExpressionSyntaxEvaluator extends ExpressionSyntax {
  def evaluate(e: Expression): Int = e match {
    case IntegerLiteral(i) => i
    case Sum(e1, e2) => evaluate(e1) + evaluate(e2)
    case _ => 0
  }
}
```
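To see how the recursive case handles the nested-sum test, here is that evaluation unfolded step by step, writing a bare number n as shorthand for `IntegerLiteral(n)` and eval for `evaluate`. The first two steps use the `Sum(e1, e2)` case; the third uses the `IntegerLiteral(i)` case.

```latex
\begin{aligned}
\operatorname{eval}(\mathrm{Sum}(\mathrm{Sum}(2, 2),\ \mathrm{Sum}(3, 1)))
  &= \operatorname{eval}(\mathrm{Sum}(2, 2)) + \operatorname{eval}(\mathrm{Sum}(3, 1)) \\
  &= (\operatorname{eval}(2) + \operatorname{eval}(2)) + (\operatorname{eval}(3) + \operatorname{eval}(1)) \\
  &= (2 + 2) + (3 + 1) \\
  &= 8
\end{aligned}
```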

Products can be handled the same way. With tests for them added, the complete spec looks like this:

```scala
package com.donroby.parsing

import com.donroby.parsing.ExpressionSyntaxEvaluator._
import org.scalatest.FlatSpec

class ExpressionSyntaxEvaluatorSpec extends FlatSpec {

  "The expression syntax evaluator" should "evaluate a literal" in {
    assert(evaluate(IntegerLiteral(42)) == 42)
  }

  it should "evaluate a sum of two literals" in {
    assert(evaluate(Sum(IntegerLiteral(2), IntegerLiteral(3))) == 5)
  }

  it should "evaluate a sum of two sum expressions" in {
    assert(
      evaluate(
        Sum(
          Sum(
            IntegerLiteral(2),
            IntegerLiteral(2)),
          Sum(
            IntegerLiteral(3),
            IntegerLiteral(1))
          )
      ) == 8)
  }

  it should "evaluate a product of two literals" in {
    assert(evaluate(Product(IntegerLiteral(2), IntegerLiteral(3))) == 6)
  }

  it should "evaluate a product of two expressions" in {
    assert(
      evaluate(
        Product(
          Product(
            IntegerLiteral(2),
            IntegerLiteral(2)),
          Sum(
            IntegerLiteral(3),
            IntegerLiteral(1))
        )
      ) == 16)
  }

}
```

And the evaluation code remains quite simple:

```scala
package com.donroby.parsing

object ExpressionSyntaxEvaluator extends ExpressionSyntax {
  def evaluate(e: Expression): Int = e match {
    case IntegerLiteral(i) => i
    case Sum(e1, e2) => evaluate(e1) + evaluate(e2)
    case Product(e1, e2) => evaluate(e1) * evaluate(e2)
    case _ => 0
  }
}
```

The default case can now never be reached, since the first three cases cover every kind of Expression, but I’ve chosen to keep it as a safety net for when more expression types are added later.
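One alternative worth weighing: because `Expression` is declared `sealed`, the compiler knows all of its subclasses and checks matches on it for exhaustiveness. Without the default case, adding a new expression type later produces a compile-time "match may not be exhaustive" warning at this match; with `case _ => 0`, the new type silently evaluates to 0. A self-contained sketch of that variant (the object name `ExhaustiveEvaluator` is hypothetical):

```scala
object ExhaustiveEvaluator {
  sealed abstract class Expression
  case class IntegerLiteral(i: Int) extends Expression
  case class Sum(e1: Expression, e2: Expression) extends Expression
  case class Product(e1: Expression, e2: Expression) extends Expression

  // No default case: the compiler checks that every subclass of the sealed
  // Expression is covered, and warns here if a new one is added without a case.
  def evaluate(e: Expression): Int = e match {
    case IntegerLiteral(i) => i
    case Sum(e1, e2)       => evaluate(e1) + evaluate(e2)
    case Product(e1, e2)   => evaluate(e1) * evaluate(e2)
  }

  def main(args: Array[String]): Unit =
    println(evaluate(Product(Product(IntegerLiteral(2), IntegerLiteral(2)),
                             Sum(IntegerLiteral(3), IntegerLiteral(1))))) // prints 16
}
```

If that warning is ignored and an unhandled expression reaches the match at runtime, a `scala.MatchError` is thrown instead of a silent 0.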
