#include "mex.h"
#include "matrix.h"
#include <stack>
#include <algorithm>

void mexFunction( int nlhs, mxArray *plhs[],
				 int nrhs, const mxArray *prhs[] )
{
	double *adjmat = mxGetPr( prhs[0] );
	const mwSize *dims = mxGetDimensions( prhs[0] );
	// mexPrintf( "dims[0]: %d, dims[1]: %d\n", dims[0], dims[1] );
	int numRegion = dims[0];
	if( numRegion != dims[1] )
	{
		mexErrMsgTxt( "Error, adjmat should be square.\n" );
	}

	plhs[0] = mxCreateDoubleMatrix( numRegion, 1, mxREAL );
	double *labels = mxGetPr( plhs[0] );
	for( int ix = 0; ix < numRegion; ++ix )
		labels[ix] = 0;

	double *ind = mxGetPr( prhs[1] );
	const mwSize *dm = mxGetDimensions( prhs[1] );
	// mexPrintf( "dm[0]: %d, dm[1]: %d\n", dm[0], dm[1] );
	int curLabel = 1;
	for( int ix = 0; ix < dm[0]; ++ix )
	{
		int index = ind[ix];
		int x = static_cast<int>( index / numRegion );
		int y = index % numRegion;
		// mexPrintf( "index: %d, x: %d, y: %d, adjmat[%d]: %.1f\n", index, x, y, index, adjmat[index] );
		if( labels[x] != 0 && labels[y] != 0 )
		{
			if( labels[x] != labels[y] )
			{
				mexPrintf( "x: %d, y: %d, labels[x]: %.1f, labels[y]: %.1f\n", x, y, labels[x], labels[y] );
				mexPrintf( "curLabel: %d\n", curLabel );
				mexErrMsgTxt( "Error in merging." );
			}
			else
				continue;
		} 

		std::stack<int> toCheck;
		toCheck.push( y );
		while( !toCheck.empty() )
		{
			int tempY = toCheck.top();
			toCheck.pop();
			if( labels[tempY] != 0 )
				continue;
			labels[tempY] = curLabel;
			for( int jx = 0; jx < numRegion; ++jx )
			{
				if( adjmat[jx*numRegion+tempY] != 0 )
					toCheck.push( jx );
			}
			// mexPrintf( "\t\\\\\\\\\\ size of toCheck: %d\n", toCheck.size() );
			//std::stack<int>::iterator it = toCheck.begin

		}
		curLabel++;
		//mexPrintf( "\t*** ix = %d\n", ix );
		//for( int kx = 0; kx < numRegion; ++kx )
		//	mexPrintf( "%.1f ", labels[kx] );
		//mexPrintf( "\n" );
	}
}