/****************************************************************************
 *  This file is part of PPMD project                                       *
 *  Written and distributed to public domain by Dmitry Shkarin 1997,        *
 *  1999-2000                                                               *
 *  Contents: model description and encoding/decoding routines              *
 ****************************************************************************/
#include <string.h>
#include <time.h>
#include "ppmd.h"
#include "suballoc.h"
#pragma hdrstop
#include "coder.hpp"

// Model tuning constants.
const int ALFA        = 16;                 // weight in the ExpandBlk/ShrinkBlk flag tests
const int INT_BITS    = 7;
const int PERIOD_BITS = 7;
const int MAX_FREQ    = 124;                // state frequencies above this trigger rescale()
const int INTERVAL    = 1 << INT_BITS;
const int BIN_SCALE   = INTERVAL << PERIOD_BITS;  // coding range of binary contexts

#pragma pack(1)
// One node of the context tree.
//
// A context with NumStats > 1 keeps its states in the separately allocated
// array Stats[0..NumStats-1] and SummFreq is their total count.  A context
// with NumStats == 1 stores its single state inline: oneState() points just
// past NumStats, so that state overlays the SummFreq and Stats fields
// (UpdateModel copies it out before allocating a real Stats array).
// NumStats == 0 marks a freshly created, still empty context.
// Lesser links to the context one order shorter (NULL for order 0).
// The data layout is packed (see #pragma pack(1) above); do not reorder it.
static struct PPM_CONTEXT {
    WORD NumStats,SummFreq;                     // sizeof(WORD) > sizeof(BYTE)
    struct STATE { BYTE Symbol, Freq; PPM_CONTEXT* Successor; } * Stats;
    PPM_CONTEXT* Lesser;
    inline PPM_CONTEXT(STATE* pStats, PPM_CONTEXT* LesserContext);
    inline BOOL    encodeBinSymbol(int symbol); // MaxOrder:
    inline BOOL      encodeSymbol1(int symbol); //  ABCD    context
    _BIG_INLINE BOOL  innerEncode1(int symbol); //   BCD    lesser
    _BIG_INLINE BOOL encodeSymbol2(int symbol); //   BCDE   successor
    inline BOOL              decodeBinSymbol(); // other orders:
    inline BOOL                decodeSymbol1(); //   BCD    context
    _BIG_INLINE BOOL  innerDecode1(int  count); //    CD    lesser
    _BIG_INLINE BOOL           decodeSymbol2(); //   BCDE   successor
    inline void              update1(STATE* p);
    inline void              update2(STATE* p);
    inline UINT          getEscFreq2(int Diff);
    void                             rescale();
    // Contexts live in sub-allocator units.  AllocUnits() returns NULL when
    // the sub-allocator is exhausted, and UpdateModel() tests the result of
    // `new PPM_CONTEXT(...)` against NULL to trigger a model restart.  That
    // test is only well-defined if this allocation function is declared
    // non-throwing.  Otherwise a null return is undefined behaviour: the
    // compiler may run the constructor on a null pointer and drop the check.
#if __cplusplus >= 201103L
    void* operator new(size_t ) noexcept { return AllocUnits(); }
#else
    void* operator new(size_t ) throw() { return AllocUnits(); }
#endif
    STATE* oneState() const { return (STATE*) (((BYTE*)this)+sizeof(WORD)); }
} * MinContext, * MaxContext;
// Adaptive escape-frequency estimator selected by getEscFreq2().
// Summ holds the running sum scaled by 2^Shift; Count paces the (bounded)
// growth of Shift.  The data layout is packed; keep the member order.
static struct SEE2_CONTEXT {
    WORD Summ;
    BYTE Shift, Count;
    // Reset to the initial estimate InitVal.
    void init(int InitVal) {
        Shift = PERIOD_BITS - 3;
        Summ  = InitVal << Shift;
        Count = 16;
    }
    // Return the current estimate (Summ >> Shift, limited to 10 bits and never
    // 0) and decay Summ by the unlimited estimate.
    UINT getMean() {
        UINT mean = Summ >> Shift;
        Summ -= mean;
        mean &= 0x03FF;
        return (mean == 0) ? 1 : mean;
    }
    // While Shift is below PERIOD_BITS and MinContext is not binary, count
    // down; each time Count reaches zero, double Summ, bump Shift and restart
    // Count at 2 << Shift.
    void update() {
        if (Shift >= PERIOD_BITS || MinContext->NumStats <= 1)
            return;
        if (--Count != 0)
            return;
        Summ += Summ;
        ++Shift;
        Count = 2 << Shift;
    }
} SEE2Cont[44][8], * psee2c;
#pragma pack()

// ---- shared model state -------------------------------------------------
static PPM_CONTEXT::STATE* pCState;         // state of the most recently coded symbol
static DWORD CharMask[256];                 // CharMask[c] == SymbolCounter => c is masked
static DWORD SymbolCounter;                 // current masking generation
static DWORD NumConts;                      // number of contexts created (progress info)
static int NumMasked;                       // symbols masked by the escapes so far
static int InitEsc;                         // escape estimate from the last binary escape
static int MaxOrder;                        // model order chosen by the caller
static BYTE NS2Indx[256];                   // symbol count -> SEE2Cont row
static BYTE NS2BSIndx[256];                 // lesser symbol count -> BinSumm column offset
static BYTE PrevSuccess;                    // 1 if the last symbol was coded without escape
static WORD BinSumm[128][16];               // binary-context probability table

// Create an empty context (NumStats == 0) whose shorter context is
// LesserContext, hook it up as the Successor of the parent state pStats,
// and count it in NumConts.
inline PPM_CONTEXT::PPM_CONTEXT(STATE* pStats,PPM_CONTEXT *LesserContext)
    : NumStats(0), Lesser(LesserContext)
{
    ++NumConts;
    pStats->Successor = this;
}
static clock_t StartModel()
{
    int i,k;
    InitSubAllocator();
    PPM_CONTEXT* Order0 = (PPM_CONTEXT*) AllocUnits();
    Order0->Lesser=NULL;
    Order0->NumStats=256;                   Order0->SummFreq=257;
    Order0->Stats = (PPM_CONTEXT::STATE*) AllocBlk(sizeof(PPM_CONTEXT::STATE)*256);
    for (PrevSuccess=i=0;i < 256;i++) {
        Order0->Stats[i].Symbol=i;          Order0->Stats[i].Freq=1;
        Order0->Stats[i].Successor=NULL;
    }
    PPM_CONTEXT::STATE* p=(MaxContext=Order0)->Stats;
    for (NumConts=i=1; ;i++) {
        MaxContext = new PPM_CONTEXT(p,MaxContext);
        if (i == MaxOrder)                  break;
        MaxContext->NumStats=1;             p=MaxContext->oneState();
        p->Symbol = 0;                      p->Freq = 1;
    }
    MinContext=MaxContext->Lesser;          psee2c=SEE2Cont[43];
    if ( pCState )                          return 0;
static const WORD InitBinEsc[16] = {
                0x3CDD,0x1F3F,0x59BF,0x48F3,0x5FFB,0x5545,0x63D1,0x5D9D,
                0x64A1,0x5ABC,0x6632,0x6051,0x68F6,0x549B,0x6BCA,0x3AB0, };
    for (i=0;i < 128;i++)
        for (k=0;k < 16;k++)
            BinSumm[i][k]=BIN_SCALE-InitBinEsc[k]/(i+2);
    for (i=0;i <   6;i++)                   NS2BSIndx[i]=2*i;
    for (   ;i <  50;i++)                   NS2BSIndx[i]=12;
    for (   ;i < 256;i++)                   NS2BSIndx[i]=14;
    for (i=0;i < 44;i++)
            for (k=0;k < 8;k++)             SEE2Cont[i][k].init(4*i+8);
    for (i=0;i < 4;i++)                     NS2Indx[i]=i;
    for ( ;i < 4+8;i++)                     NS2Indx[i]=4+((i-4) >> 1);
    for ( ;i < 4+8+32;i++)                  NS2Indx[i]=4+4+((i-4-8) >> 2);
    for ( ;i < 256;i++)                     NS2Indx[i]=4+4+8+((i-4-8-32) >> 3);
    memset(CharMask,0,sizeof(CharMask));    SymbolCounter=1;
    return clock();
}
// End-of-run hook called by EncodeFile/DecodeFile after the last symbol.
// It currently does nothing.
inline void StopModel() { }
// Halve the frequencies of this multi-state context once pCState's
// frequency has grown past MAX_FREQ.
//
// Steps:
//  - move pCState to Stats[0] and give it an extra +4;
//  - halve every frequency, rounding up (Adder = 1) unless this is
//    MaxContext, then keep the array sorted by descending Freq;
//  - drop trailing states whose frequency became 0;
//  - fold the escape part of SummFreq (SummFreq minus the sum of state
//    frequencies) back in, roughly halved.
// If only one state survives, the context is turned into the inline
// single-state form (see oneState()).  On return pCState points at the
// first state.
void PPM_CONTEXT::rescale()
{
    int OldNS=NumStats, i=NumStats-1, Adder, EscFreq;
    STATE* p1, * p;
    // Bubble the current state to the front of the array.
    for (p=pCState;p != Stats;p--)          SWAP(p[0],p[-1]);
    Stats->Freq += 4;                       SummFreq += 4;
    // EscFreq starts as SummFreq minus the old first frequency; the loop
    // subtracts each remaining old frequency, leaving the escape part.
    EscFreq=SummFreq-p->Freq;               Adder=(this != MaxContext);
    SummFreq = (p->Freq=(p->Freq+Adder) >> 1);
    do {
        EscFreq -= (++p)->Freq;
        SummFreq += (p->Freq=(p->Freq+Adder) >> 1);
        // Insertion step: move p left past smaller frequencies so the
        // array stays ordered by descending Freq.
        if (p[0].Freq > p[-1].Freq) {
            STATE tmp=*(p1=p);
            do { p1[0]=p1[-1]; } while (--p1 != Stats && tmp.Freq > p1[-1].Freq);
            *p1=tmp;
        }
    } while ( --i );
    // Zero frequencies (possible only when Adder == 0) sit at the end; count
    // and discard them, crediting each one to the escape count.
    if (p->Freq == 0) {
        do { i++; } while ((--p)->Freq == 0);
        EscFreq += i;
        if ((NumStats -= i) == 1) {
            // Single survivor: scale it down along with the escape count,
            // release the array and store the state inline.
            STATE tmp=*Stats;
            do { tmp.Freq-=(tmp.Freq >> 1); EscFreq>>=1; } while (EscFreq > 1);
            FreeBlk(Stats,OldNS);           *(pCState=oneState())=tmp;
            return;
        }
    }
    // Add back half of the escape count (rounded up).
    SummFreq += (EscFreq -= (EscFreq >> 1));
    // Shrink the state array to the new NumStats.  NOTE(review): the meaning
    // of the last ShrinkBlk argument is defined in suballoc.h — confirm there.
    Stats=(STATE*) ShrinkBlk(Stats,OldNS,NumStats,ALFA*EscFreq > SummFreq);
    pCState=Stats;
}
// Update the model after a symbol (pCState) was coded in MinContext when
// MaxContext != MinContext, i.e. after at least one escape.
//
//  1. (non-fast build) Bump the symbol in the next shorter context too.
//  2. Add the symbol to every context from MaxContext down to (excluding)
//     MinContext, recording the new states in ps[].
//  3. Create successor contexts for those states and set MaxContext and
//     MinContext for the next symbol.
//  4. (non-fast build) Re-estimate the frequency of a binary MinContext.
// If the sub-allocator runs out of memory, the context tree is rebuilt via
// StartModel().  pCState is non-NULL here, so the adaptive tables are kept.
static void UpdateModel()
{
    PPM_CONTEXT::STATE* p, * ps[MAX_O], ** pps=ps;
    PPM_CONTEXT* pc;
    // f0: frequency of the found state; sf: MinContext count minus NumStats;
    // cf = f0-1 and s0 = sf-cf feed the initial-frequency estimates below.
    UINT cf,f0=pCState->Freq,sf=MinContext->SummFreq-MinContext->NumStats;
    UINT InitFreq,s0=sf-(cf=f0-1);
#if !defined(__I_LIKE_FAST_AND_DIRTY_PROGRAMS)
    // Low-frequency symbol: also credit it in the next shorter context.
    if (f0 < MAX_FREQ/4 && (pc=MinContext->Lesser) != NULL) {
        if (pc->NumStats == 1)
                pc->oneState()->Freq += (pc->oneState()->Freq < 32);
        else {
            psee2c->update();
            // The symbol is assumed to be present in the shorter context;
            // if it ties or beats its predecessor, move it one slot forward.
            if ((p=pc->Stats)->Symbol != pCState->Symbol) {
                do { p++; } while (p->Symbol != pCState->Symbol);
                if (p[0].Freq >= p[-1].Freq) {
                    SWAP(p[0],p[-1]);       p--;
                }
            }
            if (p->Freq < 7*MAX_FREQ/8) {
                pc->SummFreq += 2;          p->Freq += 2;
            }
        }
    }
#else
    psee2c->update();
#endif /* !defined(__I_LIKE_FAST_AND_DIRTY_PROGRAMS) */
    // Empty contexts at the top of the chain receive the symbol as their
    // single inline state; InitFreq is f0 for a binary MinContext, otherwise
    // an estimate from the ratio cf : s0.
    if ((pc=MaxContext)->NumStats == 0) {
        InitFreq=(MinContext->NumStats == 1)?(f0):
                (1+((cf <= s0)?(4*cf > s0):((cf+s0-1)/s0)));
        do {
            p = pc->oneState();             pc->NumStats = 1;
            p->Symbol = pCState->Symbol;    p->Freq = InitFreq;
            pc = pc->Lesser;                *pps++ = p;
        } while (pc->NumStats == 0);
    }
    // Remaining contexts above MinContext get the symbol appended to their
    // state array; the loop increment records each new state in ps[].
    for ( ;pc != MinContext;pc=pc->Lesser, *pps++ = p) {
        if (pc->NumStats == 1) {
            // Binary -> multi-state: move the inline state into a freshly
            // allocated array (its storage overlays pc->Stats).
            PPM_CONTEXT::STATE tmp=*(pc->oneState());
            pc->Stats=(PPM_CONTEXT::STATE*) AllocUnits();
            if ( !pc->Stats )               goto RESTART_MODEL;
            if (tmp.Freq < MAX_FREQ/4-1)    tmp.Freq += tmp.Freq;
            else                            tmp.Freq  = MAX_FREQ-4;
            pc->SummFreq=InitEsc+(pc->Stats[0]=tmp).Freq+(MinContext->NumStats > 3);
        } else {
            // Grow the array by one slot; NOTE(review): the meaning of the
            // last ExpandBlk argument is defined in suballoc.h — confirm there.
            pc->Stats=(PPM_CONTEXT::STATE*) ExpandBlk(pc->Stats,
                pc->NumStats,ALFA*(pc->NumStats+1) > pc->SummFreq);
            if ( !pc->Stats )               goto RESTART_MODEL;
            pc->SummFreq += (2*pc->NumStats < MinContext->NumStats)+2*(
                (4*pc->NumStats <= MinContext->NumStats) &
                (pc->SummFreq <= 8*pc->NumStats));
        }
        // Initial frequency 1..3 or 4..7 depending on how f0 compares with
        // this context's count; SummFreq grows accordingly.
        cf=2*f0*(pc->SummFreq+6);           sf=s0+pc->SummFreq;
        if (cf < 6*sf) {
            InitFreq=1+(cf >= sf)+(cf >= 4*sf);
            pc->SummFreq += 3;
        } else {
            InitFreq=4+(cf >= 9*sf)+(cf >= 12*sf)+(cf >= 15*sf);
            pc->SummFreq += InitFreq;
        }
        p=pc->Stats+pc->NumStats++;
        p->Symbol = pCState->Symbol;        p->Freq = InitFreq;
    }
    // Successor of the found state: reuse it (it becomes MinContext) or
    // create a new context hanging below MinContext.
    if ( pCState->Successor )               MinContext=pc=pCState->Successor;
    else if ((pc = new PPM_CONTEXT(pCState,pc)) == NULL)
                    goto RESTART_MODEL;
    // Walk the recorded states from the lowest order upwards, creating one
    // successor context for each of ps[count-1] .. ps[1].  ps[0] (the state
    // added to MaxContext) gets no context of its own; it is pointed at the
    // last one created, which becomes the new MaxContext.  UpdateModel is
    // only called with MaxContext != MinContext, so ps[] is non-empty here.
    while (--pps != ps)
            if ((pc = new PPM_CONTEXT(*pps,pc)) == NULL)
                    goto RESTART_MODEL;
    MaxContext=(*pps)->Successor=pc;
#if !defined(__I_LIKE_FAST_AND_DIRTY_PROGRAMS)
    // Binary MinContext: set its frequency to the largest of the chain of
    // binary lesser contexts and, if the first multi-state lesser context
    // starts with the same symbol, an estimate derived from that context.
    if (MinContext->NumStats == 1) {
        InitFreq=MinContext->oneState()->Freq;
        for (pc=MinContext;(pc=pc->Lesser)->NumStats == 1; )
                if (pc->oneState()->Freq > InitFreq)
                        InitFreq=pc->oneState()->Freq;
        if (pc->Stats->Symbol == MinContext->oneState()->Symbol) {
            s0=pc->SummFreq-pc->NumStats+1-(cf=pc->Stats->Freq);
            cf=1+((cf <= s0)?(2*cf > s0):((cf+s0-1)/s0));
            if (cf > InitFreq)              InitFreq=cf;
        }
        MinContext->oneState()->Freq=InitFreq;
    }
#endif /* !defined(__I_LIKE_FAST_AND_DIRTY_PROGRAMS) */
    return;
RESTART_MODEL:
    StartModel();
}
// Tabulated escapes for an exponential symbol distribution.  After a binary
// escape, encodeBinSymbol/decodeBinSymbol index this table with the updated
// BinSumm entry shifted right by 10 and store the result in InitEsc.
static const BYTE ExpEscape[16] = {
    25, 14,  9,  7,  5,  5,  4,  4,
     4,  3,  3,  3,  2,  2,  2,  2
};
// GET_MEAN(SUMM,SHIFT,ROUND): SUMM >> SHIFT after adding the rounding term
// 2^(SHIFT-ROUND).  Every argument is parenthesized so that expression
// arguments (e.g. ROUND = 1+1, SUMM = a|b) expand correctly.
#define GET_MEAN(SUMM,SHIFT,ROUND) (((SUMM)+(1 << ((SHIFT)-(ROUND)))) >> (SHIFT))
// Encode `symbol` in a binary (single-state) context.
// The probability comes from the BinSumm cell selected by the state's
// frequency, PrevSuccess and the symbol count of the lesser context.
// Hit: the interval is [0, bs), the state becomes pCState and bs adapts
// upwards; returns TRUE.  Miss: the escape interval is [bs, BIN_SCALE), the
// state's symbol is masked and bs adapts downwards; returns FALSE.
inline BOOL PPM_CONTEXT::encodeBinSymbol(int symbol)
{
    STATE* const only = oneState();
    WORD& escSum = BinSumm[only->Freq-1][PrevSuccess+NS2BSIndx[Lesser->NumStats-1]];
    if (only->Symbol != symbol) {
        SubRange.LowCount = escSum;
        escSum -= GET_MEAN(escSum,PERIOD_BITS,2);
        SubRange.HighCount = BIN_SCALE;
        InitEsc = ExpEscape[escSum >> 10];
        NumMasked = 1;
        CharMask[only->Symbol] = SymbolCounter;
        PrevSuccess = 0;
        return FALSE;
    }
    pCState = only;
    if (only->Freq < 128)
        only->Freq++;
    SubRange.LowCount = 0;
    SubRange.HighCount = escSum;
    escSum += INTERVAL-GET_MEAN(escSum,PERIOD_BITS,2);
    PrevSuccess = 1;
    return TRUE;
}
// Decoder counterpart of encodeBinSymbol: read the scaled count from the
// arithmetic decoder and compare it with the BinSumm cell.  Below bs the
// context's symbol was coded (returns TRUE); otherwise an escape
// (returns FALSE).  The model updates mirror the encoder exactly.
inline BOOL PPM_CONTEXT::decodeBinSymbol()
{
    int target = ariGetCurrentShiftCount(INT_BITS+PERIOD_BITS);
    STATE* const only = oneState();
    WORD& escSum = BinSumm[only->Freq-1][PrevSuccess+NS2BSIndx[Lesser->NumStats-1]];
    if (target >= escSum) {
        SubRange.LowCount = escSum;
        escSum -= GET_MEAN(escSum,PERIOD_BITS,2);
        SubRange.HighCount = BIN_SCALE;
        InitEsc = ExpEscape[escSum >> 10];
        NumMasked = 1;
        CharMask[only->Symbol] = SymbolCounter;
        PrevSuccess = 0;
        return FALSE;
    }
    pCState = only;
    if (only->Freq < 128)
        only->Freq++;
    SubRange.LowCount = 0;
    SubRange.HighCount = escSum;
    escSum += INTERVAL-GET_MEAN(escSum,PERIOD_BITS,2);
    PrevSuccess = 1;
    return TRUE;
}
// Encode `symbol` in a multi-state context with nothing masked yet.
// Fast path for the first (most frequent) state; everything else is handled
// by innerEncode1.  SubRange.scale is the context's SummFreq.
inline BOOL PPM_CONTEXT::encodeSymbol1(int symbol)
{
    SubRange.scale = SummFreq;
    STATE* const top = Stats;
    if (top->Symbol != symbol) {
        PrevSuccess = 0;
        return innerEncode1(symbol);
    }
    pCState = top;
    PrevSuccess = (2*top->Freq > SummFreq);
    SubRange.HighCount = top->Freq;
    SummFreq += 4;
    top->Freq += 4;
    if (top->Freq > MAX_FREQ)
        rescale();
    SubRange.LowCount = 0;
    return TRUE;
}
// Decoder counterpart of encodeSymbol1.  SubRange.scale is set before
// ariGetCurrentCount(), which presumably reads it.
inline BOOL PPM_CONTEXT::decodeSymbol1()
{
    SubRange.scale = SummFreq;
    int target = ariGetCurrentCount();
    STATE* const top = Stats;
    if (top->Freq <= target) {
        PrevSuccess = 0;
        return innerDecode1(target);
    }
    pCState = top;
    PrevSuccess = (2*top->Freq > SummFreq);
    SubRange.HighCount = top->Freq;
    SummFreq += 4;
    top->Freq += 4;
    if (top->Freq > MAX_FREQ)
        rescale();
    SubRange.LowCount = 0;
    return TRUE;
}
// Credit state p (not the first state) with +4.  If it now outranks its
// predecessor, swap the two; the moved state becomes pCState and is
// rescaled if it exceeds MAX_FREQ.
inline void PPM_CONTEXT::update1(STATE* p)
{
    pCState = p;
    p->Freq += 4;
    SummFreq += 4;
    if (p[0].Freq <= p[-1].Freq)
        return;
    SWAP(p[0],p[-1]);
    pCState = p - 1;
    if (pCState->Freq > MAX_FREQ)
        rescale();
}
// Search Stats[1..NumStats-1] for `symbol`; the caller already checked
// Stats[0].  Found: its interval is [sum of preceding freqs, +Freq) and
// update1 adjusts the model; returns TRUE.  Not found: the escape interval
// [sum of all freqs, scale) is used, every symbol of this context is
// masked and NumMasked = NumStats; returns FALSE.
_BIG_INLINE BOOL PPM_CONTEXT::innerEncode1(int symbol)
{
    int LoCnt = Stats->Freq;
    STATE* const last = Stats + NumStats - 1;
    for (STATE* p = Stats + 1; p <= last; ++p) {
        if (p->Symbol == symbol) {
            SubRange.LowCount = LoCnt;
            SubRange.HighCount = LoCnt + p->Freq;
            update1(p);
            return TRUE;
        }
        LoCnt += p->Freq;
    }
    SubRange.LowCount = LoCnt;
    for (STATE* q = Stats; q <= last; ++q)
        CharMask[q->Symbol] = SymbolCounter;
    SubRange.HighCount = SubRange.scale;
    NumMasked = NumStats;
    return FALSE;
}
// Decoder counterpart of innerEncode1: find the first state among
// Stats[1..] whose cumulative frequency exceeds `count`.  If none does,
// the escape interval is used and all symbols are masked.
_BIG_INLINE BOOL PPM_CONTEXT::innerDecode1(int count)
{
    int HiCnt = Stats->Freq;
    STATE* const last = Stats + NumStats - 1;
    for (STATE* p = Stats + 1; p <= last; ++p) {
        HiCnt += p->Freq;
        if (HiCnt > count) {
            SubRange.HighCount = HiCnt;
            SubRange.LowCount = HiCnt - p->Freq;
            update1(p);
            return TRUE;
        }
    }
    SubRange.LowCount = HiCnt;
    for (STATE* q = Stats; q <= last; ++q)
        CharMask[q->Symbol] = SymbolCounter;
    SubRange.HighCount = SubRange.scale;
    NumMasked = NumStats;
    return FALSE;
}
// Credit state p with +4 after a masked (post-escape) coding step, without
// reordering; p becomes pCState.  Rescale once it passes MAX_FREQ.
inline void PPM_CONTEXT::update2(STATE* p)
{
    pCState = p;
    p->Freq += 4;
    SummFreq += 4;
    if (p->Freq > MAX_FREQ)
        rescale();
}
// Pick the SEE2 estimator for this context and return its escape estimate.
// Diff is the number of unmasked symbols.  The order-0 context
// (NumStats == 256) always uses SEE2Cont[43][0] with a fixed escape of 1.
inline UINT PPM_CONTEXT::getEscFreq2(int Diff)
{
    if (NumStats == 256) {
        psee2c = SEE2Cont[43];
        return 1;
    }
    // Row: bucket of Diff.  Column: three binary features.
    int row = NS2Indx[Diff-1];
    int col = 4*(Diff < Lesser->NumStats-NumStats)
            + 2*(SummFreq < 11*NumStats)
            + (NumMasked > Diff);
    psee2c = SEE2Cont[row] + col;
    return psee2c->getMean();
}
// Encode `symbol` in a context after one or more escapes, skipping the
// symbols already masked (CharMask[c] == SymbolCounter).
// The coding range is: unmasked symbol frequencies in array order, followed
// by the escape estimate from getEscFreq2 (which also selects psee2c).
// Found: [cum, cum+Freq) inside scale = escape + all unmasked freqs; the
// state is updated via update2; returns TRUE.
// Escape: [sum of unmasked freqs, scale); psee2c->Summ grows by scale and
// every symbol of this context counts as masked; returns FALSE.
_BIG_INLINE BOOL PPM_CONTEXT::encodeSymbol2(int symbol)
{
    // i = number of unmasked symbols still to visit.
    int HiCnt, i=NumStats-NumMasked;
    SubRange.scale=getEscFreq2(i);
    // NOTE(review): Stats-1 forms a pointer before the array start, which is
    // formally undefined behaviour; it relies on the pre-increment below.
    STATE* p=Stats-1;                       HiCnt=0;
    do {
        do { p++; } while (CharMask[p->Symbol] == SymbolCounter);
        HiCnt += p->Freq;
        if (p->Symbol == symbol)            goto SYMBOL_FOUND;
        CharMask[p->Symbol]=SymbolCounter;
    } while ( --i );
    // Escape: sits above all unmasked symbols.
    SubRange.HighCount=(SubRange.scale += (SubRange.LowCount=HiCnt));
    psee2c->Summ += SubRange.scale;         NumMasked = NumStats;
    return FALSE;
SYMBOL_FOUND:
    SubRange.LowCount = (SubRange.HighCount=HiCnt)-p->Freq;
    // Finish summing the remaining unmasked frequencies to complete scale.
    if ( --i ) {
        STATE* p1=p;
        do {
            do { p1++; } while (CharMask[p1->Symbol] == SymbolCounter);
            HiCnt += p1->Freq;
        } while ( --i );
    }
    SubRange.scale += HiCnt;
    update2(p);                             return TRUE;
}
// Decoder counterpart of encodeSymbol2.  It first collects the unmasked
// states into a NULL-terminated list while computing
// scale = escape estimate + their frequencies, then searches that list
// with the count read from the arithmetic decoder.  Every visited symbol is
// masked, including the one found.
_BIG_INLINE BOOL PPM_CONTEXT::decodeSymbol2()
{
    int count, HiCnt, i=NumStats-NumMasked;
    count=getEscFreq2(i);
    // Every escape path sets NumMasked >= 1, so i <= 255 and ps has room
    // for the NULL terminator.  NOTE(review): Stats-1 is formally UB (see
    // encodeSymbol2); it relies on the pre-increment below.
    STATE* ps[256], ** pps=ps, * p=Stats-1;
    do {
        do { p++; } while (CharMask[p->Symbol] == SymbolCounter);
        count += p->Freq;                   *pps++ = p;
    } while ( --i );
    // scale must be set before ariGetCurrentCount(), which presumably reads it.
    *pps=NULL;                              SubRange.scale=count;
    count=ariGetCurrentCount();
    HiCnt=0;                                p=*(pps=ps);
    do {
        CharMask[p->Symbol]=SymbolCounter;
        if ((HiCnt += p->Freq) > count)     goto SYMBOL_FOUND;
    } while ((p=*++pps) != NULL);
    // Escape: the count fell into the region above all unmasked symbols.
    SubRange.LowCount=HiCnt;                SubRange.HighCount=SubRange.scale;
    psee2c->Summ += SubRange.scale;         NumMasked = NumStats;
    return FALSE;
SYMBOL_FOUND:
    SubRange.LowCount = (SubRange.HighCount=HiCnt)-p->Freq;
    update2(p);                             return TRUE;
}
// Compress DecodedFile into EncodedFile with a PPM model of order MaxOrder.
//
// Bytes are read with getc() until EOF.  EOF (-1) matches no model symbol,
// so it escapes through every context; once no lesser context remains the
// loop exits.  That final escape chain is how end-of-stream is coded.
// Progress goes to PrintInfo whenever the low 16 bits of SymbolCounter
// wrap, and once at the end.  Its last argument is the elapsed processor
// time in 1/1024-second units.  clock() ticks are CLOCKS_PER_SEC per second;
// the legacy CLK_TCK is not a valid divisor for it on POSIX systems.  The
// value is also computed in floating point so that `<< 10` on a 32-bit
// clock_t can no longer overflow.
void EncodeFile(FILE* EncodedFile,FILE* DecodedFile,int MaxOrder)
{
    ariInitEncoder(EncodedFile);
    // pCState == NULL makes StartModel perform a full initialisation.
    ::MaxOrder=MaxOrder;                    pCState=NULL;
    clock_t StartClock=StartModel();
    BOOL SymbolEncoded;
    for (int ns=MinContext->NumStats; ;ns=MinContext->NumStats) {
        ARI_ENC_NORMALIZE(EncodedFile);
        int c = getc(DecodedFile);
        if (ns == 1) {
            SymbolEncoded=MinContext->encodeBinSymbol(c);
            ariShiftEncodeSymbol(INT_BITS+PERIOD_BITS);
        } else {
            SymbolEncoded=MinContext->encodeSymbol1(c);
            ariEncodeSymbol();
        }
        // Escape to shorter contexts, skipping those whose symbols are all
        // masked already.
        while ( !SymbolEncoded ) {
            ARI_ENC_NORMALIZE(EncodedFile);
            do {
                MinContext=MinContext->Lesser;
                if ( !MinContext )          goto STOP_ENCODING;
            } while (MinContext->NumStats == NumMasked);
            SymbolEncoded=MinContext->encodeSymbol2(c);
            ariEncodeSymbol();
        }
        // Coded in MaxContext without escape: just follow the successor.
        if (MaxContext == MinContext)       MinContext=MaxContext=pCState->Successor;
        else {
            UpdateModel();
            if ((++SymbolCounter & 0xFFFF) == 0)
                    PrintInfo(DecodedFile,EncodedFile,NumConts,
                        clock_t((clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
        }
    }
STOP_ENCODING:
    StopModel();
    ARI_FLUSH_ENCODER(EncodedFile);
    PrintInfo(DecodedFile,EncodedFile,NumConts,
            clock_t((clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
}
// Decompress EncodedFile into DecodedFile with a PPM model of order MaxOrder.
//
// This mirrors EncodeFile.  Each decoded symbol is written with putc();
// decoding stops when an escape chain runs past the last (order-0) context,
// which is how the encoder signals end-of-stream.  Progress goes to
// PrintInfo with the elapsed processor time in 1/1024-second units, computed
// from clock() ticks via CLOCKS_PER_SEC (not the legacy CLK_TCK) in floating
// point, so it cannot overflow a 32-bit clock_t.
void DecodeFile(FILE* DecodedFile,FILE* EncodedFile,int MaxOrder)
{
    ARI_INIT_DECODER(EncodedFile);
    // pCState == NULL makes StartModel perform a full initialisation.
    ::MaxOrder=MaxOrder;                    pCState=NULL;
    clock_t StartClock=StartModel();
    BOOL SymbolDecoded;
    for (int ns=MinContext->NumStats; ;ns=MinContext->NumStats) {
        ARI_DEC_NORMALIZE(EncodedFile);
        SymbolDecoded = (ns == 1) ? MinContext->decodeBinSymbol():
                                    MinContext->decodeSymbol1();
        ariRemoveSubrange();
        // Escape to shorter contexts, skipping those whose symbols are all
        // masked already.
        while ( !SymbolDecoded ) {
            ARI_DEC_NORMALIZE(EncodedFile);
            do {
                MinContext=MinContext->Lesser;
                if ( !MinContext )          goto STOP_DECODING;
            } while (MinContext->NumStats == NumMasked);
            SymbolDecoded = MinContext->decodeSymbol2();
            ariRemoveSubrange();
        }
        putc(pCState->Symbol,DecodedFile);
        // Decoded in MaxContext without escape: just follow the successor.
        if (MaxContext == MinContext)       MinContext=MaxContext=pCState->Successor;
        else {
            UpdateModel();
            if ((++SymbolCounter & 0xFFFF) == 0)
                    PrintInfo(DecodedFile,EncodedFile,NumConts,
                        clock_t((clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
        }
    }
STOP_DECODING:
    StopModel();
    PrintInfo(DecodedFile,EncodedFile,NumConts,
            clock_t((clock()-StartClock)*1024.0/CLOCKS_PER_SEC));
}
