(*------------------------------------------------------------------
   file  : preempt_co_thread.sml
   date  : August 22, 1990
   info  : goes with SML/NJ version 0.59
   author: Eric Cooper and Greg Morrisett
   desc  :

         This file contains the functor for the Standard ML Thread 
         signature plus primitive preemptive scheduling hooks.  This
	 implementation uses continuations as threads and the UNIX
         SIGALRM to trigger preemption.

         The corresponding signature is in the file:
         preempt_co_thread.sig.sml

         Many of the ideas used in this implementation were taken
         from John Reppy's preemptive Events implementation.

         Special thanks to Andrzej Filinski for help with the
         per-thread state implementation.
 ------------------------------------------------------------------*)
import "queue.sig";
import "preempt_co_thread.sig"; 

functor Preempt_Co_Thread (Queue : QUEUE) : PREEMPT_CO_THREAD =
    struct
        (********************)
        (* Thread Structure *)
        (********************)
        structure Thread =
            struct
                (************************************************)
                (* per-thread state                             *)
                (************************************************)

                (* An env is a freshly allocated ref cell used only as an
                   identity token (compared with =) for one thread.        *)
                type env = unit ref

                (* A per-thread variable: an association list mapping the
                   env of every thread that has bound it to its value.    *)
                datatype 'a var = VAR of (env * 'a) list ref

                (* Raised by get when the current thread has no binding.  *)
                exception Undefined

                fun new_env () = ref ()

                (* env of the thread that is currently running *)
                val current_env = ref (new_env ())

                (* New variable, bound to a in the current thread only.   *)
                fun var a = VAR (ref [(!current_env, a)])

                fun find _ [] = raise Undefined
                  | find env ((e, a) :: rest) =
                    if e = env then a else find env rest

                (* Value of the variable for the current thread; raises
                   Undefined if this thread never bound it.               *)
                fun get (VAR v) = find (!current_env) (!v)

                (* Rebind env to a in the list, or append a new binding.  *)
                fun replace env [] a = [(env, a)]
                  | replace env ((pair as (e, _)) :: rest) a =
                    if e = env then (e, a) :: rest
                    else pair :: replace env rest a

                (* Set the variable's value for the current thread only.  *)
                fun set (VAR v) a = (v := replace (!current_env) (!v) a)

                (************************************************)
                (* miscellaneous                                *)
                (************************************************)

                (* Run pre obj, then body (), then post obj.  If body
                   raises, post obj still runs and the exception is
                   re-raised.                                            *)
                fun bracket pre post obj body =
                    let val _ = pre obj
                        val result = body () handle exn =>
                            (post obj; raise exn)
                    in
                        post obj;
                        result
                    end

                (************************************************)
                (* threads                                      *)
                (************************************************)

                (* A suspended thread: where to resume it, plus its env.  *)
                datatype thread = THREAD of unit cont * env

                (* Package continuation k with the running thread's env.  *)
                fun thread k = THREAD (k, !current_env)

                fun block thread q = Queue.enq q thread

                (************************************************)
                (* scheduling                                   *)
                (************************************************)
                val run_queue : (thread Queue.t) ref = ref (Queue.create ())

                fun reschedule thread = block thread (!run_queue)

                exception Deadlock

                (* Next runnable thread; raises Deadlock if none exists.  *)
                fun next () =
                    Queue.deq (!run_queue)
                    handle Queue.Deq => raise Deadlock

                (*---------------------------------------------------------*)
                (* When atomicLevel goes above 0, no context switching     *)
                (* (logically) occurs.  Leaving an atomic region when the  *)
                (* level is already 0 raises the exception Atomic.         *)
                (*---------------------------------------------------------*)
                val atomicLevel = ref 0
                exception Atomic

                (*---------------------------------------------------------*)
                (* Set when a SIGALRM occurs during an atomic region; the  *)
                (* deferred preemption happens when the outermost atomic   *)
                (* region is left.                                         *)
                (*---------------------------------------------------------*)
                val signalPending = ref false

                (*---------------------------------------------------------*)
                (* A call to setTimer sets SIGALRM to go off every msec    *)
                (* milliseconds.  If msec = 0 the timer is turned off.     *)
                (* The interval is split into whole seconds plus the       *)
                (* remaining microseconds, because setitimer rejects a     *)
                (* microsecond field of 1,000,000 or more.                 *)
                (*---------------------------------------------------------*)
                local
                    open System.Timer
                    val setitimer = System.Unsafe.CInterface.setitimer
                in
                    fun setTimer msec =
                        let val interval =
                                TIME {sec = msec div 1000,
                                      usec = 1000 * (msec mod 1000)}
                        in
                            (setitimer (0, interval, interval); ())
                        end
                end

                (*---------------------------------------------------------*)
                (* Turns preemptive scheduling on and off by installing a  *)
                (* handler for SIGALRM and then calling setTimer.          *)
                (*---------------------------------------------------------*)
                local
                    open System.Signals

                    (* SIGALRM handler in force when the functor was applied *)
                    val oldAlarmHandler = inqHandler SIGALRM

                    (* Outside an atomic region: requeue the interrupted
                       thread k and resume the next runnable one.  Inside
                       an atomic region: record the signal and resume k.  *)
                    fun alarmHandler (_, k) =
                        if (!atomicLevel = 0) then
                            (reschedule (thread k);
                             let val THREAD (c, env) = next ()
                             in
                                 current_env := env;
                                 c
                             end)
                        else
                            (signalPending := true; k)
                in
                    fun setPreempt' (SOME msec) =
                        (setHandler (SIGALRM, SOME alarmHandler);
                         setTimer msec;
                         ())
                      | setPreempt' NONE =
                        (setTimer 0;
                         setHandler (SIGALRM, oldAlarmHandler))
                end

                fun enterAtomic () =
                    atomicLevel := !atomicLevel + 1

                (*---------------------------------------------------------*)
                (* Used to exit an atomic region during a context switch   *)
                (* of some sort.  The switch itself services any pending   *)
                (* SIGALRM, so the pending flag is cleared here; leaving   *)
                (* it set would make the next thread yield spuriously at   *)
                (* its next leaveAtomic.  Raises Atomic at level 0.        *)
                (*---------------------------------------------------------*)
                fun resetAtomic () =
                    if !atomicLevel = 0 then
                        raise Atomic
                    else
                        (atomicLevel := 0;
                         signalPending := false)

                (*---------------------------------------------------------*)
                (* Gives the next thread in the running queue the cpu.     *)
                (* The current thread is NOT requeued.                     *)
                (*---------------------------------------------------------*)
                fun switch () =
                    let val _ = enterAtomic ()
                        val THREAD (kont, env) = next ()
                    in
                        current_env := env;
                        resetAtomic ();
                        throw kont ()
                    end

                (*---------------------------------------------------------*)
                (* The current thread gives up the cpu and hands it off to *)
                (* the next thread in the running queue.                   *)
                (*---------------------------------------------------------*)
                fun yield () =
                    (enterAtomic ();
                     callcc (fn oldKont =>
                             (reschedule (thread oldKont);
                              let
                                  val THREAD (newKont, newEnv) = next ()
                              in
                                  current_env := newEnv;
                                  resetAtomic ();
                                  throw newKont (); ()
                              end)))

                (*---------------------------------------------------------*)
                (* Used to leave an atomic region.  Only when leaving the  *)
                (* outermost region (level 1) is a pending signal acted    *)
                (* upon, by yielding to another thread.  At deeper levels  *)
                (* the pending flag is kept so the outermost exit still    *)
                (* sees it.  Raises Atomic when not in an atomic region.   *)
                (*---------------------------------------------------------*)
                fun leaveAtomic () =
                    case !atomicLevel of
                        0 => raise Atomic
                      | 1 => if !signalPending then
                                 (signalPending := false; yield ())
                             else
                                 atomicLevel := 0
                      | n => atomicLevel := n - 1

                fun atomically f = bracket enterAtomic leaveAtomic () f

                (*---------------------------------------------------------*)
                (* Run child in a new thread with a fresh env; the parent  *)
                (* is requeued.  An exception escaping child is reported   *)
                (* and the child thread then ends.                         *)
                (*---------------------------------------------------------*)
                fun fork child =
                    callcc (fn parent =>
                        (enterAtomic ();
                         reschedule (thread parent);
                         current_env := new_env ();
                         resetAtomic ();
                         (child () handle exn =>
                              (print "Unhandled exception ";
                               print (System.exn_name exn);
                               print " raised in thread.\n"));
                         switch ()))

                fun exit () = switch ()

                fun atomicBlock x q = atomically (fn () => Queue.enq q x)

                fun reset () =
                    (enterAtomic ();
                     run_queue := Queue.create ();
                     signalPending := false;
                     setPreempt' NONE;
                     resetAtomic ())

                (************************************************)
                (* mutex locks                                  *)
                (*                                              *)
                (* Ownership is handed off directly on release: *)
                (* if a thread is waiting, release keeps held = *)
                (* true and makes that waiter runnable, so the  *)
                (* waiter owns the mutex when it resumes.  held *)
                (* is false only when no thread owns the mutex. *)
                (************************************************)
                datatype mutex =
                    MUTEX of bool ref * thread Queue.t

                fun mutex () =
                    MUTEX (ref false, Queue.create ())

                (*---------------------------------------------------------*)
                (* Your basic test and set operation.                      *)
                (*---------------------------------------------------------*)
                fun try_acquire (MUTEX (held, _)) =
                    atomically (fn () =>
                                if not (!held) then
                                    (held := true; true)
                                else
                                    false)

                (*---------------------------------------------------------*)
                (* Test and block in ONE atomic region, so no release can  *)
                (* slip in between and no preemption can requeue this      *)
                (* thread while it also sits on the mutex queue.  When     *)
                (* resumed from the queue, the releaser has already handed *)
                (* the mutex to this thread.                               *)
                (*---------------------------------------------------------*)
                fun acquire (MUTEX (held, q)) =
                    (enterAtomic ();
                     if not (!held) then
                         (held := true; leaveAtomic ())
                     else
                         callcc (fn k => (block (thread k) q;
                                          switch ())))

                (* Caller must already be in an atomic region.  Hands the
                   mutex to the first waiter, or marks it free if none.  *)
                fun nonatomic_release (MUTEX (held, q)) =
                    case (SOME (Queue.deq q) handle Queue.Deq => NONE) of
                        SOME waiter => reschedule waiter
                      | NONE => held := false

                fun release mutex =
                    atomically (fn () => nonatomic_release mutex)

                fun with_mutex mutex body =
                    bracket acquire release mutex body

                (************************************************)
                (* conditions                                   *)
                (************************************************)
                datatype condition =
                    CONDITION of (thread * mutex) Queue.t

                fun condition () =
                    CONDITION (Queue.create ())

                (*---------------------------------------------------------*)
                (* Wake up a thread on the condition's queue.  If its      *)
                (* mutex is free it is taken on the thread's behalf and    *)
                (* the thread is made runnable; otherwise the thread waits *)
                (* on the mutex queue and is handed the mutex on release.  *)
                (* Raises Queue.Deq if no thread is waiting.               *)
                (*---------------------------------------------------------*)
                fun awaken condition_queue =
                    let fun awaken' () =
                        let val (thread, mutex as MUTEX (_, mutex_queue)) =
                            Queue.deq condition_queue
                        in
                            if try_acquire mutex then
                                reschedule thread
                            else
                                block thread mutex_queue
                        end
                    in
                        atomically awaken'
                    end

                fun repeat f =
                    (f (); repeat f)

                (* Only an empty condition queue is ignored; any other
                   exception propagates to the caller.                   *)
                fun signal (CONDITION q) =
                    awaken q handle Queue.Deq => ()

                fun broadcast (CONDITION q) =
                    (repeat (fn () => awaken q)) handle Queue.Deq => ()

                (*---------------------------------------------------------*)
                (* Give up the mutex and enqueue it and the current thread *)
                (* on the condition queue.  Then switch to another thread. *)
                (* On return the thread owns the mutex again (see awaken). *)
                (*---------------------------------------------------------*)
                fun wait mutex (CONDITION q) =
                    (enterAtomic ();
                     nonatomic_release mutex;
                     callcc (fn k =>
                             (Queue.enq q ((thread k), mutex);
                              switch ())))

                fun await mutex cond test =
                    if test () then
                        ()
                    else
                        (wait mutex cond; await mutex cond test)
            end

        (*********************)
        (* Control Structure *)
        (*********************)
        structure Control =
            struct
                val setPreempt = Thread.setPreempt'
            end (* structure Control *)

    end (* structure Preempt_Co_Thread *)
