import React, { ReactNode } from 'react';
import {
  Checkbox as MuiCheckbox,
  TableCell as MuiTableCell,
  TableHead as MuiTableHead,
  TableRow as MuiTableRow,
} from '@mui/material';
import { SystemStyleObject } from '@mui/system/styleFunctionSx/styleFunctionSx';
import { ConditionalRender } from '@egym/ui';
import { GroupedTableColumn, TableColumn, TableRenderSubComponentParams, TableState } from '../../TableProps';
import TableFilter from '../TableFilter';
import ThColumnCell from './components/ThColumnCell';
import ThGroupedColumnCell from './components/ThGroupedColumnCell';

/** Props accepted by the table header component declared in this file. */
type Props = {
  // --- Columns and data ---

  /** Columns to render; a grouped column is expanded into one header cell per child. */
  columns: (TableColumn | GroupedTableColumn)[];
  /** Together with `hasFilterableColumns`, gates rendering of the filter row. */
  hasData: boolean;
  /** Together with `hasData`, gates rendering of the filter row. */
  hasFilterableColumns: boolean;
  /** Forwarded to every column's optional `hidden` callback to decide visibility. */
  isViewMode?: boolean;
  /** When provided, a trailing empty cell is appended to the title row. */
  renderSubComponent?: (params: TableRenderSubComponentParams) => ReactNode;

  // --- Row selection ---

  /** The leading checkbox cell is rendered only when this is truthy and `totalRowsCount` is non-zero. */
  checkboxSelectionProp?: string;
  /** The leading checkbox cell is rendered only when this is non-zero and `checkboxSelectionProp` is truthy. */
  totalRowsCount: number;
  /** Condition handed to `ConditionalRender` around the select-all checkbox. */
  showCheckAll?: boolean;
  /** `checked` value of the select-all checkbox. */
  isSelectAllChecked: boolean;
  /** `indeterminate` value of the select-all checkbox. */
  isSelectAllIndeterminate: boolean;
  /** `onChange` handler of the select-all checkbox. */
  onSelectAllClick: (event: React.ChangeEvent<HTMLInputElement>) => void;

  // --- Table state ---

  /** Forwarded unchanged to the header cells and to `TableFilter`. */
  tableState: TableState;
  /** Forwarded unchanged to the header cells. */
  updateTableStateSorting: (fieldName: string) => void;
  /** Forwarded unchanged to `TableFilter`. */
  updateTableStateFilters: (fieldNames: string[], value: any) => void;

  // --- Presentation and testing ---

  /** Spread over the default `{ textTransform: 'uppercase' }` style passed to header cells. */
  headersCellSx?: SystemStyleObject;
  /** Prepended to the `data-testid` of each filter cell and forwarded to the header cells. */
  testIdPrefix: string;
};

/**
 * Table header section.
 *
 * Renders a title row with one cell per visible column (grouped columns are
 * expanded into their children) and, when `hasFilterableColumns` and `hasData`
 * are both truthy, a second row holding a `TableFilter` per visible column.
 *
 * Both rows share the same leading checkbox cell condition, the same list of
 * visible columns and the same trailing placeholder cell, so they always
 * contain the same number of cells.
 */
const TableHead: React.FC<Props> = ({
  columns,
  hasData,
  checkboxSelectionProp,
  renderSubComponent,
  onSelectAllClick,
  totalRowsCount,
  isSelectAllChecked,
  isSelectAllIndeterminate,
  showCheckAll,
  tableState,
  updateTableStateSorting,
  updateTableStateFilters,
  hasFilterableColumns,
  headersCellSx,
  isViewMode,
  testIdPrefix,
}) => {
  // Flatten to one entry per leaf column. Children of a grouped column keep a
  // reference to their group and a flag marking the first child of that group.
  const groupedColumns: {
    group?: GroupedTableColumn;
    first?: boolean;
    column: TableColumn;
  }[] = columns.flatMap(column => {
    const { children } = column as GroupedTableColumn;
    if (children) {
      return children.map((child, index) => ({ group: column, first: index === 0, column: child }));
    }
    return [{ column: column as TableColumn }];
  });

  // Evaluated once so the title row and the filter row list exactly the same
  // columns and each `hidden` callback runs a single time per render.
  const visibleColumns = groupedColumns.filter(({ column }) => !(column.hidden && column.hidden(isViewMode)));

  // The selection column is shown only when selection is configured and there are rows.
  const showSelectionCell = !!checkboxSelectionProp && !!totalRowsCount;

  return (
    <MuiTableHead>
      <MuiTableRow>
        {showSelectionCell && (
          <MuiTableCell variant="head" component="th" padding="checkbox">
            <ConditionalRender condition={showCheckAll}>
              <MuiCheckbox
                color="primary"
                indeterminate={isSelectAllIndeterminate}
                checked={isSelectAllChecked}
                onChange={onSelectAllClick}
              />
            </ConditionalRender>
          </MuiTableCell>
        )}
        {visibleColumns.map(({ group, first, column }) =>
          group ? (
            <ThGroupedColumnCell
              key={column.field}
              group={group}
              column={column}
              first={Boolean(first)}
              headersCellSx={{ textTransform: 'uppercase', ...headersCellSx }}
              tableState={tableState}
              updateTableStateSorting={updateTableStateSorting}
              testIdPrefix={testIdPrefix}
            />
          ) : (
            <ThColumnCell
              key={column.field}
              column={column}
              headersCellSx={{ textTransform: 'uppercase', ...headersCellSx }}
              tableState={tableState}
              updateTableStateSorting={updateTableStateSorting}
              testIdPrefix={testIdPrefix}
            />
          ),
        )}
        {/* Empty trailing cell reserved when a sub-component renderer is supplied. */}
        {renderSubComponent && <MuiTableCell />}
      </MuiTableRow>
      {hasFilterableColumns && hasData && (
        <MuiTableRow>
          {/* Empty placeholder under the select-all checkbox cell. */}
          {showSelectionCell && <MuiTableCell variant="head" component="th" padding="checkbox" />}
          {visibleColumns.map(({ column }) => (
            <MuiTableCell
              key={column.field}
              variant="head"
              component="th"
              align={column.numeric ? 'right' : 'left'}
              padding={column.disablePadding ? 'none' : 'normal'}
              width={column.width}
              sx={{ py: 1, ...column.headerCellSx, display: 'table-cell', bgcolor: 'grey.100' }}
              data-testid={`${testIdPrefix}table-filter-cell-${column.field}`}
            >
              <TableFilter
                column={column}
                updateTableStateFilters={updateTableStateFilters}
                tableState={tableState}
              />
            </MuiTableCell>
          ))}
          {/* Mirrors the trailing cell of the title row so both rows have the same cell count. */}
          {renderSubComponent && <MuiTableCell />}
        </MuiTableRow>
      )}
    </MuiTableHead>
  );
};

export default TableHead;
