//-----------------------------------------------------------------------------
// FXSpikeTube
//-----------------------------------------------------------------------------
#include "SM_CommonFXPCH.h"
#include "SM_Engine3DPCH.h"
#include "SM_DemoEffect.h"
#include "SM_FVF.h"
#include "SM_Shader.h"
#include "fxmesh.h"
#include "FXParamParser.h"

//#define TRI_STRIP
//#define SHOW_NORMALS
//#define DBL_BUFFER

using namespace ShaderManager;

//-----------------------------------------------------------------------------
// Types and Macros
//-----------------------------------------------------------------------------
// Short alias for the engine's 3D vector type, used by the geometry code in this file.
typedef Vector3D vec3_t;


//-----------------------------------------------------------------------------
// Constants
//-----------------------------------------------------------------------------
const float c_pi  = 3.1415926535897932384626433f;	// pi, rounded to float
const float c_2pi = 2.0f * c_pi;					// one full turn in radians (exact doubling of c_pi)

namespace
{
inline void Cross(vec3_t& o, const vec3_t& a, const vec3_t& b) { o = Vector3D::Cross(a,b); }
inline void Normalize(vec3_t& v) { v.Normalize(); }

void LookAt(const vec3_t& pos, const vec3_t& at, const vec3_t& up)
{
	Matrix4X4 mat;
	mat.LookAt(pos, at, up);
	SM_D3d::Device()->SetTransform(D3DTS_VIEW, (D3DMATRIX*)&mat);
}

void SetWireFrame(bool wire)
{
	SM_D3d::Device()->SetRenderState(D3DRS_FILLMODE, wire?D3DFILL_WIREFRAME:D3DFILL_SOLID);
}


SM_DemoEffect::Helper LoadHelp[] =
{
	{"shader=metaball radius=10 level=1 numCells=8 numRings=7 spikeSize=10 camVel=24 camHeight=10", 
		"shader: what to shade tube with\r\n"
		"radius: base radius of tube\r\n"
		"level: higher levels = many more triangles\r\n"
		"numCells: how many spikes in one segment\r\n"
		"numRings: resolution of spike\r\n"
		"spikeSize: maximum length of spike\r\n"
		"camVel: speed of camera\r\n"
		"camHeight: height of camera above x axis\r\n"
	},

	{"shader=metaball radius=10 level=1 numCells=8 numRings=7 spikeSize=10 camVel=24 camHeight=10", "Default"},
	{"shader=green radius=2 spikeSize=20 numRings=10", "Weirdness"},
	{"shader=metaball radius=20 level=2 numCells=8 numRings=7 spikeSize=4 camVel=14 camHeight=5", "Scales"},
	{"shader=metaball radius=10 level=1 numCells=8 numRings=8 spikeSize=-10 camVel=40 camHeight=-4", "Honeycomb"},
};

SM_DemoEffect::Helper CommandsHelp[] =
{
	{"",""}
/*
	{"SETSPOKES n delay","Set number of ridges (spokes) on tunnel wall, transition takes 'delay' seconds"},
	{"SETINNER radius delay", "Set inner radius of tunnel"},
	{"SETOUTER radius delay", "Set outer radius of tunnel"},
	{"TIMEMULT multiplier delay", "Make time move faster (>1) or slower (<1)"}
*/
};

}

//-----------------------------------------------------------------------------
// FXSpikeTube
//-----------------------------------------------------------------------------
class FXSpikeTube : public SM_DemoEffect
{
private:
	FXMesh	m_mesh;
	FXMesh	m_rmesh;
	int		m_shader;
	int		m_numPerSide;
	int		m_n;
	int		m_numCells;
	int		m_numRings;
	int		m_numSegments;
	int		m_res;
	float	m_segRadius;
	float	m_spikeSize;

	float m_camX;		// min x coord of cam window
	float m_camVel;		// speed of cam in units/sec
	float m_camHeight;

#ifdef SHOW_NORMALS
	std::vector<cvertex>	m_normals;
#endif

public:              
	FXSpikeTube(char const* pcName) 
		: SM_DemoEffect(pcName)
		, m_numCells(8)
		, m_numRings(7)
		, m_segRadius(10)
		, m_spikeSize(10)
		, m_numSegments(11)
		, m_n(1)
		, m_camX(0)
		, m_camVel(24)
		, m_camHeight(10)
	{
		m_numPerSide = 2*m_n+1;
		m_res = 4*m_numPerSide-4;
	}

	virtual ~FXSpikeTube()
	{
	}

	int LoadArgumentsHelp     (Helper*& pHelpers)
	{
		pHelpers = LoadHelp;
		return (sizeof(LoadHelp)/sizeof(Helper));
	}

	int CommandArgumentsHelp  (Helper*& pHelpers)
	{
		pHelpers = CommandsHelp;
		return (sizeof(CommandsHelp)/sizeof(Helper));
	}

	int Init(const char* pcCommand)
	{
		int   iReturn=0;

		FXParamParser parse;
		parse.Parse(pcCommand);

		const char* szShader = parse.GetStr("shader", "green");

		m_n = parse.GetInt("level", 1);
		m_numCells = parse.GetInt("numCells", 8);
		m_numRings = parse.GetInt("numRings", 7);
		m_spikeSize = parse.GetFloat("spikeSize", 10);
		m_camVel = parse.GetFloat("camVel", 24);
		m_segRadius = parse.GetFloat("radius", 10);
		m_camHeight = parse.GetFloat("camHeight", 10);

		m_numPerSide = 2*m_n+1;
		m_res = 4*m_numPerSide-4;


		m_mesh.Init(sizeof(nvertex), nvertex_fvf, false);
		m_mesh.PrimType() = D3DPT_TRIANGLELIST;
		UpdateMesh(0);
		BuildIndices2();

#ifdef DBL_BUFFER
		{
			m_rmesh.Init(sizeof(nvertex), nvertex_fvf, false);
			m_rmesh.PrimType() = D3DPT_TRIANGLELIST;
			for (int i=0; i<m_mesh.NumIndices(); i++)
				m_rmesh.AddIndex(m_mesh.GetIndex(i));
			for (int j=0; j<m_mesh.NumVerts(); j++)
				m_rmesh.AddVert(m_mesh.GetVert(j));
	
			m_rmesh.Optimize(false);
		}
#endif


		//
		// load shader
		//
		m_shader = LoadShader(szShader);
		if (m_shader == -1)
		{
			::MessageBox(NULL, "Could not load shader", szShader, MB_OK);
			iReturn = -1;
			goto FAILED;
		}

		
	FAILED:

		return iReturn;        
	}

	int Shutdown()
	{
		m_mesh.Destroy();
		return 0;
	}

	int Start(float fTime)
	{
		return 0;
	}

	int Stop()
	{
		return 0;
	}

	int Reset()
	{
		return 0;
	}

	int Run(float time)
	{
		UpdateMesh(time);

#ifdef DBL_BUFFER
		m_rmesh.LockVerts(false);
		memcpy(m_rmesh.GetVert(0), m_mesh.GetVert(0), sizeof(nvertex)*m_mesh.NumVerts());
		m_rmesh.UnlockVerts();
#endif
		
		RenderContext RC;   
		RC.Set(Vector3D(m_camX, 0.0f, -2.5f),
			   Quaternion(1.0f, 0.0f, 0.0f, 0.0f),
			   75,
			   0.75f,
			   1.0f,
			   200.0f);

	    RC.SetViewport(0, 0, 640.0f, 480.0f);  
	  
	    RC.SyncRasterizer();
	    RC.UpdateFrustum();
		ShaderManager::SetRenderContext(&RC);

		// does rc have to be updated?  most likely, but best
		// way to do that?
		const float r = 40.0f;
//		LookAt(vec3_t(r*cos(time), 2, r*sin(time)), vec3_t(0,0,0), vec3_t(0,1,0));
		const float t = time * 1.1f;

		const float x = 0;
		const vec3_t up(0,-sin(t),cos(t));
		LookAt(vec3_t(x,r*cos(t),r*sin(t)), vec3_t(x,0,0)+up*m_camHeight, up);

		//SetWireFrame(true);

		Shader* ps = GetShader(m_shader);
		if (ps)
		{
			for (int i=0; i<ps->m_uPasses; i++)
			{
				ps->SetShaderState(i);
#ifdef DBL_BUFFER
				m_rmesh.Render();
#else
				m_mesh.Render();
#endif
			}
		}

#ifdef SHOW_NORMALS
		SM_D3d::Device()->SetTexture(0, 0);
		SM_D3d::Device()->SetTextureStageState(0, D3DTSS_TEXTURETRANSFORMFLAGS, D3DTTFF_DISABLE);
		SM_D3d::Device()->SetVertexShader(cvertex_fvf);
		SM_D3d::Device()->DrawPrimitiveUP(D3DPT_LINELIST, m_normals.size()/2,
			&m_normals[0], sizeof(cvertex));
#endif

		return 1;
	}

	int Command(float fTime, const char* pcCommand)
	{
		return 0;
	}  

private:

	void UpdateMesh(float time)
	{
		m_mesh.LockVerts();

		m_camX = time * m_camVel;

		// build skeleton
		BuildSkeleton(time);

#if 1
		// build normals from mesh
		for (int i=0; i<m_mesh.NumIndices(); )
		{
			WORD i0 = m_mesh.GetIndex(i++);
			WORD i1 = m_mesh.GetIndex(i++);
			WORD i2 = m_mesh.GetIndex(i++);
			MakeAddNormal(i0,i1,i2);
		}

		// normalize each vert's normal
		for (int n=0; n< m_mesh.NumVerts(); n++)
		{
			ntxvertex& v = *(ntxvertex*)m_mesh.GetVert(n);
			Normalize(v.n);
		}
#endif

#ifdef SHOW_NORMALS
		m_normals.clear();
		for (int i=0; i<m_mesh.NumVerts(); i++)
		{
			ntxvertex& v = *(ntxvertex*)m_mesh.GetVert(i);
			m_normals.push_back(cvertex(v.p, 0xff0000));
			m_normals.push_back(cvertex(v.p + v.n * 2.0f, 0xffffffff));
		}
#endif

		m_mesh.UnlockVerts();
	}

	inline float GetSpikeScale(float x, float time) 
	{ 
		//float t = (cos(time*.12f+cos(time*.21212f))+cos(time*1.182123f))*.25f+.5f;
		//float t = 1;//.9f + .1f*(sin(time*1.1212f)+sin(time*2.1313f));
		//if (time < 3) t *= (time/3);
		//return t*t; 
		
		//return (.2f+.8f*(2.0f+( 1.2*sin(x*.165f+time*c_pi) + .8*sin(x*.165f+time*c_pi*.5f) ))*.25f);
		//return .75f*(.2f+.8f*(2.0f+( 1.2*sin(x*.165f+time*c_pi) + .8*sin(x*.165f+time*c_pi*.5f) ))*.25f) + .25f*(.5f*(sin(x*.23434f)+1));
		float t = GetRadius(x, time) / m_segRadius;
		return t*t;
	}
	inline int GetOfs(int seg) { return (seg%2)*(m_numPerSide/2); } 
	inline float GetRadius(float x, float time) 
	{ 
		//return .8f*((.3f+.7f*(sin(x*.165f+time*c_pi)*.5f+.5f)) + .4f* (sin((x+time)*.23432f+time*.12312f)+1.0f)+.2f*(sin(x*.0224f+time*.012312f)+1)  )*m_segRadius; 
		//return m_segRadius * (.1f+.9f*(3.0f+(sin(x*.165f+time*c_pi) + sin(time*.6234234f) + sin(x*0.085f+time*c_pi*2.0f)))/6.0f);
		//return m_segRadius*(.2f+.6f*((.3f+.7f*(sin(x*.165f+time*c_pi)*.5f+.5f)) + .15f* (sin((x+time)*.23432f+time*.12312f)+1.0f)+.2f*(sin(x*.0224f+time*.012312f)+1)  )); 
		//return m_segRadius * (.2f+.8f*(2.0f+( 1.2*sin(x*.165f+time*c_pi) + .8*sin(x*.165f+time*c_pi*.5f) ))*.25f);
		//return m_segRadius * GetSpikeScale(x, time);
		return m_segRadius*(.75f*(.2f+.8f*(2.0f+( 1.4*sin(x*.165f+time*c_pi) + .8*sin(x*.265f+time*c_pi*.4215f) ))*.25f) + .25f*(.6f*(sin(x*.13434f)+1)));
	}
//	inline float GetVertOfs(float x, float time) { return sin(x*.125) * 6; }


	void BuildIndices2()
	{
		const int segRingSize = m_numCells*m_numPerSide-m_numCells;
		int numPerCell = m_numPerSide-2 + (m_numRings-1)*m_res; // column + rings
//		int numPerRing = segRingSize + m_numCells*numPerCell;
		int numPerSeg = m_numCells*numPerCell + segRingSize;

		ASSERT(m_numSegments*numPerSeg+segRingSize == m_mesh.NumVerts());

		std::vector<WORD> outerRing(m_res);
		int k=0;
		
		for (int seg=0; seg<m_numSegments; seg++)
		{
			const int ofs = GetOfs(seg);
			for (int c=0; c<m_numCells; c++)
			{
				int o = (ofs + c*(m_numPerSide-1))%segRingSize;

				// build outer ring
				k=0;

				// top
				int i;
				for (i=0; i<m_numPerSide; i++)
					outerRing[k++] = seg * numPerSeg + (o+i)%segRingSize;
				
				// right side
				int tro = seg * numPerSeg + segRingSize + ((c+1)%m_numCells)*numPerCell;
				for (i=0; i<m_numPerSide-2; i++)
					outerRing[k++] = tro+i;

				// bottom
				int tbo = (ofs + c*(m_numPerSide-1))%segRingSize;
				for (i=0; i<m_numPerSide; i++)
					outerRing[k+m_numPerSide-i-1] = (seg+1) * numPerSeg + (tbo+i)%segRingSize;
				k+=m_numPerSide;

				// left side
				int tlo = seg * numPerSeg + segRingSize + c*numPerCell;
				for (i=0; i<m_numPerSide-2; i++)
					outerRing[k+(m_numPerSide-2)-i-1] = tlo+i;
				k+=m_numPerSide-2;

				ASSERT(k <= m_res);

#if 0
				for (int t=0; t<m_res; t++)
				{
					char tc[128];
					sprintf(tc, "%d, ", outerRing[t]);
					OutputDebugString(tc);
				}
#endif

				int v0 = seg * numPerSeg + segRingSize + c*numPerCell + m_numPerSide-2;

				// stitch rings to skeloton
				for (int x=0; x<m_res; x++)
				{
					int i0 = outerRing[x];
					int i1 = outerRing[(x+1)%m_res];
					int i2 = v0 + x;
					int i3 = v0 + (x+1)%m_res;

					ASSERT(i0 < m_mesh.NumVerts());
					ASSERT(i1 < m_mesh.NumVerts());
					ASSERT(i2 < m_mesh.NumVerts());
					ASSERT(i3 < m_mesh.NumVerts());

					m_mesh.AddIndex(i0);
					m_mesh.AddIndex(i3);
					m_mesh.AddIndex(i1);

					m_mesh.AddIndex(i0);
					m_mesh.AddIndex(i2);
					m_mesh.AddIndex(i3);
				}

				for (int y=0; y<m_numRings-2; y++)
				{
					for (int x=0; x<m_res; x++)
					{
						const int i0 = v0 + y*m_res + x;
						const int i1 = v0 + y*m_res + (x+1)%m_res;
						const int i2 = v0 + (y+1)*m_res + x;
						const int i3 = v0 + (y+1)*m_res + (x+1)%m_res;

						ASSERT(i0 < m_mesh.NumVerts());
						ASSERT(i1 < m_mesh.NumVerts());
						ASSERT(i2 < m_mesh.NumVerts());
						ASSERT(i3 < m_mesh.NumVerts());

						m_mesh.AddIndex(i0);
						m_mesh.AddIndex(i3);
						m_mesh.AddIndex(i1);

						m_mesh.AddIndex(i0);
						m_mesh.AddIndex(i2);
						m_mesh.AddIndex(i3);
					}
				}
			}
		}
	}

	void BuildSkeleton(float time)
	{
		const float theta = c_2pi / m_numCells;
		const float h = 8;//sqrtf(2*m_segRadius*m_segRadius - 2*m_segRadius*m_segRadius*cos(theta));
		const float th = h * m_numSegments;

		//float startX = h*(int(m_camX/h)-int(m_camX/h)%2) - (m_numSegments/2)*h;
		float startX = fmodf(-m_camX, h*2)-(m_numSegments/2)*h;

		// add all rings
		for (int seg=0; seg<m_numSegments+1; seg++)
		{
			//const float fseg = float(seg)/float(m_numSegments-1);
			float x = startX + seg * h;
			AddSegRing(x, h, time);

			const int ofs = GetOfs(seg);
			if (seg != m_numSegments)
				AddColumnsForRing(x, h, ofs, time);
		}
	}

	void AddCell2(float a0, float a1, float midX, float time)
	{
		const float aMid = .5f*(a0+a1);
		float segRadius = 10;

		// find basis for this cell
		const vec3_t N(0, cos(aMid), sin(aMid));
		const vec3_t B(1,0,0);
		const vec3_t T(0, -sin(aMid), cos(aMid));
		//const vec3_t origin(segRadius*N);

		const vec3_t p0 = vec3_t(0, m_segRadius*cos(a0), m_segRadius*sin(a0));
		const vec3_t p1 = vec3_t(0, m_segRadius*cos(a1), m_segRadius*sin(a1));

		const float tLength = (p1-p0).Length();
		const vec3_t origin = .5f*(p1+p0);

		// precalc sins and coss for inner loop
		static std::vector<float> sinus;//(m_res);
		static std::vector<float> cosinus;//(m_res);
		if (!sinus.size())
		{
			sinus.resize(m_res);
			cosinus.resize(m_res);
			for (int i=0; i<m_res; i++)
			{
				float a = -float(i)/float(m_res)*c_2pi + 135.0f*c_pi/180;
				cosinus[i] = cos(a);
				sinus[i] = -sin(a);
			}
		}

		// for each ring
		for (int r=1; r<m_numRings; r++)
		{
			float tr = float(r)/float(m_numRings-1);
			float ringRadius = (1.0f-tr) * tLength/2;
			if (r == 0)
				ringRadius = tLength / sqrtf(2.0f);

			const float dispScale = GetSpikeScale(midX, time);//midX/63.232);
			const float normDisp = tr*tr*dispScale;

			const vec3_t spikeFactor = N*normDisp*m_spikeSize + vec3_t(midX + normDisp*normDisp*normDisp*m_spikeSize,0,0);

			// for each point on ring
			vec3_t p;
			for (int n=0; n<m_res; n++)
			{
				//float a = -float(n)/float(m_res)*c_2pi + 135.0f*c_pi/180;
				//float px = ringRadius*cos(a);
				//float py = -ringRadius*sin(a);

				float px = ringRadius*cosinus[n];
				float py = ringRadius*sinus[n];

				// transform by basis
#if 1
				p = origin + px*T;
				p = p/p.Length()*GetRadius(py*B.x+midX, time);
				p += py*B;
				p += spikeFactor;
				//vec3_t p = vec3_t(m_segRadius*cos(ainc), midY, m_segRadius*sin(ainc)) + py*B;// + N*normDisp*m_spikeSize + vec3_t(0,normDisp*normDisp*normDisp*m_spikeSize,0);
#else
				vec3_t p = origin + vec3_t(midX,0,0) + px*T + py*B + N*normDisp*m_spikeSize + vec3_t(0,normDisp*normDisp*normDisp*m_spikeSize,0);
#endif
				//p.y += GetVertOfs(p.x,time);
				m_mesh.AddVert(&ntxvertex(p, vec3_t(0,0,0), 0,0));
			}
		}	
	}

	void AddColumnsForRing(float x, float w, int ofs, float time)
	{
		for (int c=0; c<m_numCells; c++)
		{
			const int segRingSize = m_numCells*m_numPerSide-m_numCells;
			int n = (c*(m_numPerSide-1)+ofs);//%segRingSize;
			int n2 = ((c+1)*(m_numPerSide-1)+ofs);//%segRingSize;

			const float a = float(n)/float(segRingSize)*c_2pi;
			const float a2 = float(n2)/float(segRingSize)*c_2pi;

			vec3_t p(0, cos(a), sin(a));

			// add column 
			for (int i=1; i<m_numPerSide-1; i++)
			{
				float t = float(i)/float(m_numPerSide-1);
				const float segRadius = GetRadius(x+t*w, time);
				const float y = 0;//GetVertOfs(x+t*w,time);
				m_mesh.AddVert(&nvertex(vec3_t(x+t*w,y,0)+segRadius*p, vec3_t(0,0,0)));
			}

			// add cell
			AddCell2(a, a2, x + .5f*w, time);
		}
	}

	void AddSegRing(float x, float w, float time)
	{
		const int segRingSize = m_numCells*m_numPerSide-m_numCells;
		const float segRadius = GetRadius(x, time);
		for (int i=0; i<segRingSize; i++)
		{
			const float t = float(i)/float(segRingSize);
			const float a = t*c_2pi;

			m_mesh.AddVert(&nvertex(vec3_t(x, segRadius*cos(a)/*+GetVertOfs(x,time)*/, segRadius*sin(a)), vec3_t(0,0,0)));
		}
	}

	void MakeAddNormal(int i0, int i1, int i2)
	{
		ntxvertex& v0 = *(ntxvertex*)m_mesh.GetVert(i0);
		ntxvertex& v1 = *(ntxvertex*)m_mesh.GetVert(i1);
		ntxvertex& v2 = *(ntxvertex*)m_mesh.GetVert(i2);

		vec3_t u = v2.p-v0.p;
		vec3_t v = v1.p-v0.p;
		vec3_t n;
		
		Cross(n, u, v);
		//Normalize(n);

		v0.n += n;
		v1.n += n;
		v2.n += n;
	}
};

// DEFINE_EFFECT is a framework macro defined elsewhere (presumably it
// registers FXSpikeTube with the demo system — confirm in SM_DemoEffect.h).
DEFINE_EFFECT(FXSpikeTube)
// Two independent global instances of the effect, named "SPIKETUBE_00" and
// "SPIKETUBE_01"; each is constructed at static-initialization time with its
// own copy of the effect state.
FXSpikeTube Efecto00("SPIKETUBE_00");
FXSpikeTube Efecto01("SPIKETUBE_01");

