#include <comp.hpp>
// #define ROBERT

namespace ngcomp
{
  using namespace ngcomp;

  // Construct a linear form named aname on the finite element space
  // afespace.  The flags argument is accepted but not read here.
  LinearForm ::
  LinearForm (const FESpace & afespace, 
	      const string & aname,
	      const Flags & flags)
    : NGS_Object (afespace.GetMeshAccess(), aname),
      fespace (afespace)
  {
    // element vectors are not dumped to testout unless requested
    printelvec = 0;
    // use the element-wise Assemble path by default
    independent = 0;
  }

  // The form owns its integrators and releases them on destruction.
  LinearForm :: ~LinearForm ()
  { 
    for (int nr = 0; nr < parts.Size(); ++nr)
      {
	delete parts[nr];
      }
  }


  // Write the name of the underlying space followed by one line per
  // integrator name to ost.
  void LinearForm :: PrintReport (ostream & ost)
  {
    ost << "on space " << GetFESpace().GetName() << endl;
    ost << "integrators: " << endl;

    for (int nr = 0; nr < parts.Size(); ++nr)
      {
	ost << "  " << parts[nr]->Name() << endl;
      }
  }

  // Append the memory-usage entries of the form's vector to mu and tag each
  // newly appended entry with " lf <name of this form>".
  void LinearForm :: MemoryUsage (ARRAY<MemoryUsageStruct*> & mu) const
  {
    // NOTE(review): this tests the address of the reference returned by
    // GetVector() for null -- presumably to skip forms whose vector has not
    // been allocated yet.  A reference bound to a dereferenced null pointer
    // is undefined behaviour, and compilers may fold this test to "true";
    // consider an explicit "vector allocated" query instead.
    if (&GetVector())  
      {
	int olds = mu.Size();
	GetVector().MemoryUsage (mu);
	// label only the entries the vector has just appended
	for (int i = olds; i < mu.Size(); i++)
	  mu[i]->AddName (string(" lf ")+GetName());
      }
  }


  /*
    Assemble the vector of the linear form.

    If the form is flagged independent, the work is delegated to
    AssembleIndependent.  Otherwise the vector is (re)allocated, and then
      - every non-boundary integrator is evaluated on each volume element
        whose index it is defined on,
      - every boundary integrator is evaluated on each surface element
        whose index it is defined on.
    Each element vector is transformed with TransformVec(..., TRANSFORM_RHS)
    and added at the element's dof numbers.  With printelvec set, every
    element vector is also written to testout.

    An "Assemble Vector" status is pushed on the mesh access for the
    duration of the assembly.  It is popped on the normal path and also
    on both error paths, so a failing integrator does not leave a stale
    entry on the status stack.

    Errors: an ngcomp Exception gets "in Assemble LinearForm" appended
    and is rethrown.  A std::exception is converted into an Exception.
  */
  template <class SCAL>
  void S_LinearForm<SCAL> :: Assemble (LocalHeap & lh)
  {
    if (independent)
      {
	AssembleIndependent(lh);
	return;
      }

    // true while our "Assemble Vector" status is on the status stack
    bool status_pushed = false;

    try
      {
	ma.PushStatus ("Assemble Vector");
	status_pushed = true;

	AllocateVector();

	ARRAY<int> dnums;
	ElementTransformation eltrans;
	
	int ne = ma.GetNE();
	int nse = ma.GetNSE();

	// find out which of the two element loops is needed at all
	bool hasbound = false;
	bool hasinner = false;
	for (int j = 0; j < parts.Size(); j++)
	  {
	    if (parts[j] -> BoundaryForm())
	      hasbound = true;
	    else
	      hasinner = true;
	  }
	
	if (hasinner)
	  {
	    for (int i = 0; i < ne; i++)
	      {
		if (i % 10 == 0)
		  cout << "\rassemble element " << i << "/" << ne << flush;
		ma.SetThreadPercentage ( 100.0*(i+1) / (ne+nse) );

		// element data below is allocated on lh; reset it per element
		lh.CleanUp();

		const FiniteElement & fel = fespace.GetFE (i, lh);
		ma.GetElementTransformation (i, eltrans, lh);
		fespace.GetDofNrs (i, dnums);
		
		for (int j = 0; j < parts.Size(); j++)
		  {
		    if (parts[j] -> BoundaryForm()) continue;
		    if (!parts[j] -> DefinedOn (ma.GetElIndex (i))) continue;

		    FlatVector<TSCAL> elvec;
		    parts[j] -> AssembleElementVector (fel, eltrans, elvec, lh);
                 
		    if (printelvec)
		      {
			testout->precision(8);

			(*testout) << "elnum= " << i << endl;
			(*testout) << "integrator " << parts[j]->Name() << endl;
			(*testout) << "dnums = " << endl << dnums << endl;
			(*testout) << "element-index = " << eltrans.GetElementIndex() << endl;
			(*testout) << "elvec = " << endl << elvec << endl;
		      }

		    fespace.TransformVec (i, false, elvec, TRANSFORM_RHS);
		    AddElementVector (dnums, elvec);
		  }
	      }
	    cout << "\rassemble element " << ne << "/" << ne << endl;
	  }

	if (hasbound)
	  {
	    for (int i = 0; i < nse; i++)
	      {
		if (i % 100 == 0)
		  cout << "\rassemble surface element " << i << "/" << nse << flush;
		ma.SetThreadPercentage ( 100.0*(ne+i+1) / (ne+nse) );

		lh.CleanUp();
	      
		const FiniteElement & fel = fespace.GetSFE (i, lh);
		ma.GetSurfaceElementTransformation (i, eltrans, lh);
		fespace.GetSDofNrs (i, dnums);
	      
		for (int j = 0; j < parts.Size(); j++)
		  {
		    if (!parts[j] -> BoundaryForm()) continue;
		    if (!parts[j] -> DefinedOn (ma.GetSElIndex (i))) continue;
		  
		    FlatVector<TSCAL> elvec;
		    parts[j] -> AssembleElementVector (fel, eltrans, elvec, lh);

		    if (printelvec)
		      {
			testout->precision(8);

			(*testout) << "surface-elnum= " << i << endl;
			(*testout) << "integrator " << parts[j]->Name() << endl;
			(*testout) << "dnums = " << endl << dnums << endl;
			(*testout) << "element-index = " << eltrans.GetElementIndex() << endl;
			(*testout) << "elvec = " << endl << elvec << endl;
		      }

		    fespace.TransformVec (i, true, elvec, TRANSFORM_RHS);
		    AddElementVector (dnums, elvec);
		  }
	      }
	    cout << "\rassemble surface element " << nse << "/" << nse << endl;	  
	  }

	// clear the flag first so a throwing PopStatus is not popped twice
	status_pushed = false;
	ma.PopStatus ();
      }
    catch (Exception & e)
      {
	if (status_pushed) ma.PopStatus ();
	stringstream ost;
	ost << "in Assemble LinearForm" << endl;
	e.Append (ost.str());
	throw;
      }
    catch (exception & e)
      {
	if (status_pushed) ma.PopStatus ();
	throw (Exception (string(e.what()) +
			  string("\n in Assemble LinearForm\n")));
      }
  }








  /*
    Assemble an "independent" linear form.

    The integration runs over surface elements with a fixed order-5 rule.
    For each surface integration point, FindElementOfPoint locates a volume
    element.  The integrators are then evaluated with that volume element's
    finite element and transformation, and their element vectors are added
    at the volume element's dofs, scaled by |surface Jacobian| * weight.

    Whether a surface element is used at all is decided by the first
    integrator's DefinedOn.  With no integrators, the result is the freshly
    allocated vector.  Previously this case indexed parts[0] out of bounds
    whenever surface elements existed.

    Errors: an ngcomp Exception gets "in Assemble LinearForm Independent"
    appended and is rethrown.  A std::exception is converted into an
    Exception.
  */
  template <class SCAL>
  void S_LinearForm<SCAL> :: AssembleIndependent (LocalHeap & lh)
  {
    try
      {
	AllocateVector();

	// nothing to integrate; also guards the parts[0] access below
	if (parts.Size() == 0) return;

	int nse = ma.GetNSE();

	ARRAY<int> dnums;
	ElementTransformation seltrans, geltrans;

	for (int i = 0; i < nse; i++)
	  {
	    lh.CleanUp();

	    // NOTE(review): only the first integrator's domain is consulted,
	    // for all integrators -- confirm this is intended.
	    // Checked before any element setup so skipped elements cost nothing.
	    if (!parts[0]->DefinedOn (ma.GetSElIndex(i))) continue;

	    const FiniteElement & sfel = fespace.GetSFE (i, lh);
	    ma.GetSurfaceElementTransformation (i, seltrans, lh);
	    
	    // fixed integration order 5 on the surface element
	    const IntegrationRule & ir =
	      GetIntegrationRules().SelectIntegrationRule (sfel.ElementType(), 5);
	    
	    for (int j = 0; j < ir.GetNIP(); j++)
	      {
		const IntegrationPoint & ip = ir.GetIP(j);
		// mapped point of a 2D surface element in 3D space
		SpecificIntegrationPoint<2,3> sip(ip, seltrans, lh);

		// search the volume element containing the mapped point;
		// gip presumably receives its local coordinates there
		IntegrationPoint gip;
		int elnr = ma.FindElementOfPoint (FlatVector<>(sip.GetPoint()), gip, 1);

		// -1: presumably no containing volume element was found
		if (elnr == -1) continue;
		
		const FiniteElement & gfel = fespace.GetFE (elnr, lh);
		ma.GetElementTransformation (elnr, geltrans, lh);
		fespace.GetDofNrs (elnr, dnums);
		
		for (int k = 0; k < parts.Size(); k++)
		  {
		    FlatVector<TSCAL> elvec;
		    parts[k] -> AssembleElementVectorIndependent
		      (gfel, seltrans, ip, geltrans, gip, elvec, lh);

		    // apply the surface quadrature weight
		    elvec *= fabs (sip.GetJacobiDet()) * ip.Weight();
		    fespace.TransformVec (elnr, false, elvec, TRANSFORM_RHS);
		    AddElementVector (dnums, elvec);
		  }
	      }
	  }
      }
    catch (Exception & e)
      {
	stringstream ost;
	ost << "in Assemble LinearForm Independent" << endl;
	e.Append (ost.str());
	throw;
      }
    catch (exception & e)
      {
	throw (Exception (string(e.what()) +
			  string("\n in Assemble LinearForm Independent\n")));
      }
  }














  // Explicit instantiations of the scalar-typed linear form for real and
  // complex scalars; the member templates above are defined in this file only.
  template class S_LinearForm<double>;
  template class S_LinearForm<Complex>;




  // Construct the form without a vector; AllocateVector creates it.
  template <typename TV>
  T_LinearForm<TV> ::
  T_LinearForm (const FESpace & afespace, const string & aname, const Flags & flags)
    : S_LinearForm<TSCAL> (afespace, aname, flags), 
      vec (0)
  { }
  
  // Release the vector created by AllocateVector (deleting a null pointer
  // is a no-op if none was allocated).
  template <typename TV>
  T_LinearForm<TV> :: ~T_LinearForm ()
  {
    delete this->vec;
  }


  // (Re)create the vector with one TV entry per dof of the space and set
  // it to zero.  Any previously allocated vector is released.
  template <typename TV>
  void T_LinearForm<TV> :: AllocateVector ()
  {
    delete vec;
    // Clear the pointer before allocating: if the allocation below throws,
    // the destructor must not delete the already-freed vector a second time.
    vec = 0;
    vec = new ngla::VVector<TV> (this->fespace.GetNDof());
    (*vec) = TSCAL(0);
  }


  // Add an element vector into the global vector (block-valued TV).
  // elvec stores HEIGHT consecutive components per local dof; local dofs
  // whose global number is -1 are skipped.
  template <typename TV>
  void T_LinearForm<TV> ::
  AddElementVector (const ARRAY<int> & dnums,
		    const FlatVector<TSCAL> & elvec) 
  {
    FlatVector<TV> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	for (int comp = 0; comp < HEIGHT; comp++)
	  target(dof)(comp) += elvec(loc*HEIGHT+comp);
      }
  }
  
  // Scalar (double) case: one entry per local dof; dof -1 is skipped.
  template <> void T_LinearForm<double>::
  AddElementVector (const ARRAY<int> & dnums,
		    const FlatVector<double> & elvec) 
  {
    FlatVector<double> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	target(dof) += elvec(loc);
      }
  }
  
  // Scalar (Complex) case: one entry per local dof; dof -1 is skipped.
  template <> void T_LinearForm<Complex>::
  AddElementVector (const ARRAY<int> & dnums,
		    const FlatVector<Complex> & elvec) 
  {
    FlatVector<Complex> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	target(dof) += elvec(loc);
      }
  }
  




  // Overwrite global vector entries with an element vector (block-valued TV).
  // elvec stores HEIGHT consecutive components per local dof; local dofs
  // whose global number is -1 are skipped.
  template <typename TV>
  void T_LinearForm<TV> ::
  SetElementVector (const ARRAY<int> & dnums,
		    const FlatVector<TSCAL> & elvec) 
  {
    FlatVector<TV> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	for (int comp = 0; comp < HEIGHT; comp++)
	  target(dof)(comp) = elvec(loc*HEIGHT+comp);
      }
  }
  
  // Scalar (double) case: overwrite one entry per local dof; dof -1 is skipped.
  template <> void T_LinearForm<double>::
  SetElementVector (const ARRAY<int> & dnums,
		    const FlatVector<double> & elvec) 
  {
    FlatVector<double> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	target(dof) = elvec(loc);
      }
  }
  
  // Scalar (Complex) case: overwrite one entry per local dof; dof -1 is skipped.
  template <> void T_LinearForm<Complex>::
  SetElementVector (const ARRAY<int> & dnums,
		    const FlatVector<Complex> & elvec) 
  {
    FlatVector<Complex> target = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	target(dof) = elvec(loc);
      }
  }
  




  // Gather global vector entries into an element vector (block-valued TV).
  // elvec receives HEIGHT consecutive components per local dof.  Entries for
  // local dofs whose global number is -1 are left unchanged.
  template <typename TV>
  void T_LinearForm<TV> ::
  GetElementVector (const ARRAY<int> & dnums,
		    FlatVector<TSCAL> & elvec) const
  {
    FlatVector<TV> source = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	for (int comp = 0; comp < HEIGHT; comp++)
	  elvec(loc*HEIGHT+comp) = source(dof)(comp);
      }
  }
  
  // Scalar (double) case: gather one entry per local dof.  Entries for
  // dof -1 are left unchanged.
  template <> void T_LinearForm<double>::
  GetElementVector (const ARRAY<int> & dnums,
		    FlatVector<double> & elvec) const
  {
    FlatVector<double> source = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	elvec(loc) = source(dof);
      }
  }
  
  // Scalar (Complex) case: gather one entry per local dof.  Entries for
  // dof -1 are left unchanged.
  template <> void T_LinearForm<Complex>::
  GetElementVector (const ARRAY<int> & dnums,
		    FlatVector<Complex> & elvec) const
  {
    FlatVector<Complex> source = vec->FV();
    for (int loc = 0; loc < dnums.Size(); loc++)
      {
	const int dof = dnums[loc];
	if (dof == -1) continue;
	elvec(loc) = source(dof);
      }
  }
  







  /*
    Factory for linear forms.  CreateVecObject is given the space's
    dimension and complex flag, presumably to select the matching
    T_LinearForm<TV> instantiation.  The "independent" define flag is
    forwarded to SetIndependent, presumably setting the member that makes
    Assemble delegate to AssembleIndependent.
  */
  LinearForm * CreateLinearForm (const FESpace * space,
				 const string & name, const Flags & flags)
  {
    LinearForm * form =
      CreateVecObject <T_LinearForm, LinearForm, const FESpace, const string, const Flags>
      (space->GetDimension(), space->IsComplex(), *space, name, flags);

    form->SetIndependent (flags.GetDefineFlag ("independent"));
    return form;
  }

}
