// PPMLanguageModel.h
//
/////////////////////////////////////////////////////////////////////////////
//
// Copyright (c) 1999-2002 David Ward
//
/////////////////////////////////////////////////////////////////////////////

#include <math.h>
#include "PPMLanguageModel.h"

using namespace Dasher;
using namespace std;

// static TCHAR debug[256];
typedef unsigned long ulong;

////////////////////////////////////////////////////////////////////////
/// PPMnode definitions 
////////////////////////////////////////////////////////////////////////

CPPMLanguageModel::CPPMnode *CPPMLanguageModel::CPPMnode::find_symbol(int sym)
// see if symbol is a child of node
{
	//  printf("finding symbol %d at node %d\n",sym,node->id);
	CPPMnode *found=child;
	while (found) {
		if (found->symbol==sym)
			return found;
		found=found->next;
	}
	return 0;
}


CPPMLanguageModel::CPPMnode * CPPMLanguageModel::CPPMnode::add_symbol_to_node(int sym,int *update)
{
	CPPMnode *born,*search;
	search=find_symbol(sym);
	if (!search) {
		born = new CPPMnode(sym);
		born->next=child;
		child=born;
		//   node->count=1;
		return born;		
	} else {
		if (*update) {   // perform update exclusions
			search->count++;
			*update=0;
		}
		return search;
	}
	
}


/////////////////////////////////////////////////////////////////////
// CPPMLanguageModel defs
/////////////////////////////////////////////////////////////////////

CPPMLanguageModel::CPPMLanguageModel(CAlphabet *_alphabet,int _normalization)
	: CLanguageModel(_alphabet,_normalization)
{
	root=new CPPMnode(-1);
	m_rootcontext=new CPPMContext(root,0);
}


CPPMLanguageModel::~CPPMLanguageModel()
{
	delete root;
}


bool CPPMLanguageModel::GetProbs(CContext *context,vector<unsigned int> &probs,double addprob)
	// get the probability distribution at the context
{
	// seems like we have to have this hack for VC++
	CPPMContext *ppmcontext=static_cast<CPPMContext *> (context);
	
	
	int modelchars=GetNumberModelChars();
	int norm=CLanguageModel::normalization();
	probs.resize(modelchars);
	CPPMnode *temp,*s; 
	int loop,total;
	int sym; 
	ulong spent=0; 
	ulong size_of_slice;
	bool *exclusions=new bool [modelchars];
	ulong uniform=modelchars;
	ulong tospend=norm-uniform;
	temp=ppmcontext->head;
	for (loop=0; loop <modelchars; loop++) {   /* set up the exclusions array */
		probs[loop]=0;
		exclusions[loop]=0;
	}
	while (temp!=0) {
		//	Usprintf(debug,TEXT("tospend %u\n"),tospend);
		//	DebugOutput(TEXT("round\n"));
		total=0;
		s=temp->child;
		while (s) {
			sym=s->symbol; 
			if (!exclusions[s->symbol])
				total=total+s->count;
			s=s->next;
		}
		if (total) {
			//	Usprintf(debug,TEXT"escape %u\n"),tospend*
			size_of_slice=tospend;
			s=temp->child;
			while (s) {
				if (!exclusions[s->symbol]) {
					exclusions[s->symbol]=1;
					ulong p=size_of_slice*(2*s->count-1)/2/ulong(total);
					probs[s->symbol]+=p;
					tospend-=p;		
				}
				//				Usprintf(debug,TEXT("sym %u counts %d p %u tospend %u \n"),sym,s->count,p,tospend);	 
				//				DebugOutput(debug);
				s=s->next;
			}
		}
		temp = temp->vine;
	}
	//	Usprintf(debug,TEXT("Norm %u tospend %u\n"),Norm,tospend);
	//	DebugOutput(debug);
	
	size_of_slice=tospend;
	int symbolsleft=0;
	for (sym=1;sym<modelchars;sym++)
		if (!probs[sym])
			symbolsleft++;
		for (sym=1;sym<modelchars;sym++) 
			if (!probs[sym]) {
				ulong p=size_of_slice/symbolsleft;
				probs[sym]+=p;
				tospend-=p;
			}
			
			// distribute what's left evenly	
			tospend+=uniform;
			for (sym=1;sym<modelchars;sym++) {
				ulong p=tospend/(modelchars-sym);
				probs[sym]+=p;
				tospend-=p;
			}
			//	Usprintf(debug,TEXT("finaltospend %u\n"),tospend);
			//	DebugOutput(debug);
			
			// free(exclusions); // !!!
			// !!! NB by IAM: p577 Stroustrup 3rd Edition: "Allocating an object using new and deleting it using free() is asking for trouble"
			delete[] exclusions;
			return true;
}


void CPPMLanguageModel::AddSymbol(CPPMLanguageModel::CPPMContext &context,int symbol)
	// add symbol to the context
	// creates new nodes, updates counts
	// and leaves 'context' at the new context
{
	// sanity check
	if (symbol==0 || symbol>=GetNumberModelChars())
		return;
	
	CPPMnode *vineptr,*temp;
	int updatecnt=1;
	
	temp=context.head->vine;
	context.head=context.head->add_symbol_to_node(symbol,&updatecnt);
	vineptr=context.head;
	context.order++;
	
	while (temp!=0) {
		vineptr->vine=temp->add_symbol_to_node(symbol,&updatecnt);    
		vineptr=vineptr->vine;
		temp=temp->vine;
	}
	vineptr->vine=root;
	if (context.order>MAX_ORDER){
		context.head=context.head->vine;
		context.order--;
	}
}


// update context with symbol 'Symbol'
void CPPMLanguageModel::EnterSymbol(CContext* Context, modelchar Symbol)
{
	CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
	
	CPPMnode *find;
	CPPMnode *temp=context.head;
	
	while (context.head) {
		find =context.head->find_symbol(Symbol);
		if (find) {
			context.order++;
			context.head=find;
			//	Usprintf(debug,TEXT("found context %x order %d\n"),head,order);
			//	DebugOutput(debug);
			return;
		}
		context.order--;
		context.head=context.head->vine;
	}
	
	if (context.head==0) {
		context.head=root;
		context.order=0;
	}
	
}


void CPPMLanguageModel::LearnSymbol(CContext* Context, modelchar Symbol)
{
	CPPMLanguageModel::CPPMContext& context = * static_cast<CPPMContext *> (Context);
	AddSymbol(context, Symbol);
}


void CPPMLanguageModel::dumpSymbol(int symbol)
{
	if ((symbol <= 32) || (symbol >= 127))
		printf( "<%d>", symbol );
	else
		printf( "%c", symbol );
}


void CPPMLanguageModel::dumpString( char *str, int pos, int len )
	// Dump the string STR starting at position POS
{
	char cc;
	int p;
	for (p = pos; p<pos+len; p++) {
		cc = str [p];
		if ((cc <= 31) || (cc >= 127))
			printf( "<%d>", cc );
		else
			printf( "%c", cc );
	}
}


void CPPMLanguageModel::dumpTrie( CPPMLanguageModel::CPPMnode *t, int d )
	// diagnostic display of the PPM trie from node t and deeper
{
//TODO
/*
	dchar debug[256];
	int sym;
	CPPMnode *s;
	Usprintf( debug,TEXT("%5d %7x "), d, t );
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	if (t < 0) // pointer to input
		printf( "                     <" );
	else {
		Usprintf(debug,TEXT( " %3d %5d %7x  %7x  %7x    <"), t->symbol,t->count, t->vine, t->child, t->next );
		//TODO: Uncomment this when headers sort out
		//DebugOutput(debug);
	}
	
	dumpString( dumpTrieStr, 0, d );
	Usprintf( debug,TEXT(">\n") );
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	if (t != 0) {
		s = t->child;
		while (s != 0) {
			sym =s->symbol;
			
			dumpTrieStr [d] = sym;
			dumpTrie( s, d+1 );
			s = s->next;
		}
	}
*/
}


void CPPMLanguageModel::dump()
	// diagnostic display of the whole PPM trie
{
// TODO:
/*
	dchar debug[256];
	Usprintf(debug,TEXT(  "Dump of Trie : \n" ));
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	Usprintf(debug,TEXT(   "---------------\n" ));
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	Usprintf( debug,TEXT(  "depth node     symbol count  vine   child      next   context\n") );
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	dumpTrie( root, 0 );
	Usprintf( debug,TEXT(  "---------------\n" ));
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
	Usprintf(debug,TEXT( "\n" ));
	//TODO: Uncomment this when headers sort out
	//DebugOutput(debug);
*/
}