import { GridInitialState } from '@mui/x-data-grid';
import { GridApi, GridColumns } from '@mui/x-data-grid-pro';
import { useCallback, useMemo } from 'react';

/** Column state read back from local storage, validated and normalized. */
interface StoredColumnsState {
	/** Fields in their persisted order; columns not listed here keep their original position. */
	orderedFields: string[];
	/** Persisted widths keyed by field; only finite, positive widths are kept. */
	widths: Map<string, number>;
	/** Persisted visibility model, if one was stored. */
	columnVisibilityModel?: Record<string, boolean>;
}

/** Narrow `value` to a non-null, non-array object. */
const isPlainObject = (value: unknown): value is Record<string, unknown> =>
	typeof value === 'object' && value !== null && !Array.isArray(value);

/**
 * Read and validate the column state stored under `key`.
 * @param key the local storage key the state was stored under
 * @returns the normalized state, or `null` when nothing is stored, when local storage cannot be
 * accessed (e.g. disabled, or not defined outside the browser), or when the stored value is not a
 * JSON object. Returning `null` instead of throwing lets the grid fall back to its defaults
 * rather than crash while rendering.
 */
const readStoredColumnsState = (key: string): StoredColumnsState | null => {
	let parsed: unknown = null;
	try {
		const raw = localStorage.getItem(key);
		parsed = raw === null ? null : JSON.parse(raw);
	} catch {
		// Storage unavailable or value is not valid JSON: behave as if nothing was stored.
		return null;
	}
	if (!isPlainObject(parsed)) {
		return null;
	}
	// Every key is optional: a partially written state must not break the restore.
	const { orderedFields, dimensions, columnVisibilityModel } = parsed;

	const widths = new Map<string, number>();
	if (isPlainObject(dimensions)) {
		for (const field of Object.keys(dimensions)) {
			const dimension = dimensions[field];
			const width = isPlainObject(dimension) ? dimension.width : undefined;
			if (typeof width === 'number' && Number.isFinite(width) && width > 0) {
				widths.set(field, width);
			}
		}
	}

	let visibility: Record<string, boolean> | undefined;
	if (isPlainObject(columnVisibilityModel)) {
		visibility = {};
		for (const field of Object.keys(columnVisibilityModel)) {
			const visible = columnVisibilityModel[field];
			if (typeof visible === 'boolean') {
				visibility[field] = visible;
			}
		}
	}

	return {
		orderedFields: Array.isArray(orderedFields)
			? orderedFields.filter((field): field is string => typeof field === 'string')
			: [],
		widths,
		columnVisibilityModel: visibility,
	};
};

/**
 * Apply stored widths and ordering to `columns` without mutating it.
 * Columns listed in `orderedFields` are ranked by their stored position; any other column
 * (e.g. one added after the state was saved) is ranked by its index in `columns`.
 * Equal ranks keep the order of `columns` because `Array.prototype.sort` is stable.
 * @param columns the column definitions passed to the grid
 * @param stored the validated stored state
 * @returns a new array of columns, reordered, with stored widths applied
 */
const restoreColumns = (columns: GridColumns, stored: StoredColumnsState): GridColumns => {
	// Precompute ranks once instead of calling indexOf inside the comparator.
	const storedRank = new Map<string, number>(
		stored.orderedFields.map((field, index): [string, number] => [field, index])
	);
	return columns
		.map((column, originalIndex) => {
			const width = stored.widths.get(column.field);
			return {
				column: width === undefined ? column : { ...column, width },
				// Rank by field/position, not object identity: the width-restored copy above
				// is a new object that would not be found in `columns`.
				rank: storedRank.get(column.field) ?? originalIndex,
			};
		})
		.sort((left, right) => left.rank - right.rank)
		.map(({ column }) => column);
};

/**
 * Store and retrieve column state from local storage.
 * @param key the name of the state in the local storage, it will be used to retrieve the state
 * @param initialState the DataGrid initialState
 * @param columns an array representing the columns of the DataGridPro
 * @param apiRef the DataGrid apiRef used to export the current state
 * @returns An object containing a method to store the state, the updated initial state and the updated columns.
 * When nothing valid is stored, `updatedColumns` is `columns` itself and the initial state only
 * gains an empty `columnVisibilityModel` default. `initialState` and `columns` are memo
 * dependencies, so pass stable references to avoid re-reading local storage on every render.
 */
const useColumnsState = (
	key: string,
	initialState: GridInitialState,
	columns: GridColumns,
	apiRef: React.MutableRefObject<GridApi>
) => {
	const storeColumnState = useCallback(() => {
		// `columns` is optional in the exported state; skip the write rather than storing "undefined".
		const columnsState = apiRef.current?.exportState().columns;
		if (!columnsState) {
			return;
		}
		try {
			localStorage.setItem(key, JSON.stringify(columnsState));
		} catch (error) {
			// Persisting is best effort: storage may be full (QuotaExceededError) or disabled.
			console.warn(`useColumnsState: could not persist column state under "${key}"`, error);
		}
	}, [key, apiRef]);

	const { updatedInitialState, updatedColumns } = useMemo(() => {
		const storedState = readStoredColumnsState(key);

		const updatedInitialState = {
			...initialState,
			columns: {
				columnVisibilityModel: {},
				...initialState.columns,
			},
		};
		// Restore the visibility of the columns
		if (storedState?.columnVisibilityModel) {
			updatedInitialState.columns.columnVisibilityModel = storedState.columnVisibilityModel;
		}

		// Restore the width and the order of the columns
		const updatedColumns: GridColumns = storedState ? restoreColumns(columns, storedState) : columns;

		return { updatedInitialState, updatedColumns };
	}, [key, initialState, columns]);

	return { storeColumnState, updatedInitialState, updatedColumns };
};

export { useColumnsState };
