#include "SM_CommonFXPCH.h"
#include "SM_FVF.h"
#include "fxmesh.h"
#include "SM_Vector3d.h"

// Short alias for the project's 3D vector type, used by the normal-generation code.
typedef Vector3D vec3_t;

// Free-function adapters over Vector3D's static/member API.

// Stores Vector3D::Cross(a, b) in o.
inline void Cross(vec3_t& o, const vec3_t& a, const vec3_t& b)
{
	o = Vector3D::Cross(a, b);
}

// Normalizes v in place via Vector3D::Normalize.
inline void Normalize(vec3_t& v)
{
	v.Normalize();
}

// Returns Vector3D::Dot(a, b).
inline float Dot(const vec3_t& a, const vec3_t& b)
{
	return Vector3D::Dot(a, b);
}

// Returns v.Length().
inline float Length(const vec3_t& v)
{
	return v.Length();
}

// Returns v.SquaredLength().
inline float SqrLength(const vec3_t& v)
{
	return v.SquaredLength();
}

// Initial CPU-side capacities used by FXMesh::Init.
const int c_nMinVerts   = 32;	// vertices
const int c_nMinIndices = 32;	// 16-bit indices

//-----------------------------------------------------------------------------
// Name: Init
// Desc: (Re)initializes the mesh as an empty system-memory mesh whose vertices
//       are vertSize bytes in format fvf. Any previous buffers are destroyed
//       first. bstatic is remembered for Optimize(). Always returns true.
//-----------------------------------------------------------------------------
bool FXMesh::Init(int vertSize, DWORD fvf, bool bstatic)
{
	// Start from a clean slate (GPU buffers released, CPU storage freed).
	Destroy();

	// Vertex layout and usage.
	m_fvf      = fvf;
	m_vertSize = vertSize;
	m_static   = bstatic;

	// Nothing stored yet.
	m_numindices = 0;
	m_numverts   = 0;

	// Small initial capacities; AddIndex/AddVert grow them on demand.
	m_pIndices = new WORD[c_nMinIndices];
	m_nIndices = c_nMinIndices;

	m_pVerts = new byte[c_nMinVerts*m_vertSize];
	m_nVerts = c_nMinVerts;

	return true;
}

void FXMesh::ResetIndices()
{
	m_numindices = 0;
}

//-----------------------------------------------------------------------------
// Name: ReserveVerts
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::ReserveVerts(int n)
{
	byte* t = new byte[m_vertSize*n];
	if (m_numverts)
		memcpy(t, m_pVerts, min(m_numverts,n)*m_vertSize);
	delete[] m_pVerts;
	m_pVerts = t;
	m_nVerts = n;
}

//-----------------------------------------------------------------------------
// Name: ResizeVerts
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::ResizeVerts(int n)
{
	ReserveVerts(n);
	m_numverts = n;
}

//-----------------------------------------------------------------------------
// Name: Destroy
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::Destroy()
{
  if (m_vb) { m_vb->Release(); m_vb = 0; }
  if (m_ib) { m_ib->Release(); m_ib = 0; }

  delete[] m_pIndices;  m_pIndices = 0;
	delete[] m_pVerts; m_pVerts = 0;
}


//-----------------------------------------------------------------------------
// Name: AddVert
// Desc: add a new vertex to the mesh
//-----------------------------------------------------------------------------
int FXMesh::AddVert(void* pv)
{
	if (m_vb)
	{
		ASSERT(m_pVerts);
		ASSERT(m_numverts < m_nVerts);
		memcpy(&m_pVerts[m_numverts*m_vertSize], pv, m_vertSize);
		m_numverts++;
		return m_numverts-1;
	}
	else
	{
		if (m_numverts+1 >= m_nVerts)
		{
			byte* pt = new byte[m_vertSize*m_nVerts*2];
			memcpy(pt, m_pVerts, m_vertSize*m_nVerts);
			delete[] m_pVerts;
			m_pVerts = pt;
			m_nVerts *= 2;
		}

		memcpy(&m_pVerts[m_numverts*m_vertSize], pv, m_vertSize);
		m_numverts++;
		return m_numverts-1;
	}
}


//-----------------------------------------------------------------------------
// Name: AddIndex
// Desc: Appends one 16-bit index to the system-memory index list, growing the
//       storage as needed.
//-----------------------------------------------------------------------------
void FXMesh::AddIndex(WORD ind)
{
	if (!m_pIndices || m_numindices+1 >= m_nIndices)
	{
		// Grow geometrically, keeping one spare slot as before. A zero capacity
		// cannot grow by doubling alone, so restart from the default minimum.
		int newCap = (m_nIndices > 0) ? m_nIndices*2 : c_nMinIndices;
		while (m_numindices+1 >= newCap)
			newCap *= 2;

		WORD* pt = new WORD[newCap];
		// m_pIndices is null after Optimize(false) frees it; nothing to copy then.
		if (m_pIndices)
			memcpy(pt, m_pIndices, sizeof(WORD)*m_nIndices);
		delete[] m_pIndices;
		m_pIndices = pt;
		m_nIndices = newCap;
	}

	m_pIndices[m_numindices++] = ind;
}

//-----------------------------------------------------------------------------
// Name: Render
// Desc: renders this mesh
//-----------------------------------------------------------------------------
void FXMesh::Render()
{
	HRESULT hr;

    int numtri;
	switch (m_pt)
	{
	case D3DPT_TRIANGLELIST:
		numtri = m_numindices/3;
		break;
	case D3DPT_TRIANGLESTRIP:
		numtri = m_numindices-2;
		break;
	case D3DPT_LINESTRIP:
		numtri = m_numindices-1;
		break;
	case D3DPT_POINTLIST:
		numtri = m_numverts;
		break;
	case D3DPT_LINELIST:
		numtri = m_numverts/2;
		break;
	default:
		ASSERT(!"unknown prim type");
		return;
	}
	
	if (m_vb)
	{
		hr = SM_D3d::Device()->SetVertexShader(m_fvf);
		hr = SM_D3d::Device()->SetStreamSource(0, m_vb, m_vertSize);
		if (m_pt != D3DPT_POINTLIST)
		{
			hr = SM_D3d::Device()->SetIndices(m_ib, 0);
			hr = SM_D3d::Device()->DrawIndexedPrimitive(m_pt, 0, m_numverts, 0, numtri); 
		}
		else
		{
			hr = SM_D3d::Device()->DrawPrimitive(m_pt, 0, numtri);
		}
	}
	else
	{
		hr = SM_D3d::Device()->SetVertexShader(m_fvf);
		if (m_pt != D3DPT_POINTLIST)
		{
#if 1
			// check all the indices to make sure they're valid before
			// sending them to nvidia's crappy drivers
			for (int i=0; i<m_numindices; i++)
			{
				ASSERT(m_pIndices[i] < m_numverts);
			}
#endif


			hr = SM_D3d::Device()->DrawIndexedPrimitiveUP(m_pt,
												 0,
												 m_numverts,
												 numtri,
												 m_pIndices,
												 D3DFMT_INDEX16,
												 m_pVerts,
												 m_vertSize);	
		}
		else
		{
			hr = SM_D3d::Device()->DrawPrimitiveUP(m_pt, numtri, m_pVerts, m_vertSize);
		}
	}
}


//-----------------------------------------------------------------------------
// Name: LockVerts
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::LockVerts(bool resetNumVerts)
{
	if (m_vb)
	{
		HRESULT hr = m_vb->Lock(0, m_numverts*m_vertSize, &m_pVerts, D3DLOCK_DISCARD);
		ASSERT(SUCCEEDED(hr));
		m_nVerts = m_numverts;
		if (resetNumVerts)
			m_numverts = 0;
	}
	else
	{
		if (resetNumVerts)
			m_numverts = 0;
	}
}


//-----------------------------------------------------------------------------
// Name: UnlockVerts
// Desc: Undoes LockVerts for a GPU-backed mesh: unlocks the vertex buffer and
//       clears m_pVerts. Does nothing for a system-memory mesh.
//-----------------------------------------------------------------------------
void FXMesh::UnlockVerts()
{
	if (!m_vb)
		return;		// system-memory mesh: nothing is locked

	m_vb->Unlock();
	m_pVerts = 0;	// the locked pointer is no longer valid
}

//-----------------------------------------------------------------------------
// Name: Optimize
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::Optimize(bool keepIndices)
{
	HRESULT hr;
	ASSERT(!m_vb && !m_ib);

	// make index buffer
	if (m_numindices)
	{
		hr = SM_D3d::Device()->CreateIndexBuffer(m_numindices*sizeof(WORD),
			D3DUSAGE_WRITEONLY,
			D3DFMT_INDEX16,
			D3DPOOL_DEFAULT,
			&m_ib);
		ASSERT((hr == D3D_OK) && (m_ib));	

		// copy in indices
		WORD* pi;
		hr = m_ib->Lock(0, sizeof(WORD)*m_numindices, (byte**)&pi, 0);
		ASSERT(SUCCEEDED(hr));
		memcpy(pi, m_pIndices, m_numindices*sizeof(WORD));
		m_ib->Unlock();

		if (!keepIndices) { delete[] m_pIndices; m_pIndices=0; }
	}

	// make vertex buffer
	hr = SM_D3d::Device()->CreateVertexBuffer( m_numverts*m_vertSize,
		m_static ? D3DUSAGE_WRITEONLY : D3DUSAGE_WRITEONLY | D3DUSAGE_DYNAMIC,
		m_fvf,
		D3DPOOL_DEFAULT,
		&m_vb);
	ASSERT((hr == D3D_OK) && (m_vb));	

	// copy in verts
	byte* pb;
	hr = m_vb->Lock(0, m_numverts, &pb, 0);
	ASSERT(SUCCEEDED(hr));
	memcpy(pb, m_pVerts, m_numverts*m_vertSize);
	m_vb->Unlock();

	delete[] m_pVerts;   m_pVerts=0;
}


//-----------------------------------------------------------------------------
// Name: TriStripIndices
// Desc: 
//-----------------------------------------------------------------------------
void FXMesh::TriStripIndices(int w, int h)
{
	PrimType() = D3DPT_TRIANGLESTRIP;
#if 1
	for (int v=0; v<h-1; v++)
	{
		for (int u=0; u<w; u++)
		{
			AddIndex(u+v*w);
			AddIndex(u+(v+1)*w);
		}
		if (v<h-2)
		{
			AddIndex(u-1+(v+1)*w);
			AddIndex(0+(v+1)*w);
		}
	}
#else
	for (int v=0; v<h-1; v++)
	{
		if (v%2==0)
		{
			for (int u=0; u<w; u++)
			{
				AddIndex(u+v*w);
				AddIndex(u+(v+1)*w);
			}
			AddIndex(w-1+(v+1)*w);
		}
		else
		{
			for (int u=w-1; u>=0; u--)
			{
				AddIndex(u+v*w);
				AddIndex(u+(v+1)*w);
			}
			AddIndex(0+(v+1)*w);
		}
	}
#endif
}

//-----------------------------------------------------------------------------
// Name: MakeAddNormals
// Desc: 
//-----------------------------------------------------------------------------
template<typename T>
void MakeAddNormals(T& v0, T& v1, T& v2)
{
	vec3_t u = v2.p-v0.p;
	vec3_t v = v1.p-v0.p;
	vec3_t n;
	
	Cross(n, u, v);
	//Normalize(&n, &n);
	v0.n += n;
	v1.n += n;
	v2.n += n;
}

//-----------------------------------------------------------------------------
// Name: GenNormals
// Desc: Generates per-vertex normals for a w x h vertex grid stored row-major
//       in `mesh` (vertex (x,y) at index x+y*w) with vertex type T.
//       1. Each grid cell's two triangles add their face normals to their
//          vertices.
//       2. If fixSeam, the normals of the first/last column and of the
//          first/last row are summed pairwise, as for a surface whose edges
//          meet.
//       3. Every normal is normalized.
//-----------------------------------------------------------------------------
template<typename T>
void GenNormals(FXMesh& mesh, int w, int h, bool fixSeam)
{
	// 1. Accumulate face normals, two triangles per cell.
	for (int y=0; y<h-1; y++)
	{
		for (int x=0; x<w-1; x++)
		{
			{
			T& v0 = *(T*)mesh.GetVert( x+y*w );
			T& v1 = *(T*)mesh.GetVert( x+1+(y+1)*w );
			T& v2 = *(T*)mesh.GetVert( x+(y+1)*w );
			MakeAddNormals(v0, v1, v2);
			}

			{
			T& v0 = *(T*)mesh.GetVert( x+y*w );
			T& v1 = *(T*)mesh.GetVert( x+1+y*w );
			T& v2 = *(T*)mesh.GetVert( x+1+(y+1)*w );
			MakeAddNormals(v0, v1, v2);
			}
		}
	}

	// 2. Optionally make the normals agree across the grid's seams.
	if (fixSeam)
	{
		// left/right columns
		for (int v=0; v<h; v++)
		{
			T& a = *(T*)mesh.GetVert( v*w );
			T& b = *(T*)mesh.GetVert( v*w+w-1 );

			const vec3_t n1 = a.n;
			const vec3_t n2 = b.n;
			a.n += n2;
			b.n += n1;
		}
#if 1
		// top/bottom rows (corners end up with the sum of all four corners)
		for (int u=0; u<w; u++)
		{
			T& a = *(T*)mesh.GetVert( u );
			T& b = *(T*)mesh.GetVert( u+(h-1)*w );

			const vec3_t n1 = a.n;
			const vec3_t n2 = b.n;
			a.n += n2;
			b.n += n1;
		}
#endif
	}

	// 3. Normalize. The loop counters get their own names: the original reused
	// `y` from the first loop, which only compiles with non-conforming
	// (pre-/Zc:forScope) scoping, while redeclaring `y` would break there.
	for (int row=0; row<h; row++)
	{
		for (int col=0; col<w; col++)
		{
			T& v0 = *(T*)mesh.GetVert( col+row*w );
			v0.n.Normalize();
		}
	}
}


//-----------------------------------------------------------------------------
// Name: GenNormals
// Desc: Generates normals for this mesh viewed as a w x h vertex grid by
//       dispatching on the vertex format. Only ntxvertex_fvf and
//       cntxvertex_fvf are supported; any other format asserts.
//-----------------------------------------------------------------------------
void FXMesh::GenNormals(int w, int h, bool fixSeam)
{
	ASSERT(m_pVerts);

	if (m_fvf == ntxvertex_fvf)
	{
		::GenNormals<ntxvertex>(*this, w, h, fixSeam);
		return;
	}

	if (m_fvf == cntxvertex_fvf)
	{
		::GenNormals<cntxvertex>(*this, w, h, fixSeam);
		return;
	}

	ASSERT(!"Unsupported fvf for gennormals!");
}
