#include "UnitTestSharpHLSL.fx"
#include "MatrixDecomposition.fxh"

/// AbsScale
// The identity matrix scales both axes by exactly 1, so both absolute
// singular values reported by SVD_AbsScaleOnly must be 1.
void AbsScale_Identity()
{
	float2x2 identity = float2x2(1, 0, 0, 1);
	float2 singularValues = SVD_AbsScaleOnly(identity);
	CheckEqual(float2(1,1), singularValues);
}

// Absolute singular values of a non-trivial matrix.
// For M = [1 2; -2 -1], transpose(M)*M = [5 4; 4 5] has eigenvalues 9 and 1,
// so the singular values are 3 and 1.
void AbsScale_Simple()
{
	float2x2 input = float2x2( 1,  2,
	                          -2, -1);
	float2 singularValues = SVD_AbsScaleOnly(input);
	CheckEqual(float2(3,1), singularValues);
}

/// Eigenvector

// Eigenvector recovery for a matrix with a hand-computed answer.
// For M = [-10 8; 10 -1], transpose(M)*M has eigenvalues 245 and 20, so the
// singular values passed in are sqrt(245) = 7*sqrt(5) and sqrt(20) = 2*sqrt(5).
// (-0.8, 0.6) and (0.6, 0.8) are the matching unit eigenvectors of M*transpose(M).
// The expected matrix is symmetric, so it reads the same whether the
// eigenvectors are stored as rows or as columns.
void Eigenvector_Known()
{
	float2x2 input = float2x2(-10,  8,
	                           10, -1);
	float2 singularValues = float2(7 * sqrt(5), 2 * sqrt(5));
	
	float2x2 eigenvectors = SVD_EigenvectorGivenEigenvalue(input, singularValues);
	
	CheckEqual(float2x2(-0.8, 0.6,
	                     0.6, 0.8), eigenvectors);
}

// Pins how normalize() treats a vector whose second component is NaN.
// x/y is 0/0, which is NaN under IEEE-754, so vec starts as (1, NaN).
// NOTE(review): with strict IEEE-754 arithmetic the NaN would propagate
// through the length computation inside normalize(), giving (NaN, NaN).
// The expected (1, 0) therefore relies on the float behavior of the
// compiler/runtime these tests target -- confirm before porting elsewhere.
void NormalizedNAN_1()
{
	float x = 0;
	float y = 0;
	
	float2 vec = float2(1, x/y);
	vec = normalize(vec);
	
	// NaNs disappear magically: the NaN component comes out as 0 and the
	// finite component keeps unit length (see NOTE above).
	CheckEqual(float2(1,0), vec);
}

// Mirror of NormalizedNAN_1 with the NaN in the first component.
// x/y is 0/0, which is NaN under IEEE-754, so vec starts as (NaN, 1).
// NOTE(review): strict IEEE-754 arithmetic would yield (NaN, NaN) from
// normalize(); the expected (0, 1) relies on the float behavior of the
// compiler/runtime these tests target -- confirm before porting elsewhere.
void NormalizedNAN_2()
{
	float x = 0;
	float y = 0;
	
	float2 vec = float2(x/y, 1);
	vec = normalize(vec);
	
	// NaNs disappear magically: the NaN component comes out as 0 and the
	// finite component keeps unit length (see NOTE above).
	CheckEqual(float2(0,1), vec);
}

// Checks the "step(v, v) * v" idiom for zeroing non-finite components.
// vec starts as (0/0, 1/0), i.e. (NaN, +Inf) under IEEE-754.
// step(a, b) is (b >= a) ? 1 : 0 per component, so step(v, v) is 0 for a NaN
// component (every comparison involving NaN is false) and 1 for finite ones.
// NOTE(review): under strict IEEE-754, +Inf >= +Inf is true (step gives 1 and
// 1 * Inf = Inf), and 0 * NaN is NaN, so a strictly IEEE backend would produce
// (NaN, Inf) rather than (0, 0). This test relies on the float behavior of the
// compiler/runtime these tests target -- confirm before using the idiom
// elsewhere.
void SetNaNInfTo0()
{
	float x = 0;
	float y = 0;
	
	float2 vec = float2(x/y, 1/x);
	
	// Multiply by a per-component mask meant to be 0 for non-finite values.
	vec = step(vec, vec) * vec;
	
	CheckEqual(float2(0,0), vec);
}

// Companion to SetNaNInfTo0: for finite components, step(v, v) is 1
// (v >= v holds), so multiplying by that mask must leave the values unchanged.
void NormalNotSetTo0()
{
	float numerator = 3;
	float denominator = 1;
	
	float2 values = float2(numerator/denominator, 1/denominator);
	float2 finiteMask = step(values, values);
	
	CheckEqual(float2(3,1), finiteMask * values);
}

// Degenerate case: the identity matrix has a repeated singular value (1, 1), so
// every vector is an eigenvector and any orthonormal basis would be valid.
// This test pins the particular basis the implementation returns (the
// axis-swapping permutation matrix).
void Eigenvector_Identity()
{
	float2x2 identity = float2x2(1, 0, 0, 1);
	float2x2 swapAxes = float2x2(0, 1, 1, 0);
	
	float2x2 eigenvectors = SVD_EigenvectorGivenEigenvalue(identity, float2(1,1));
	
	CheckEqual(swapAxes, eigenvectors);
}

// Runs SVD_FullDecompose on Mat and checks that the result is a genuine
// singular value decomposition:
//   1. U * diag(scale) * V (in that multiplication order) reproduces Mat.
//   2. U and V are orthogonal: each multiplied by its transpose is the identity.
//      This holds whether V is returned as V or as transpose(V).
// Check 1 alone is also satisfied by trivial factorizations such as
// U = Mat, scale = (1,1), V = identity. Check 2 is what rules those out.
// The sign of scale is not checked. The header also exposes
// SVD_AbsScaleOnly, which suggests the full decomposition may return signed
// values -- TODO confirm.
void CheckSVD(float2x2 Mat)
{
	float2x2 U;
	float2 scale;
	float2x2 V;
	
	SVD_FullDecompose(Mat, U, scale, V);
	
	float2x2 Scale = {
		scale.x, 0,
		0, scale.y,
	};
	
	float2x2 rebuiltMat = mul(U, mul(Scale, V));
	
	CheckEqual(Mat, rebuiltMat);
	
	// A valid SVD's outer factors are orthogonal (rotations/reflections).
	float2x2 Identity = {
		1, 0,
		0, 1,
	};
	
	CheckEqual(Identity, mul(U, transpose(U)));
	CheckEqual(Identity, mul(V, transpose(V)));
}

// Full SVD round-trip on the same matrix used by Eigenvector_Known.
void SVDFull_Basic()
{
	float2x2 input = float2x2(-10,  8,
	                           10, -1);
	CheckSVD(input);
}

// Full SVD round-trip on a symmetric matrix.
void SVDFull_Basic2()
{
	float2x2 input = float2x2(1, 2,
	                          2, 1);
	CheckSVD(input);
}

// Full SVD round-trip on a matrix with widely differing entry magnitudes.
void SVDFull_Basic3()
{
	float2x2 input = float2x2(10, 100,
	                           7,  11);
	CheckSVD(input);
}

// Full SVD round-trip on a mixed-sign matrix.
void SVDFull_Basic4()
{
	float2x2 input = float2x2(-81, 55,
	                           83, 36);
	CheckSVD(input);
}

// Full SVD round-trip on a mixed-sign matrix.
void SVDFull_Basic5()
{
	float2x2 input = float2x2(71, -15,
	                          18, -37);
	CheckSVD(input);
}

// Full SVD round-trip on a matrix with all-negative entries.
void SVDFull_Basic6()
{
	float2x2 input = float2x2(-33, -62,
	                          -50, -55);
	CheckSVD(input);
}

// Full SVD round-trip on a mixed-sign matrix.
void SVDFull_Basic7()
{
	float2x2 input = float2x2( 24, -68,
	                          -67, -16);
	CheckSVD(input);
}

// Full SVD round-trip on a uniform scale matrix (repeated singular value 10).
void SVDFull_Diagonal()
{
	float2x2 input = float2x2(10,  0,
	                           0, 10);
	CheckSVD(input);
}

// Full SVD round-trip on the identity matrix (repeated singular value 1).
void SVDFull_Identity()
{
	float2x2 input = float2x2(1, 0,
	                          0, 1);
	CheckSVD(input);
}