/****************************************************************************
 *  This file is part of PPMD project                                       *
 *  Written and distributed to public domain by Dmitry Shkarin 97, 99       *
 *  Contents: model description and encoding/decoding routines              *
 ****************************************************************************/
#include <string.h>
#include <time.h>
#include "model.h"
#include "suballoc.h"
#pragma hdrstop
#include "coder.hpp"

// Threshold factor passed (as ALFA*escape > total) to the sub-allocator's
// ExpandBlk/ShrinkBlk calls.
const int ALFA = 16;
// Fixed-point layout of the binary-context probabilities: a binary context
// is coded on a scale of 2^(INT_BITS + PERIOD_BITS).
const int INT_BITS = 7;
const int PERIOD_BITS = 7;
// Step added to a BinSumm entry on a hit, and the full binary-context scale.
const int INTERVAL = 1 << INT_BITS;
const int BIN_SCALE = INTERVAL << PERIOD_BITS;
#pragma pack(1)
// One node of the PPM context tree.
// The struct is packed so that, for a context with exactly one state, that
// STATE (BYTE, BYTE, pointer) is stored inline over SummFreq + Stats
// (WORD + pointer); oneState() returns the address of that inline state.
struct PPM_CONTEXT {
    // NumStats: number of states (0 = freshly created, 1 = single inline state).
    // SummFreq: total frequency; exceeds the sum of the states' Freq by the
    // escape count (see rescale).
    WORD NumStats,SummFreq;                 // sizeof(WORD) > sizeof(BYTE)
    // Successor: context to continue in after this symbol is coded.
    struct STATE { BYTE Symbol, Freq; PPM_CONTEXT* Successor; } * Stats;
    // Context tried next after an escape; NULL for the order-0 context.
    PPM_CONTEXT* Lesser;
    inline PPM_CONTEXT(STATE* pStats, PPM_CONTEXT* LesserContext);
    inline BOOL  encodeBinSymbol(int symbol);// MaxOrder:
    inline BOOL    encodeSymbol1(int symbol);//  ABCD    context
    BOOL _FASTCALL  innerEncode1(int symbol);//   BCD    lesser
    BOOL _FASTCALL encodeSymbol2(int symbol);//   BCDE   successor
    inline BOOL            decodeBinSymbol();// other orders:
    inline BOOL              decodeSymbol1();//   BCD    context
    BOOL _FASTCALL  innerDecode1(int  count);//    CD    lesser
    BOOL _FASTCALL           decodeSymbol2();//   BCDE   successor
    void _FASTCALL rescale(STATE* p);
    // Contexts are carved from the sub-allocator. The empty exception
    // specification makes this a non-throwing allocation function. When
    // AllocUnits() returns NULL, the new-expression then yields NULL without
    // running the constructor, which the "== NULL" checks in UpdateModel rely
    // on. Without it, a NULL return is undefined behaviour: the constructor
    // would write through a null pointer and the compiler may elide the check.
    void* operator new(size_t ) throw() { return AllocUnits(); }
    // Address of the inline single state (overlays SummFreq and Stats).
    STATE* oneState() const { return (STATE*) (((BYTE*)this)+sizeof(WORD)); }
// UPDATE1(p): after coding state p (not the first) in a multi-state context,
// add 2 to its Freq and to SummFreq. If p now outranks its predecessor, swap
// the two, and rescale when the moved state exceeds MAX_FREQ. Sets pCState.
#define UPDATE1(p) {                                                        \
    SummFreq += 2;                          p->Freq += 2;                   \
    if (p[0].Freq <= p[-1].Freq)            pCState=p;                      \
    else {                                                                  \
        SWAP(p[0],p[-1]);                                                   \
        if ((--p)->Freq <= MAX_FREQ)        pCState=p;                      \
        else                                rescale(p);                     \
    }                                                                       \
}
// UPDATE2(p): after coding state p in encode/decodeSymbol2, add 2 to its
// Freq and to SummFreq, rescaling if it exceeds MAX_FREQ. Then clear the
// CharMask exclusions for the NumMasked symbols of UpperContext (all of
// CharMask when NumMasked >= 32). Uses the caller's local `i` and reuses `p`.
#define UPDATE2(p) {                                                        \
    SummFreq += 2;                          p->Freq += 2;                   \
    if (p->Freq <= MAX_FREQ)                pCState=p;                      \
    else                                    rescale(p);                     \
    if (NumMasked == 1)     CharMask[UpperContext->oneState()->Symbol]=0;   \
    else if (NumMasked < 32) {                                              \
        i=NumMasked-1;      CharMask[(p=UpperContext->Stats)->Symbol]=0;    \
        do CharMask[(++p)->Symbol]=0; while (--i != 0);                     \
    } else                  memset(CharMask,0,sizeof(CharMask));            \
}
};
#pragma pack()

// Model/coder state shared by StartModel, UpdateModel, EncodeFile and DecodeFile.
static PPM_CONTEXT * MinContext;    // context the next symbol is coded in first
static PPM_CONTEXT * MaxContext;    // highest-order current context; MinContext is reached from it via Lesser
static PPM_CONTEXT * UpperContext;  // context escaped from; its symbols are the current exclusions
static PPM_CONTEXT::STATE* pCState; // state of the most recently coded symbol (non-NULL also marks a restart for StartModel)
static BYTE CharMask[256];          // nonzero = symbol excluded in the current escape chain
static int NumMasked;               // number of symbols currently excluded
static int BSVal;                   // set on a binary-context escape; base SummFreq when UpdateModel expands a binary context
static int NumConts;                // number of contexts created so far
static int MaxOrder;                // model order (depth of the initial context chain)
static int NS2Indx[256];            // maps a symbol count to an Esc2Summ row
static WORD BinSumm[128][8];        // adaptive single-symbol probabilities for binary contexts
static WORD Esc2Summ[44][4];        // adaptive escape estimates for multi-state contexts

// Construct an empty context (NumStats == 0) whose escape link is
// LesserContext, make pStats->Successor point at it, and count it in NumConts.
// SummFreq and Stats are intentionally left unset: the caller fills them in
// (or uses the inline oneState() layout) when the first symbol is added.
inline PPM_CONTEXT::PPM_CONTEXT(STATE* pStats, PPM_CONTEXT* LesserContext):
    NumStats(0), Lesser(LesserContext)
{
    pStats->Successor = this;
    ++NumConts;
}
// Build the initial model: reset the sub-allocator, create the order-0
// context, and create a chain of MaxOrder contexts above it.
// If pCState is NULL (callers clear it for a fresh start), this also
// initializes BinSumm, Esc2Summ, NS2Indx and CharMask and returns clock().
// If pCState is non-NULL (a restart from UpdateModel after allocation
// failure), those adaptive tables are kept as they are and 0 is returned.
clock_t StartModel()
{
    int i,k;
    InitSubAllocator();
    // Order-0 context: all 256 symbols with Freq 1; SummFreq 257 leaves an
    // escape count of 1. Lesser == NULL marks the bottom of the chain.
    PPM_CONTEXT* Order0 = (PPM_CONTEXT*) AllocUnits();
    Order0->Lesser=NULL;
    Order0->NumStats=256;                   Order0->SummFreq=257;
    Order0->Stats = (PPM_CONTEXT::STATE*) AllocBlk(sizeof(PPM_CONTEXT::STATE)*256);
    for (i=0;i < 256;i++) {
        Order0->Stats[i].Symbol=i;          Order0->Stats[i].Freq=1;
        Order0->Stats[i].Successor=NULL;
    }
    // Chain MaxOrder contexts through symbol 0. Each one except the last gets
    // a single inline state for symbol 0; the last (MaxContext) is left empty.
    // NOTE(review): AllocUnits/AllocBlk/new results are not checked here;
    // presumably a freshly initialized sub-allocator always has room.
    PPM_CONTEXT::STATE* p=(MaxContext=Order0)->Stats;
    for (NumConts=i=1; ;i++) {
        MaxContext = new PPM_CONTEXT(p,MaxContext);
        if (i == MaxOrder)                  break;
        MaxContext->NumStats=1;             p=MaxContext->oneState();
        p->Symbol = 0;                      p->Freq = 1;
    }
    MinContext=MaxContext->Lesser;
    // Restart: keep the adaptive statistics learned so far.
    if ( pCState )                          return 0;
static const WORD InitBinEsc[8] = { 0x5592,0x2738,0x3F6D,0x4A3F,0x51DD,0x5424,0x579D,0x58C1 };
    // Initial binary-context probabilities: row = state Freq-1, column =
    // Lesser->NumStats bucket (see encode/decodeBinSymbol).
    for (i=0;i < 128;i++)
        for (k=0;k < 8;k++)
            BinSumm[i][k]=BIN_SCALE-InitBinEsc[k]/(i+2);
    for (i=0;i < 44;i++)
        Esc2Summ[i][0]=Esc2Summ[i][1]=Esc2Summ[i][2]=Esc2Summ[i][3]=(2*i+4) << PERIOD_BITS;
    // Bucket symbol counts into Esc2Summ rows: 0..3 one each, then widths of
    // 2, 4 and 8 (largest row index used is 42 < 44).
    for (i=0;i < 4;i++)                     NS2Indx[i]=i;
    for ( ;i < 4+8;i++)                     NS2Indx[i]=4+((i-4) >> 1);
    for ( ;i < 4+8+32;i++)                  NS2Indx[i]=4+4+((i-4-8) >> 2);
    for ( ;i < 256;i++)                     NS2Indx[i]=4+4+8+((i-4-8-32) >> 3);
    memset(CharMask,0,sizeof(CharMask));
    return clock();
}
// Counterpart of StartModel(), called once the coding loop has finished.
// It currently has nothing to release.
// NOTE(review): the sub-allocator set up by InitSubAllocator() is not torn
// down here; confirm that this is intended.
inline void StopModel()
{
    /* intentionally empty */
}
// Halve this context's statistics after a state's Freq has exceeded MAX_FREQ.
//   p: the state whose Freq triggered the rescale; it is first moved to the
//      front of Stats, shifting the states before it down by one.
// Each Freq becomes (Freq + Adder) >> 1. Adder is 1 in every context except
// MaxContext, so only MaxContext can end up with zero-frequency states.
// States are re-sorted by descending Freq. Trailing zero-frequency states are
// dropped and their count is added to the escape count. The escape count
// (SummFreq minus the old frequencies) is halved, rounding up, and folded
// back into SummFreq. The Stats block is then shrunk. If only one state
// survives, it is moved into the inline oneState() slot instead.
// On return pCState points at Stats[0], or at the inline single state.
void _FASTCALL PPM_CONTEXT::rescale(STATE* p)
{
    int OldNS=NumStats, i=NumStats-1, Adder, EscFreq;
    STATE* p1;
    for ( ;p != Stats;p--)                  SWAP(p[0],p[-1]);
    EscFreq=SummFreq-p->Freq;               Adder=(this != MaxContext);
    SummFreq = (p->Freq=(p->Freq+Adder) >> 1);
    do {
        EscFreq -= (++p)->Freq;             // remove old Freq -> leaves escape count
        SummFreq += (p->Freq=(p->Freq+Adder) >> 1);
        if (p[0].Freq > p[-1].Freq) {
            // insertion step: keep Stats in descending Freq order
            STATE tmp=*(p1=p);
            do p1[0]=p1[-1]; while (--p1 != Stats && tmp.Freq > p1[-1].Freq);
            *p1=tmp;
        }
    } while ( --i );
    // p is the last state; zero-frequency states are sorted to the end.
    if (p->Freq == 0) {
        do { i++; } while ((--p)->Freq == 0);
        EscFreq += i;
        if ((NumStats -= i) == 1) {
            // Collapse to a single inline state: halve (rounding up) its Freq
            // once per halving of EscFreq, until EscFreq <= 1.
            STATE tmp=*Stats;
            do { tmp.Freq-=(tmp.Freq >> 1); EscFreq>>=1; } while (EscFreq > 1);
            FreeBlk(Stats,OldNS);           *(pCState=oneState())=tmp;
            return;
        }
    }
    SummFreq += (EscFreq -= (EscFreq >> 1));
    // Last argument is a flag whose meaning is defined by the sub-allocator
    // (suballoc.h) — NOTE(review): confirm its semantics there.
    Stats=(STATE*) ShrinkBlk(Stats,OldNS,NumStats,ALFA*EscFreq > SummFreq);
    pCState=Stats;
}
// Add the just-coded symbol (pCState, found in MinContext) to every context
// from MaxContext down to, but not including, MinContext. Then follow or
// create successor contexts so that MaxContext/MinContext are set for the
// next symbol. Callers invoke this only when MaxContext != MinContext.
// If AllocUnits/ExpandBlk/new return NULL, the whole model is rebuilt with
// StartModel(). pCState is non-NULL at that point, so StartModel keeps the
// adaptive tables.
// NOTE(review): "IP" in the original comment presumably refers to the initial
// frequency (InitFreq) given to newly added symbols; confirm.
static void UpdateModel()   // IPs are slightly inflated
{
    PPM_CONTEXT::STATE* p, * ps[MAX_O], ** pps;
    PPM_CONTEXT* pc;
    // bf: coded symbol's Freq - 1; sf: MinContext's SummFreq minus NumStats.
    UINT sf1,cf,InitFreq,bf=pCState->Freq-1,sf=MinContext->SummFreq-MinContext->NumStats;
    // MaxContext is still empty: choose the Freq its first symbol will get.
    if ( !MaxContext->NumStats ) {
        if (MinContext->NumStats == 1) {
            // Collect the inline states of MinContext and of its single-state
            // Lesser contexts, stopping at the first multi-state context.
            pc=MinContext;                  pps=ps;
            do *pps++ = pc->oneState(); while ((pc=pc->Lesser)->NumStats == 1);
            // Use that context's statistics only if its first state is the
            // coded symbol.
            if (pc->Stats->Symbol == pCState->Symbol) {
                cf=pc->SummFreq-pc->NumStats+1-(bf=pc->Stats->Freq);
                InitFreq=2+((2*bf < 3*cf)?(4*bf > cf):((bf+(cf >> 1))/cf));
#if (MAX_FREQ > 128)
                if (InitFreq > 128)         InitFreq=128;
#endif /* MAX_FREQ > 128 */
            } else                          InitFreq=2;
            // Walk the collected states from the lowest order back up to
            // MinContext's. Raise each Freq to at least InitFreq, and carry
            // the largest Freq seen upward as InitFreq.
            do {
                ((p=*(--pps))->Freq > InitFreq)?
                        (InitFreq=p->Freq):(p->Freq=InitFreq);
            } while (*pps != ps[0]);
        } else {
            // Multi-state MinContext: compare bf with the rest of its total.
            cf=sf-bf;
            InitFreq=1+((2*bf < 3*cf)?(4*bf > cf):((bf+(cf >> 1))/cf));
#if (MAX_FREQ > 128)
            if (InitFreq > 128)             InitFreq=128;
#endif /* MAX_FREQ > 128 */
        }
    }
    // Add the symbol to each context above MinContext. The new state of each
    // is recorded in ps[] (ps[0] belongs to MaxContext).
    for (pc=MaxContext,pps=ps;pc != MinContext;pc=pc->Lesser,*pps++ = p) {
        // Empty context: store inline, using InitFreq computed above.
        if (pc->NumStats < 1)               p=pc->oneState();
        else if (pc->NumStats == 1) {
            // Single-state context: move the inline state into a newly
            // allocated Stats block. SummFreq starts from BSVal.
            PPM_CONTEXT::STATE tmp=*(pc->oneState());
            pc->Stats=(PPM_CONTEXT::STATE*) AllocUnits();
            if ( !pc->Stats )               goto RESTART_MODEL;
            pc->SummFreq=BSVal+((tmp.Freq < 16)?(tmp.Freq+
                    (MinContext->NumStats > 3)):(tmp.Freq=16));
            pc->Stats[0]=tmp;
            InitFreq=1;                     p=pc->Stats+1;
            if ((cf=bf*(pc->SummFreq+3)) >= (sf1=sf+pc->SummFreq))
                    pc->SummFreq += (InitFreq=2+(cf > 2*sf1))-1;
        } else {
            // Multi-state context: grow Stats by one and raise SummFreq.
            pc->Stats=(PPM_CONTEXT::STATE*) ExpandBlk(pc->Stats,
                pc->NumStats,ALFA*(pc->NumStats+1) > pc->SummFreq);
            if ( !pc->Stats )               goto RESTART_MODEL;
            pc->SummFreq += 1+(5*pc->NumStats < 2*MinContext->NumStats)+
                    ((5*pc->NumStats < MinContext->NumStats) &
                    (pc->SummFreq < 3*pc->NumStats+2));
            InitFreq=1;                     p=pc->Stats+pc->NumStats;
            if ((cf=bf*(pc->SummFreq+2)) >= (sf1=sf+pc->SummFreq))
                    pc->SummFreq += (InitFreq=2+(cf > 2*sf1))-1;
        }
        pc->NumStats++;
        p->Symbol = pCState->Symbol;        p->Freq = InitFreq;
    }
    // Either continue in the coded state's existing successor, or create
    // one whose Lesser is MinContext (MinContext is then left unchanged).
    if ( pCState->Successor )               MinContext=pc=pCState->Successor;
    else if ((pc = new PPM_CONTEXT(pCState,pc)) == NULL)
                    goto RESTART_MODEL;
    // Create successor contexts for ps[k-1] .. ps[1], each linked to the
    // previous one via Lesser.
    while (--pps != ps)
            if ((pc = new PPM_CONTEXT(*pps,pc)) == NULL)
                    goto RESTART_MODEL;
    // ps[0] (MaxContext's new state) gets no context of its own: it points
    // at the last one created, which becomes MaxContext.
    MaxContext=(*pps)->Successor=pc;
    return;
RESTART_MODEL:
    StartModel();
}
// (SUMM + 2^(SHIFT-ROUND)) >> SHIFT: the adaptation step applied to the
// 16-bit BinSumm/Esc2Summ statistics. Every argument is parenthesized so
// that expression arguments (e.g. ROUND = 1+1, or a ?: as SUMM) expand correctly.
#define GET_MEAN(SUMM,SHIFT,ROUND) (((SUMM)+(1 << ((SHIFT)-(ROUND)))) >> (SHIFT))
// Encode `symbol` in a single-state (binary) context.
// BinSumm[state Freq - 1][Lesser->NumStats, or 0 when it is >= 8] holds the
// scaled probability of the context's only symbol on a BIN_SCALE range.
// The hit interval is [0, entry) and the escape interval is [entry, BIN_SCALE).
// On a hit, pCState is set and the state's Freq grows (capped at 128).
// On an escape, the symbol is excluded via CharMask and BSVal is refreshed.
// Returns TRUE on a hit, FALSE on an escape; SubRange receives the interval.
inline BOOL PPM_CONTEXT::encodeBinSymbol(int symbol)
{
    STATE* const only = oneState();
    const int column = (Lesser->NumStats < 8) ? Lesser->NumStats : 0;
    WORD& threshold = BinSumm[only->Freq - 1][column];

    if (only->Symbol != symbol) {
        // Escape: [threshold, BIN_SCALE), then lower the hit probability.
        SubRange.LowCount  = threshold;
        SubRange.HighCount = BIN_SCALE;
        threshold -= GET_MEAN(threshold, PERIOD_BITS, 2);
        BSVal = 2 + ((threshold > BIN_SCALE/2) ? 0 : (BIN_SCALE/2) / threshold);
        NumMasked = 1;
        CharMask[only->Symbol] = 1;
        return FALSE;
    }

    // Hit: [0, threshold), then raise the hit probability.
    pCState = only;
    if (only->Freq < 128)
        ++only->Freq;
    SubRange.LowCount  = 0;
    SubRange.HighCount = threshold;
    threshold += INTERVAL - GET_MEAN(threshold, PERIOD_BITS, 2);
    return TRUE;
}
// Decode one symbol from a single-state (binary) context. This mirrors
// encodeBinSymbol: a decoder count below the BinSumm entry means the
// context's only symbol was coded; otherwise an escape was coded.
// Updates the same statistics as the encoder and fills SubRange.
// Returns TRUE on a hit, FALSE on an escape.
inline BOOL PPM_CONTEXT::decodeBinSymbol()
{
    const int target = ariGetCurrentShiftCount(INT_BITS + PERIOD_BITS);
    STATE* const only = oneState();
    const int column = (Lesser->NumStats < 8) ? Lesser->NumStats : 0;
    WORD& threshold = BinSumm[only->Freq - 1][column];

    if (target >= threshold) {
        // Escape: [threshold, BIN_SCALE)
        SubRange.LowCount  = threshold;
        SubRange.HighCount = BIN_SCALE;
        threshold -= GET_MEAN(threshold, PERIOD_BITS, 2);
        BSVal = 2 + ((threshold > BIN_SCALE/2) ? 0 : (BIN_SCALE/2) / threshold);
        NumMasked = 1;
        CharMask[only->Symbol] = 1;
        return FALSE;
    }

    // Hit: [0, threshold)
    pCState = only;
    if (only->Freq < 128)
        ++only->Freq;
    SubRange.LowCount  = 0;
    SubRange.HighCount = threshold;
    threshold += INTERVAL - GET_MEAN(threshold, PERIOD_BITS, 2);
    return TRUE;
}
// Encode `symbol` in a context with several states.
// Sets SubRange.scale = SummFreq. The first state (Stats[0]) is checked
// inline; every other symbol, and the escape, is handled by innerEncode1().
inline BOOL PPM_CONTEXT::encodeSymbol1(int symbol)
{
    SubRange.scale = SummFreq;
    STATE* const top = Stats;
    if (top->Symbol != symbol)
        return innerEncode1(symbol);

    // Hit on the first state: interval [0, Freq), then adapt.
    pCState = top;
    SubRange.LowCount  = 0;
    SubRange.HighCount = top->Freq;
    SummFreq += 2;
    top->Freq += 2;
    if (top->Freq > MAX_FREQ)
        rescale(top);
    return TRUE;
}
// Decode one symbol from a context with several states (mirrors
// encodeSymbol1). SubRange.scale is set before ariGetCurrentCount(), which
// presumably reads it. A count below Stats[0].Freq selects the first state;
// anything else is resolved by innerDecode1().
inline BOOL PPM_CONTEXT::decodeSymbol1()
{
    SubRange.scale = SummFreq;
    const int target = ariGetCurrentCount();
    STATE* const top = Stats;
    if (top->Freq <= target)
        return innerDecode1(target);

    // Hit on the first state: interval [0, Freq), then adapt.
    pCState = top;
    SubRange.LowCount  = 0;
    SubRange.HighCount = top->Freq;
    SummFreq += 2;
    top->Freq += 2;
    if (top->Freq > MAX_FREQ)
        rescale(top);
    return TRUE;
}
// Slow path of encodeSymbol1: search Stats[1..NumStats-1] for `symbol`,
// accumulating the frequencies before it. Relies on SubRange.scale having
// been set to SummFreq by encodeSymbol1.
// Found: SubRange = [LoCnt, LoCnt + Freq), then UPDATE1 adapts the state and
//        sets pCState; returns TRUE.
// Not found: every symbol of this context is excluded in CharMask,
//        NumMasked = NumStats, and the escape interval [total, scale) is
//        emitted; returns FALSE.
BOOL _FASTCALL PPM_CONTEXT::innerEncode1(int symbol)
{
    STATE* p=Stats;
    int LoCnt=p->Freq, i=NumStats-1;
    do {
        if ((++p)->Symbol == symbol)        goto SYMBOL_FOUND;
        LoCnt += p->Freq;
    } while ( --i );
    CharMask[p->Symbol]=1;
    do { CharMask[(--p)->Symbol]=1; } while (p != Stats);
    SubRange.LowCount=LoCnt;                SubRange.HighCount=SubRange.scale;
    NumMasked = NumStats;                   return FALSE;
SYMBOL_FOUND:
    SubRange.HighCount=(SubRange.LowCount=LoCnt)+p->Freq;
    UPDATE1(p);                             return TRUE;
}
// Slow path of decodeSymbol1 (mirrors innerEncode1).
//   count: the decoder's current count, already known to be >= Stats[0].Freq.
// Scans Stats[1..] with a running cumulative frequency HiCnt.
// Found: SubRange = [HiCnt - Freq, HiCnt), then UPDATE1; returns TRUE.
// Not found: excludes every symbol of this context in CharMask, sets
//        NumMasked = NumStats and the escape interval [total, scale);
//        returns FALSE.
BOOL _FASTCALL PPM_CONTEXT::innerDecode1(int count)
{
    STATE* p=Stats;
    int HiCnt=p->Freq, i=NumStats-1;
    do {
        if ((HiCnt+=(++p)->Freq) > count)   goto SYMBOL_FOUND;
    } while ( --i );
    CharMask[p->Symbol]=1;
    do { CharMask[(--p)->Symbol]=1; } while (p != Stats);
    SubRange.LowCount=HiCnt;                SubRange.HighCount=SubRange.scale;
    NumMasked = NumStats;                   return FALSE;
SYMBOL_FOUND:
    SubRange.LowCount=(SubRange.HighCount=HiCnt)-p->Freq;
    UPDATE1(p);                             return TRUE;
}
// Encode `symbol` in this context after escaping from UpperContext. Symbols
// with a nonzero CharMask entry are skipped; i counts the unmasked ones.
// The escape frequency comes from the adaptive Esc2Summ entry. Its row is
// NS2Indx[i-1]. Its column depends on whether SummFreq < 4*NumStats, and
// whether Lesser has more than 2*i extra symbols. Its mean is subtracted from
// the entry on every call. In the order-0 context (NumStats == 256) the
// escape frequency is fixed at 1.
// Found: SubRange = [LoCnt, LoCnt + Freq) over the unmasked total plus
//        escape; UPDATE2 adapts and clears the exclusions; returns TRUE.
// Not found: emits the escape [LoCnt, scale), adds scale to the Esc2Summ
//        entry, masks every symbol of this context and sets
//        NumMasked = NumStats; returns FALSE.
BOOL _FASTCALL PPM_CONTEXT::encodeSymbol2(int symbol)
{
    int f, sym, LoCnt, i=NumStats-NumMasked;
    WORD* pes;
    if (NumStats != 256) {
        pes=Esc2Summ[NS2Indx[i-1]]+2*(SummFreq < 4*NumStats)+
                (Lesser->NumStats-NumStats > 2*i);
        WORD EscFreq=GET_MEAN(*pes,PERIOD_BITS,1);
        *pes -= EscFreq;                    SubRange.scale=EscFreq+(EscFreq == 0);
    } else {
        pes=Esc2Summ[0];                    SubRange.scale=1;
    }
    // Walk the unmasked states, summing frequencies below `symbol`.
    STATE* p=Stats-1;                       LoCnt=0;
    do {
        do { f=(++p)->Freq; } while ( CharMask[sym=p->Symbol] );
        if (sym == symbol)                  goto SYMBOL_FOUND;
        LoCnt += f;
    } while ( --i );
    SubRange.HighCount=(SubRange.scale += (SubRange.LowCount=LoCnt));
    *pes += SubRange.scale;                 NumMasked = NumStats;
    for (CharMask[p->Symbol]=1;p != Stats;CharMask[(--p)->Symbol]=1)
            ;
    return FALSE;
SYMBOL_FOUND:
    SubRange.LowCount=LoCnt;                SubRange.HighCount=(LoCnt += f);
    // Finish the total over the remaining unmasked states.
    if ( --i ) {
        STATE* p1=p;
        do {
            do { f=(++p1)->Freq; } while ( CharMask[p1->Symbol] );
            LoCnt += f;
        } while ( --i );
    }
    SubRange.scale += LoCnt;
    UPDATE2(p);
    return TRUE;
}
// Decode one symbol in this context after escaping from UpperContext
// (mirrors encodeSymbol2). The escape frequency is estimated the same way.
// Pointers to the unmasked states are collected in ps[] (NULL-terminated),
// and SubRange.scale = escape + their total frequency. The decoder count is
// then located among them.
// Found: SubRange = [HiCnt - Freq, HiCnt); UPDATE2; returns TRUE.
// Not found: escape [total, scale); adds scale to the Esc2Summ entry, masks
//        the collected states and sets NumMasked = NumStats; returns FALSE.
BOOL _FASTCALL PPM_CONTEXT::decodeSymbol2()
{
    int f, count, HiCnt, i=NumStats-NumMasked;
    WORD* pes;
    if (NumStats != 256) {
        pes=Esc2Summ[NS2Indx[i-1]]+2*(SummFreq < 4*NumStats)+
                (Lesser->NumStats-NumStats > 2*i);
        WORD EscFreq=GET_MEAN(*pes,PERIOD_BITS,1);
        count=EscFreq+(EscFreq == 0);       *pes -= EscFreq;
    } else {
        count=1;                            pes=Esc2Summ[0];
    }
    // Collect the unmasked states; `count` accumulates the total scale.
    STATE* ps[256], ** pps=ps, * p=Stats-1;
    do {
        do { f=(++p)->Freq; } while ( CharMask[p->Symbol] );
        count += f;                         *pps++ = p;
    } while ( --i );
    SubRange.scale=count;                   *pps=NULL;
    count=ariGetCurrentCount();             HiCnt=0;
    p=*(pps=ps);
    do {
        if ((HiCnt += p->Freq) > count)     goto SYMBOL_FOUND;
    } while ((p=*++pps) != NULL);
    SubRange.LowCount=HiCnt;                SubRange.HighCount=SubRange.scale;
    *pes += SubRange.scale;                 NumMasked = NumStats;
    p=*(pps=ps);
    do CharMask[p->Symbol]=1; while ((p=*++pps) != NULL);
    return FALSE;
SYMBOL_FOUND:
    SubRange.LowCount = (SubRange.HighCount=HiCnt)-p->Freq;
    UPDATE2(p);
    return TRUE;
}
// Compress DecodedFile (input) into EncodedFile (output) using a model of
// order MaxOrder (stored in ::MaxOrder for StartModel).
// Each byte is first coded in MinContext, on the binary or the multi-state
// path. On escape, the coder walks down the Lesser chain, skipping contexts
// whose symbols are all excluded already (NumStats == NumMasked).
// At end of input getc() yields EOF, which matches no stored symbol. Coding
// therefore escapes until MinContext becomes NULL, and that ends the loop.
// Progress is reported via PrintInfo, with elapsed time in 1/1024 s.
void EncodeFile(FILE* EncodedFile,FILE* DecodedFile,int MaxOrder)
{
    ariInitEncoder(EncodedFile);
    ::MaxOrder=MaxOrder;                    pCState=NULL;   // NULL => full StartModel init
    clock_t StartClock=StartModel();
    BOOL SymbolEncoded;
    for (int ns=MinContext->NumStats,i=0x10000; ;ns=MinContext->NumStats) {
        ARI_ENC_NORMALIZE(EncodedFile);
        int c = getc(DecodedFile);
        if (ns == 1) {
            SymbolEncoded=MinContext->encodeBinSymbol(c);
            ariShiftEncodeSymbol(INT_BITS+PERIOD_BITS);
        } else {
            SymbolEncoded=MinContext->encodeSymbol1(c);
            ariEncodeSymbol();
        }
        while ( !SymbolEncoded ) {
            ARI_ENC_NORMALIZE(EncodedFile);
            UpperContext=MinContext;
            do {
                MinContext=MinContext->Lesser;
                if ( !MinContext )          goto STOP_ENCODING;
            } while (MinContext->NumStats == NumMasked);
            SymbolEncoded=MinContext->encodeSymbol2(c);
            ariEncodeSymbol();
        }
        if (MaxContext == MinContext)       MinContext=MaxContext=pCState->Successor;
        else                                UpdateModel();
        if ( !--i ) {
            i=0x20000;
            // clock() counts in CLOCKS_PER_SEC units. The obsolete CLK_TCK is
            // either missing or (older glibc) sysconf(_SC_CLK_TCK), a
            // different unit. Computing in double also avoids overflowing a
            // 32-bit clock_t with the old "ticks << 10" form.
            PrintInfo(DecodedFile,EncodedFile,NumConts,
                clock_t(double(clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
        }
    }
STOP_ENCODING:
    StopModel();
    ARI_FLUSH_ENCODER(EncodedFile);
    PrintInfo(DecodedFile,EncodedFile,NumConts,
            clock_t(double(clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
}
// Decompress EncodedFile (input) into DecodedFile (output). MaxOrder must
// equal the order used by EncodeFile. This mirrors EncodeFile step for step:
// the same model updates keep encoder and decoder in sync. Decoding stops
// when an escape runs past the order-0 context (MinContext becomes NULL),
// which is how the encoder marks end of input.
// Progress is reported via PrintInfo, with elapsed time in 1/1024 s.
void DecodeFile(FILE* DecodedFile,FILE* EncodedFile,int MaxOrder)
{
    ariInitDecoder(EncodedFile);
    ::MaxOrder=MaxOrder;                    pCState=NULL;   // NULL => full StartModel init
    clock_t StartClock=StartModel();
    BOOL SymbolDecoded;
    for (int ns=MinContext->NumStats,i=0x10000; ;ns=MinContext->NumStats) {
        ARI_DEC_NORMALIZE(EncodedFile);
        SymbolDecoded = (ns == 1) ? MinContext->decodeBinSymbol():
                                    MinContext->decodeSymbol1();
        ariRemoveSubrange();
        while ( !SymbolDecoded ) {
            ARI_DEC_NORMALIZE(EncodedFile);
            UpperContext=MinContext;
            do {
                MinContext=MinContext->Lesser;
                if ( !MinContext )          goto STOP_DECODING;
            } while (MinContext->NumStats == NumMasked);
            SymbolDecoded = MinContext->decodeSymbol2();
            ariRemoveSubrange();
        }
        putc(pCState->Symbol,DecodedFile);
        if (MaxContext == MinContext)       MinContext=MaxContext=pCState->Successor;
        else                                UpdateModel();
        if ( !--i ) {
            i=0x20000;
            // clock() counts in CLOCKS_PER_SEC units. The obsolete CLK_TCK is
            // either missing or (older glibc) sysconf(_SC_CLK_TCK), a
            // different unit. Computing in double also avoids overflowing a
            // 32-bit clock_t with the old "ticks << 10" form.
            PrintInfo(DecodedFile,EncodedFile,NumConts,
                clock_t(double(clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
        }
    }
STOP_DECODING:
    StopModel();
    PrintInfo(DecodedFile,EncodedFile,NumConts,
            clock_t(double(clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
}
