/* (c) Copyright 1990, 1991 Carnegie Mellon University 
 *
 * Cprim.c: Primitives for sml2c.
 * Written by David Tarditi
 *
 *  
 *  Assumptions in this file.
 *
 *  - We are using two's complement arithmetic, where the minimum
 *    integer is -2^31 and the maximum integer is 2^31-1 
 *  - floating point numbers are represented as doubles
 *  - A floating point number takes at most 8 bytes of storage
 *  - values of type C_val_t are large enough to hold pointers
 *  - arithmetic is done on integers of type C_val_t
 */

#include <math.h>
#include <errno.h>
#include <setjmp.h>
#include "tags.h"
#include "request.h"
#include "cause.h"
#include "ml_state.h"
#include "prim.h"

/* Some system headers are missing these declarations, so declare them
   here.  Modern <errno.h> defines errno as a macro (typically expanding
   to a function call for per-thread storage); redeclaring it as a plain
   "extern int" then yields a conflicting declaration, so only declare it
   when the header has not already provided the macro. */

#ifndef errno
extern int errno;
#endif
/* ML exception values raised below, e.g. as overflow_e0+1. */
extern ML_val_t overflow_e0[], sqrt_e0[], ln_e0[];
extern int inML,handlerPending;
/* longjmp target used to leave ML code (see RAISE, quicksave, saveregs). */
extern jmp_buf top_level;

/* 32-bit int limits, matching the two's complement assumption above.
   INT_MIN is written as (-INT_MAX - 1): the literal 0x80000000 does not
   fit in a 32-bit int, so "-0x80000000" would have type unsigned int and
   be a large positive value.  Guarded so that a system <limits.h> pulled
   in by an earlier header is not redefined. */
#ifndef INT_MAX
#define INT_MAX 0x7fffffff
#endif
#ifndef INT_MIN
#define INT_MIN (-INT_MAX - 1)
#endif

/* Bounds used by floor.  ML integers are stored tagged as 2n+1 in a
   32-bit word (see floor_v_function and UNTAG), so the representable
   range is [-2^30, 2^30 - 1].  FLOOR_MAX is 2^30 and FLOOR_MIN is
   -2^30.  FLOOR_MIN is parenthesized so it expands safely inside
   larger expressions. */

#define FLOOR_MAX 1073741824.0
#define FLOOR_MIN (-1073741824.0)

/* The machine word used for ML values and for the C register file;
   assumed large enough to hold a pointer (see the assumptions at the
   top of this file). */
typedef int C_val_t;

/* Layout of the C register file Csp[] used by compiled ML code.
   NUMREGS is the total number of word-sized slots.  Slots 0 and 1 are
   not referenced in this file (NOTE(review): presumably reserved or
   used by generated code -- confirm). */
#define NUMREGS 39

/* DATA_PTR, LIMIT_PTR and STORE_PTR are saved to dedicated MLState
   fields by moveregs().  SIGNAL_LIMIT_PTR is not saved: fetchregs()
   resets it to the heap limit, and restoreregs() zeroes it so that the
   next limit check fails when a signal handler is pending. */
#define LIMIT_PTR_REG 2
#define SIGNAL_LIMIT_PTR_REG 3
#define STORE_PTR_REG 4
#define DATA_PTR_REG 5
/* Slots EXN_PTR_REG (6) .. NUMREGS-1 are copied in order to and from
   MLState->ml_roots by moveregs()/fetchregs(), so this numbering must
   presumably stay in step with the root indices in ml_state.h. */
#define EXN_PTR_REG 6
#define PC_REG 7
#define STANDARD_CLOSURE_REG 8
#define STANDARD_ARG_REG 9
#define STANDARD_CONT_REG 10

/* Lvalue shorthands for the registers above. */
#define LIMIT_PTR        (Csp[LIMIT_PTR_REG])
#define SIGNAL_LIMIT_PTR (Csp[SIGNAL_LIMIT_PTR_REG])
#define STORE_PTR        (Csp[STORE_PTR_REG])
#define DATA_PTR         (Csp[DATA_PTR_REG])
#define EXN_PTR          (Csp[EXN_PTR_REG])
#define PC               (Csp[PC_REG])
#define STANDARD_CLOSURE (Csp[STANDARD_CLOSURE_REG])
#define STANDARD_ARG     (Csp[STANDARD_ARG_REG])
#define STANDARD_CONT    (Csp[STANDARD_CONT_REG])

/* register descriptor for functions using the standard calling
   convention: the mask that the primitives in this file pass to
   invoke_gc().  NOTE(review): presumably marks which registers hold
   live roots for the collector -- confirm against callgc0. */

#define STDGCMASK 7
/* CLOSURE(name,func_name): define a two-word record `name` holding a
   record descriptor of length 1 followed by the address of func_name.
   The (int) cast of a function address relies on pointers fitting in an
   int, per the assumptions at the top of this file. */
#define CLOSURE(name,func_name) int name[2] = { MAKE_DESC(1,tag_record), (int) func_name};
/* UNTAG(v): recover n from its tagged representation 2n+1.  For negative
   v, >> is implementation-defined in C; an arithmetic shift is assumed. */
#define UNTAG(v) ((v) >> 1)

/* RAISE(x): raise the ML exception value x.  Saves the allocation and
   store pointers, makes the current exception handler (EXN_PTR) the
   continuation with x as its argument and the handler's first word as
   the PC, sets request = REQ_RUN and longjmps to top_level; it never
   returns.  It expands to a bare { } block, not do { } while (0), and
   some call sites in this file omit the trailing semicolon, so its shape
   cannot be changed without updating them. */
#define RAISE(x) \
{ MLState->ml_allocptr = (int) DATA_PTR; \
  MLState->ml_storeptr = (int) STORE_PTR; \
  MLState->ml_roots[CONT_INDX] = (ML_val_t) EXN_PTR; \
  MLState->ml_roots[ARG_INDX] = (ML_val_t) (x); \
  MLState->ml_roots[PC_INDX] = (ML_val_t) (*(C_val_t*)EXN_PTR); \
  request = REQ_RUN; \
  longjmp(top_level,1); }

/* CONT: finish the current primitive by returning the first word of the
   STANDARD_CONT record; the trampoline in restoreregs() calls the
   returned value next.  Like RAISE it is a bare block and is used both
   with and without a trailing semicolon. */
#define CONT { return(*(C_val_t *) STANDARD_CONT); }

/* The C-side register file (indexed by the *_REG constants above). */
C_val_t Csp[NUMREGS];
/* GC mask recorded by invoke_gc() when it leaves ML to run a pending
   signal handler. */
unsigned int Cmask;

/* Entry point (reached through the sigh_return_c closure) for returning
   from an ML signal handler: post REQ_SIG_RETURN and leave ML through
   quicksave(), which longjmps and so never comes back here. */
int sig_return_v_function()
{
  request = REQ_SIG_RETURN;
  return quicksave();
}

/* Post REQ_SIG_RESUME and leave ML through quicksave(), which longjmps
   and so never comes back here. */
int sigh_resume()
{
  request = REQ_SIG_RESUME;
  return quicksave();
}

/* Entry point (reached through the handle_c closure) that posts REQ_EXN
   and leaves ML through quicksave(), which longjmps and so never comes
   back here. */
int handle_c_function()
{
  request = REQ_EXN;
  return quicksave();
}

/* Entry point (reached through the return_c closure) that posts
   REQ_RETURN and leaves ML through quicksave(), which longjmps and so
   never comes back here. */
int return_c_function()
{
  request = REQ_RETURN;
  return quicksave();
}

int callc_v_function()
{ l0: if (DATA_PTR<=SIGNAL_LIMIT_PTR)
         { request = REQ_CALLC; quicksave(); }
  invoke_gc(STDGCMASK,callc_v_function);
  goto l0;
}

int quicksave()
{ register MLState_t *msp = MLState;
  register C_val_t *csp=Csp;
  inML = 0;
  msp->ml_allocptr = (int) csp[DATA_PTR_REG];
  msp->ml_storeptr = (int) csp[STORE_PTR_REG];
  msp->ml_roots[EXN_INDX]  = (ML_val_t) csp[EXN_PTR_REG];
  msp->ml_roots[CONT_INDX] = (ML_val_t) csp[STANDARD_CONT_REG];
  msp->ml_roots[ARG_INDX]  = (ML_val_t) csp[STANDARD_ARG_REG];
  longjmp(top_level,1);
}

static void moveregs()
{ register C_val_t *csp = Csp;
  register C_val_t *s = csp+NUMREGS;
  register MLState_t *msp = MLState;
  register ML_val_t *roots = msp->ml_roots;
  msp->ml_allocptr = (int) csp[DATA_PTR_REG];
  msp->ml_limitptr = (int) csp[LIMIT_PTR_REG];
  msp->ml_storeptr = (int) csp[STORE_PTR_REG];
  for (csp += 6; csp < s; *roots++ = (ML_val_t) *csp++);
}

static void fetchregs()
{ register C_val_t *csp = Csp;
  register C_val_t *s = csp+NUMREGS;
  register MLState_t *msp = MLState;
  register ML_val_t *roots = msp->ml_roots;
  register C_val_t limit = (C_val_t) msp->ml_limitptr;
  csp[DATA_PTR_REG] = (C_val_t) msp->ml_allocptr;
  csp[LIMIT_PTR_REG] = limit;
  csp[SIGNAL_LIMIT_PTR_REG] = limit;
  csp[STORE_PTR_REG] = (C_val_t) msp->ml_storeptr;
  for (csp += 6; csp < s; *csp++ = (C_val_t) *roots++);
}

/* Leave ML with the full register file saved: clear inML, copy every
   register into MLState and longjmp to top_level.  The die() call after
   the longjmp is only a guard and is never reached. */
void saveregs()
{
  inML = 0;
  moveregs();
  longjmp(top_level,1);
  die("saveregs: should never reach this point!\n");
}

/* Enter (or re-enter) ML code.
   Loads the C register file from MLState.  If signals are pending, not
   masked, and no handler is already running, it sets handlerPending and
   zeroes SIGNAL_LIMIT_PTR.  That makes the next limit check in compiled
   code or in the primitives here fail, so control reaches invoke_gc().
   It then runs the trampoline: every compiled function returns the
   address of the next function to call.  The loop never exits normally;
   ML is left by longjmp (RAISE, quicksave, saveregs). */
void restoreregs()
{ extern int NumPendingSigs, maskSignals,inSigHandler,handlerPending;
  register C_val_t (*next)();
#ifdef CDEBUG
 /* debugging build: keep the previously executed function in prev */
 register C_val_t (*prev)(),(*tmp)();
#endif

  fetchregs(); 
  next = (C_val_t (*)()) PC;

 /* arrange for a pending signal to be noticed at the next limit check */
 if (NumPendingSigs && !maskSignals && !inSigHandler) {
       handlerPending = 1;
       SIGNAL_LIMIT_PTR = (C_val_t) 0;
  }

 inML = 1;
 /* trampoline: call the current function; its result is the next one */
loop:
#ifdef CDEBUG
      tmp = (C_val_t  (*)()) ((*next)());
      prev = next;
      next = tmp;
      goto loop;
#else
   /* unrolled 16 times to reduce loop-branch overhead per call */
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   next = (C_val_t (*)()) ((*next)());
   goto loop;
#endif
}

/* Slow path taken by the primitives in this file when their limit check
   against SIGNAL_LIMIT_PTR fails.
     mask - register mask handed to callgc0 (callers here pass STDGCMASK);
            on the signal path it is saved in Cmask instead.
     func - the function to resume at, stored into PC (K&R implicit-int
            parameter: function addresses are assumed to fit in an int).
   If a signal handler is pending, it sets up the handler with
   sig_setup() and leaves ML through saveregs(), which longjmps, so that
   branch does not return here.  Otherwise it saves the registers, runs
   the collector, reloads the registers and returns so that the caller
   can retry its limit check.
   NOTE(review): declared to return C_val_t but returns no value; the
   callers in this file ignore the result. */
C_val_t invoke_gc(mask,func)
unsigned int mask;
{ inML = 0;
  if (handlerPending) {
    sig_setup();
    PC = func;
    Cmask = mask;
    saveregs();   /* longjmps; does not return */
  }
  moveregs();
  callgc0(CAUSE_GC,mask);
  fetchregs();
  inML = 1;
}

/* Run the collector unconditionally with the given register mask: save
   the registers, call callgc0, and reload the registers.  Unlike
   invoke_gc() it does not check handlerPending.
   NOTE(review): presumably called from limit checks inlined into the
   generated code -- confirm.  Declared to return C_val_t but returns no
   value. */
C_val_t inlined_gc(mask)
unsigned int mask;
{ inML = 0;
  moveregs();
  callgc0(CAUSE_GC,mask);
  fetchregs();
  inML = 1;
}

C_val_t array_v_function()
{ register C_val_t *dataptr, *finish, val;
  register C_val_t *arg;
  register int l;

l0:
  dataptr = (C_val_t *) DATA_PTR;
  arg = (C_val_t *) STANDARD_ARG;
  l = UNTAG(*arg);
  if (dataptr+l < (C_val_t *) SIGNAL_LIMIT_PTR)
     { *dataptr++ = (l << width_tags) | tag_array;
        STANDARD_ARG = (C_val_t) dataptr;
        finish = dataptr+l;
	for (val = *(arg+1); dataptr<finish; *dataptr++ = val);
        DATA_PTR = (C_val_t) finish;
        CONT;
     }
  invoke_gc(STDGCMASK,array_v_function);
  goto l0;
}

/* ML primitive to create an uninitialised string.  STANDARD_ARG holds
   the tagged byte length.  The primitive reserves a descriptor word plus
   enough 4-byte words for the bytes, leaves the string pointer (just
   past the descriptor) in STANDARD_ARG, and continues.  If the heap is
   short, it collects and retries. */
C_val_t create_s_v_function()
{
  int nbytes = UNTAG(STANDARD_ARG);
  C_val_t desc = (nbytes << width_tags) | tag_string;
  int nwords = (nbytes + 3) >> 2;   /* round up to whole 32-bit words */
  C_val_t *base;

  for (;;) {
    base = (C_val_t *) DATA_PTR;
    if (base + nwords < (C_val_t *) SIGNAL_LIMIT_PTR)
      break;
    invoke_gc(STDGCMASK,create_s_v_function);
  }
  base[0] = desc;
  STANDARD_ARG = (C_val_t) (base + 1);
  DATA_PTR = (C_val_t) (base + 1 + nwords);
  CONT;
}

/* ML primitive to create an uninitialised bytearray.  Same as
   create_s_v_function except that the descriptor tag is tag_bytearray.
   STANDARD_ARG holds the tagged byte length.  On success STANDARD_ARG
   points just past the descriptor; if the heap is short, the primitive
   collects and retries. */
C_val_t create_b_v_function()
{
  int nbytes = UNTAG(STANDARD_ARG);
  C_val_t desc = (nbytes << width_tags) | tag_bytearray;
  int nwords = (nbytes + 3) >> 2;   /* round up to whole 32-bit words */
  C_val_t *base;

  for (;;) {
    base = (C_val_t *) DATA_PTR;
    if (base + nwords < (C_val_t *) SIGNAL_LIMIT_PTR)
      break;
    invoke_gc(STDGCMASK,create_b_v_function);
  }
  base[0] = desc;
  STANDARD_ARG = (C_val_t) (base + 1);
  DATA_PTR = (C_val_t) (base + 1 + nwords);
  CONT;
}
 
/* ML primitive logb: no computation is attempted; it always raises
   Overflow (NOTE(review): presumably a placeholder). */
int logb_v_function()
{
  RAISE(overflow_e0+1);
}

/* ML primitive scalb: no computation is attempted; it always raises
   Overflow (NOTE(review): presumably a placeholder). */
int scalb_v_function()
{
  RAISE(overflow_e0+1);
}

/* ML primitive floor.  STANDARD_ARG points to a double.  Its floor is
   returned in STANDARD_ARG as a tagged ML integer (2n+1).  Overflow is
   raised when the result is not a representable 31-bit ML integer (not
   in [-2^30, 2^30 - 1]), and also when the argument is NaN. */
int floor_v_function()
{ register double d = floor(*(double *) STANDARD_ARG);

  /* The upper bound is exclusive: floor = 2^30 would make d*2+1 = 2^31+1,
     which overflows int.  The test is written as a negated range check so
     that NaN (for which every comparison is false) is rejected as well,
     instead of reaching an undefined double-to-int conversion. */
  if (!(d >= FLOOR_MIN && d < FLOOR_MAX)) {
    RAISE(overflow_e0+1);
  }
  STANDARD_ARG = (C_val_t) ((C_val_t) d * 2 + 1);
  CONT
}

/* MATH_FUNC(f,name): define the ML primitive `name`.  It applies the
   libm function f to the double that STANDARD_ARG points to and returns
   a freshly boxed result.  The box is a descriptor (a sizeof(double)
   byte string) followed by the double's words, and STANDARD_ARG is set
   to point just past the descriptor.  As in the other allocators in this
   file, the limit test covers the descriptor and the whole payload.  If
   there is not enough room, the primitive collects and retries. */
#define MATH_FUNC(f,name) \
int name() \
{ register C_val_t *dataptr; \
l0: dataptr = (C_val_t *) DATA_PTR; \
    if (dataptr+(sizeof(double)/sizeof(C_val_t)) < (C_val_t *) SIGNAL_LIMIT_PTR) { \
      *dataptr++ = MAKE_DESC(sizeof(double),tag_string); \
      *(double *) dataptr = f (*(double *)STANDARD_ARG); \
      STANDARD_ARG = (C_val_t) dataptr; \
      DATA_PTR = (C_val_t) (dataptr+(sizeof(double)/sizeof(C_val_t))); \
      CONT \
  } \
  invoke_gc(STDGCMASK,name); \
  goto l0; \
}

/* MATH_FUNC_WITH_ERR(f,name,err): like MATH_FUNC, but raises the ML
   exception value err when f reports a domain or range error (EDOM or
   ERANGE) through errno.  errno is cleared immediately before the call.
   Otherwise a stale EDOM/ERANGE left by an earlier library call would be
   taken as an error of f.  The limit test covers the descriptor and the
   whole payload.
   NOTE(review): this relies on libm reporting errors through errno
   (math_errhandling & MATH_ERRNO).  With -fno-math-errno or similar,
   no exception is raised. */
#define MATH_FUNC_WITH_ERR(f,name,err) \
int name() \
{   register C_val_t *dataptr; \
l0: dataptr = (C_val_t *) DATA_PTR; \
    if (dataptr+(sizeof(double)/sizeof(C_val_t)) < (C_val_t *) SIGNAL_LIMIT_PTR) { \
    *dataptr++ = MAKE_DESC(sizeof(double),tag_string); \
    errno = 0; \
    *(double *) dataptr = f (*(double *)STANDARD_ARG); \
    STANDARD_ARG = (C_val_t) dataptr; \
    DATA_PTR = (C_val_t) (dataptr+sizeof(double)/sizeof(C_val_t)); \
    if ((errno == EDOM) || (errno == ERANGE)) {errno = -1; RAISE(err); } \
    CONT } \
   invoke_gc(STDGCMASK,name); \
   goto l0; \
}

/* Math primitives callable from ML.  sin, cos and atan are used without
   any errno check.  exp, log and sqrt raise the listed ML exception when
   libm reports EDOM or ERANGE. */
MATH_FUNC(sin, sin_v_function)
MATH_FUNC(cos,  cos_v_function)
MATH_FUNC(atan, arctan_v_function)

MATH_FUNC_WITH_ERR(exp,  exp_v_function, overflow_e0+1)
MATH_FUNC_WITH_ERR(log, ln_v_function, ln_e0+1)
MATH_FUNC_WITH_ERR(sqrt, sqrt_v_function, sqrt_e0+1)

/* Closure records through which ML code reaches the C primitives above.
   Each one is a two-word record built by CLOSURE: a descriptor (length 1,
   tag_record) followed by the function's address.
   NOTE(review): ML presumably refers to each as name+1 (past the
   descriptor), as with overflow_e0+1 -- confirm. */
CLOSURE(sigh_return_c,sig_return_v_function)
CLOSURE(handle_c,handle_c_function)
CLOSURE(return_c,return_c_function)
CLOSURE(callc_v,callc_v_function)
CLOSURE(array_v,array_v_function)
CLOSURE(create_b_v,create_b_v_function)
CLOSURE(create_s_v,create_s_v_function)
CLOSURE(arctan_v,arctan_v_function)
CLOSURE(cos_v,cos_v_function)
CLOSURE(exp_v,exp_v_function)
CLOSURE(floor_v,floor_v_function)
CLOSURE(ln_v,ln_v_function)
CLOSURE(sin_v,sin_v_function)
CLOSURE(sqrt_v,sqrt_v_function)
CLOSURE(logb_v,logb_v_function)
CLOSURE(scalb_v,scalb_v_function)

/* multiplication with overflow checking.
   We break the operands into 16-bit parts, multiply them, and put them
   back together.
*/

/* overflow check for unsigned integer addition, where a and b are the
   msb of the operands and r is the msb of the result:

        a b r | ov
        ----------
        0 0 0   0
        0 0 1   0
        0 1 0   1
        0 1 1   0
        1 0 0   1
        1 0 1   0
        1 1 0   1
        1 1 1   1

    Overflow = and(a,b)|((eor(a,b)&(~r)))

   NO_OVERFLOW(a,b,r) applies this formula to whole 32-bit words and
   tests only the msb of the result.  It masks with 0x80000000u instead
   of testing ">= 0".  With unsigned operands, the whole expression is
   unsigned and ">= 0" is always true, so overflow was never detected.
   The arguments are parenthesized, and the whole expansion is wrapped in
   parentheses so that `if NO_OVERFLOW(...)` remains valid. */

#define NO_OVERFLOW(a,b,r) ((((((a)&(b))|(((a)^(b))&~(r)))) & 0x80000000u) == 0)
#define WORD_SIZE 16
#define LONG_WORD_SIZE 32
/* 2^x as an unsigned value, so that POW2(31) does not shift into the
   sign bit of an int (undefined behavior). */
#define POW2(x) (1u<<(x))

/* mult: multiply two two's complement numbers, raise exception
   if an overflow occurs. */

int mult(b,d)
register unsigned int b,d;
{ register unsigned int a,c;
  register int sign = b^d;

/* break b and d into hi/lo words 

      -------------   ---------
      |  a  |  b  |  |c   |  d|
      -------------  ---------
*/

  if ((int)b<0) {b = -(int)b; }
  if ((int)d<0) {d = -(int)d; }
  a = b >> WORD_SIZE;
  b = b & (POW2(WORD_SIZE)-1);
  c = d >> WORD_SIZE;
  d = d & (POW2(WORD_SIZE)-1);
  if (a&c) goto overflow;
  a = a*d;
  c = c*b;
  b = b*d;
  if (a<(POW2(LONG_WORD_SIZE-WORD_SIZE)) &&
      c<(POW2(LONG_WORD_SIZE-WORD_SIZE)))
    { d = a+c;
      if (d<(POW2(LONG_WORD_SIZE-WORD_SIZE)))
          { d <<= WORD_SIZE;
            a=d+b;
           if NO_OVERFLOW(d,b,a)
	       if (sign<0)
		  if (a<=POW2(LONG_WORD_SIZE-1))
                      return (-a);
		  else goto overflow;
	       else if (a<(POW2(LONG_WORD_SIZE-1))) return (a);
	  }
    }
 overflow:
#ifdef DEBUG
      printf("overflow occurred\n");
#endif
  inML = 0;
  RAISE (overflow_e0+1)
}
