// Introduction to Scala — Part 4 — Types and Case Classes

/**
 * Exercise 4.1 Suppose you are given the following definition of a list type.
 *
 *    type 'a mylist = Nil | Cons of 'a * 'a mylist
 *
 * 1. Write a function map : ('a -> 'b) -> 'a mylist -> 'b mylist, where
 *
 *    map f [x0 ; x1 ; ...; xn ] = [f x0 ; f x1 ; ...; f xn ].
 *
 * 2. Write a function append : 'a mylist -> 'a mylist -> 'a mylist, where
 *
 *    append [x1 ;...; xn ] [xn+1 ; ...; xn+m ] = [x1 ; ...; xn+m ].
 */
/**
 * Immutable singly linked list, the Scala counterpart of the OCaml
 * `'a mylist` type above: either the empty list `Nil()` or a cell
 * `Cons(h, v)` holding the head `h` and the tail `v`.
 *
 * NOTE(review): this case class `Nil` shadows `scala.Nil` for any later code
 * in the same scope (later exercises in this file match standard `List`s
 * against `Nil`); confirm each exercise is compiled/run separately.
 */
sealed abstract class MyList[T]
 // Non-empty list: head element `h` followed by the rest of the list `v`.
 case class Cons[T](h: T, v: MyList[T]) extends MyList[T]
 // Empty list; a case class rather than a case object because it carries `T`.
 case class Nil[T]() extends MyList[T]
 
// 1. Non-tail-recursive map: each Cons can only be completed after the
//    mapped tail has been computed, so stack depth grows with list length.
/**
 * Applies `f` to every element of `l`, keeping the elements in order:
 * map(Cons(x0, Cons(x1, Nil())))(f) == Cons(f(x0), Cons(f(x1), Nil())).
 *
 * @param l list to transform
 * @param f function applied to each element, head first
 * @return a new list of the same length holding the results of `f`
 */
def map[U, T](l: MyList[U])(f: U => T): MyList[T] =
  l match {
    case Cons(head, tail) =>
      // apply f to the head before recursing, so f sees elements in order
      val mappedHead = f(head)
      Cons(mappedHead, map(tail)(f))
    case Nil() =>
      Nil()
  }
 
// 2. Non-tail-recursive append: rebuilds every cell of l1, then shares l2.
/**
 * Concatenates two lists: the elements of `l1` followed by those of `l2`.
 * Runs in time linear in the length of `l1`; `l2` is reused, not copied.
 *
 * @param l1 prefix list (copied)
 * @param l2 suffix list (shared as-is)
 * @return a list containing l1's elements then l2's
 */
def append[T](l1: MyList[T], l2: MyList[T]): MyList[T] =
  l1 match {
    case Cons(first, rest) => Cons(first, append(rest, l2))
    case Nil()             => l2
  }
 
 
// Test map and append on two small integer lists.
val lst1 = Cons(12, Cons(100, Cons(17, Cons(8, Nil()))))
val lst2 = Cons(9, Cons(4, Cons(2, Nil())))
val mapped1 = map(lst1) { x => 1.0 * x * x }
val mapped2 = map(lst2) { x => math.sqrt(x) }
println(s"List1: $lst1")
println(s"Not tail-recursive mapping: $mapped1")
println(s"\nList2: $lst2")
println(s"Not tail-recursive mapping: $mapped2")
// Label fixed: the lists are named List1/List2 above, not "Lis1"/"Lis2".
println(s"\nConcatenation of List1 and List2: ${append(lst1, lst2)}")

 

 

/**
 * Exercise 4.2 A type of unary (base-1) natural numbers can be defined as
 * follows,
 *
 *    type unary_number = Z | S of unary_number
 *
 * where Z represents the number zero, and if i is a unary number, then S i
 * is i + 1. For example, the number 5 would be represented as
 * S (S (S (S (S Z)))).
 *
 * 1. Write a function to add two unary numbers. What is the complexity of your
 * function?
 *
 * 2. Write a function to multiply two unary numbers.
 */
/**
 * Unary (base-1) natural numbers: `Z()` is zero and `S(counter)` is
 * `counter + 1`, so 5 is `S(S(S(S(S(Z())))))`.
 *
 * NOTE(review): Exercise 4.4 below declares a case class also named `Unary`;
 * the two cannot share one scope — confirm exercises are run separately.
 */
sealed abstract class Unary
  // Zero.
  case class Z() extends Unary
  // Successor: the number one greater than `counter`.
  case class S(counter: Unary) extends Unary
 
// 1. Write a function to add two unary numbers. What is the complexity of your
// function?
// R. add(m, n) removes one S from n per call and makes one extra call for Z(),
// so it runs in O(n) where n is the value of the second argument.
/**
 * Adds two unary numbers by moving each `S` wrapper from `n` onto `m`.
 *
 * @param m first addend (grows by one S per step)
 * @param n second addend (consumed one S per step)
 * @return the unary number m + n
 */
@scala.annotation.tailrec
def add(m: Unary, n: Unary): Unary = n match {
  case Z()     => m
  case S(pred) => add(S(m), pred)
}
 
// 2. Write a function to multiply two unary numbers.
/**
 * Multiplies two unary numbers by adding `m` to an accumulator (starting at
 * `Z()`) once for every `S` in `n`. Each `add(product, m)` costs O(m), so the
 * product costs O(m * n) in the values of the arguments.
 *
 * @param m multiplicand, added repeatedly
 * @param n multiplier, counts how many times `m` is added
 * @return the unary number m * n
 */
def multiply(m: Unary, n: Unary): Unary = {
  @scala.annotation.tailrec
  def loop(remaining: Unary, product: Unary): Unary = remaining match {
    case S(rest) => loop(rest, add(product, m))
    case Z()     => product
  }
  loop(n, Z())
}
 
// test of functions: 5 and 4 in unary, plus zero for the identity cases
val u5 = S(S(S(S(S(Z())))))
val u4 = S(S(S(S(Z()))))
val zero = Z()
println("5 + 4 = 9, in unary is: ")
println(s"$u5 + $u4 = ${add(u5, u4)}")
println("\n5 + 0 = 5, in unary is: ")
println(s"$u5 + $zero = ${add(u5, zero)}")
println("\n5 * 0 = 0, in unary is: ")
println(s"$u5 * $zero = ${multiply(u5, zero)}")
println("\n5 * 4 = 20, in unary is: ")
println(s"$u5 * $u4 = ${multiply(u5, u4)}")

 

 

/**
 * Exercise 4.3 Suppose we have the following definition for a type of small
 * numbers.
 *
 *      type small = Four | Three | Two | One
 *
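 * The builtin comparison (<) orders the numbers in reverse order, because
 * OCaml compares constant constructors by their position in the declaration
 * (Scala case classes have no builtin < at all):
 *
 *    # Four < Three;;
 *    - : bool = true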
 *
 * 1. Write a function lt_small : small -> small -> bool that orders the
 * numbers in the normal way.
 *
 * 2. Suppose the type small defines n small integers. How does the size of
 * your code depend on n?
 */
/**
 * The four small numbers of Exercise 4.3, one case class per value.
 * Scala case classes get no builtin `<`, so the ordering is written by hand
 * in `lt_small` / `lt_small_smarty`.
 */
sealed abstract class SmallNumbers
  case class One() extends SmallNumbers
  case class Two() extends SmallNumbers
  case class Three() extends SmallNumbers
  case class Four() extends SmallNumbers
 
 
/**
 * Naive "less than" on small numbers in the normal order
 * One() < Two() < Three() < Four(), spelled out case by case.
 *
 * Every (left, right) combination is listed explicitly, so the code size
 * grows quadratically with the number of constructors.
 *
 * @return true iff n1 is strictly smaller than n2
 */
def lt_small(n1: SmallNumbers, n2: SmallNumbers): Boolean = n1 match {
  case One() => n2 match {
    case Two() | Three() | Four() => true
    case One()                    => false
  }
  case Two() => n2 match {
    case Three() | Four() => true
    case One() | Two()    => false
  }
  case Three() => n2 match {
    case Four()                  => true
    case One() | Two() | Three() => false
  }
  // nothing is larger than Four()
  case Four() => false
}
 
 
// 2. Suppose the type small defines n small integers. How does the size of
// your code depend on n?
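// R. With lt_small, the code size grows as $O(n^2)$: the match needs a row for
// each left-hand number, and each row lists up to n right-hand numbers. To
// reduce the code size, map each number to an integer once (n cases) and
// compare the integers with Scala's native `<` operator.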
/**
 * Maps each small number to the Int it denotes
 * (One() -> 1, Two() -> 2, Three() -> 3, Four() -> 4).
 */
def smallNumbersToInt(number: SmallNumbers): Int = number match {
  case Four()  => 4
  case Three() => 3
  case Two()   => 2
  case One()   => 1
}
/**
 * "Less than" on small numbers in the normal order, computed by comparing
 * their Int values. It needs only the n cases of smallNumbersToInt instead of
 * lt_small's per-pair enumeration.
 *
 * @return true iff n1 is strictly smaller than n2
 */
def lt_small_smarty(n1: SmallNumbers, n2: SmallNumbers): Boolean = {
  val lhs = smallNumbersToInt(n1)
  val rhs = smallNumbersToInt(n2)
  lhs < rhs
}
 
// test cases: the same three comparisons for both implementations
println("Naive lt_small: ")
println(s"One < Two? ${lt_small(One(), Two())}")
println(s"Four < Three? ${lt_small(Four(), Three())}")
println(s"Two < Two? ${lt_small(Two(), Two())}")

println("\nNot so naive lt_small")
println(s"One < Two? ${lt_small_smarty(One(), Two())}")
println(s"Four < Three? ${lt_small_smarty(Four(), Three())}")
println(s"Two < Two? ${lt_small_smarty(Two(), Two())}")

 

 

/**
 * Exercise 4.4 We can define a data type for simple arithmetic expressions as
 * follows.
 *
 *    type unop = Neg
 *    type binop = Add | Sub | Mul | Div
 *    type exp =
 *      Constant of int
 *      | Unary of unop * exp
 *      | Binary of exp * binop * exp
 *
 * Write a function eval : exp -> int to evaluate an expression, performing
 * the calculation to produce an integer result.
 */
// Unary operators on expressions; negation is the only one.
sealed abstract class UnOp
  case class Neg() extends UnOp
 
// Binary arithmetic operators: addition, subtraction, multiplication, division.
sealed abstract class BinOp
  case class Add() extends BinOp
  case class Sub() extends BinOp
  case class Mul() extends BinOp
  case class Div() extends BinOp
 
/**
 * Arithmetic expression tree: an Int literal, a unary operator applied to a
 * sub-expression, or a binary operator applied to two sub-expressions.
 *
 * NOTE(review): `Unary` is also declared as an abstract class in Exercise 4.2
 * of this file; the two cannot share one scope — confirm exercises are run
 * separately.
 */
sealed abstract class Exp
  case class Constant(value: Int) extends Exp
  case class Unary(unop: UnOp, expr: Exp) extends Exp
  case class Binary(expr1: Exp, binop: BinOp, expr2: Exp) extends Exp
 
/**
 * Evaluates an arithmetic expression to an Int.
 *
 * Both operands of a Binary node are evaluated (left first) before the
 * operator is applied. Int arithmetic wraps on overflow, and `Div` truncates
 * toward zero and throws ArithmeticException when the divisor is zero.
 *
 * @param e expression to evaluate
 * @return the integer value of `e`
 */
def eval(e: Exp): Int = e match {
  case Constant(value)       => value
  case Unary(Neg(), operand) => -eval(operand)
  case Binary(lhs, op, rhs) =>
    val left = eval(lhs)
    val right = eval(rhs)
    op match {
      case Add() => left + right
      case Sub() => left - right
      case Mul() => left * right
      case Div() => left / right
    }
}
 
 
// Test: evaluate (12 + 15) - (-10), which should give 37
val b1 = Binary(Constant(12), Add(), Constant(15))
val u1 = Unary(Neg(), Constant(10))
val express = Binary(b1, Sub(), u1)
println("Evaluation of (12 + 15) - (-10) = 37: ")
println(s"$express = ${eval(express)}")

 

 

/**
 * Exercise 4.5 One way to implement a dictionary is with a tree, where each
 * node in the tree has a label and a value.
 *
 * 1. Implement a polymorphic dictionary, ('key, 'value) dictionary, as a tree
 * with the three dictionary operations.
 *
 *    empty: ('key, 'value) dict
 *    add: ('key, 'value) dict -> 'key -> 'value -> ('key, 'value) dict
 *    find : ('key, 'value) dict -> 'key -> 'value
 *
 */
/**
 * Dictionary stored as a binary tree. `Node(no, l, r)` holds one
 * (key, value) pair `no` plus the left and right subtrees; `Leaf()` is an
 * empty tree. `add` and `find` below keep keys smaller than a node's key in
 * `l` and larger keys in `r`.
 *
 * NOTE(review): `Node` and `Leaf` are declared again for Exercises 4.6 and
 * 4.7 in this file; confirm the exercises are run in separate scopes.
 */
sealed abstract class Dict[T,U]
  case class Node[T,U](no:Tuple2[T,U], l: Dict[T,U], r: Dict[T,U])
    extends Dict[T,U]
  case class Leaf[T,U]() extends Dict[T,U]
 
/**
 * The empty dictionary (a single `Leaf`), for any key type `T` and value
 * type `U`.
 *
 * The type parameters matter: `Dict` is invariant, so an unannotated
 * `Leaf()` is inferred as `Leaf[Nothing, Nothing]`, which can never be
 * combined with a real key in `add` or `find`.
 */
def empty[T, U]: Dict[T, U] = Leaf()
 
/**
 * Returns a dictionary equal to `dict` except that `key` now maps to `value`
 * (replacing any previous binding for `key`). The input tree is not
 * modified; only the nodes on the path to `key` are rebuilt.
 *
 * Keys are ordered by an implicit `Ordering[T]`. This replaces the former
 * view bound `T <% Ordered[T]`, which is deprecated and removed in Scala 3.
 * `Ordering` instances exist for the usual key types (Int, String, ...)
 * and for any type extending `Ordered`.
 *
 * @param dict  dictionary to extend
 * @param key   key to bind
 * @param value value to associate with `key`
 * @return the updated dictionary
 */
def add[T: Ordering, U](dict: Dict[T, U], key: T, value: U): Dict[T, U] =
  dict match {
    case Leaf() =>
      Node((key, value), Leaf(), Leaf())
    case Node((k, v), left, right) =>
      // One three-way comparison covers all cases, so no catch-all is needed.
      val c = implicitly[Ordering[T]].compare(key, k)
      if (c < 0) Node((k, v), add(left, key, value), right)
      else if (c > 0) Node((k, v), left, add(right, key, value))
      else Node((key, value), left, right)
  }
 
/**
 * Looks up the value bound to `key` in `dict`, walking left or right
 * according to the implicit `Ordering[T]` (the same ordering `add` uses).
 *
 * @param dict dictionary to search
 * @param key  key to look up
 * @return the value associated with `key`
 * @throws NoSuchElementException (a subclass of Exception) if `key` is absent
 */
@scala.annotation.tailrec
def find[T: Ordering, U](dict: Dict[T, U], key: T): U =
  dict match {
    case Leaf() => throw new NoSuchElementException("Not Found.")
    case Node((k, v), left, right) =>
      val c = implicitly[Ordering[T]].compare(key, k)
      if (c < 0) find(left, key)
      else if (c > 0) find(right, key)
      else v
  }

 

 

/** 
 * Exercise 4.6 Consider the function insert for unbalanced, ordered, binary
 * trees. One potential implementation problem is that it uses the builtin
 * comparison (<). Rewrite the definition so that it is parameterized by a
 * comparison function that, given two elements, returns one of three values
 *
 *    type comparison = LessThan | Equal | GreaterThan.
 *
 * The expression insert compare x tree inserts an element x into the tree
 * tree. The type is
 * 
 *    insert : ('a -> 'a -> comparison) -> 'a -> 'a tree -> 'a tree.
 */
 
/**
 * Unbalanced binary tree: `Leaf()` is empty, and `Node(v, left, right)`
 * holds the element `v` with its two subtrees.
 */
sealed abstract class Tree[T]
  case class Leaf[T]() extends Tree[T]
  case class Node[T](v: T, left: Tree[T], right: Tree[T]) extends Tree[T]
 
/**
 * Result of a three-way comparison of two elements: the first is less
 * than, equal to, or greater than the second.
 */
sealed abstract class Comparison
  case class LessThan() extends Comparison
  case class Equal() extends Comparison
  case class GreaterThan() extends Comparison
 
/**
 * Inserts `value` into the unbalanced binary search tree `tree`, ordering
 * elements with the caller-supplied `compare` function as Exercise 4.6
 * requires (`insert compare x tree`). `compare(a, b)` says whether `a` is
 * LessThan(), Equal() to, or GreaterThan() `b`.
 *
 * A value smaller than a node's element goes into its left subtree and a
 * larger one into its right subtree. A value Equal() to an existing element
 * leaves the tree unchanged, so duplicates are not stored.
 *
 * @param compare three-way comparison used instead of the builtin `<`
 * @param value   element to insert
 * @param tree    tree to insert into (not modified; a new tree is returned)
 * @return a tree containing the elements of `tree` plus `value`
 */
def insert[T](compare: (T, T) => Comparison, value: T, tree: Tree[T]): Tree[T] =
  tree match {
    case Leaf() => Node(value, Leaf(), Leaf())
    case Node(v, left, right) =>
      compare(value, v) match {
        case LessThan()    => Node(v, insert(compare, value, left), right)
        case GreaterThan() => Node(v, left, insert(compare, value, right))
        case Equal()       => tree
      }
  }

/**
 * Original fixed-direction variant, kept only so existing callers compile.
 *
 * `comp` is one constant rather than a comparison function, so the element
 * values are never compared. LessThan() always descends left, GreaterThan()
 * always descends right, and Equal() drops `value` unless `tree` is empty.
 * The result is therefore not an ordered tree.
 */
@deprecated("ignores element order; use insert(compare, value, tree)", "Part 4")
def insert[T](value: T, tree: Tree[T]) (comp: Comparison): Tree[T] =
  tree match {
    case Leaf() => Node(value, Leaf(), Leaf())
    case Node(v, left, right) =>
      comp match {
        case Equal() => Node(v, left, right)
        case LessThan() => Node(v, insert(value, left){comp}, right)
        case GreaterThan() => Node(v, left, insert(value, right){comp})
      }
  }

/**
 * Builds a binary search tree by inserting the elements of `l` in list order
 * (the head becomes the root). Elements are compared with the implicit
 * `Ordering[T]`, adapted to the three-valued `Comparison` that `insert`
 * expects. Uses foldLeft, so long lists do not grow the stack.
 *
 * @param l elements to insert; duplicates are kept only once
 * @return the resulting (unbalanced) search tree
 */
def listToTree[T](l: List[T])(implicit ord: Ordering[T]): Tree[T] = {
  val cmp: (T, T) => Comparison = (a, b) => {
    val c = ord.compare(a, b)
    if (c < 0) LessThan() else if (c > 0) GreaterThan() else Equal()
  }
  l.foldLeft[Tree[T]](Leaf())((acc, x) => insert(cmp, x, acc))
}
 
 
// Tests: build a search tree from a small Int list and print both
val list = List(7, 5, 9, 11, 3)
val tree = listToTree(list)

println(s"Convert list: $list into tree: $tree")

 

 

/**
 * Exercise 4.7 A heap of integers is a data structure supporting the
 * following operations.
 *
 *   makeheap: int -> heap: create a heap containing a single element,
 *
 *   insert: heap -> int -> heap: add an element to a heap; duplicates are
 *            allowed,
 * 
 *   findmin: heap -> int: return the smallest element of the heap.
 *
 *   deletemin: heap -> heap: return a new heap that is the same as the
 *               original, without the smallest element.
 *
 *   meld: heap -> heap -> heap: join two heaps into a new heap containing the
 *         elements of both.
 *
 * A heap can be represented as a binary tree, where for any node a, if b is a
 * child node of a, then label(a) <= label(b). The order of children does not
 * matter. A pairing heap is a particular implementation where the operations
 * are performed as follows.
 *
 *    makeheap i: produce a single-node tree with i at the root.
 *
 *    insert h i = meld h (makeheap i).
 *
 *    findmin h: return the root label.
 *
 *    deletemin h: remove the root, and meld the subtrees.
 *
 *    meld h1 h2: compare the roots, and make the heap with the larger element
 *    a subtree of the other.
 *
 * 1. Define a type heap and implement the five operations.
 * 
 * 2. A heap sort is performed by inserting the elements to be sorted into a
 *    heap, then the values are extracted from smallest to largest. Write a
 *    function heap_sort : int list -> int list that performs a heap sort,
 *    where the result is sorted from largest to smallest.
 */
 
// 1. Define a type heap and implement the five operations.
// Element type stored in the heap (the exercise asks for a heap of integers).
type T = Int
 
/**
 * Pairing heap: a node holds its `value` and the list `l` of child heaps.
 * There is no empty-heap case, so every heap holds at least one element.
 * Intended invariant (per the exercise): a node's value is <= the value of
 * each of its children.
 */
sealed abstract class Heap
  case class Node(value: T, l: List[Node]) extends Heap
 
/** Creates a heap holding the single element `i` (a root with no children). */
def makeheap(i: T): Node = Node(i, List.empty[Node])
 
/**
 * Melds two heaps. The heap whose root has the smaller label stays on top,
 * and the other heap becomes its newest child.
 *
 * Keeping the smaller root on top preserves the min-heap invariant
 * label(parent) <= label(child), which `findmin` relies on. The previous
 * version kept the larger root, which built a max-heap. Ties keep `h1` on top.
 *
 * @param h1 first heap
 * @param h2 second heap
 * @return a heap containing the elements of both
 */
def meld(h1: Node, h2: Node): Node =
  if (h1.value <= h2.value) Node(h1.value, h2 :: h1.l)
  else Node(h2.value, h1 :: h2.l)
 
/**
 * Adds `i` to heap `h` (duplicates allowed) by melding `h` with a
 * one-element heap, as the pairing-heap definition prescribes.
 */
def insert(h: Node, i: T): Node = {
  val singleton = makeheap(i)
  meld(h, singleton)
}
 
/**
 * Returns the label at the root of `h`. This is the heap's smallest element
 * provided `meld` keeps the smaller label at the root.
 */
def findmin(h: Node): T = h.value
 
/**
 * Removes the root of `h` and melds its children into a single heap.
 *
 * The children are melded left to right: the first with the second, then
 * that result with the third, and so on. This removes the unreachable
 * single-child branch of the earlier version; a single child is simply
 * returned as-is.
 *
 * @param h heap with at least two elements
 * @return `h` without its root
 * @throws Exception if `h` has no children: removing its only element would
 *         leave an empty heap, which the `Heap` type cannot represent
 */
def deletemin(h: Node): Node =
  if (h.l.isEmpty) throw new Exception("Invalid Argument in deletemin.")
  else h.l.reduceLeft((merged, child) => meld(merged, child))
 
// 2. heapsort
 
/**
 * Heap sort: inserts every element of `l` into a heap (in list order), then
 * repeatedly takes `findmin` and `deletemin` until one node remains.
 *
 * Each extracted element is prepended to the result, so the output is the
 * reverse of extraction order. It is largest-to-smallest (as Exercise 4.7
 * asks) when `findmin` yields the minimum.
 *
 * @param l elements to sort; an empty list yields an empty list
 * @return the elements of `l`, in reverse extraction order
 */
def heap_sort(l: List[T]): List[T] = l match {
  case first :: rest =>
    val heap = rest.foldLeft(makeheap(first))((acc, x) => insert(acc, x))
    @scala.annotation.tailrec
    def drain(h: Node, acc: List[T]): List[T] =
      if (h.l.isEmpty) h.value :: acc
      else drain(deletemin(h), findmin(h) :: acc)
    drain(heap, List.empty[T])
  case _ => List.empty[T]
}
 
// tests: a small heap built from 12 then 1, 17, 200, plus a heap sort
var h = List(1, 17, 200).foldLeft(makeheap(12))((heap, i) => insert(heap, i))
val l = List(5, 7, 8, 10, 11, 0, 101, 112, 99)
println(s"heap: $h")
println(s"$l sorted ${heap_sort(l)}")