/*
 * This source was translated from a C++ implementation found in
 * the MaART (Music and Audio Retrieval Tools) software package.
 * MaART can be found on sourceforge at http://maart.sourceforge.net
 * 
 * Translated by Andrew Matheny 
 */

package graph;

import java.util.*;
import java.io.*;
import java.text.DecimalFormat;
import java.text.DecimalFormatSymbols;


/**
 * FastMap projection: assigns each object in a collection {@code k} real-valued
 * coordinates, one column at a time.
 *
 * <p>For each column a pair of distant "pivot" objects is chosen. Every object
 * is projected onto the line through the pivots. Distances for the next column
 * are then taken on the hyper-plane perpendicular to that line.
 *
 * <p>Typical use: construct, call {@link #doMap()}, then read coordinates with
 * {@link #getRow(String)} or {@link #printData(String)}, or project unseen
 * vectors with {@link #MapNewInstance(ArrayList)}.
 *
 * <p>Instances hold mutable state (coordinates, pivots, number formatters) and
 * perform no synchronization.
 */
public class FastMap
{
	/** Base distance: square root of the summed squared component differences. */
	public static final int EUCLIDEAN = 0;
	/** Base distance: summed squared component differences (no square root). */
	public static final int SUM = 1;
	/** Base distance: mean absolute component difference. */
	public static final int BASIC = 2;

	/** The original objects; each inner list is one object's feature vector. */
	public ArrayList<ArrayList<Double>> _origObjects;
	/** Mapped coordinates: row i holds the k coordinates of original object i. */
	private double[][] _mappedObjects;
	/**
	 * Pivot pair (indices into _origObjects) for each mapped column. Column c
	 * uses row (_k - 1 - c): the recursion stores at index k-1 while k counts
	 * down as the column counts up.
	 */
	private int[][] _pivotArray;
	/** Number of target dimensions. */
	private int _k;
	/** Number of original objects, captured at construction. */
	private int _numInstances;
	/** Optional id per object, parallel to _origObjects; may be null. */
	private ArrayList<String> _idLookup;

	/** Selected base distance; values other than SUM/BASIC behave as EUCLIDEAN. */
	private int _distanceFn = EUCLIDEAN;

	/**
	 * Formatters for 0..5 fixed decimals. They use locale-independent symbols
	 * ('.' decimal point) and no digit grouping, so written numbers never
	 * contain commas.
	 */
	private final DecimalFormat[] floatFormatters = new DecimalFormat[]{
			newFormatter("0"),
			newFormatter("0.0"),
			newFormatter("0.00"),
			newFormatter("0.000"),
			newFormatter("0.0000"),
			newFormatter("0.00000")
	};

	/**
	 * Creates a mapper using the Euclidean base distance and no id lookup.
	 *
	 * @param objects the objects to map, one feature vector each
	 * @param k       number of target dimensions
	 */
	public FastMap(ArrayList<ArrayList<Double>> objects, int k)
	{
		this(objects, k, EUCLIDEAN, null);
	}

	/**
	 * Creates a mapper using the Euclidean base distance.
	 *
	 * @param objects the objects to map, one feature vector each
	 * @param k       number of target dimensions
	 * @param lookup  id of each object, parallel to {@code objects}; may be null
	 */
	public FastMap(ArrayList<ArrayList<Double>> objects, int k, ArrayList<String> lookup)
	{
		this(objects, k, EUCLIDEAN, lookup);
	}

	/**
	 * Creates a mapper.
	 *
	 * @param objects      the objects to map, one feature vector each
	 * @param k            number of target dimensions
	 * @param distanceFunc {@link #EUCLIDEAN}, {@link #SUM} or {@link #BASIC};
	 *                     any other value behaves as EUCLIDEAN
	 * @param lookup       id of each object, parallel to {@code objects}; may be null
	 */
	public FastMap(ArrayList<ArrayList<Double>> objects, int k, int distanceFunc, ArrayList<String> lookup)
	{
		_k = k;
		_distanceFn = distanceFunc;
		_numInstances = objects.size();
		_origObjects = objects;
		_idLookup = lookup;
		_mappedObjects = new double[_numInstances][_k];
		_pivotArray = new int[_k][2];
	}

	/**
	 * Returns a copy of the mapped coordinates of the object with the given id.
	 *
	 * @param id object id as given in the lookup list
	 * @return the object's k coordinates, or null when no lookup was supplied
	 *         or the id is unknown
	 */
	public double[] getRow(String id)
	{
		if (_idLookup == null)
			return null;

		int index = _idLookup.indexOf(id);
		if (index < 0 || index >= _mappedObjects.length)
			return null;

		// Copy so callers cannot silently alter the stored mapping.
		return _mappedObjects[index].clone();
	}

	/**
	 * Changes the number of target dimensions. This discards the current
	 * coordinates and pivots; call {@link #doMap()} again afterwards.
	 *
	 * @param k new number of target dimensions
	 */
	public void setK(int k)
	{
		_k = k;
		_mappedObjects = new double[_numInstances][_k];
		_pivotArray = new int[_k][2];
	}

	/** @return the number of target dimensions */
	public int getK()
	{
		return _k;
	}

	/** Builds a locale-independent, non-grouping formatter for {@code pattern}. */
	private static DecimalFormat newFormatter(String pattern)
	{
		return new DecimalFormat(pattern, new DecimalFormatSymbols(Locale.ROOT));
	}

	/**
	 * Formats {@code value} in plain (non-scientific) notation with '.' as the
	 * decimal separator and no grouping.
	 *
	 * @param numDecimals below 6: exactly that many decimals; otherwise one
	 *                    mandatory decimal plus up to numDecimals-1 optional ones
	 */
	private String formatFloatNonSci(double value, int numDecimals)
	{
		if (numDecimals < floatFormatters.length)
			return floatFormatters[numDecimals].format(value);

		StringBuilder pattern = new StringBuilder("0.0");
		for (int i = 2; i <= numDecimals; i++)
			pattern.append('#');

		return newFormatter(pattern.toString()).format(value);
	}

	/**
	 * Projects an unseen vector into the space computed by {@link #doMap()},
	 * using the same pivot pairs. Projecting one of the original objects
	 * reproduces its mapped row.
	 *
	 * @param x feature vector with the same length as the original objects
	 * @return a new list of k coordinates
	 * @throws IllegalArgumentException if {@code x} has a different length than
	 *                                  the original objects
	 */
	public ArrayList<Double> MapNewInstance(ArrayList<Double> x)
	{
		if (!_origObjects.isEmpty() && x.size() != _origObjects.get(0).size())
			throw new IllegalArgumentException("expected a vector of length "
					+ _origObjects.get(0).size() + " but got " + x.size());

		ArrayList<Double> xMapped = new ArrayList<Double>(_k);
		for (int i = 0; i < _k; i++)
			xMapped.add(0.0);

		MapNewInstanceRecursive(_k, 0, x, xMapped);

		return xMapped;
	}

	/**
	 * Returns up to {@code n} labels for the mapped objects nearest to {@code x},
	 * ordered by increasing base distance. Tied objects appear in index order.
	 *
	 * @param n maximum number of results; fewer are returned when there are
	 *          fewer objects, none when n &lt;= 0
	 * @param x a point in the mapped space (at least k coordinates), e.g. the
	 *          result of {@link #MapNewInstance(ArrayList)}
	 * @return entries of the form {@code "<id> <distance>"}; the id is the
	 *         object's index when no lookup was supplied
	 */
	public ArrayList<String> getNClosest(int n, ArrayList<Double> x)
	{
		// TreeMap keeps distances ascending (Double.compareTo, the same order
		// Collections.sort uses) and groups equidistant objects under one key.
		TreeMap<Double, ArrayList<Integer>> byDistance = new TreeMap<Double, ArrayList<Integer>>();

		for (int i = 0; i < _numInstances; i++)
		{
			ArrayList<Double> row = new ArrayList<Double>(_k);
			for (int c = 0; c < _k; c++)
				row.add(_mappedObjects[i][c]);

			Double d = baseDistance(row, x);
			ArrayList<Integer> group = byDistance.get(d);
			if (group == null)
			{
				group = new ArrayList<Integer>();
				byDistance.put(d, group);
			}
			group.add(i);
		}

		ArrayList<String> ids = new ArrayList<String>();
		for (Map.Entry<Double, ArrayList<Integer>> entry : byDistance.entrySet())
		{
			for (Integer idx : entry.getValue())
			{
				if (ids.size() >= n)
					return ids;
				// Label each object with its own group's distance.
				ids.add(labelFor(idx) + " " + entry.getKey());
			}
		}
		return ids;
	}

	/** Id of object {@code index} from the lookup, or the index itself when there is none. */
	private String labelFor(int index)
	{
		return _idLookup != null ? _idLookup.get(index) : String.valueOf(index);
	}

	/**
	 * Fills {@code xMapped[column .. column+k-1]} with the projection of
	 * {@code x}, using the pivot pairs recorded by {@link #doMap()}.
	 *
	 * <p>Entries of {@code xMapped} before {@code column} must already hold x's
	 * coordinates for the earlier columns, because residual distances for this
	 * column are derived from them.
	 *
	 * @param k       number of columns still to fill
	 * @param column  first column to fill
	 * @param x       the raw feature vector being projected
	 * @param xMapped output list, at least column+k long
	 */
	public void MapNewInstanceRecursive(int k, int column, ArrayList<Double> x, ArrayList<Double> xMapped)
	{
		if (k == 0 || _origObjects.isEmpty())
			return;

		int a = _pivotArray[k-1][0];
		int b = _pivotArray[k-1][1];

		double dab = fmDist(_origObjects, a, b, column);
		if (dab == 0.0)
		{
			// Same rule as doMap: with coincident pivots, this and every
			// remaining column is 0.
			for (int n = 0; n < k; n++)
				xMapped.set(column + n, 0.0);
			return;
		}

		double dai = fmDistToNew(a, x, xMapped, column);
		double dbi = fmDistToNew(b, x, xMapped, column);

		xMapped.set(column, ((dai*dai) + (dab*dab) - (dbi*dbi)) / (dab * 2.0));

		MapNewInstanceRecursive(k-1, column+1, x, xMapped);
	}

	/**
	 * Writes the mapped coordinates to {@code fileName} in ARFF-style sections.
	 *
	 * <p>The file contains one {@code @attribute <name> real} line per dimension,
	 * named with successive characters starting at 'a', then {@code @data}. After
	 * that comes one comma-separated row per object, prefixed with the object's
	 * id when a lookup was supplied.
	 *
	 * <p>Numbers use '.' as decimal separator and no digit grouping regardless of
	 * the default locale, so commas only ever separate fields. The file is
	 * always closed. I/O errors are reported on standard output, not thrown.
	 * NOTE(review): for k &gt; 26 the attribute names continue past 'z'.
	 *
	 * @param fileName path of the file to create or overwrite
	 */
	public void printData(String fileName)
	{
		BufferedWriter out = null;
		try
		{
			out = new BufferedWriter(new FileWriter(fileName));

			System.out.println("k = " + _k + " numInstances= " + _numInstances);

			char attributeName = 'a';
			for (int i = 0; i < _k; i++)
				out.write("@attribute " + attributeName++ + " real \n");
			out.write("@data \n");

			for (int i = 0; i < _numInstances; i++)
			{
				if (_idLookup != null)
					out.write(_idLookup.get(i) + ", ");

				for (int j = 0; j < _k; j++)
				{
					if (j > 0)
						out.write(", ");
					out.write(formatFloatNonSci(_mappedObjects[i][j], 6));
				}

				out.write("\n");
			}
		}
		catch (IOException ex)
		{
			System.out.println("Error writing fastMap to file " + ex.getMessage());
		}
		finally
		{
			if (out != null)
			{
				try
				{
					// close() also flushes buffered rows, so a failure here is a write error.
					out.close();
				}
				catch (IOException ex)
				{
					System.out.println("Error writing fastMap to file " + ex.getMessage());
				}
			}
		}
	}

	/**
	 * Computes the k coordinates of every original object. It also records the
	 * pivot pairs that {@link #MapNewInstance(ArrayList)} reuses.
	 */
	public void doMap()
	{
		doMap(_k, _origObjects, 0);
	}

	/**
	 * Maps {@code column} and, recursively, the {@code k-1} columns after it.
	 *
	 * @param k       number of columns still to compute
	 * @param objects the original objects
	 * @param column  column being computed
	 */
	private void doMap(int k, ArrayList<ArrayList<Double>> objects, int column)
	{
		// Stop when all columns are done, or when there are no objects at all
		// (pivot selection would otherwise index an empty list).
		if (k == 0 || objects.isEmpty())
			return;

		int[] a = new int[1];
		int[] b = new int[1];
		chooseDistantObjects(objects, a, b, column);

		_pivotArray[k-1][0] = a[0];
		_pivotArray[k-1][1] = b[0];

		double dab = fmDist(objects, a[0], b[0], column);
		if (dab == 0.0)
		{
			// The pivots are the most distant pair found, so all inter-object
			// distances are taken to be 0: zero this and every remaining column.
			for (int i = 0; i < objects.size(); i++)
				for (int n = 0; n < k; n++)
					_mappedObjects[i][column + n] = 0.0;

			return;
		}

		// Project every object onto the line through the pivots (law of cosines).
		for (int i = 0; i < objects.size(); i++)
		{
			double dai = fmDist(objects, a[0], i, column);
			double dbi = fmDist(objects, b[0], i, column);

			_mappedObjects[i][column] = ((dai*dai) + (dab*dab) - (dbi*dbi)) / (dab * 2.0);
		}

		// Next, consider the projections of the objects on the hyper-plane
		// perpendicular to the line (Oa, Ob). fmDist derives distances there
		// from the coordinates just written.
		doMap(k-1, objects, column+1);
	}

	/**
	 * Chooses two far-apart pivot objects for {@code column}.
	 *
	 * <p>Starting from object 0, the method alternately takes the object farthest
	 * from the current pivot. It runs for at most 5 rounds and stops early once
	 * the pivot distance stops increasing.
	 *
	 * @param a receives the first pivot index in a[0]
	 * @param b receives the second pivot index in b[0]
	 */
	private void chooseDistantObjects(ArrayList<ArrayList<Double>> objects, int[] a, int[] b, int column)
	{
		final int numIterations = 5;

		// Start from the first object rather than a random one, for determinism.
		b[0] = 0;

		double lastDistance = 0.0;  // distance the current round has to beat
		for (int iteration = 0; iteration < numIterations; iteration++)
		{
			// Oa = the object farthest from Ob.
			a[0] = b[0];
			double maxDistance = 0.0;
			for (int n = 0; n < objects.size(); n++)
			{
				double distance = fmDist(objects, b[0], n, column);
				if (distance > maxDistance)
				{
					a[0] = n;
					maxDistance = distance;
				}
			}

			// Ob = the object farthest from Oa.
			b[0] = a[0];
			maxDistance = 0.0;
			for (int n = 0; n < objects.size(); n++)
			{
				double distance = fmDist(objects, a[0], n, column);
				if (distance > maxDistance)
				{
					b[0] = n;
					maxDistance = distance;
				}
			}

			// Stop once a round no longer increases the distance (typically the
			// same pair was selected again).
			if (maxDistance > lastDistance)
				lastDistance = maxDistance;
			else
				break;
		}
	}

	/**
	 * Distance between original objects a and b in the projection used for
	 * {@code column}. For column 0 this is the base distance. For later columns
	 * it is the previous column's distance with that column's coordinate
	 * difference removed (Pythagoras).
	 */
	private double fmDist(ArrayList<ArrayList<Double>> objects, int a, int b, int column)
	{
		if (column == 0)
			return baseDistance(objects.get(a), objects.get(b));

		double d = fmDist(objects, a, b, column - 1);
		return residual(d, _mappedObjects[a][column - 1] - _mappedObjects[b][column - 1]);
	}

	/**
	 * Like fmDist, but between original object {@code objIndex} and an unseen
	 * vector {@code x}. The residual uses x's already-computed coordinates in
	 * {@code xMapped}, not its raw feature values.
	 */
	private double fmDistToNew(int objIndex, ArrayList<Double> x, ArrayList<Double> xMapped, int column)
	{
		if (column == 0)
			return baseDistance(_origObjects.get(objIndex), x);

		double d = fmDistToNew(objIndex, x, xMapped, column - 1);
		return residual(d, _mappedObjects[objIndex][column - 1] - xMapped.get(column - 1));
	}

	/**
	 * Returns sqrt(d^2 - diff^2), clamped at 0. The clamp matters because
	 * rounding, or a base distance that is not Euclidean, can make the radicand
	 * slightly negative. An unclamped sqrt would then yield NaN, which would slip
	 * past the "== 0.0" pivot check and poison every later coordinate.
	 */
	private static double residual(double d, double diff)
	{
		return Math.sqrt(Math.max(0.0, (d*d) - (diff*diff)));
	}

	/** Applies the configured base distance function to two vectors. */
	private double baseDistance(ArrayList<Double> a, ArrayList<Double> b)
	{
		switch (_distanceFn)
		{
		case BASIC:
			return basicDistance(a, b);
		case SUM:
			return sumOfSquaresDistance(a, b);
		case EUCLIDEAN:
		default:
			return euclideanDistance(a, b);
		}
	}

	/** Square root of the summed squared differences over a's components. */
	private double euclideanDistance(ArrayList<Double> a, ArrayList<Double> b)
	{
		return Math.sqrt(sumOfSquaresDistance(a, b));
	}

	/** Summed squared differences over a's components. */
	private double sumOfSquaresDistance(ArrayList<Double> a, ArrayList<Double> b)
	{
		double sum = 0;
		for (int n = 0; n < a.size(); n++)
		{
			double delta = a.get(n) - b.get(n);
			sum += delta * delta;
		}
		return sum;
	}

	/** Mean absolute difference over a's components; 0 for empty vectors. */
	private double basicDistance(ArrayList<Double> a, ArrayList<Double> b)
	{
		if (a.isEmpty())
			return 0.0;

		double sum = 0;
		for (int n = 0; n < a.size(); n++)
			sum += Math.abs(a.get(n) - b.get(n));

		return sum / a.size();
	}

}
