/*--------------------------------------------------------------------
   Alged:  Algebra Editor    henckel@vnet.ibm.com

   Copyright (c) 1994 John Henckel
   Permission to use, copy, modify, distribute and sell this software
   and its documentation for any purpose is hereby granted without fee,
   provided that the above copyright notice appear in all copies.
*/
#include "alged.h"
/*--------------------------------------------------------------------
   associate
   rotate the association on add,sub,mul,div
   a+b-c => a-c+b => b+a-c => b-c+a
   note, we make use of the fact that ADD+1 = SUB and ADD is even.

   Mechanics: walking down the left spine of p while the left child's
   operator is in the same family as p (ADD/SUB or MUL/DIV), each node
   takes over the operator and right operand of its left child.  The
   bottom node of that spine then gets p's original operator and p's
   original right operand b:
     ADD/MUL:  bottom becomes  b op leftmost   (b is the new first term)
     SUB/DIV:  bottom becomes  leftmost op b
   Repeated calls cycle through the orderings shown above.
   EQU: the two sides are swapped.  Any other operator: no change.
*/
void associate(node *p) {
  node *a,*b;
  int opr;

  opr = p->kind;
  b = p->rt;              /* the operand that will be moved */

  if (opr==DIV && b->kind==MUL) {     /* special handle for bisected */
    /* a/(x*y) ==> (a/x)/y, reusing the MUL node as the inner DIV */
    p->rt = b->rt;
    b->rt = b->lf;
    b->lf = p->lf;
    p->lf = b;
    b->kind = DIV;
    b = p->rt;
  }
  if (opr==ADD || opr==MUL || opr==SUB || opr==DIV) {
    a = p;
    /* (k|1) collapses ADD,SUB and MUL,DIV pairs: "same operator family" */
    while ((a->lf->kind|1) == (opr|1)) {
      a->kind = a->lf->kind;      /* shift operator and right operand up one level */
      a->rt = a->lf->rt;
      a = a->lf;
    }
    a->kind = opr;
    if (opr&1) {           /* subtr or divide */
      a->rt = b;           /* leftmost - b  (or leftmost / b) */
    }
    else {
      a->rt = a->lf;       /* b + leftmost  (or b * leftmost) */
      a->lf = b;
    }
  }
  else if (opr==EQU) {           /* commute equality */
    p->rt = p->lf;
    p->lf = b;
  }
}

/*--------------------------------------------------------------------
   commute  (currently unused)
   Exchange the two operands of p.
     ADD, MUL, EQU:  plain swap, a+b ==> b+a
     SUB, DIV:       a-b ==> b*(-1) - a*(-1),  a/b ==> b^(-1) / a^(-1)
   For SUB/DIV, if both operands already carry a numeric coefficient
   (x*n under SUB, x^n under DIV), the operands are swapped and those
   numbers negated instead of wrapping the operands again.
   Any other operator is left untouched.
*/
void commuteNOTUSED(node *p) {
  node *left = p->lf;
  node *right = p->rt;
  int wrap;

  switch (p->kind) {
  case ADD:
  case MUL:
  case EQU:
    p->lf = right;
    p->rt = left;
    break;

  case SUB:
  case DIV:
    wrap = p->kind + 1;               /* SUB -> MUL, DIV -> EXP */
    if (left->kind == wrap && right->kind == wrap &&
        left->rt->kind == NUM && right->rt->kind == NUM) {
      /* coefficients already present: swap and flip their signs */
      p->lf = right;
      p->rt = left;
      left->rt->value = -left->rt->value;
      right->rt->value = -right->rt->value;
    } else {
      /* wrap each operand with a -1 coefficient */
      node *new_left = newoper(wrap);
      node *new_right = newoper(wrap);
      new_left->lf = right;
      new_left->rt = newnum(-1);
      new_right->lf = left;
      new_right->rt = newnum(-1);
      p->lf = new_left;
      p->rt = new_right;
    }
    break;

  default:
    break;
  }
}

/*--------------------------------------------------------------------
   expjoin -- merge powers inside a product or quotient.
      a^x * a^y ==> a^(x+y)          a^x / a^y ==> a^(x-y)
      a   * a^y ==> a^(1+y)          a   / a^y ==> a^(1-y)
      a^x * a   ==> a^(x+1)          a^x / a   ==> a^(x-1)
      x^a * y^a ==> (xy)^a           x^a / y^a ==> (x/y)^a
   Children are processed first.  At each node only the first matching
   rule is applied, and the subtrees it discards are freed.
   Returns the number of rewrites made in the whole subtree.
   The exponent operator is computed as p->kind-2, which relies on the
   operator codes being laid out so that MUL-2 == ADD and DIV-2 == SUB.
*/
int expjoin(node *p) {
  int i,r=0;
  node *a,*b;

  for (i=0; i<p->nump; ++i)
    r+=expjoin(p->parm[i]);

  a = p->lf;
  b = p->rt;
  if (p->kind==MUL || p->kind==DIV) {
    if (a->kind==EXP && b->kind==EXP &&
        equal(a->lf,b->lf)) {      /*  a^x * a^y = a^(x+y) */
      freetree(a->lf);             /* duplicate base; b's base is kept */
      a = a->rt;                   /* keep exponent x */
      freenode(p->lf);             /* drop the old a^x node itself */
      b->kind = p->kind-2;         /* reuse b as x+y (x-y for DIV) */
      p->kind = EXP;
      p->lf = b->lf;
      b->lf = a;
    }
    else if (b->kind==EXP &&
        equal(a,b->lf)) {               /*  a * a^y = a^(1+y) */
      freetree(a);
      a = newnum(1);
      b->kind = p->kind-2;              /* reuse b as 1+y (1-y for DIV) */
      p->kind = EXP;
      p->lf = b->lf;
      b->lf = a;
    }
    else if (a->kind==EXP &&
             equal(a->lf,b)) {          /*  a^x * a = a^(x+1) */
      freetree(b);
      b = newnum(1);
      a->kind = p->kind-2;              /* reuse a as x+1 (x-1 for DIV) */
      p->kind = EXP;
      p->lf = a->lf;
      a->lf = a->rt;
      a->rt = b;
      p->rt = a;
    }
    else if (a->kind==EXP && b->kind==EXP &&
        equal(a->rt,b->rt)) {      /*  x^a * y^a = (xy)^a */
      freetree(b->rt);             /* duplicate exponent */
      p->rt = a->rt;
      a->rt = b->lf;
      freenode(b);
      a->kind = p->kind;           /* reuse a as x*y (x/y for DIV) */
      p->kind = EXP;
    }
    else return r;
    ++r;
  }
  return r;
}

/*--------------------------------------------------------------------
   remove sub and div

   x/y  =  x*y^-1
   x-y  =  x+y*-1

   Numeric right operands are folded directly instead of being wrapped:
   x-2 becomes x+(-2), and x/0.5 becomes x*2 whenever the divisor is not
   a whole number.  Whole-number divisors keep the y^-1 form.
   Children are processed first.  Returns the number of nodes rewritten.
*/
int nosubdiv(node *p) {
  int k, count = 0;
  node *rhs, *wrapped;

  for (k = 0; k < p->nump; ++k)
    count += nosubdiv(p->parm[k]);

  if (p->kind != SUB && p->kind != DIV)
    return count;

  rhs = p->rt;
  p->kind -= 1;                           /* SUB -> ADD, DIV -> MUL */
  if (rhs->kind == NUM && p->kind == ADD) {
    rhs->value = -rhs->value;             /* x-2   ==> x+(-2) */
  } else if (rhs->kind == NUM && !whole(rhs->value)) {
    rhs->value = 1.0 / rhs->value;        /* x/0.5 ==> x*2 */
  } else {
    wrapped = newoper(p->kind + 2);       /* y*-1 under ADD, y^-1 under MUL */
    wrapped->lf = rhs;
    wrapped->rt = newnum(-1);
    p->rt = wrapped;
  }
  return count + 1;
}

/*--------------------------------------------------------------------
   bisect node

   This converts expressions to canonical form in which
   1. + * are left assoc, ^ is right assoc.
   2. scope of / is increased, e.g. (a/b)*c ==> ac/b,
   3. x+y*-2 is changed to x-y*2
   3' x+(-2) is changed to x-2
   4. x*y^-2 is changed to x/y^2
   4' x*(0.5) is changed to x/2
   5. scope of ^ is reduced
   The name "bisect" means that each MUL clag is broken into two pieces the
   numerator elements and the denominator ones.

   Children are processed first.  At each node at most one rule fires,
   and the return value counts the rewrites in the whole subtree, so
   callers repeat until it returns 0.
   NOTE(review): no branch in this function implements 4'; confirm
   whether another pass is expected to do it.
   swinga/swingb are macros defined elsewhere.  From their use here they
   appear to rewire pointers without touching operator kinds:
     swingb: a op (b1 op' b2) -> (a ? b1) ? b2, node b becomes p's left child
     swinga: (a1 op' a2) op c -> a1 ? (a2 ? c), node a becomes p's right child
   with each branch fixing up the kinds afterwards.  Confirm against the
   macro definitions.
*/
int bisect(node *p) {
  int i,r=0;
  node *a,*b;

  for (i=0; i<p->nump; ++i)
    r+=bisect(p->parm[i]);

  a = p->lf;
  b = p->rt;
  switch (p->kind) {
  case ADD:
  case MUL:
    /*--------------------------------------------------------------------
       Add or Multiply
    */
    if (b->kind==p->kind ||              /* a+(b+c) = a+b+c */
        b->kind==p->kind+1) {            /* a+(b-c) = a+b-c */
      swingb;
      i = b->kind;
      b->kind = p->kind;
      p->kind = i;
    }
    else if (p->kind==ADD &&             /* a+(-2) = a-2 */
             b->kind==NUM &&
             b->value < 0) {
      p->kind = SUB;
      b->value = -b->value;
    }
    else if (b->kind==p->kind+2 &&       /* a+b*-2 = a-b*2,  a*b^-2 = a/b^2 */
        b->rt->kind==NUM &&
        b->rt->value < 0) {
      ++p->kind;
      b->rt->value = -b->rt->value;
    }
    /*--------------------------------------------------------------------
       Add only
    */
    else if (p->kind==ADD &&
        aop(b->kind)) {                   /* a+(b-c) = (a+b)-c */
      /* NOTE(review): if aop() is true only for ADD/SUB, this branch is
         unreachable, because the first test above already handles
         b->kind == ADD or SUB. */
      swingb;
      p->kind = b->kind;
      b->kind = ADD;
    }
    /*--------------------------------------------------------------------
       Multiply only
    */
    else if (p->kind==MUL &&
        a->kind==p->kind+1) {       /* (a/b)*c = (a*c)/b */
      p->rt = a->rt;
      a->rt = b;
      ++p->kind;
      --a->kind;
    }
    else if (p->kind==MUL &&
        a->kind==p->kind+2 &&       /* b^-2 * a = a/b^2 */
        a->rt->kind==NUM &&
        a->rt->value < 0) {
      ++p->kind;
      p->lf = b;
      p->rt = a;
      a->rt->value = -a->rt->value;
    }
    else break;
    ++r; break;
  case SUB:
  case DIV:
    /*--------------------------------------------------------------------
       Subtract or divide
    */
    if (b->kind==p->kind+1 &&       /* a-b*-2 = a+b*2,  a/b^-2 = a*b^2   NOT NECES. */
        b->rt->kind==NUM &&
        b->rt->value < 0) {
      --p->kind;
      b->rt->value = -b->rt->value;
    }
    /*--------------------------------------------------------------------
       Subtract only
    */
    else if (p->kind==SUB &&             /* a-(-2) = a+2 */
        b->kind==NUM &&
        b->value < 0) {
      p->kind = ADD;
      b->value = -b->value;
    }
    else if (p->kind==SUB &&
         aop(b->kind)) {                   /* a-(b+c) = (a-b)-c,  a-(b-c) = (a-b)+c */
      swingb;
      p->kind = (ADD+SUB) - b->kind;       /* ADD <-> SUB */
      b->kind = SUB;
    }
    /*--------------------------------------------------------------------
       Divide only
    */
    else if (p->kind==DIV &&
        b->kind==p->kind) {             /* a/(b/c) = (a*c)/b */
      p->lf = b;
      p->rt = b->lf;
      b->lf = a;
      --b->kind;
    }
    else if (p->kind==DIV &&
        a->kind==p->kind) {        /* (a/b)/c = a/(b*c) */
      swinga;
      --a->kind;
    }
    else break;
    ++r; break;
  case EXP:
    /*--------------------------------------------------------------------
       exponent
    */
    if (a->kind==EXP) {                  /* (x^y)^z = x^(y*z) */
      swinga;
      a->kind = MUL;
    }
    else if (b->kind==NUM &&             /* a^(-2) = 1/a^2 */
             b->value < 0) {
      p->kind = DIV;
      p->lf = newnum(1);
      p->rt = newoper(EXP);
      p->rt->lf = a;
      p->rt->rt = b;
      b->value = -b->value;
    }
    else break;
    ++r; break;
  }
  return r;
}

/*--------------------------------------------------------------------
   exponent expand - peel one factor off a whole-number power:
       a^n  ==>  a^(n-1) * a          (a^2 ==> a * a)
   Applies when the exponent is a NUM holding a whole number n with
   1 < n <= maxpow.  Children are expanded first, and each node loses
   at most one factor per call, so callers repeat until the returned
   count of rewrites is 0.  The extra base is a deepcopy() of the
   original.
*/
int exexpand(node *p) {
  int k, count = 0;
  node *base, *power, *rest;

  for (k = 0; k < p->nump; ++k)
    count += exexpand(p->parm[k]);

  if (p->kind != EXP)
    return count;
  base = p->lf;
  power = p->rt;
  if (power->kind != NUM ||
      !(power->value > 1 && power->value <= maxpow && whole(power->value)))
    return count;

  if (--power->value == 1) {
    rest = deepcopy(base);        /* a^2: the exponent node is no longer needed */
    freenode(power);
  } else {
    rest = newoper(EXP);          /* a^(n-1), reusing the decremented NUM */
    rest->lf = deepcopy(base);
    rest->rt = power;
  }
  p->lf = rest;
  p->rt = base;
  p->kind = MUL;
  return count + 1;
}

/*-----------------------------------------------------------------
   within - if p is a factor within q, then return the rest of q.
   q is searched as a left-associative product.  Starting at q and
   following ->lf while the node is a MUL, each node's right operand
   and then its left operand are compared with p using equal().
   On the first match, a deep copy of q with that factor removed is
   returned; q itself is restored and left unchanged.
   Returns NULL if q is not a MUL or no factor matches.
*/
node *within(node *p, node *q) {
  node *parent = NULL;
  node *cur;

  for (cur = q; cur->kind == MUL; parent = cur, cur = cur->lf) {
    node *rest;
    node *result;

    if (equal(p, cur->rt))
      rest = cur->lf;
    else if (equal(p, cur->lf))
      rest = cur->rt;
    else
      continue;

    if (parent == NULL)
      return deepcopy(rest);      /* match at the top: copy the other operand */

    /* temporarily splice cur out so the copy omits the matched factor */
    parent->lf = rest;
    result = deepcopy(q);
    parent->lf = cur;             /* restore q */
    return result;
  }
  return NULL;
}

/*--------------------------------------------------------------------
   ComDeno - find common denominators.
   For a node accepted by aop() (the additive operators) that has at
   least one quotient operand:
      a/x + b/x   ==> (a+b)/x
      a/x + b/xy  ==> (a*y+b)/xy    x is a factor of the other denominator
      a/xy + b/x  ==> (a+b*y)/xy    (located with within())
      a/c + b/d   ==> (ad+bc)/cd
      a/b + c     ==> (a+cb)/b
      a + b/c     ==> (ac+b)/c
   p's own operator is reused for the new numerator, and p becomes the DIV.
   Children are processed first.  Returns the number of rewrites.
*/
int comdeno(node *p) {
  int i,r=0;
  double x;          /* NOTE(review): unused */
  node *a,*b,*nu,*de;

  for (i=0; i<p->nump; ++i)
    r+=comdeno(p->parm[i]);

  a = p->lf;
  b = p->rt;
  if (aop(p->kind) && (a->kind==DIV || b->kind==DIV)) {
    /* Every branch leaves the two numerator operands in a and b and the
       common denominator in de; they are assembled after the if-chain. */
    if (a->kind==DIV && b->kind==DIV) {
      if (equal(a->rt,b->rt)) {     /* a/x + b/x = (a+b)/x */
        freetree(b->rt);
        de = a->rt;
        a = a->lf;
        b = b->lf;
        freenode(p->lf);            /* the two old DIV nodes themselves */
        freenode(p->rt);
      }
      else if (!!(nu=within(a->rt,b->rt))) {   /* a/x + b/xy = (ay+b)/xy */
        /* nu is a fresh copy of y, the rest of b's denominator */
        freetree(a->rt);
        a = cons(nu,MUL,a->lf);
        de = b->rt;
        b = b->lf;
        freenode(p->lf);
        freenode(p->rt);
      }
      else if (!!(nu=within(b->rt,a->rt))) {   /* a/xy + b/x = (a+by)/xy */
        /* nu is a fresh copy of y, the rest of a's denominator */
        freetree(b->rt);
        b = cons(nu,MUL,b->lf);
        de = a->rt;
        a = a->lf;
        freenode(p->lf);
        freenode(p->rt);
      }
      else {                        /* a/c + b/d = (ad+bc)/cd */
        /* turn both quotients into cross products, then copy c and d
           for the new denominator (the originals stay in a and b) */
        a->kind = b->kind = MUL;
        de = b->rt;
        b->rt = a->rt;
        a->rt = de;
        de = newoper(MUL);
        de->lf = deepcopy(b->rt);
        de->rt = deepcopy(a->rt);
      }
    }
    else if (a->kind==DIV) {        /* a/b + c = (a + bc)/b */
      a->kind = MUL;                /* old a node becomes c*b */
      de = a->lf;
      a->lf = b;
      b = a;
      a = de;
      de = deepcopy(b->rt);
    }
    else {                        /* a + b/c = (ac + b)/c */
      b->kind = MUL;              /* old b node becomes a*c */
      de = b->lf;
      b->lf = a;
      a = b;
      b = de;
      de = deepcopy(a->rt);
    }
    nu = newoper(p->kind);
    nu->lf = a;
    nu->rt = b;
    p->kind = DIV;
    p->lf = nu;
    p->rt = de;
    ++r;
  }
  return r;
}

/*-----------------------------------------------------------------
   distribute2

   Distribute division over the additive operator (aop) of the
   numerator:
       (a+b)/d  ==>  a/d + b/d         (a-b)/d  ==>  a/d - b/d
   The divisor is deep-copied for the second quotient.  Children are
   processed first.  Returns the number of rewrites performed.
*/
int distribute2(node *p) {
  int k, done = 0;
  node *num, *copy, *quot;

  for (k = 0; k < p->nump; ++k)
    done += distribute2(p->parm[k]);

  num = p->parm[0];
  if (p->kind==DIV && aop(num->kind)) {
    copy = deepcopy(p->parm[1]);
    quot = newoper(p->kind);
    quot->parm[0] = num->parm[1];       /* b / d' */
    quot->parm[1] = copy;
    num->parm[1] = p->parm[1];          /* a / d, reusing the numerator node */
    p->parm[1] = quot;
    p->kind = num->kind;                /* the sum's operator moves up to p */
    num->kind = quot->kind;             /* ...and the old sum node divides */
    ++done;
  }
  return done;
}

/*    This is used to control the direction of distribute */
/*    0 (default): bottom-up, children are distributed first.
      1: top-down, the node itself first, and its children only when the
         node was not rewritten.  distribute_c sets it while it runs. */
static int topdown = 0;
/*-----------------------------------------------------------------
   distribute

   This applies the distributive laws
   1. multiplication over addition.
   2. (ab)^2 = a^2*b^2
   3. x^(a+b) = x^a*x^b

   In full, with "+" standing for any operator accepted by aop()
   (presumably ADD or SUB):
     (a+b)*x ==> a*x + b*x        x*(a+b) ==> x*a + x*b
     x^(a+b) ==> x^a * x^b        (ADD->MUL, SUB->DIV via p->kind+=2)
     (a*b)^n ==> a^n * b^n        (a/b)^n ==> a^n / b^n
   The operand being distributed (x or n) is deep-copied once.  At each
   node only the first operand position (left, then right) that matches
   is rewritten.  Returns the number of rewrites performed.
*/
int distribute(node *p) {
  node *a,*b,*c;
  int i,r=0;

  if (!topdown)
    for (i=0; i<p->nump; ++i)
      r+=distribute(p->parm[i]);

  for (i=0; i<2; ++i) {           /* i = index of the operand being expanded */
    c = p->parm[i];
    if (p->kind==MUL && aop(c->kind) ||
        i && p->kind==EXP && aop(c->kind) ||
        !i && p->kind==EXP && (c->kind==MUL || c->kind==DIV)) {
      b = deepcopy(p->parm[1-i]);   /* copy of the distributed operand */
      a = newoper(p->kind);         /* new piece: c's other child with the copy */
      a->parm[i] = c->parm[1-i];
      a->parm[1-i] = b;
      c->parm[1-i] = p->parm[1-i];  /* c keeps its i-th child, takes the original */
      p->parm[1-i] = a;
      p->kind = c->kind;            /* c's operator moves up to p */
      c->kind = a->kind;
      if (i && a->kind==EXP) p->kind+=2;    /* change ADD to MUL */
      ++r;
      break;   /* just to be safe */
    }
  }

  if (topdown && !r)
    for (i=0; i<p->nump; ++i)
      r+=distribute(p->parm[i]);

  return r;
}

/*-----------------------------------------------------------------
   distribute_c
   this applies the distributive law over multiplication which is
   not governed by another multiplication or equality.

   The function walks both sides of an equation and down the left spine
   of a product.  Each right-hand factor of that spine, and any node that
   is neither EQU nor MUL, is passed to distribute() in top-down mode.
   Returns the number of rewrites.

   The previous value of topdown is saved and restored rather than
   forced to 0.  This function recurses, and a nested call must not
   switch its caller back to bottom-up mode partway through the walk.
*/
int distribute_c(node *p) {
  int r=0;
  int saved = topdown;

  topdown = 1;
  if (p->kind==EQU) {
    r+=distribute_c(p->lf);
    r+=distribute_c(p->rt);
  }
  else if (p->kind==MUL) {
    r+=distribute(p->rt);
    r+=distribute_c(p->lf);
  }
  else r+=distribute(p);
  topdown = saved;

  return r;
}

/*--------------------------------------------------------------------
   fixassoc
     + * are left assoc, ^ is right assoc.
   Rewrites right-nested chains into left-nested ones and flattens
   nested powers:
     a+(b+c) ==> (a+b)+c      a+(b-c) ==> (a+b)-c    (same for * and /)
     a-(b-c) ==> (a-b)+c      a-(b+c) ==> (a-b)-c    (same for / and *)
     (x^y)^z ==> x^(y*z)
   Children are processed first, with at most one rewrite per node per
   call.  Returns the number of rewrites in the subtree.
   swinga/swingb are macros defined elsewhere.  As used here they appear
   to rewire pointers without changing operator kinds (swingb reuses node
   b as p's new left child holding a and b1; swinga reuses node a as p's
   new right child holding a2 and c).  The kinds are fixed up afterwards.
   Confirm against the macro definitions.
*/
int fixassoc(node *p) {
  int i,r=0;
  node *a,*b;

  for (i=0; i<p->nump; ++i)
    r+=fixassoc(p->parm[i]);

  a = p->lf;
  b = p->rt;
  switch (p->kind) {
  case ADD:
  case MUL:
    if (b->kind==p->kind ||              /* a+(b+c) = a+b+c */
        b->kind==p->kind+1) {            /* a+(b-c) = a+b-c */
      swingb;
      i=b->kind; b->kind=p->kind; p->kind=i;   /* swap the two operators */
    }
    else break;
    ++r; break;
  case SUB:
  case DIV:
    if (b->kind==p->kind) {             /* a-(b-c) = a-b+c */
      swingb;
      --p->kind;                        /* SUB->ADD, DIV->MUL */
    }
    else if (b->kind==p->kind-1) {      /* a-(b+c) = a-b-c */
      swingb;
      ++b->kind;                        /* ADD->SUB, MUL->DIV */
    }
    else break;
    ++r; break;
  case EXP:
    if (a->kind==EXP) {                  /* (x^y)^z = x^(y*z) */
      swinga;
      a->kind = MUL;
    }
    else break;
    ++r; break;
  }
  return r;
}

/*-----------------------------------------------------------------
   get term, return p without coefficient
   If p is an `oper` node whose right operand is a NUM (e.g. x*3 when
   oper is MUL, x^3 when oper is EXP), store that number in *r and
   return the left operand x.  Otherwise store 1.0 and return p.
*/
node *get_term(node *p, int oper, double *r) {
  *r = 1.0;
  if (p->kind != oper || p->rt->kind != NUM)
    return p;
  *r = p->rt->value;
  return p->lf;
}
/*--------------------------------------------------------------------
   Combine all terms with common base within a clag.
   Only the last two terms of the ADD or MUL chain at p are compared:
   a->rt and b when p->lf uses the same operator, otherwise p->lf and
   p->rt.  The coefficient is a numeric right factor for ADD (x*3) and a
   numeric exponent for MUL (x^3); see get_term.  If the bases are
   equal(), the coefficients are summed:
       x*2 + x*3 ==> x*5      x + x ==> x*2      x^2 * x ==> x^3
   A resulting coefficient of exactly 1 is dropped.  Like terms must
   already be adjacent (presumably guaranteed by a preceding sort).
   Children are processed first.  Returns the number of merges.
*/
int combine(node *p) {
  int i,oper,r=0;
  node *a,*b;
  node *t1,*t2;
  double c1,c2;

  for (i=0; i<p->nump; ++i)
    r+=combine(p->parm[i]);

  a = p->lf;
  b = p->rt;
  oper = p->kind;
  if (oper!=ADD && oper!=MUL)          /* not a clag */
    return r;
  i = 1;                               /* compare a->rt with b ... */
  if (a->kind!=oper) {                 /* ...unless the chain has only two terms */
    a=p; i=0;
  }

  t1 = get_term(a->parm[i],oper+2,&c1);
  t2 = get_term(b,oper+2,&c2);
  if (equal(t1,t2)) {            /* same base, combine terms */
    if (a->parm[i]==t1)          /* first term has no coefficient: drop it, keep b */
      freetree(t1);
    else {                    /* free one of the expressions */
      freetree(b);            /* keep the first term and its coefficient node */
      b = a->parm[i];
      t2 = t1;
    }
    c2 += c1;
    if (b==t2) {             /* make room for coeff */
      b = newoper(oper+2);
      b->lf = t2;
      b->rt = newnum(0);
    }
    b->rt->value = c2;
    if (c2==1.0) {            /* remove unit coeff */
      freenode(b->rt);
      freenode(b);
      b = t2;
    }
    if (i) {                  /* cleanup... 2 stage */
      nodecpy(p,a);           /* p takes over the rest of the chain */
      freenode(a);
      p->rt = b;              /* merged term replaces the last two */
    }
    else {                    /* 1 stage */
      nodecpy(p,b);           /* p becomes the merged term */
      freenode(b);
    }
    ++r;
  }
  return r;
}

/*-----------------------------------------------------------------
   Simplify
   This function converts a rational expression to normal form.
   Some of the attributes of normal form: no negative exponents,
   terms are sorted over mul and add.
   The first movenums pushes all numbers to the left so that
   when the sortnode runs it pushes them to the right and combines
   them.

   simplify2 runs a fixed sequence of passes over p.  The while-passes
   repeat until they report no change; the sortnode passes run once.
   When `slow` is nonzero, debug(p) is called after every productive
   pass, and its return value decides whether tracing continues.
   simplify() is simplify2() with tracing off.
*/

/* Trace step for simplify2: returns the new value of the slow flag. */
static int simplify_trace(node *p, int slow)
{
  if (!slow)
    return 0;
  return debug(p);
}

void simplify2(node *p, int slow)
{
  while (calcnode(p,1))       slow = simplify_trace(p, slow);
  while (fixassoc(p))         slow = simplify_trace(p, slow);
  while (nosubdiv(p))         slow = simplify_trace(p, slow);
  while (fixassoc(p))         slow = simplify_trace(p, slow);

  while (movenums(p,0,MUL))   slow = simplify_trace(p, slow);
  if (sortnode(p,MUL))        slow = simplify_trace(p, slow);
  while (movenums(p,0,ADD))   slow = simplify_trace(p, slow);
  if (sortnode(p,ADD))        slow = simplify_trace(p, slow);
  while (combine(p))          slow = simplify_trace(p, slow);
  while (calcnode(p,1))       slow = simplify_trace(p, slow);
  if (sortnode(p,MUL))        slow = simplify_trace(p, slow);
  if (sortnode(p,ADD))        slow = simplify_trace(p, slow);

  while (bisect(p))           slow = simplify_trace(p, slow);
  while (calcnode(p,1))       slow = simplify_trace(p, slow);
  while (movenums(p,0,MUL))   slow = simplify_trace(p, slow);
}

void simplify(node *p) { simplify2(p,0); }

/*-----------------------------------------------------------------
   substitute -- prereq tgt MUST be an EQU.
   Searches p top-down for subtrees equal() to tgt->lf.  Each match is
   passed to movenode() together with a fresh deep copy of tgt->rt
   (presumably replacing it in place).  A matched subtree is not
   searched further.
*/
void substitution(node *p) {
  int k;

  if (equal(p, tgt->lf)) {
    movenode(p, deepcopy(tgt->rt));
    return;
  }
  for (k = 0; k < p->nump; ++k)
    substitution(p->parm[k]);
}


/*-----------------------------------------------------------------
   insert key
   has special handling for equations...

   Inserts an operation built from the target tgt around the selection
   src (both globals).  Let T1,T2 be tgt->lf,tgt->rt when tgt is an
   equation, and both equal to tgt otherwise:
     src is EQU:    L = R  ==>  (L opr T1) = (R opr T2)
     opr == EXP:    s      ==>  (s^T1)^(1/T2)
     otherwise:     s      ==>  (s opr T1) opr^1 T2
   opr^1 flips ADD<->SUB and MUL<->DIV, so given T1 = T2 the value of s
   is preserved.  In the last two forms src is left pointing at the
   inner node (s^T1 or s opr T1).  Nothing happens if src or tgt is unset.

   The copies of T1/T2 are taken BEFORE src is modified.  tgt may be src
   itself or may contain src (e.g. src is a subexpression of the equation
   tgt).  Copying afterwards would walk a partially rebuilt subtree.
*/
void insertkey(int opr) {
  node *tmp,*c1,*c2;

  if (!src || !tgt) return;
  if (tgt->kind==EQU) {
    c1 = deepcopy(tgt->lf);
    c2 = deepcopy(tgt->rt);
  } else {
    c1 = deepcopy(tgt);
    c2 = deepcopy(tgt);
  }

  if (src->kind==EQU) {
    tmp = src->lf;
    src->lf = newoper(opr);
    src->lf->lf = tmp;
    src->lf->rt = c1;
    tmp = src->rt;
    src->rt = newoper(opr);
    src->rt->lf = tmp;
    src->rt->rt = c2;
  }
  else if (opr==EXP) {
    tmp = newnode();
    nodecpy(tmp,src);            /* tmp takes over the original expression */
    src->kind = EXP;
    src->nump = 2;
    src->lf = cons(tmp,EXP,c1);
    src->rt = cons(newnum(1),DIV,c2);
    src = src->lf;
  }
  else {
    tmp = newnode();
    nodecpy(tmp,src);            /* tmp takes over the original expression */
    src->kind = opr ^ 1;
    src->nump = 2;
    src->lf = newoper(opr);
    src->rt = c2;
    src = src->lf;
    src->lf = tmp;
    src->rt = c1;
  }
}

/*-----------------------------------------------------------------
   cross_eq -- move one operand of the left side across the '='.
   p must be an EQU whose left side is ADD, SUB, MUL or DIV; otherwise
   nothing happens.  i (0 or 1) selects the operand that stays on the
   left:
     i=0:  a+b=c ==> a=c-b   a-b=c ==> a=c+b   a*b=c ==> a=c/b   a/b=c ==> a=c*b
     i=1:  a+b=c ==> b=c-a   a-b=c ==> b=a-c   a*b=c ==> b=c/a   a/b=c ==> b=a/c
*/
void cross_eq(node *p, int i) {
  node *op, *held;

  if (p->kind != EQU) return;
  op = p->lf;
  if (op->kind != ADD && op->kind != SUB &&
      op->kind != MUL && op->kind != DIV)
    return;

  /* operand i stays on the left; the old right side takes its slot */
  p->lf = op->parm[i];
  op->parm[i] = p->rt;
  p->rt = op;

  if (!i) {
    op->kind ^= 1;              /* invert: ADD<->SUB, MUL<->DIV */
  } else if (!(op->kind & 1)) {
    held = op->rt;              /* ADD/MUL: reorder to c op a, then invert */
    op->rt = op->lf;
    op->lf = held;
    op->kind ^= 1;
  }
  /* i=1 with SUB/DIV: a-c or a/c is already the answer */
}