import type { DataGridProps, GridColDef } from "@mui/x-data-grid";
import { DataGrid } from "@mui/x-data-grid";
import type { FC } from "react";
import { useState } from "react";
import type { QueryKey } from "react-query";
import { useQuery } from "react-query";
import { get } from "../lib/amplify";
import type { CountResponse } from "../shared/api_schema";
import { LoadingScreen } from "./LoadingScreen";

// Number of rows requested per page; it is also the only page size the grid offers.
const PAGE_SIZE = 10;

/**
 * Describes a server-side collection that `PaginatedCustomDataGrid` can page
 * through: where to fetch the total count, where to fetch a page of items, and
 * how to find the items in the response.
 */
export type PaginatedCustomModel = {
  /** Base react-query key; the count and per-page data queries append their own suffix to it. */
  key: QueryKey; // used for the query keys for count and data queries
  /** Path requested to obtain the total number of items. */
  countPath: string; // used to get a count of items
  /** Builds the path requested for one page, given the row offset (`skip`) and page size (`take`). */
  dataPath: (skip: number, take: number) => string; // used to get the items themselves
  /** Property of the data response that holds the rows, e.g. "appointments". */
  responseKey: string; // "appointments" -- used to access the server response at the right key
};

/**
 * A `DataGrid` wired for server-side pagination against a `PaginatedCustomModel`.
 *
 * Two queries are issued: one for the total row count (`model.countPath`) and
 * one for the current page (`model.dataPath(skip, PAGE_SIZE)`). The page query
 * only runs once the count is known to be greater than zero. Column sorting is
 * disabled on every column.
 *
 * @param model   Describes the endpoints and response shape to page through.
 * @param columns Column definitions; each one is rendered with `sortable: false`.
 * @param rest    Any other `DataGrid` props. They are spread last, so they
 *                override the defaults set by this component.
 */
export const PaginatedCustomDataGrid: FC<
  Partial<Omit<DataGridProps, "pagination">> & {
    model: PaginatedCustomModel;
    columns: GridColDef[];
  }
> = ({ model, columns, ...rest }) => {
  // Offset of the first row of the page currently shown.
  const [skip, setSkip] = useState(0);

  // Appends `suffix` to the model's base key so the count query and each page
  // query get distinct cache entries under the same prefix.
  const queryKey = (suffix: unknown) =>
    Array.isArray(model.key) ? [...model.key, suffix] : [model.key, suffix];

  const countQuery = useQuery<CountResponse>(queryKey("count"), async () =>
    get(model.countPath)
  );

  // `data` is undefined while loading AND after a failed request, so fall back
  // to 0 rather than asserting non-null (which would throw on error).
  const totalRows = countQuery.data?.count ?? 0;

  const dataQuery = useQuery(
    queryKey(skip),
    async () => get(model.dataPath(skip, PAGE_SIZE)),
    {
      // Keep showing the previous page's rows while the next page loads.
      keepPreviousData: true,
      // Nothing to fetch until the count says there is at least one row.
      enabled: totalRows > 0,
    }
  );

  const unsortableColumns: GridColDef[] = columns.map((column) => ({
    ...column,
    sortable: false,
  }));

  // Only block on the very first count load. Background refetches keep the
  // grid mounted (preserving its state) and are surfaced via `loading` below.
  if (countQuery.isLoading) {
    return <LoadingScreen />;
  }

  return (
    <DataGrid
      loading={countQuery.isFetching || dataQuery.isFetching}
      rows={dataQuery.data?.[model.responseKey] ?? []}
      columns={unsortableColumns}
      disableRowSelectionOnClick
      autoHeight
      paginationMode="server"
      paginationModel={{ page: skip / PAGE_SIZE, pageSize: PAGE_SIZE }}
      onPaginationModelChange={({ page, pageSize }) => setSkip(page * pageSize)}
      pageSizeOptions={[PAGE_SIZE]}
      rowCount={totalRows}
      {...rest}
    />
  );
};
