#include "differ.h"

#define fnsym  '#'
#define opsym  '$'
#define xsym   'x'
#define numsym ''
#define consym '@'
#define lbrac  '('
#define rbrac  ')'
#define fnend 1

logical parsefailed;

/*
   Functions recognised in the input, one entry per spelling.
   reducefnstosymbols() walks this table from the top and takes the first
   name that matches at the current input position, so a name that is a
   prefix of another ("sin" of "sinh", "cos" of "cosh", "cosec" of
   "cosech", ...) must be listed after the longer one.  printout() uses
   the same table to turn a function code back into its name.  The entry
   with a NULL name terminates the table.
*/

struct sfuncs
{ char *name;         /* spelling in the input, compared with strnicmp() */
  functions command;  /* code stored after fnsym and in expval.fn */
};

struct sfuncs funcs[] =
{ { "exp",     EXP        },
  { "sinh",    SINH       },
  { "sin",     SIN        },
  { "ln",      LN         },
  { "tanh",    TANH       },
  { "tan",     TAN        },
  { "cosech",  COSECH     },
  { "cosec",   COSEC      },
  { "coth",    COTH       },
  { "cot",     COT        },
  { "cosh",    COSH       },
  { "cos",     COS        },
  { "sech",    SECH       },
  { "sec",     SEC        },
  { "arcsinh", ARCSINH    },
  { "arcsin",  ARCSIN     },
  { "arccosh", ARCCOSH    },
  { "arccos",  ARCCOS     },
  { "arctanh", ARCTANH    },
  { "arctan",  ARCTAN     },
  { NULL,      UNDEF_FUNC }
};

/*
   Heres the list of supported operators. Note that although unitary minus
   does not appear in the list it is supported separately in the routines.
*/

struct sops
{ char *name;
  functions command;
} ops[] =
{  {"*",TIMES},
   {"/",DIVIDE},
   {"+",PLUS},
   {"-",MINUS},
   {"^",POWER},
   {NULL,UNDEF_OP}
};

/*
  This function takes the input, looks for functions and replaces them
  with markers prior to the parse.
*/

logical reducefnstosymbols(char *buff)
{ int newbufptr=0,
	oldbufptr=0,
	fncounter;
  char tempbuff[MAXINPLINELEN*2];
  logical found;
  while(buff[oldbufptr])
  { fncounter=0;
    found=FALSE;
    if(buff[oldbufptr]==fnsym) return FALSE;
    while(funcs[fncounter].name!=NULL)
    { if(!strnicmp(funcs[fncounter].name,&buff[oldbufptr],strlen(funcs[fncounter].name)))
	{ found=TRUE;
	  break;
	}
	fncounter++;
    }
    if(found)
    { tempbuff[newbufptr++]=fnsym;
	tempbuff[newbufptr++]=funcs[fncounter].command;
	oldbufptr+=strlen(funcs[fncounter].name);
    }
    else
	tempbuff[newbufptr++]=buff[oldbufptr++];
  }
  tempbuff[newbufptr]=0;
  strcpy(buff,tempbuff);
  return TRUE;
}

/*
   This function identifies operators and marks them. It also identifies
   unitary minus.
*/

logical reduceopstosymbols(char *buff)
{ int newbufptr=0,
	oldbufptr=0,
	opcounter;
  char lastchar=0;
  char tempbuff[MAXINPLINELEN*2];
  logical found;
  while(buff[oldbufptr])
  { opcounter=0;
    found=FALSE;
    if(buff[oldbufptr]==opsym) return FALSE;
    while(ops[opcounter].name!=NULL)
    { if(!strnicmp(ops[opcounter].name,&buff[oldbufptr],strlen(ops[opcounter].name)))
	{ found=TRUE;
	  break;
	}
	opcounter++;
    }
    if(found)
    { tempbuff[newbufptr++]=opsym;
	if(((!lastchar)||(lastchar==lbrac))&&(ops[opcounter].command==MINUS))
	  tempbuff[newbufptr++]=UMINUS;
	else
	  tempbuff[newbufptr++]=ops[opcounter].command;
	oldbufptr+=strlen(ops[opcounter].name);
    }
    else
    { if(buff[oldbufptr]!=' ')lastchar=buff[oldbufptr];
	tempbuff[newbufptr++]=buff[oldbufptr++];
    }
  }
  tempbuff[newbufptr]=0;
  strcpy(buff,tempbuff);
  return TRUE;
}

/*
   Third pass of parser(): marks the operands left in buff.  Every numeric
   literal (a run of digits and '.') is preceded by numsym, and every
   letter other than xsym is treated as a one-letter constant and preceded
   by consym.  Everything else is copied unchanged.  The rewritten text is
   built in a local buffer and copied back into buff.
   Returns FALSE, with buff unchanged, if buff already contains consym or
   numsym; otherwise TRUE.
   NOTE(review): every letter becomes two bytes, so the result can be up
   to twice as long as the input and is strcpy()'d back into buff --
   confirm callers size buff for that.
*/

logical findconsts(char *buff)
{ int newbufptr=0,
      oldbufptr=0;
  char nextch;
  char tempbuff[MAXINPLINELEN*2];
  while(buff[oldbufptr])
  { nextch=buff[oldbufptr];
    if((nextch==consym)||(nextch==numsym))return FALSE;
    /* <ctype.h> functions require a value representable as unsigned char
       (or EOF); passing a negative plain char, e.g. a non-ASCII byte where
       char is signed, is undefined behaviour.  Hence the casts. */
    if((isdigit((unsigned char)nextch))||(nextch=='.'))
    { /* One marker for the whole run so parseexpr() can hand the literal
         to strtod(). */
      tempbuff[newbufptr++]=numsym;
      while((isdigit((unsigned char)nextch))||(nextch=='.'))
      { tempbuff[newbufptr++]=buff[oldbufptr++];
        nextch=buff[oldbufptr];
      }
    }
    else if((isalpha((unsigned char)nextch))&&(nextch!=xsym))
    { /* Any other single letter is a named constant. */
      tempbuff[newbufptr++]=consym;
      tempbuff[newbufptr++]=buff[oldbufptr++];
    }
    else
      tempbuff[newbufptr++]=buff[oldbufptr++];
  }
  tempbuff[newbufptr]=0;
  strcpy(buff,tempbuff);
  return TRUE;
}

/*
   Allocates one expression node with lval and rval set to NULL; the
   caller fills in exptype and the value fields.  If malloc() fails, a
   message is printed and the program exits with status 1, so this never
   returns NULL.
*/

struct expression *newtree(void)
{ struct expression *node=malloc(sizeof *node);
  if(node==NULL)
  { printf("Unable to allocate enough memory\n\r");
    exit(1);
  }
  node->rval=NULL;
  node->lval=NULL;
  return node;
}

/*
   Here is our expression parser. It is recursive and returns a pointer
   to a tree containing the expression.
*/

struct expression *parseexpr(char **buff,char endchar)
{ struct expression *tree=NULL,
			  *temptree,
			  *rectree,
			  *restree;
  operators optype;
  while(**buff)
  { switch(**buff)
    { case ' ':  (*buff)++;
		     break;
	case lbrac:(*buff)++;
		     if(tree!=NULL)
		     { parsefailed=TRUE;
			 printf("Misplaced '('\n\r");
			 return tree;
		     }
		     tree=parseexpr(buff,rbrac);
		     temptree=newtree();
		     temptree->rval=tree;
		     temptree->exptype=OPERATOR;
		     temptree->expval.op=BRACKET;
		     tree=temptree;
		     break;
	case rbrac:if(endchar==fnend) return tree;
		     (*buff)++;
		     if(tree==NULL)
		     { parsefailed=TRUE;
			 printf("No expression in brackets\n\r");
			 return tree;
		     }
		     if(endchar!=rbrac)
		     { parsefailed=TRUE;
			 printf("Misplaced ')'\n\r");
			 return tree;
		     }
		     return tree;
	case xsym: (*buff)++;
		     if(tree!=NULL)
		     { parsefailed=TRUE;
			 printf("Improper expression\n\r");
			 return tree;
		     }
		     tree=newtree();
		     tree->exptype=X;
		     break;
	case opsym:if(endchar==fnend)return tree;
		     (*buff)++;
		     optype=(**buff);
		     (*buff)++;
		     if((tree==NULL)&&(optype!=UMINUS))
		     { parsefailed=TRUE;
			 printf("Invalid syntax - missing L-Value\n\r");
			 return tree;
		     }
		     temptree=parseexpr(buff,endchar);
		     if(temptree==NULL)
		     { parsefailed=TRUE;
			 printf("Operator with no R-Value\n\r");
			 return tree;
		     }
		     if((temptree->exptype==OPERATOR)&&(optype>=temptree->expval.op))
		     { rectree=temptree;
			 while((rectree->lval->exptype==OPERATOR)&&(optype>=rectree->lval->expval.op))
			   rectree=rectree->lval;
			 restree=newtree();
			 restree->exptype=OPERATOR;
			 restree->expval.op=optype;
			 restree->lval=tree;
			 restree->rval=rectree->lval;
			 rectree->lval=restree;
			 tree=temptree;
		     }
		     else
		     { restree=newtree();
			 restree->exptype=OPERATOR;
			 restree->expval.op=optype;
			 restree->lval=tree;
			 restree->rval=temptree;
			 tree=restree;
		     }
		     if(endchar) return tree;
		     break;
	case fnsym:(*buff)++;
		     optype=(**buff);
		     (*buff)++;
		     if(tree!=NULL)
		     { parsefailed=TRUE;
			 printf("Invalid syntax - missing operator\n\r");
			 return tree;
		     }
		     temptree=parseexpr(buff,fnend);
		     if(temptree==NULL)
		     { parsefailed=TRUE;
			 printf("Function with no R-Value\n\r");
			 return tree;
		     }
		     tree=newtree();
		     tree->exptype=FUNCTION;
		     tree->expval.fn=optype;
		     tree->rval=temptree;
		     break;
	case numsym:(*buff)++;
		     if(tree!=NULL)
		     { parsefailed=TRUE;
			 printf("Improper expression\n\r");
			 return tree;
		     }
		     tree=newtree();
		     tree->exptype=VALUE;
		     tree->cval=(float)strtod(*buff,buff);
		     break;
	case consym:(*buff)++;
		     optype=(**buff);
		     (*buff)++;
		     if(tree!=NULL)
		     { parsefailed=TRUE;
			 printf("Improper expression\n\r");
			 return tree;
		     }
		     tree=newtree();
		     tree->exptype=CONSTANT;
		     tree->expval.co=optype;
		     break;
	default:   parsefailed=TRUE;
		     printf("Unknown symbol : %c",**buff);
		     return tree;
    }
  }
  if(endchar==rbrac)
  { parsefailed=TRUE;
    printf("Expression incomplete\n\r");
    return tree;
  }
  return tree;
}

/*
   Removes, bottom-up, every BRACKET node that parseexpr() inserted.  Each
   one is replaced by its (already cleaned) rval and freed.  parseexpr()
   needs these nodes while building the tree, but nothing needs them
   afterwards.  Returns the new root, which is NULL for a NULL tree.
*/

struct expression *removebrackets(struct expression *expr)
{ struct expression *inner;
  if(expr==NULL)
    return NULL;
  expr->lval=removebrackets(expr->lval);
  expr->rval=removebrackets(expr->rval);
  if((expr->exptype!=OPERATOR)||(expr->expval.op!=BRACKET))
    return expr;
  inner=expr->rval;                 /* the bracketed contents */
  free(expr);
  return inner;
}

/*
   Parses one input line into an expression tree.

   buff is rewritten in place by the three marker passes before the
   recursive parse, so it must be writable and its original text is lost.
   NOTE(review): those passes can make the text up to about twice as long;
   confirm callers allocate buff accordingly.

   Returns the tree with BRACKET nodes removed, owned by the caller.
   Returns NULL on any error, after a message has been printed.  Also
   returns NULL, without a message, when buff holds no expression at all
   (e.g. an empty or all-blank line).  Error indications could be more
   specific.
*/

struct expression *parser(char *buff)
{ struct expression *expr;
  if(!reducefnstosymbols(buff)||!reduceopstosymbols(buff)||!findconsts(buff))
  { printf("Syntax error\n\r");
    return NULL;
  }
  parsefailed=FALSE;
  expr=parseexpr(&buff,0);
  if(parsefailed)
  { freetree(expr);
    return NULL;
  }
  /* parseexpr() advanced buff; anything left was not consumed. */
  if(buff[0])
  { printf("Syntax error\n\r");
    freetree(expr);                 /* release the partial tree */
    return NULL;
  }
  return removebrackets(expr);
}

/*
   This routine prints the expression out. I have tried to remove the
   printing of as many extraneous brackets as possible, this was after
   looking at expressions like ((a)*((x)^(2))) for too long. Its quite
   good now though.
*/

void printout(struct expression *expr)
{ int tval=0;
  operators op1,op2;
  if(expr==NULL) return;
  switch(expr->exptype)
  { case OPERATOR:   if(expr->expval.op!=UMINUS)
			     if(expr->lval->exptype==OPERATOR)
				 if((op1=expr->lval->expval.op)<=(op2=expr->expval.op))
				 { if(!((op1==op2)&&((op1==PLUS)||(op2==TIMES))))
				   { printf("(");
				     printout(expr->lval);
				     printf(")");
				   }
				   else printout(expr->lval);
				 }
				 else printout(expr->lval);
			     else printout(expr->lval);
			   switch(expr->expval.op)
			   { case PLUS:    printf("+");
						 break;
			     case MINUS:   printf("-");
						 break;
			     case DIVIDE:  printf("/");
						 break;
			     case TIMES:   printf("*");
						 break;
			     case POWER:   printf("^");
						 break;
			     case UMINUS:  printf("-");
						 if(expr->rval->exptype==OPERATOR)
						   if((expr->rval->expval.op==PLUS)||
							(expr->rval->expval.op==MINUS))
						   { printf("(");
						     printout(expr->rval);
						     printf(")");
						     return;
						   }
						 printout(expr->rval);
						 return;
			     case UNDEF_OP:
			     default:      printf("\n\rInternal error:Undefined operation\n\r");
						 return;
			   }
			   if(expr->rval->exptype==OPERATOR)
			     if((op1=expr->rval->expval.op)<=(op2=expr->expval.op))
			     { if(!((op1==op2)&&((op1==PLUS)||(op2==TIMES))))
				 { printf("(");
				   printout(expr->rval);
				   printf(")");
				 }
				 else printout(expr->rval);
			     }
			     else printout(expr->rval);
			   else printout(expr->rval);
			   return;
    case VALUE:      printf("%g",expr->cval);
			   return;
    case X:          printf("x");
			   return;
    case CONSTANT:   printf("%c",expr->expval.co);
			   return;
    case FUNCTION:   while((expr->expval.fn!=funcs[tval].command)&&
				  (funcs[tval].command!=UNDEF_FUNC))tval++;
			   if (funcs[tval].name==NULL)
			   {  printf("\n\rInternal error:Undefined function\n\r");
				return;
			   }
			   printf("%s(",funcs[tval].name);
			   printout(expr->rval);
			   printf(")");
			   return;
    case UNDEF_ETYPE:
    default:         printf("\n\rInternal error:Undefined type\n\r");
			   break;
  }
}