import React, { HTMLAttributes, ReactNode } from "react";
import { makeStyles, createStyles, Theme } from "@material-ui/core/styles";
import { Breakpoint } from "@material-ui/core/styles/createBreakpoints";
import MasonryController, { IColumn } from "./MasonryController";
import clsx from "clsx";
import useMediaQuery from "@material-ui/core/useMediaQuery";

// Default styles for the masonry layout: `root` is the outer flex container,
// `column` is a flex container that stacks its children vertically.
const useStyles = makeStyles((theme: Theme) => {
    return createStyles({
        root: {
            display: "flex",
            flex: "1",
        },
        column: {
            flex: "1",
            display: "flex",
            flexDirection: "column",
        },
    });
});

/**
 * Associates a Material-UI breakpoint with the number of columns to use for it.
 */
export interface IMasonryBreakPoint {
    /** Material-UI breakpoint key used to build the media query for this entry. */
    breakpoint: Breakpoint;
    /** Column count handed to the controller's `getColumns` for this entry. */
    numberOfColums: number;
}

/** Optional class-name overrides; the keys mirror whatever `useStyles` returns. */
type IClasses = Partial<ReturnType<typeof useStyles>>;

/**
 * Props accepted by `Masonry`.
 */
interface IProps<T> extends HTMLAttributes<HTMLDivElement> {
    /**
     * One entry per breakpoint that should display a layout. Each entry is
     * matched with `theme.breakpoints.only(...)`, so its columns render only
     * while the viewport is inside that exact breakpoint.
     * NOTE(review): an earlier comment here said breakpoints were evaluated
     * using `up`; the rendering code uses `only` - confirm which is intended.
     */
    breakpoints: IMasonryBreakPoint[];
    /** Items handed to `MasonryController` to be split into columns. */
    items: T[];
    /** Extra class names merged onto the `root` and `column` elements. */
    classes?: IClasses;
    /** Produces the rendered node for a single item. */
    renderItem: (item: T) => ReactNode;
}

/**
 * Masonry layout: splits `items` into columns and renders one `DisplayGrid`
 * per configured breakpoint. Each `DisplayGrid` decides for itself (via a
 * media query) whether it is currently visible.
 *
 * @param props.breakpoints Breakpoint/column-count pairs to render layouts for.
 * @param props.items       Items distributed into columns by `MasonryController`.
 * @param props.renderItem  Renders a single item.
 * @param props.className   Extra class name for the root element.
 * @param props.classes     Optional overrides for the `root` and `column` classes.
 *
 * NOTE(review): `IProps` extends `HTMLAttributes<HTMLDivElement>`, but only
 * `className` is read here; other HTML attributes are not forwarded to the
 * root `<div>` - confirm whether that is intended.
 */
export default function Masonry<T>(props: IProps<T>) {
    const classes = useStyles();
    const { breakpoints, items, renderItem, className = "", classes: overideClasses } = props;

    const listDisplayController = new MasonryController(items);

    // Every element produced by `map` needs a key. The breakpoint name identifies
    // the entry; the index is appended so a repeated breakpoint still gets a
    // unique key.
    return (
        <div className={clsx(classes.root, className, overideClasses?.root)}>
            {breakpoints.map((breakpoint: IMasonryBreakPoint, index: number) => (
                <DisplayGrid
                    key={`${breakpoint.breakpoint}-${index}`}
                    breakPoint={breakpoint}
                    className={clsx(classes.column, overideClasses?.column)}
                    columns={listDisplayController.getColumns(breakpoint.numberOfColums)}
                    renderItem={renderItem}
                />
            ))}
        </div>
    );
}

/**
 * Props accepted by `DisplayGrid`.
 */
interface IDisplayGridProps<T> extends HTMLAttributes<HTMLDivElement> {
    /** Breakpoint entry whose `breakpoint` key drives the media query. */
    breakPoint: IMasonryBreakPoint;
    /** Columns to render; each column is rendered as its own `<div>`. */
    columns: IColumn<T>[];
    /** Produces the rendered node for a single item. */
    renderItem: (item: T) => ReactNode;
}

/**
 * Renders the given columns, but only while the viewport matches the
 * breakpoint of `breakPoint` (checked with `theme.breakpoints.only`).
 * Renders nothing otherwise.
 *
 * @param props.breakPoint Breakpoint entry that gates rendering.
 * @param props.columns    Columns of items; one `<div>` is emitted per column.
 * @param props.renderItem Renders a single item.
 * @param props.className  Class name applied to every column `<div>`.
 */
function DisplayGrid<T>(props: IDisplayGridProps<T>) {
    const { columns, renderItem, breakPoint, className = "" } = props;

    // The hook must run on every render, so it stays above the early return.
    const matches = useMediaQuery((theme: Theme) => theme.breakpoints.only(breakPoint.breakpoint));

    if (!matches) {
        return null;
    }

    return (
        <>
            {columns.map((column: T[], i: number) => (
                <div className={className} key={i}>
                    {/*
                        `Children.toArray` keeps any key the caller set inside
                        `renderItem` and assigns a positional key to nodes
                        without one, so the list no longer triggers React's
                        missing-key warning.
                    */}
                    {React.Children.toArray(column.map((item: T) => renderItem(item)))}
                </div>
            ))}
        </>
    );
}
