"""
Module for reading Lobster output files. For more information
on LOBSTER see www.cohp.de.
If you use this module, please cite:
J. George, G. Petretto, A. Naik, M. Esters, A. J. Jackson, R. Nelson, R. Dronskowski, G.-M. Rignanese, G. Hautier,
"Automated Bonding Analysis with Crystal Orbital Hamilton Populations",
ChemPlusChem 2022, e202200123,
DOI: 10.1002/cplu.202200123.
"""

from __future__ import annotations

import collections
import fnmatch
import os
import re
import warnings
from collections import defaultdict
from typing import TYPE_CHECKING, Any

import numpy as np
from monty.io import zopen

from pymatgen.core.structure import Structure
from pymatgen.electronic_structure.bandstructure import LobsterBandStructureSymmLine
from pymatgen.electronic_structure.core import Orbital, Spin
from pymatgen.electronic_structure.dos import Dos, LobsterCompleteDos
from pymatgen.io.vasp.inputs import Kpoints
from pymatgen.io.vasp.outputs import Vasprun, VolumetricData
from pymatgen.util.due import Doi, due

if TYPE_CHECKING:
    from pymatgen.core.structure import IStructure

# Authorship and version metadata for this module.
__author__ = "Janine George, Marco Esters"
__copyright__ = "Copyright 2017, The Materials Project"
__version__ = "0.2"
__maintainer__ = "Janine George "
__email__ = "janinegeorge.ulfen@gmail.com"
__date__ = "Dec 13, 2017"

# Absolute path of the directory that contains this module file.
MODULE_DIR = os.path.dirname(os.path.abspath(__file__))

# Register the publication (by DOI) that should be cited when this module is used,
# through the `due` helper imported from pymatgen.util.due.
due.cite(
    Doi("10.1002/cplu.202200123"),
    description="Automated Bonding Analysis with Crystal Orbital Hamilton Populations",
)


class Cohpcar:
    """
    Class to read COHPCAR/COOPCAR/COBICAR files generated by LOBSTER.

    Attributes:
        are_coops (bool): Whether the file was read as a COOPCAR.
        are_cobis (bool): Whether the file was read as a COBICAR.
        cohp_data (dict[str, Dict[str, Any]]): A dictionary containing the COHP data of the form:
            {bond: {"COHP": {Spin.up: cohps, Spin.down:cohps},
                    "ICOHP": {Spin.up: icohps, Spin.down: icohps},
                    "length": bond length,
                    "sites": sites corresponding to the bond}
            Also contains an entry for the average, which does not have a "length" key.
        efermi (float): The Fermi energy in eV.
        energies (Sequence[float]): Sequence of energies in eV. Note that LOBSTER shifts the energies
            so that the Fermi energy is at zero.
        is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized.
        orb_res_cohp (dict[str, Dict[str, Dict[str, Any]]] | None): A dictionary containing the
            orbital-resolved COHPs of the form:
            orb_res_cohp[label] = {bond_data["orb_label"]: {
                "COHP": {Spin.up: cohps, Spin.down:cohps},
                "ICOHP": {Spin.up: icohps, Spin.down: icohps},
                "orbitals": orbitals,
                "length": bond lengths,
                "sites": sites corresponding to the bond},
            }
            None if the file contains no orbital-resolved entries.
    """

    def __init__(self, are_coops: bool = False, are_cobis: bool = False, filename: str | None = None):
        """
        Args:
            are_coops: Determines if the file is a list of COHPs or COOPs.
              Default is False for COHPs.
            are_cobis: Determines if the file is a list of COHPs or COBIs.
              Default is False for COHPs.

            filename: Name of the COHPCAR file. If it is None, the default
              file name will be chosen, depending on the values of are_coops
              and are_cobis.

        Raises:
            ValueError: If both are_coops and are_cobis are True.
        """
        if are_coops and are_cobis:
            raise ValueError("You cannot have info about COOPs and COBIs in the same file.")
        self.are_coops = are_coops
        self.are_cobis = are_cobis
        if filename is None:
            if are_coops:
                filename = "COOPCAR.lobster"
            elif are_cobis:
                filename = "COBICAR.lobster"
            else:
                filename = "COHPCAR.lobster"

        with zopen(filename, mode="rt") as file:
            contents = file.read().split("\n")

        # The parameters line is the second line in a COHPCAR file. It
        # contains all parameters that are needed to map the file.
        parameters = contents[1].split()
        # Subtract 1 to skip the average
        num_bonds = int(parameters[0]) - 1
        self.efermi = float(parameters[-1])
        self.is_spin_polarized = int(parameters[1]) == 2
        spins = [Spin.up, Spin.down] if self.is_spin_polarized else [Spin.up]

        # The COHP data start in row num_bonds + 3. Blank rows (e.g. the empty
        # string left behind by a trailing newline) are skipped, otherwise the
        # rows would have unequal lengths and could not form a 2D array.
        data = np.array(
            [np.array(row.split(), dtype=float) for row in contents[num_bonds + 3 :] if row.strip()]
        ).transpose()

        # After transposing, row 0 holds the energies, rows 1 and 2 the average
        # COHP/ICOHP, followed by one COHP/ICOHP pair per bond. The block for the
        # second spin channel starts 2 * (num_bonds + 1) rows later.
        self.energies = data[0]
        cohp_data: dict[str, dict[str, Any]] = {
            "average": {
                "COHP": {spin: data[1 + 2 * s * (num_bonds + 1)] for s, spin in enumerate(spins)},
                "ICOHP": {spin: data[2 + 2 * s * (num_bonds + 1)] for s, spin in enumerate(spins)},
            }
        }

        orb_cohp: dict[str, Any] = {}
        # present for Lobster versions older than Lobster 2.2.0
        very_old = False
        # the labeling had to be changed: there are more than one COHP for each atom combination
        # this is done to make the labeling consistent with ICOHPLIST.lobster
        bond_num = 0
        for bond in range(num_bonds):
            bond_data = self._get_bond_data(contents[3 + bond])

            # Orbital-resolved lines belong to the most recently seen bond.
            label = str(bond_num)

            orbs = bond_data["orbitals"]
            cohp = {spin: data[2 * (bond + s * (num_bonds + 1)) + 3] for s, spin in enumerate(spins)}
            icohp = {spin: data[2 * (bond + s * (num_bonds + 1)) + 4] for s, spin in enumerate(spins)}

            if orbs is None:
                # A line without orbitals starts a new bond.
                bond_num += 1
                cohp_data[str(bond_num)] = {
                    "COHP": cohp,
                    "ICOHP": icohp,
                    "length": bond_data["length"],
                    "sites": bond_data["sites"],
                }
                continue

            orb_entry = {
                "COHP": cohp,
                "ICOHP": icohp,
                "orbitals": orbs,
                "length": bond_data["length"],
                "sites": bond_data["sites"],
            }
            if label not in orb_cohp:
                # present for Lobster versions older than Lobster 2.2.0:
                # an orbital-resolved line shows up before any plain bond line.
                if bond_num == 0:
                    very_old = True
                if very_old:
                    bond_num += 1
                    label = str(bond_num)
                orb_cohp[label] = {}
            orb_cohp[label][bond_data["orb_label"]] = orb_entry

        # present for lobster older than 2.2.0: add placeholder entries without
        # total COHP/ICOHP. Length and sites are taken from the orbital-resolved
        # entries stored under the same label, not from whichever header line
        # happened to be parsed last.
        if very_old:
            for bond_str, orb_entries in orb_cohp.items():
                first_entry = next(iter(orb_entries.values()))
                cohp_data[bond_str] = {
                    "COHP": None,
                    "ICOHP": None,
                    "length": first_entry["length"],
                    "sites": first_entry["sites"],
                }

        self.orb_res_cohp = orb_cohp or None
        self.cohp_data = cohp_data

    @staticmethod
    def _get_bond_data(line: str) -> dict:
        """
        Subroutine to extract site indices, bond length and (if present)
        orbital information from a LOBSTER header line. The site indices are
        zero-based, so they can be easily used with a Structure object.

        Example header line: No.4:Fe1->Fe9(2.4524893531900283)
        Example header line for orbital-resolved COHP:
            No.1:Fe1[3p_x]->Fe2[3d_x^2-y^2](2.456180552772262)

        Args:
            line: line in the COHPCAR header describing the bond.

        Returns:
            Dict with the keys "length" (bond length), "sites" (tuple of the
            site indices), "orbitals" and "orb_label". The last two are None
            if the line is not orbital-resolved.
        """
        # Everything after the last "(" is the bond length followed by ")".
        line_new = line.rsplit("(", 1)
        length = float(line_new[-1][:-1])

        # "No.4:Fe1->Fe9" -> ["Fe1", "Fe9"]
        sites = line_new[0].replace("->", ":").split(":")[1:3]
        # The first run of digits in each site string is the one-based site number.
        site_indices = tuple(int(re.split(r"\D+", site)[1]) - 1 for site in sites)

        if "[" in sites[0]:
            orbs = [re.findall(r"\[(.*)\]", site)[0] for site in sites]
            orb_label, orbitals = get_orb_from_str(orbs)

        else:
            orbitals = orb_label = None

        return {
            "length": length,
            "sites": site_indices,
            "orbitals": orbitals,
            "orb_label": orb_label,
        }


class Icohplist:
    """
    Class to read ICOHPLIST/ICOOPLIST/ICOBILIST files generated by LOBSTER.

    Attributes:
        are_coops (bool): Indicates whether the object is consisting of COOPs.
        are_cobis (bool): Indicates whether the object is consisting of COBIs.
        is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized.
        orbitalwise (bool): Whether the file was detected as containing orbital-wise entries.
        icohplist (dict[str, Dict[str, Any]]): Dict containing the listfile data of the form: {
                bond: "length": bond length,
                "number_of_bonds": number of bonds
                "icohp": {Spin.up: ICOHP(Ef) spin up, Spin.down: ...}
                "translation": translation read from the file ([0, 0, 0] for the old format)
                "orbitals": orbital-wise data of the bond
            }
        icohpcollection (IcohpCollection): IcohpCollection Object.
    """

    def __init__(self, are_coops: bool = False, are_cobis: bool = False, filename: str | None = None):
        """
        Args:
            are_coops: Determines if the file is a list of ICOOPs.
              Defaults to False for ICOHPs.
            are_cobis: Determines if the file is a list of ICOBIs.
              Defaults to False for ICOHPs.
            filename: Name of the ICOHPLIST file. If it is None, the default
              file name will be chosen, depending on the values of are_coops
              and are_cobis.

        Raises:
            ValueError: If both are_coops and are_cobis are True, or if the
                first data line has neither 8 nor 6 columns.
            OSError: If the file contains no data lines.
        """
        if are_coops and are_cobis:
            raise ValueError("You cannot have info about COOPs and COBIs in the same file.")
        self.are_coops = are_coops
        self.are_cobis = are_cobis
        if filename is None:
            if are_coops:
                filename = "ICOOPLIST.lobster"
            elif are_cobis:
                filename = "ICOBILIST.lobster"
            else:
                filename = "ICOHPLIST.lobster"

        # We don't need the header. LOBSTER list files usually end with an
        # extra blank line; strip only lines that really are blank so that the
        # last data row survives when the file has no trailing newline.
        with zopen(filename, mode="rt") as file:
            data = file.read().split("\n")[1:]
        while data and not data[-1].strip():
            data.pop()
        if len(data) == 0:
            raise OSError("ICOHPLIST file contains no data.")

        # Which Lobster version? The column count of the first data line decides.
        n_columns = len(data[0].split())
        if n_columns == 8:
            version = "3.1.1"
        elif n_columns == 6:
            version = "2.2.1"
            warnings.warn("Please consider using the new Lobster version. See www.cohp.de.")
        else:
            raise ValueError(
                f"Unexpected number of columns ({n_columns}) in the first data line of {filename}; expected 8 or 6."
            )

        # If the calculation is spin polarized, the line in the middle
        # of the file will be another header line.
        # TODO: adapt this for orbital-wise stuff
        self.is_spin_polarized = "distance" in data[len(data) // 2]

        # check if orbital-wise ICOHPLIST
        # include case when there is only one ICOHP!!!
        self.orbitalwise = len(data) > 2 and "_" in data[1].split()[1]

        if self.orbitalwise:
            # Orbital-wise lines carry an underscore in their second column.
            data_without_orbitals = []
            data_orbitals = []
            for line in data:
                if "_" not in line.split()[1]:
                    data_without_orbitals.append(line)
                else:
                    data_orbitals.append(line)

        else:
            data_without_orbitals = data

        if "distance" in data_without_orbitals[len(data_without_orbitals) // 2]:
            # TODO: adapt this for orbital-wise stuff
            num_bonds = len(data_without_orbitals) // 2
            if num_bonds == 0:
                raise OSError("ICOHPLIST file contains no data.")
        else:
            num_bonds = len(data_without_orbitals)

        list_labels = []
        list_atom1 = []
        list_atom2 = []
        list_length = []
        list_translation = []
        list_num = []
        list_icohp = []

        # Column holding the ICOHP value in the respective file format.
        icohp_column = 7 if version == "3.1.1" else 4

        for bond in range(num_bonds):
            line = data_without_orbitals[bond].split()

            label = f"{line[0]}"
            atom1 = str(line[1])
            atom2 = str(line[2])
            length = float(line[3])
            if version == "3.1.1":
                translation = [int(line[4]), int(line[5]), int(line[6])]
                num = 1
            else:
                # The old format has no translation columns but a bond count.
                translation = [0, 0, 0]
                num = int(line[5])

            icohp = {Spin.up: float(line[icohp_column])}
            if self.is_spin_polarized:
                # The second spin block follows after one additional header line.
                icohp[Spin.down] = float(data_without_orbitals[bond + num_bonds + 1].split()[icohp_column])

            list_labels.append(label)
            list_atom1.append(atom1)
            list_atom2.append(atom2)
            list_length.append(length)
            list_translation.append(translation)
            list_num.append(num)
            list_icohp.append(icohp)

        list_orb_icohp: list[dict] | None = None
        if self.orbitalwise:
            list_orb_icohp = []
            num_orbs = len(data_orbitals) // 2 if self.is_spin_polarized else len(data_orbitals)

            for i_data_orb in range(num_orbs):
                data_orb = data_orbitals[i_data_orb]
                icohp = {}
                line = data_orb.split()
                label = f"{line[0]}"
                orbs = re.findall(r"_(.*?)(?=\s)", data_orb)
                orb_label, orbitals = get_orb_from_str(orbs)
                icohp[Spin.up] = float(line[7])

                if self.is_spin_polarized:
                    icohp[Spin.down] = float(data_orbitals[num_orbs + i_data_orb].split()[7])

                # The label is the one-based bond number: start a new dict for a
                # bond seen for the first time, otherwise extend the existing one.
                if len(list_orb_icohp) < int(label):
                    list_orb_icohp.append({orb_label: {"icohp": icohp, "orbitals": orbitals}})
                else:
                    list_orb_icohp[int(label) - 1][orb_label] = {"icohp": icohp, "orbitals": orbitals}

        # to avoid circular dependencies
        from pymatgen.electronic_structure.cohp import IcohpCollection

        self._icohpcollection = IcohpCollection(
            are_coops=are_coops,
            are_cobis=are_cobis,
            list_labels=list_labels,
            list_atom1=list_atom1,
            list_atom2=list_atom2,
            list_length=list_length,
            list_translation=list_translation,
            list_num=list_num,
            list_icohp=list_icohp,
            is_spin_polarized=self.is_spin_polarized,
            list_orb_icohp=list_orb_icohp,
        )

    @property
    def icohplist(self) -> dict[Any, dict[str, Any]]:
        """Returns: icohplist compatible with older version of this class."""
        return {
            key: {
                "length": value._length,
                "number_of_bonds": value._num,
                "icohp": value._icohp,
                "translation": value._translation,
                "orbitals": value._orbitals,
            }
            for key, value in self._icohpcollection._icohplist.items()
        }

    @property
    def icohpcollection(self):
        """Returns: IcohpCollection object."""
        return self._icohpcollection


class NciCobiList:
    """
    Class to read NcICOBILIST (multi-center ICOBI) files generated by LOBSTER.

    Attributes:
        is_spin_polarized (bool): Boolean to indicate if the calculation is spin polarized.
        orbital_wise (bool): Whether orbital-wise entries were detected in the file
            (they are currently not read).
        list_labels (list[str]): First column of every parsed line.
        list_n_atoms (list[str]): Second column of every parsed line (number of atoms).
        list_ncicobi (list[dict]): {Spin.up: ..., Spin.down: ...} values per parsed line.
        list_interaction_type (list[str]): Interaction type per parsed line.
        list_num (list[int]): Always 1 for every parsed line.
        ncicobi_list (dict): Dict containing the listfile data of the form:
        {bond: "number_of_atoms": number of atoms involved in the multi-center interaction,
                "ncicobi": {Spin.up: Nc-ICOBI(Ef) spin up, Spin.down: ...}},
                "interaction_type": type of the multi-center interaction
    """

    def __init__(self, filename: str | None = "NcICOBILIST.lobster"):  # LOBSTER < 4.1.0: no COBI/ICOBI/NcICOBI
        """
        Args:
            filename: Name of the NcICOBILIST file.

        Raises:
            OSError: If the file contains no data lines.
        """

        # We don't need the header. LOBSTER list files usually end with an
        # extra blank line; strip only lines that really are blank so that the
        # last data row survives when the file has no trailing newline.
        with zopen(filename, mode="rt") as file:
            data = file.read().split("\n")[1:]
        while data and not data[-1].strip():
            data.pop()
        if len(data) == 0:
            raise OSError("NcICOBILIST file contains no data.")

        # If the calculation is spin-polarized, the line in the middle
        # of the file will be another header line.
        self.is_spin_polarized = "spin" in data[len(data) // 2]  # TODO: adapt this for orbitalwise case

        # check if orbitalwise NcICOBILIST
        # include case when there is only one NcICOBI
        # NcICOBIs orbitalwise and non-orbitalwise can be mixed, so a single
        # matching line is enough.
        self.orbital_wise = len(data) > 2 and any("s]" in str(entry.split()[3:]) for entry in data)
        if self.orbital_wise:
            warnings.warn(
                "This is an orbitalwise NcICOBILIST.lobster file. "
                "Currently, the orbitalwise information is not read!"
            )

        if self.orbital_wise:
            # Keep only the lines whose interaction columns show no orbital marker.
            data_without_orbitals = [
                line
                for line in data
                if "_" not in str(line.split()[3:]) and "s]" not in str(line.split()[3:])
            ]
        else:
            data_without_orbitals = data

        if "spin" in data_without_orbitals[len(data_without_orbitals) // 2]:
            # TODO: adapt this for orbitalwise case
            n_bonds = len(data_without_orbitals) // 2
            if n_bonds == 0:
                raise OSError("NcICOBILIST file contains no data.")
        else:
            n_bonds = len(data_without_orbitals)

        self.list_labels = []
        self.list_n_atoms = []
        self.list_ncicobi = []
        self.list_interaction_type = []
        self.list_num = []

        for bond in range(n_bonds):
            line = data_without_orbitals[bond].split()
            ncicobi = {}

            label = f"{line[0]}"
            n_atoms = str(line[1])
            ncicobi[Spin.up] = float(line[2])
            # All remaining columns form the interaction type; the list repr is
            # stripped of quotes and spaces.
            interaction_type = str(line[3:]).replace("'", "").replace(" ", "")
            num = 1

            if self.is_spin_polarized:
                # The second spin block follows after one additional header line.
                ncicobi[Spin.down] = float(data_without_orbitals[bond + n_bonds + 1].split()[2])

            self.list_labels.append(label)
            self.list_n_atoms.append(n_atoms)
            self.list_ncicobi.append(ncicobi)
            self.list_interaction_type.append(interaction_type)
            self.list_num.append(num)

        # TODO: add functions to get orbital resolved NcICOBIs

    @property
    def ncicobi_list(self) -> dict[Any, dict[str, Any]]:
        """
        Returns: ncicobilist, keyed by the one-based position of each entry (as a string).
        """
        return {
            str(position): {
                "number_of_atoms": int(n_atoms),
                "ncicobi": ncicobi,
                "interaction_type": interaction_type,
            }
            for position, (n_atoms, ncicobi, interaction_type) in enumerate(
                zip(self.list_n_atoms, self.list_ncicobi, self.list_interaction_type), start=1
            )
        }


class Doscar:
    """
    Class to deal with Lobster's projected DOS and local projected DOS.
    The beforehand quantum-chemical calculation was performed with VASP.

    Attributes:
        completedos (LobsterCompleteDos): LobsterCompleteDos Object.
        pdos (list): List of Dict including numpy arrays with pdos. Access as
            pdos[atomindex]['orbitalstring']['Spin.up/Spin.down'].
        tdos (Dos): Dos Object of the total density of states.
        energies (numpy.ndarray): Numpy array of the energies at which the DOS was calculated
            (in eV, relative to Efermi).
        tdensities (dict): tdensities[Spin.up]: numpy array of the total density of states for
            the Spin.up contribution at each of the energies. tdensities[Spin.down]: numpy array
            of the total density of states for the Spin.down contribution at each of the energies.
            If is_spin_polarized=False, tdensities[Spin.up]: numpy array of the total density of states.
        itdensities (dict): itdensities[Spin.up]: numpy array of the integrated total density of
            states for the Spin.up contribution at each of the energies. itdensities[Spin.down]:
            numpy array of the integrated total density of states for the Spin.down contribution
            at each of the energies. If is_spin_polarized=False, itdensities[Spin.up]: numpy array
            of the integrated total density of states.
        is_spin_polarized (bool): Boolean. Tells if the system is spin polarized.
    """

    def __init__(
        self,
        doscar: str = "DOSCAR.lobster",
        structure_file: str | None = "POSCAR",
        structure: IStructure | Structure | None = None,
    ):
        """
        Args:
            doscar: DOSCAR filename, typically "DOSCAR.lobster"
            structure_file: for vasp, this is typically "POSCAR"
            structure: instead of a structure file, the structure can be given
                directly. structure_file will be preferred, so pass
                structure_file=None to make use of this argument.
        """
        self._doscar = doscar

        self._final_structure = Structure.from_file(structure_file) if structure_file is not None else structure

        self._parse_doscar()

    @staticmethod
    def _collect_pdos(data, orbital_names, spin_polarized):
        """Map the columns of one atom's DOS block onto {orbital: {spin: densities}}."""
        pdos = defaultdict(dict)
        # Column 0 holds the energies, so the projections start at column 1.
        for col in range(1, data.shape[1]):
            if spin_polarized:
                # Columns alternate up/down; each orbital owns two consecutive columns.
                spin = Spin.up if col % 2 else Spin.down
                orb = orbital_names[(col - 1) // 2]
            else:
                spin = Spin.up
                orb = orbital_names[col - 1]
            pdos[orb][spin] = data[:, col]
        return pdos

    def _parse_doscar(self):
        """
        Read the DOSCAR file and fill the total DOS, the projected DOS and the
        LobsterCompleteDos object.

        Raises:
            ValueError: If the total DOS block has neither 3 nor 5 columns.
        """
        doscar = self._doscar

        with zopen(doscar, mode="rt") as file:
            natoms = int(file.readline().split()[0])
            # The Fermi energy is read from token 17 of the fourth of the next four lines.
            header_lines = [file.readline() for _ in range(4)]
            efermi = float(header_lines[3].split()[17])
            dos = []
            orbitals = []
            # One block for the total DOS, followed by one block per atom.
            for _atom in range(natoms + 1):
                line = file.readline()
                # The third token of the block header is the number of DOS rows;
                # the orbital names follow the last ";".
                ndos = int(line.split()[2])
                orbitals.append(line.split(";")[-1].split())
                line = file.readline().split()
                cdos = np.zeros((ndos, len(line)))
                cdos[0] = np.array(line, dtype=float)
                for nd in range(1, ndos):
                    line = file.readline().split()
                    cdos[nd] = np.array(line, dtype=float)
                dos.append(cdos)

        total = np.array(dos[0])
        n_columns = len(total[0, :])
        if n_columns == 5:
            self._is_spin_polarized = True
        elif n_columns == 3:
            self._is_spin_polarized = False
        else:
            raise ValueError("There is something wrong with the DOSCAR. Can't extract spin polarization.")

        energies = total[:, 0]
        if self._is_spin_polarized:
            tdensities = {Spin.up: total[:, 1], Spin.down: total[:, 2]}
            itdensities = {Spin.up: total[:, 3], Spin.down: total[:, 4]}
        else:
            tdensities = {Spin.up: total[:, 1]}
            itdensities = {Spin.up: total[:, 2]}

        # Block 0 is the total DOS, so atom i lives in block i + 1.
        pdoss = [
            self._collect_pdos(dos[atom + 1], orbitals[atom + 1], self._is_spin_polarized)
            for atom in range(natoms)
        ]

        self._efermi = efermi
        self._pdos = pdoss
        self._tdos = Dos(efermi, energies, tdensities)
        self._energies = energies
        self._tdensities = tdensities
        self._itdensities = itdensities
        final_struct = self._final_structure

        # Key the projected DOS by the corresponding site of the structure.
        pdossneu = {final_struct[i]: pdos for i, pdos in enumerate(self._pdos)}

        self._completedos = LobsterCompleteDos(final_struct, self._tdos, pdossneu)

    @property
    def completedos(self) -> LobsterCompleteDos:
        """LobsterCompleteDos"""
        return self._completedos

    @property
    def pdos(self) -> list:
        """Projected DOS"""
        return self._pdos

    @property
    def tdos(self) -> Dos:
        """Total DOS"""
        return self._tdos

    @property
    def energies(self) -> np.ndarray:
        """Energies"""
        return self._energies

    @property
    def tdensities(self) -> dict[Spin, np.ndarray]:
        """Total densities as a dict mapping each spin to a np.ndarray"""
        return self._tdensities

    @property
    def itdensities(self) -> dict[Spin, np.ndarray]:
        """Integrated total densities as a dict mapping each spin to a np.ndarray"""
        return self._itdensities

    @property
    def is_spin_polarized(self) -> bool:
        """Whether run is spin polarized."""
        return self._is_spin_polarized


class Charge:
    """
    Reader for CHARGE files written by LOBSTER.

    Attributes:
        atomlist (list[str]): Atom labels, built from the element symbol followed by
            the atom number as given in CHARGE.lobster.
        types (list[str]): Element symbol of every atom in CHARGE.lobster.
        Mulliken (list[float]): Mulliken charge of every atom in CHARGE.lobster.
        Loewdin (list[float]): Loewdin charge of every atom in CHARGE.lobster.
        num_atoms (int): Number of atoms in CHARGE.lobster.
    """

    def __init__(self, filename: str = "CHARGE.lobster"):
        """
        Args:
            filename: filename for the CHARGE file, typically "CHARGE.lobster".

        Raises:
            OSError: If no atom lines remain after removing header and footer.
        """
        with zopen(filename, mode="rt") as file:
            all_lines = file.read().split("\n")

        # The first three and the last three lines carry no per-atom data.
        data = all_lines[3:-3]
        if len(data) == 0:
            raise OSError("CHARGES file contains no data.")

        self.num_atoms = len(data)
        self.atomlist: list[str] = []
        self.types: list[str] = []
        self.Mulliken: list[float] = []
        self.Loewdin: list[float] = []
        for atom_line in data:
            # Columns: atom number, element symbol, Mulliken charge, Loewdin charge.
            fields = atom_line.split()
            self.atomlist.append(fields[1] + fields[0])
            self.types.append(fields[1])
            self.Mulliken.append(float(fields[2]))
            self.Loewdin.append(float(fields[3]))

    def get_structure_with_charges(self, structure_filename):
        """
        Build a Structure that carries the Mulliken and Loewdin charges as site properties.

        Args:
            structure_filename: filename of POSCAR

        Returns:
            Structure Object with Mulliken and Loewdin charges as site properties.
        """
        charges = {"Mulliken Charges": self.Mulliken, "Loewdin Charges": self.Loewdin}
        return Structure.from_file(structure_filename).copy(site_properties=charges)


class Lobsterout:
    """
    Class to read in the lobsterout and evaluate the spilling, save the basis, save warnings, save infos.

    Attributes:
        basis_functions (list[list[str]]): For every element, the list of basis functions that were used
            in the lobster run.
        basis_type (list[str]): List of basis type that were used in lobster run as strings.
        charge_spilling (list[float]): List of charge spilling (first entry: result for spin 1,
            second entry: result for spin 2 or not present).
        dft_program (str): String representing the DFT program used for the calculation of the wave function.
        elements (list[str]): List of strings of elements that were present in lobster calculation.
        has_charge (bool): Whether CHARGE.lobster is present.
        has_cohpcar (bool): Whether COHPCAR.lobster and ICOHPLIST.lobster are present.
        has_madelung (bool): Whether SitePotentials.lobster and MadelungEnergies.lobster are present.
        has_coopcar (bool): Whether COOPCAR.lobster and ICOOPLIST.lobster are present.
        has_cobicar (bool): Whether COBICAR.lobster and ICOBILIST.lobster are present.
        has_doscar (bool): Whether DOSCAR.lobster is present.
        has_doscar_lso (bool): Whether DOSCAR.LSO.lobster is present.
        has_projection (bool): Whether projectionData.lobster is present.
        has_bandoverlaps (bool): Whether bandOverlaps.lobster is present.
        has_density_of_energies (bool): Whether DensityOfEnergy.lobster is present.
        has_fatbands (bool): Whether fatband calculation was performed.
        has_grosspopulation (bool): Whether GROSSPOP.lobster is present.
        info_lines (list[str]): Additional infos on the run, one entry per "INFO:" line.
        info_orthonormalization (list[str]): Infos on orthonormalization, one entry per matching line.
        is_restart_from_projection (bool): Boolean that indicates that calculation was restarted
            from existing projection file.
        lobster_version (str): String that indicates Lobster version.
        number_of_spins (int): Integer indicating the number of spins.
        number_of_threads (int): Integer that indicates how many threads were used.
        timing (dict[str, dict[str, str]]): Dictionary with the keys "wall_time", "user_time" and
            "sys_time"; each value is a dict with the keys "h", "min", "s" and "ms" holding the
            numbers as they are printed in lobsterout (strings).
        total_spilling (list[float]): List of values indicating the total spilling for spin
            channel 1 (and spin channel 2).
        warning_lines (list[str]): All warnings, one entry per "WARNING:" line.
    """

    # TODO: add tests for skipping COBI and madelung
    # TODO: add tests for including COBI and madelung
    def __init__(self, filename="lobsterout"):
        """
        Args:
            filename: filename of lobsterout.

        Raises:
            OSError: if no lines were read from the file.
            RuntimeError: if the LOBSTER version line is missing.
            ValueError: if the line reporting the number of threads is missing.
        """
        with zopen(filename, mode="rt") as file:
            data = file.read().split("\n")
        # NOTE: str.split always returns at least one element, so this guard only documents intent;
        # an empty file is reported by the version lookup in _parse_lines instead.
        if len(data) == 0:
            raise OSError("lobsterout does not contain any data")

        self._parse_lines(data)

    def _parse_lines(self, data):
        """Fill all public attributes from the lines of a lobsterout file (no file access)."""
        # check if Lobster starts from a projection
        self.is_restart_from_projection = "loading projection from projectionData.lobster..." in data

        self.lobster_version = self._get_lobster_version(data=data)

        self.number_of_threads = int(self._get_threads(data=data))
        self.dft_program = self._get_dft_program(data=data)

        self.number_of_spins = self._get_number_of_spins(data=data)
        self.charge_spilling, self.total_spilling = self._get_spillings(
            data=data, number_of_spins=self.number_of_spins
        )

        (
            self.elements,
            self.basis_type,
            self.basis_functions,
        ) = self._get_elements_basistype_basisfunctions(data=data)

        wall_time, user_time, sys_time = self._get_timing(data=data)
        self.timing = {"wall_time": wall_time, "user_time": user_time, "sys_time": sys_time}

        self.warning_lines = self._get_all_warning_lines(data=data)
        self.info_orthonormalization = self._get_warning_orthonormalization(data=data)
        self.info_lines = self._get_all_info_lines(data=data)

        self.has_doscar = self._wrote_output(data, "writing DOSCAR.lobster...")
        self.has_doscar_lso = self._wrote_output(data, "writing DOSCAR.LSO.lobster...")
        # Each flag is derived from the message naming its own files (COHP <-> COHPCAR, COOP <-> COOPCAR).
        self.has_cohpcar = self._wrote_output(data, "writing COHPCAR.lobster and ICOHPLIST.lobster...")
        self.has_coopcar = self._wrote_output(data, "writing COOPCAR.lobster and ICOOPLIST.lobster...")
        self.has_cobicar = self._wrote_output(data, "writing COBICAR.lobster and ICOBILIST.lobster...")

        self.has_charge = "SKIPPING writing CHARGE.lobster..." not in data
        self.has_projection = "saving projection to projectionData.lobster..." in data
        self.has_bandoverlaps = "WARNING: I dumped the band overlap matrices to the file bandOverlaps.lobster." in data
        self.has_fatbands = self._has_fatband(data=data)
        self.has_grosspopulation = "writing CHARGE.lobster and GROSSPOP.lobster..." in data
        self.has_density_of_energies = "writing DensityOfEnergy.lobster..." in data
        self.has_madelung = (
            "writing SitePotentials.lobster and MadelungEnergies.lobster..." in data
            and "skipping writing SitePotentials.lobster and MadelungEnergies.lobster..." not in data
        )

    @staticmethod
    def _wrote_output(data, message):
        """Return True if `message` is a line of `data` and its "SKIPPING ..." counterpart is not."""
        return message in data and f"SKIPPING {message}" not in data

    def get_doc(self):
        """Returns: LobsterDict with all the information stored in lobsterout."""
        return {
            "restart_from_projection": self.is_restart_from_projection,
            "lobster_version": self.lobster_version,
            "threads": self.number_of_threads,
            "dft_program": self.dft_program,
            "charge_spilling": self.charge_spilling,
            "total_spilling": self.total_spilling,
            "elements": self.elements,
            "basis_type": self.basis_type,
            "basis_functions": self.basis_functions,
            "timing": self.timing,
            "warning_lines": self.warning_lines,
            "info_orthonormalization": self.info_orthonormalization,
            "info_lines": self.info_lines,
            "has_doscar": self.has_doscar,
            "has_doscar_lso": self.has_doscar_lso,
            "has_cohpcar": self.has_cohpcar,
            "has_coopcar": self.has_coopcar,
            "has_cobicar": self.has_cobicar,
            "has_charge": self.has_charge,
            "has_madelung": self.has_madelung,
            "has_projection": self.has_projection,
            "has_bandoverlaps": self.has_bandoverlaps,
            "has_fatbands": self.has_fatbands,
            "has_grosspopulation": self.has_grosspopulation,
            "has_density_of_energies": self.has_density_of_energies,
        }

    @staticmethod
    def _get_lobster_version(data):
        """Return the token following "LOBSTER" on the first line that starts with it."""
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 1 and splitrow[0] == "LOBSTER":
                return splitrow[1]
        raise RuntimeError("Version not found.")

    @staticmethod
    def _has_fatband(data):
        """Return True if any line has "FatBand" as its second whitespace-separated token."""
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 1 and splitrow[1] == "FatBand":
                return True
        return False

    @staticmethod
    def _get_dft_program(data):
        """Return the token after "program..." (fourth token of the line), or None if no line matches."""
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 4 and splitrow[3] == "program...":
                return splitrow[4]
        return None

    @staticmethod
    def _get_number_of_spins(data):
        """Return 2 if a line "spillings for spin channel 2" is present, otherwise 1."""
        if "spillings for spin channel 2" in data:
            return 2
        return 1

    @staticmethod
    def _get_threads(data):
        """Return the thread count (as string) from the line whose twelfth token is "thread(s)"."""
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 11 and splitrow[11] in ("threads", "thread"):
                return splitrow[10]
        raise ValueError("Threads not found.")

    @staticmethod
    def _get_spillings(data, number_of_spins):
        """
        Collect charge and total spillings as fractions (the percentages in the file divided by 100).

        Reading stops as soon as one charge and one total spilling per spin channel were found.
        """
        charge_spilling = []
        total_spilling = []
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 2 and splitrow[2] == "spilling:":
                # np.float64 instead of the np.float_ alias, which was removed in NumPy 2.0.
                if splitrow[1] == "charge":
                    charge_spilling.append(np.float64(splitrow[3].replace("%", "")) / 100.0)
                if splitrow[1] == "total":
                    total_spilling.append(np.float64(splitrow[3].replace("%", "")) / 100.0)

            if len(charge_spilling) == number_of_spins and len(total_spilling) == number_of_spins:
                break

        return charge_spilling, total_spilling

    @staticmethod
    def _get_elements_basistype_basisfunctions(data):
        """
        Read the basis listing that follows the line "setting up local basis functions...".

        Returns:
            Tuple (elements, basistype, basisfunctions) with one entry per listed element.
        """
        # First tokens that mark the end of the basis listing.
        section_end_tokens = (
            "INFO:",
            "WARNING:",
            "setting",
            "calculating",
            "post-processing",
            "saving",
            "spillings",
            "writing",
        )
        begin = False
        end = False
        elements = []
        basistype = []
        basisfunctions = []
        for row in data:
            if begin and not end:
                splitrow = row.split()
                if splitrow[0] not in section_end_tokens:
                    elements.append(splitrow[0])
                    # the basis name is printed in parentheses
                    basistype.append(splitrow[1].replace("(", "").replace(")", ""))
                    basisfunctions.append(splitrow[2:])
                else:
                    end = True
            # Checked after the block above so that the marker line itself is not parsed as a basis line.
            if "setting up local basis functions..." in row:
                begin = True
        return elements, basistype, basisfunctions

    @staticmethod
    def _get_timing(data):
        """
        Read wall, user and sys time from the lines at and after the one containing "finished".

        Returns:
            Tuple of three dicts (wall, user, sys), each with the keys "h", "min", "s" and "ms".
        """
        begin = False

        for row in data:
            splitrow = row.split()
            if "finished" in splitrow:
                begin = True
            if begin:
                if "wall" in splitrow:
                    # the wall-time numbers start at the third token of this line
                    wall_time = splitrow[2:10]
                if "user" in splitrow:
                    user_time = splitrow[0:8]
                if "sys" in splitrow:
                    sys_time = splitrow[0:8]

        # Tokens alternate between number and unit, so the numbers sit at the even positions.
        wall_time_dict = {"h": wall_time[0], "min": wall_time[2], "s": wall_time[4], "ms": wall_time[6]}
        user_time_dict = {"h": user_time[0], "min": user_time[2], "s": user_time[4], "ms": user_time[6]}
        sys_time_dict = {"h": sys_time[0], "min": sys_time[2], "s": sys_time[4], "ms": sys_time[6]}

        return wall_time_dict, user_time_dict, sys_time_dict

    @staticmethod
    def _get_warning_orthonormalization(data):
        """Return every line containing the token "orthonormalized", without its first token."""
        orthowarning = []
        for row in data:
            splitrow = row.split()
            if "orthonormalized" in splitrow:
                orthowarning.append(" ".join(splitrow[1:]))
        return orthowarning

    @staticmethod
    def _get_all_warning_lines(data):
        """Return the text of every line starting with "WARNING:" (prefix removed)."""
        ws = []
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 0 and splitrow[0] == "WARNING:":
                ws.append(" ".join(splitrow[1:]))
        return ws

    @staticmethod
    def _get_all_info_lines(data):
        """Return the text of every line starting with "INFO:" (prefix removed)."""
        infos = []
        for row in data:
            splitrow = row.split()
            if len(splitrow) > 0 and splitrow[0] == "INFO:":
                infos.append(" ".join(splitrow[1:]))
        return infos


class Fatband:
    """
    Reads in FATBAND_x_y.lobster files.

    Attributes:
        efermi (float): Fermi energy read in from vasprun.xml.
        eigenvals (dict[Spin, list]): Eigenvalues as a dictionary of nested lists indexed as
            [band_index][kpoint_index]. The stored value is the energy read from the FATBAND file
            plus efermi. The kpoints are ordered according to the order of the kpoints_array attribute.
            If the band structure is not spin polarized, we only store one data set under Spin.up.
        is_spinpolarized (bool): Boolean that tells you whether this was a spin-polarized calculation.
        kpoints_array (list[np.ndarray]): List of kpoints as numpy arrays, taken from the "#" lines of
            the first FATBAND file.
        label_dict (dict[str, np.ndarray]): Dictionary that links a kpoint label from the KPOINTS file
            to the corresponding entry of kpoints_array.
        lattice (Lattice): Lattice object of reciprocal lattice as read in from vasprun.xml.
        nbands (int): Number of bands used in the calculation.
        number_kpts (int): Number of kpoints expected in the FATBAND files (derived from the KPOINTS
            file and the first FATBAND file).
        p_eigenvals (dict[Spin, list]): Dictionary of orbital projections as nested lists indexed as
            [band_index][kpoint_index]. Each entry is a dict of the form
            {atom name: {orbital name as read in from FATBAND file: projection}}.
            If the band structure is not spin polarized, we only store one data set under Spin.up.
        structure (Structure): Structure read in from vasprun.xml.
    """

    def __init__(self, filenames=".", vasprun="vasprun.xml", Kpointsfile="KPOINTS"):
        """
        Args:
            filenames (list or string): can be a list of file names or a path to a folder from which all
                "FATBAND_*" files will be read
            vasprun: corresponding vasprun file
            Kpointsfile: KPOINTS file for bandstructure calculation, typically "KPOINTS".

        Raises:
            ValueError: if no FATBAND file is found or given, if two files describe the same atom and
                orbital, or if the number of files per orbital shell is not 1, 3, 5 or 7.
        """
        # Both warnings are emitted unconditionally on every instantiation.
        warnings.warn("Make sure all relevant FATBAND files were generated and read in!")
        warnings.warn("Use Lobster 3.2.0 or newer for fatband calculations!")

        # Only structure and Fermi level are taken from the vasprun; eigenvalue parsing is switched off.
        vasp_run = Vasprun(
            filename=vasprun,
            ionic_step_skip=None,
            ionic_step_offset=0,
            parse_dos=True,
            parse_eigen=False,
            parse_projected_eigen=False,
            parse_potcar_file=False,
            occu_tol=1e-8,
            exception_on_bad_xml=True,
        )
        self.structure = vasp_run.final_structure
        self.lattice = self.structure.lattice.reciprocal_lattice
        self.efermi = vasp_run.efermi
        kpoints_object = Kpoints.from_file(Kpointsfile)

        # NOTE(review): atomtype is filled below but not read again inside this method.
        atomtype = []
        atomnames = []
        orbital_names = []

        # Anything that is not a list is treated as a directory that is searched for FATBAND files.
        if not isinstance(filenames, list) or filenames is None:
            filenames_new = []
            if filenames is None:
                filenames = "."
            for file in os.listdir(filenames):
                if fnmatch.fnmatch(file, "FATBAND_*.lobster"):
                    filenames_new.append(os.path.join(filenames, file))
            filenames = filenames_new
        if len(filenames) == 0:
            raise ValueError("No FATBAND files in folder or given")
        # First pass: only the file name and the header line of every file are evaluated.
        for filename in filenames:
            with zopen(filename, mode="rt") as file:
                contents = file.read().split("\n")

            # atom name = second "_"-separated part of the file name, capitalized
            atomnames.append(os.path.split(filename)[1].split("_")[1].capitalize())
            parameters = contents[0].split()
            # element = fourth header token up to its first digit, capitalized
            atomtype.append(re.split(r"[0-9]+", parameters[3])[0].capitalize())
            # orbital = fifth header token
            orbital_names.append(parameters[4])

        # map every atom name to the list of orbitals for which a file was read
        atom_orbital_dict = {}
        for iatom, atom in enumerate(atomnames):
            if atom not in atom_orbital_dict:
                atom_orbital_dict[atom] = []
            atom_orbital_dict[atom].append(orbital_names[iatom])
        # test if there are the same orbitals twice or if two different formats were used or if all necessary orbitals
        # are there
        for items in atom_orbital_dict.values():
            if len(set(items)) != len(items):
                raise ValueError("The are two FATBAND files for the same atom and orbital. The program will stop.")
            # group by the part of the orbital name before "_" (e.g. "2p" for "2p_x")
            split = []
            for item in items:
                split.append(item.split("_")[0])
            # every group must have 1, 3, 5 or 7 members
            for number in collections.Counter(split).values():
                if number not in (1, 3, 5, 7):
                    raise ValueError(
                        "Make sure all relevant orbitals were generated and that no duplicates (2p and 2p_x) are "
                        "present"
                    )

        # Second pass: read the band data of every file.
        kpoints_array = []
        for ifilename, filename in enumerate(filenames):
            with zopen(filename, mode="rt") as file:
                contents = file.read().split("\n")

            if ifilename == 0:
                # NOTE(review): `parameters` is left over from the first pass, i.e. it is the header of the
                # LAST file in `filenames`, not of the file currently being read.
                self.nbands = int(parameters[6])
                # presumably contents[1] is the first k-point line and its third token is that k-point's
                # 1-based index in the KPOINTS file — TODO confirm
                self.number_kpts = kpoints_object.num_kpts - int(contents[1].split()[2]) + 1

            # Decide on spin polarization from the number of lines after the header; if the count matches
            # neither single-k-point layout, count the "#" lines within the first nbands * 2 + 3 lines
            # after the header (two of them are taken to mean spin polarized).
            if len(contents[1:]) == self.nbands + 2:
                self.is_spinpolarized = False
            elif len(contents[1:]) == self.nbands * 2 + 2:
                self.is_spinpolarized = True
            else:
                linenumbers = []
                for iline, line in enumerate(contents[1 : self.nbands * 2 + 4]):
                    if line.split()[0] == "#":
                        linenumbers.append(iline)

                # in this branch only the first file decides
                if ifilename == 0:
                    self.is_spinpolarized = len(linenumbers) == 2

            # The containers are allocated once, while the first file is processed.
            if ifilename == 0:
                # eigenvals[spin][band][kpoint]; the placeholders are overwritten with floats below
                eigenvals = {}
                eigenvals[Spin.up] = [
                    [collections.defaultdict(float) for i in range(self.number_kpts)] for j in range(self.nbands)
                ]
                if self.is_spinpolarized:
                    eigenvals[Spin.down] = [
                        [collections.defaultdict(float) for i in range(self.number_kpts)] for j in range(self.nbands)
                    ]

                # p_eigenvals[spin][band][kpoint][atom name][orbital name]
                p_eigenvals = {}
                p_eigenvals[Spin.up] = [
                    [
                        {
                            str(e): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[e]}
                            for e in atomnames
                        }
                        for i in range(self.number_kpts)
                    ]
                    for j in range(self.nbands)
                ]

                if self.is_spinpolarized:
                    p_eigenvals[Spin.down] = [
                        [
                            {
                                str(e): {str(orb): collections.defaultdict(float) for orb in atom_orbital_dict[e]}
                                for e in atomnames
                            }
                            for i in range(self.number_kpts)
                        ]
                        for j in range(self.nbands)
                    ]

            ikpoint = -1
            # Skips the header line and the last list element (presumably the empty string that follows a
            # trailing newline — TODO confirm).
            for line in contents[1:-1]:
                if line.split()[0] == "#":
                    # "#" lines start a new k-point; tokens 4-6 are its coordinates
                    KPOINT = np.array(
                        [
                            float(line.split()[4]),
                            float(line.split()[5]),
                            float(line.split()[6]),
                        ]
                    )
                    # the k-point list is taken from the first file only
                    if ifilename == 0:
                        kpoints_array.append(KPOINT)

                    linenumber = 0
                    iband = 0
                    ikpoint += 1
                # after nbands data lines the band counter restarts for the second spin channel
                if linenumber == self.nbands:
                    iband = 0
                if line.split()[0] != "#":
                    # data line: token 1 is the energy, token 2 the projection
                    if linenumber < self.nbands:
                        # energies are stored once (from the first file), shifted by efermi
                        if ifilename == 0:
                            eigenvals[Spin.up][iband][ikpoint] = float(line.split()[1]) + self.efermi

                        p_eigenvals[Spin.up][iband][ikpoint][atomnames[ifilename]][orbital_names[ifilename]] = float(
                            line.split()[2]
                        )
                    if linenumber >= self.nbands and self.is_spinpolarized:
                        if ifilename == 0:
                            eigenvals[Spin.down][iband][ikpoint] = float(line.split()[1]) + self.efermi
                        p_eigenvals[Spin.down][iband][ikpoint][atomnames[ifilename]][orbital_names[ifilename]] = float(
                            line.split()[2]
                        )

                    linenumber += 1
                    iband += 1

        self.kpoints_array = kpoints_array
        self.eigenvals = eigenvals
        self.p_eigenvals = p_eigenvals

        # Pair the last number_kpts labels of the KPOINTS file with the k-points read above; for a label
        # that occurs more than once the last occurrence wins.
        label_dict = {}
        for ilabel, label in enumerate(kpoints_object.labels[-self.number_kpts :], start=0):
            if label is not None:
                label_dict[label] = kpoints_array[ilabel]

        self.label_dict = label_dict

    def get_bandstructure(self):
        """Returns a LobsterBandStructureSymmLine object which can be plotted with a normal BSPlotter."""
        return LobsterBandStructureSymmLine(
            kpoints=self.kpoints_array,
            eigenvals=self.eigenvals,
            lattice=self.lattice,
            efermi=self.efermi,
            labels_dict=self.label_dict,
            structure=self.structure,
            projections=self.p_eigenvals,
        )


class Bandoverlaps:
    """
    Class to read in bandOverlaps.lobster files. These files are not created during every Lobster run.

    Attributes:
        bandoverlapsdict (dict[Spin, Dict[str, Dict[str, Union[float, list]]]]): A dictionary
            containing the band overlap data of the form: {spin: {"kpoint as string": {"maxDeviation":
            float that describes the max deviation, "matrix": list of rows (lists of floats) of the
            overlap matrix}}}.
        max_deviation (list[float]): A list of floats describing the maximal deviation for each problematic kpoint.
    """

    def __init__(self, filename: str = "bandOverlaps.lobster"):
        """
        Args:
            filename: filename of the "bandOverlaps.lobster" file.
        """
        with zopen(filename, mode="rt") as file:
            contents = file.read().split("\n")

        # Spin channels are numbered 0/1 or 1/2; the last token of the first line tells which.
        spin_numbers = [0, 1] if contents[0].split()[-1] == "0" else [1, 2]

        self._read(contents, spin_numbers)

    def _read(self, contents: list, spin_numbers: list):
        """
        Will read in all contents of the file

        Args:
            contents: list of strings
            spin_numbers: list of spin numbers depending on `Lobster` version.
        """
        self.bandoverlapsdict: dict[Any, dict] = {}  # Any is spin number 1 or -1
        self.max_deviation = []
        header = "Overlap Matrix (abs) of the orthonormalized projected bands for spin"
        # This has to be done like this because there can be different numbers of problematic k-points per spin
        for line in contents:
            if f"{header} {spin_numbers[0]}" in line:
                spin = Spin.up
            elif f"{header} {spin_numbers[1]}" in line:
                spin = Spin.down
            elif "k-point" in line:
                # everything except the words "at" and "k-point" forms the key of this k-point
                kpoint_array = [str(token) for token in line.split(" ") if token not in ["at", "k-point", ""]]

            elif "maxDeviation" in line:
                entry = self.bandoverlapsdict.setdefault(spin, {}).setdefault(" ".join(kpoint_array), {})
                maxdev = line.split(" ")[2]
                entry["maxDeviation"] = float(maxdev)
                self.max_deviation.append(float(maxdev))
                # the matrix rows follow on the next lines
                entry["matrix"] = []

            else:
                # any other line is a matrix row of the current spin / k-point
                overlaps = [float(el) for el in line.split(" ") if el != ""]
                self.bandoverlapsdict[spin][" ".join(kpoint_array)]["matrix"].append(overlaps)

    def has_good_quality_maxDeviation(self, limit_maxDeviation: float = 0.1) -> bool:
        """
        Will check if the maxDeviation from the ideal bandoverlap is smaller or equal to limit_maxDeviation

        Args:
            limit_maxDeviation: limit of the maxDeviation

        Returns:
            Boolean that will give you information about the quality of the projection.
        """
        return all(deviation <= limit_maxDeviation for deviation in self.max_deviation)

    @staticmethod
    def _occupied_bands_within_limit(matrices, number_occ_bands, limit_deviation):
        """Return False if any occupied-band element of `matrices` deviates too much from the identity."""
        for matrix in matrices:
            for iband1, band1 in enumerate(matrix["matrix"]):
                for iband2, band2 in enumerate(band1):
                    # only the block spanned by the occupied bands is checked
                    if iband1 >= number_occ_bands or iband2 >= number_occ_bands:
                        continue
                    if iband1 == iband2:
                        # diagonal elements should be 1
                        if abs(band2 - 1.0) > limit_deviation:
                            return False
                    elif band2 > limit_deviation:
                        # off-diagonal elements should be 0
                        return False
        return True

    def has_good_quality_check_occupied_bands(
        self,
        number_occ_bands_spin_up: int,
        number_occ_bands_spin_down: int | None = None,
        spin_polarized: bool = False,
        limit_deviation: float = 0.1,
    ) -> bool:
        """
        Will check if the deviation from the ideal bandoverlap of all occupied bands
        is smaller or equal to limit_deviation.

        Args:
            number_occ_bands_spin_up (int): number of occupied bands of spin up
            number_occ_bands_spin_down (int): number of occupied bands of spin down
            spin_polarized (bool): If True, then it was a spin polarized calculation
            limit_deviation (float): limit of the maxDeviation

        Returns:
            Boolean that will give you information about the quality of the projection

        Raises:
            ValueError: if spin_polarized is True but number_occ_bands_spin_down is not given.
        """
        # Validate the arguments first so that an incomplete call can never be reported as "good".
        if spin_polarized and number_occ_bands_spin_down is None:
            raise ValueError("number_occ_bands_spin_down has to be specified")

        if not self._occupied_bands_within_limit(
            self.bandoverlapsdict[Spin.up].values(), number_occ_bands_spin_up, limit_deviation
        ):
            return False

        if spin_polarized:
            return self._occupied_bands_within_limit(
                self.bandoverlapsdict[Spin.down].values(), number_occ_bands_spin_down, limit_deviation
            )
        return True


class Grosspop:
    """
    Reader for GROSSPOP.lobster files.

    Attributes:
        list_dict_grosspop (list[dict[str, str | dict[str, float]]]): One dictionary per atom, in the
            order the atoms appear in GROSSPOP.lobster. Each dictionary contains the following keys:
            - 'element': The element symbol of the atom.
            - 'Mulliken GP': A dictionary of Mulliken gross populations, where the keys are the orbital
                labels and the values are the corresponding gross populations as floats.
            - 'Loewdin GP': A dictionary of Loewdin gross populations, where the keys are the orbital
                labels and the values are the corresponding gross populations as floats.
    """

    def __init__(self, filename: str = "GROSSPOP.lobster"):
        """
        Args:
            filename: filename of the "GROSSPOP.lobster" file.
        """
        with zopen(filename, mode="rt") as file:
            contents = file.read().split("\n")

        self.list_dict_grosspop = []
        # The first three lines are skipped; everything after them is per-atom data.
        for line in contents[3:]:
            tokens = [tok for tok in line.split(" ") if tok != ""]
            if len(tokens) == 5:
                # Five columns start a new atom: columns 2-5 are element, orbital, Mulliken, Loewdin.
                orbital = tokens[2]
                atom_entry = {
                    "element": tokens[1],
                    "Mulliken GP": {orbital: float(tokens[3])},
                    "Loewdin GP": {orbital: float(tokens[4])},
                }
            elif tokens:
                # Continuation line of the current atom: label, Mulliken, Loewdin.
                label = tokens[0]
                atom_entry["Mulliken GP"][label] = float(tokens[1])
                atom_entry["Loewdin GP"][label] = float(tokens[2])
                # The "total" line closes the atom.
                if "total" in label:
                    self.list_dict_grosspop.append(atom_entry)

    def get_structure_with_total_grosspop(self, structure_filename: str) -> Structure:
        """
        Attach the total Mulliken and Loewdin gross populations to a structure as site properties.

        Args:
            structure_filename (str): filename of POSCAR

        Returns:
            Structure Object with Mulliken and Loewdin total grosspopulations as site properties.
        """
        struct = Structure.from_file(structure_filename)
        site_properties: dict[str, Any] = {
            "Total Mulliken GP": [entry["Mulliken GP"]["total"] for entry in self.list_dict_grosspop],
            "Total Loewdin GP": [entry["Loewdin GP"]["total"] for entry in self.list_dict_grosspop],
        }
        return struct.copy(site_properties=site_properties)


class Wavefunction:
    """
    Read a wave function file written by Lobster and expose it as VolumetricData objects.

    Attributes:
        grid (list[int]): Grid for the wave function [Nx+1,Ny+1,Nz+1], read from the header line.
        points (list[list[float]]): List of points (first three columns of each data line).
        real (list[float]): List of real part of wave function.
        imaginary (list[float]): List of imaginary part of wave function.
        distance (list[float]): List of distance to first point in wave function file.
    """

    def __init__(self, filename, structure):
        """
        Args:
            filename: filename of the wave function file from Lobster
            structure: Structure object (e.g., created by Structure.from_file("")).
        """
        self.filename = filename
        self.structure = structure
        parsed = Wavefunction._parse_file(filename)
        self.grid, self.points, self.real, self.imaginary, self.distance = parsed

    @staticmethod
    def _parse_file(filename):
        """
        Read the grid and the per-point data from a wave function file.

        Args:
            filename: filename of the wave function file from Lobster

        Returns:
            tuple of (grid, points, real, imaginary, distance)

        Raises:
            ValueError: if the number of data lines differs from the product of the grid dimensions.
        """
        with zopen(filename, mode="rt") as file:
            lines = file.read().split("\n")

        # the grid dimensions are tokens 7, 8 and 9 of the first line
        header = lines[0].split()
        grid = [int(header[position]) for position in (7, 8, 9)]

        points, distance, real, imaginary = [], [], [], []
        for line in lines[1:]:
            fields = line.split()
            # lines with fewer than six columns carry no data point
            if len(fields) < 6:
                continue
            points.append([float(fields[0]), float(fields[1]), float(fields[2])])
            distance.append(float(fields[3]))
            real.append(float(fields[4]))
            imaginary.append(float(fields[5]))

        expected = grid[0] * grid[1] * grid[2]
        if len(real) != expected or len(imaginary) != expected:
            raise ValueError("Something went wrong while reading the file")
        return grid, points, real, imaginary, distance

    def set_volumetric_data(self, grid, structure):
        """
        Will create the VolumetricData Objects.

        Sets ``final_real``, ``final_imaginary`` and ``final_density`` (arrays of shape
        [Nx, Ny, Nz] with Nx = grid[0] - 1 etc.) and the matching ``volumetricdata_*``
        attributes. The density is real**2 + imaginary**2.

        Args:
            grid: grid on which wavefunction was calculated, e.g. [2,3,3]
            structure: Structure object

        Raises:
            ValueError: if a stored point does not agree with the position expected from the lattice.
        """
        n_x, n_y, n_z = grid[0] - 1, grid[1] - 1, grid[2] - 1
        lattice_matrix = structure.lattice.matrix
        vec_a, vec_b, vec_c = lattice_matrix[0], lattice_matrix[1], lattice_matrix[2]

        kept_real = []
        kept_imaginary = []
        kept_density = []

        # index into the flat point lists; advanced for every grid point, kept or not
        index = 0
        for i_x in range(n_x + 1):
            for i_y in range(n_y + 1):
                for i_z in range(n_z + 1):
                    frac_x = i_x / float(n_x)
                    frac_y = i_y / float(n_y)
                    frac_z = i_z / float(n_z)
                    expected = (
                        frac_x * vec_a[0] + frac_y * vec_b[0] + frac_z * vec_c[0],
                        frac_x * vec_a[1] + frac_y * vec_b[1] + frac_z * vec_c[1],
                        frac_x * vec_a[2] + frac_y * vec_b[2] + frac_z * vec_c[2],
                    )

                    # the last layer along each axis is dropped from the stored data
                    if i_x != n_x and i_y != n_y and i_z != n_z:
                        point = self.points[index]
                        # NOTE(review): the error fires only when none of the three coordinates
                        # is close to the expected value - confirm whether one mismatching
                        # coordinate should already be rejected.
                        if not any(np.isclose(point[axis], expected[axis], rtol=1e-3) for axis in range(3)):
                            raise ValueError(
                                "The provided wavefunction from Lobster does not contain all relevant"
                                " points. "
                                "Please use a line similar to: printLCAORealSpaceWavefunction kpoint 1 "
                                "coordinates 0.0 0.0 0.0 coordinates 1.0 1.0 1.0 box bandlist 1 "
                            )

                        real_value = self.real[index]
                        imaginary_value = self.imaginary[index]
                        kept_real.append(real_value)
                        kept_imaginary.append(imaginary_value)
                        kept_density.append(real_value**2 + imaginary_value**2)

                    index += 1

        shape = [n_x, n_y, n_z]
        self.final_real = np.reshape(kept_real, shape)
        self.final_imaginary = np.reshape(kept_imaginary, shape)
        self.final_density = np.reshape(kept_density, shape)

        self.volumetricdata_real = VolumetricData(structure, {"total": self.final_real})
        self.volumetricdata_imaginary = VolumetricData(structure, {"total": self.final_imaginary})
        self.volumetricdata_density = VolumetricData(structure, {"total": self.final_density})

    def _lazy_volumetricdata(self, attribute):
        """Return the named volumetric attribute, building all of them first if it is missing."""
        if not hasattr(self, attribute):
            self.set_volumetric_data(self.grid, self.structure)
        return getattr(self, attribute)

    def get_volumetricdata_real(self):
        """
        Will return a VolumetricData object including the real part of the wave function.

        Returns:
            VolumetricData
        """
        return self._lazy_volumetricdata("volumetricdata_real")

    def get_volumetricdata_imaginary(self):
        """
        Will return a VolumetricData object including the imaginary part of the wave function.

        Returns:
            VolumetricData
        """
        return self._lazy_volumetricdata("volumetricdata_imaginary")

    def get_volumetricdata_density(self):
        """
        Will return a VolumetricData object including the density (real**2 + imaginary**2)
        of the wave function.

        Returns:
            VolumetricData
        """
        return self._lazy_volumetricdata("volumetricdata_density")

    def write_file(self, filename="WAVECAR.vasp", part="real"):
        """
        Will save the wavefunction in a file format that can be read by VESTA
        This will only work if the wavefunction from lobster was constructed with:
        "printLCAORealSpaceWavefunction kpoint 1 coordinates 0.0 0.0 0.0 coordinates 1.0 1.0 1.0 box bandlist 1 2 3 4
        5 6 "
        or similar (the whole unit cell has to be covered!).

        Args:
            filename: Filename for the output, e.g., WAVECAR.vasp
            part: which part of the wavefunction will be saved ("real", "imaginary" or "density")

        Raises:
            ValueError: if part is none of "real", "imaginary" or "density".
        """
        names = ("real", "imaginary", "density")
        # build the volumetric data once if any of the three objects is still missing
        if not all(hasattr(self, f"volumetricdata_{name}") for name in names):
            self.set_volumetric_data(self.grid, self.structure)

        if part not in names:
            raise ValueError('part can be only "real" or "imaginary" or "density"')
        getattr(self, f"volumetricdata_{part}").write_file(filename)


# Madelung energy and site potential classes
class MadelungEnergies:
    """
    Class to read MadelungEnergies.lobster files generated by LOBSTER.

    Attributes:
        madelungenergies_Mulliken (float): Float that gives the Madelung energy based on the Mulliken approach.
        madelungenergies_Loewdin (float): Float that gives the Madelung energy based on the Loewdin approach.
        ewald_splitting (float): Ewald splitting parameter to compute SitePotentials.
    """

    def __init__(self, filename: str = "MadelungEnergies.lobster"):
        """
        Args:
            filename: filename of the "MadelungEnergies.lobster" file.

        Raises:
            OSError: if the file has no sixth line or that line is blank.
        """
        with zopen(filename, mode="rt") as file:
            lines = file.read().split("\n")

        # The three values are read from the sixth line. Check that the line exists before
        # indexing so that an empty or truncated file gives the intended OSError
        # instead of an IndexError.
        if len(lines) < 6 or not lines[5].strip():
            raise OSError("MadelungEnergies file contains no data.")

        values = lines[5].split()
        self.ewald_splitting = float(values[0])
        self.madelungenergies_Mulliken = float(values[1])
        self.madelungenergies_Loewdin = float(values[2])


class SitePotential:
    """
    Class to read SitePotentials.lobster files generated by LOBSTER.

    Attributes:
        atomlist (list[str]): List of atoms in SitePotentials.lobster (type followed by the number from the file).
        types (list[str]): List of types of atoms in SitePotentials.lobster.
        num_atoms (int): Number of atoms in SitePotentials.lobster.
        sitepotentials_Mulliken (list[float]): List of Mulliken potentials of sites in SitePotentials.lobster.
        sitepotentials_Loewdin (list[float]): List of Loewdin potentials of sites in SitePotentials.lobster.
        madelungenergies_Mulliken (float): Float that gives the Madelung energy based on the Mulliken approach.
        madelungenergies_Loewdin (float): Float that gives the Madelung energy based on the Loewdin approach.
        ewald_splitting (float): Ewald Splitting parameter to compute SitePotentials.
    """

    def __init__(self, filename: str = "SitePotentials.lobster"):
        """
        Args:
            filename: filename for the SitePotentials file, typically "SitePotentials.lobster".

        Raises:
            OSError: if the file contains nothing but whitespace.
        """
        with zopen(filename, mode="rt") as file:
            lines = file.read().split("\n")

        # str.split always returns at least one element (an empty file gives [""]),
        # so emptiness has to be tested on the content rather than on the list length.
        if not any(line.strip() for line in lines):
            raise OSError("SitePotentials file contains no data.")

        # the Ewald splitting parameter is the tenth token of the first line
        self.ewald_splitting = float(lines[0].split()[9])

        # drop the five leading lines and the final element of the split
        table = lines[5:-1]
        # the last two remaining lines are not per-atom rows
        self.num_atoms = len(table) - 2
        self.atomlist: list[str] = []
        self.types: list[str] = []
        self.sitepotentials_Mulliken: list[float] = []
        self.sitepotentials_Loewdin: list[float] = []
        for atom in range(self.num_atoms):
            number, atom_type, mulliken, loewdin = table[atom].split()[:4]
            self.atomlist.append(f"{atom_type}{number}")
            self.types.append(atom_type)
            self.sitepotentials_Mulliken.append(float(mulliken))
            self.sitepotentials_Loewdin.append(float(loewdin))

        # the Madelung energies are tokens 3 and 4 of the last remaining line
        summary = table[self.num_atoms + 1].split()
        self.madelungenergies_Mulliken = float(summary[3])
        self.madelungenergies_Loewdin = float(summary[4])

    def get_structure_with_site_potentials(self, structure_filename):
        """
        Get a Structure with Mulliken and Loewdin site potentials as site properties

        Args:
            structure_filename: filename of POSCAR

        Returns:
            Structure Object with the site properties "Mulliken Site Potentials (eV)"
            and "Loewdin Site Potentials (eV)".
        """
        struct = Structure.from_file(structure_filename)
        site_properties = {
            "Mulliken Site Potentials (eV)": self.sitepotentials_Mulliken,
            "Loewdin Site Potentials (eV)": self.sitepotentials_Loewdin,
        }
        return struct.copy(site_properties=site_properties)


def get_orb_from_str(orbs):
    """
    Translate orbital strings into Orbital objects and build a combined label.

    Args:
        orbs: list of two str, e.g. ["2p_x", "3s"]. The first character of each string
            is converted to an int, the remainder is looked up as the orbital label.

    Returns:
        tuple of (orb_label, orbitals): orbitals is a list of (int, Orbital) tuples, one
        per input string; orb_label joins the first two entries as
        "<number><orbital name>-<number><orbital name>".
    """
    # TODO: also useful for plotting of DOS
    # the position of a label in this list is the integer passed to Orbital()
    orb_labs = [
        "s",
        "p_y", "p_z", "p_x",
        "d_xy", "d_yz", "d_z^2", "d_xz", "d_x^2-y^2",
        "f_y(3x^2-y^2)", "f_xyz", "f_yz^2", "f_z^3", "f_xz^2", "f_z(x^2-y^2)", "f_x(x^2-3y^2)",
    ]  # fmt: skip

    orbitals = []
    for orb in orbs:
        number, label = orb[0], orb[1:]
        orbitals.append((int(number), Orbital(orb_labs.index(label))))

    first, second = orbitals[0], orbitals[1]
    orb_label = f"{first[0]}{first[1].name}-{second[0]}{second[1].name}"
    return orb_label, orbitals


class LobsterMatrices:
    """
    Class to read Matrices file generated by LOBSTER (e.g. hamiltonMatrices.lobster).

    Which attributes are set depends on the keyword found in the filename. In every case
    the "onsite" list holds the real part of the matrix diagonal for each parsed matrix,
    and the "average" dict maps each basis function label to the mean of those diagonals.

    Attributes:
        for filename containing "hamilton"

        onsite_energies (list[np.ndarray]): Real part of the onsite energies (matrix diagonal minus
                                        e_fermi) for each matrix in the file.
        average_onsite_energies (dict): dict with average onsite elements energies for all k-points with keys as
                                        basis used in the LOBSTER computation (uses only real part of matrix).
        hamilton_matrices (dict[str, dict[Spin, np.ndarray]]): dict with the complex hamilton matrix
                                        at each k-point with k-point and spin as keys

        for filename containing "coefficient"

        onsite_coefficients (list[np.ndarray]): Real part of the onsite coefficients for each matrix in the file.
        average_onsite_coefficient (dict): dict with average onsite elements coefficients for all k-points with keys
                                        as basis used in the LOBSTER computation (uses only real part of matrix).
        coefficient_matrices (dict[str, dict[Spin, np.ndarray]]): dict with the coefficients matrix
                                        at each k-point with k-point and spin as keys

        for filename containing "transfer"

        onsite_transfer (list[np.ndarray]): Real part of the onsite transfer coefficients for each matrix
                                        in the file.
        average_onsite_transfer (dict): dict with average onsite elements transfer coefficients for all k-points with
                                        keys as basis used in the LOBSTER computation (uses only real part of matrix).
        transfer_matrices (dict[str, dict[Spin, np.ndarray]]): dict with the transfer matrix at
                                        each k-point with k-point and spin as keys

        for filename containing "overlap"

        onsite_overlaps (list[np.ndarray]): Real part of the onsite overlaps for each matrix in the file.
        average_onsite_overlaps (dict): dict with average onsite elements overlaps for all k-points with keys as
                                        basis used in the LOBSTER computation (uses only real part of matrix).
        overlap_matrices (dict[str, np.ndarray]): dict with the overlap matrix at
                                        each k-point with k-point as keys
    """

    def __init__(self, e_fermi=None, filename: str = "hamiltonMatrices.lobster"):
        """
        Args:
            e_fermi: fermi level in eV for the structure; only relevant (and then required)
                if the input file contains hamilton matrices data.
            filename: filename for the matrices file, typically "hamiltonMatrices.lobster".
                The keywords "hamilton", "coefficient", "transfer" and "overlap" are searched
                in this string, in that order; if none is found nothing is parsed.

        Raises:
            OSError: if the file contains no lines.
            ValueError: if a hamilton file is given without e_fermi.
        """
        self._filename = filename
        with zopen(self._filename, mode="rt") as file:
            file_data = file.readlines()
        if not file_data:
            raise OSError("Please check provided input file, it seems to be empty")

        pattern_coeff_hamil_trans = r"(\d+)\s+kpoint\s+(\d+)"  # regex pattern to extract spin and k-point number
        pattern_overlap = r"kpoint\s+(\d+)"  # regex pattern to extract k-point number

        if "hamilton" in self._filename:
            if e_fermi is None:
                raise ValueError("Please provide the fermi energy in eV ")
            self.onsite_energies, self.average_onsite_energies, self.hamilton_matrices = self._parse_matrix(
                file_data=file_data, pattern=pattern_coeff_hamil_trans, e_fermi=e_fermi
            )

        elif "coefficient" in self._filename:
            self.onsite_coefficients, self.average_onsite_coefficient, self.coefficient_matrices = self._parse_matrix(
                file_data=file_data, pattern=pattern_coeff_hamil_trans, e_fermi=0
            )

        elif "transfer" in self._filename:
            self.onsite_transfer, self.average_onsite_transfer, self.transfer_matrices = self._parse_matrix(
                file_data=file_data, pattern=pattern_coeff_hamil_trans, e_fermi=0
            )

        elif "overlap" in self._filename:
            self.onsite_overlaps, self.average_onsite_overlaps, self.overlap_matrices = self._parse_matrix(
                file_data=file_data, pattern=pattern_overlap, e_fermi=0
            )

    @staticmethod
    def _parse_matrix(file_data, pattern, e_fermi):
        """
        Extract the complex matrices and their diagonals from the lines of a matrices file.

        Args:
            file_data: list of lines of the file.
            pattern: regex applied to the line preceding each "Real parts" line. With two
                groups they are read as (spin, k-point), with one group as the k-point.
            e_fermi: value subtracted from the real part of every matrix diagonal.

        Returns:
            tuple of (list of diagonals, dict basis function -> averaged diagonal value,
            dict of complex matrices keyed by k-point and, for two-group patterns, by spin).

        Raises:
            ValueError: if no pair of "Real parts"/"Imag parts" blocks is found.
        """
        start_inxs_real, end_inxs_real = [], []
        start_inxs_imag, end_inxs_imag = [], []

        # locate the real and imaginary block of every matrix
        for i, line in enumerate(file_data):
            if "Real parts" in line:
                # Every "Real parts" line after the first one also terminates the preceding
                # imaginary block. The first occurrence is recognised by the bookkeeping
                # itself rather than by a fixed line index, so leading lines do not shift
                # the pairing of blocks.
                if start_inxs_real:
                    end_inxs_imag.append(i - 1)
                start_inxs_real.append(i + 1)
            if "Imag parts" in line:
                # the line directly before "Imag parts" is excluded from the real block
                end_inxs_real.append(i - 1)
                start_inxs_imag.append(i + 1)
        # the last imaginary block runs to the end of the file
        if file_data:
            end_inxs_imag.append(len(file_data))

        complex_matrices: dict = {}
        matrix_diagonal_values = []
        matrix_real = None
        for start_inx_real, end_inx_real, start_inx_imag, end_inx_imag in zip(
            start_inxs_real, end_inxs_real, start_inxs_imag, end_inxs_imag
        ):
            # blocks still carrying their header row and the label column
            matrix_real = file_data[start_inx_real:end_inx_real]
            matrix_imag = file_data[start_inx_imag:end_inx_imag]

            # drop header row and label column, convert the rest to floats
            matrix_array_real = np.array([row.split()[1:] for row in matrix_real[1:]], dtype=float)
            matrix_array_imag = np.array([row.split()[1:] for row in matrix_imag[1:]], dtype=float)
            comp_matrix = matrix_array_real + 1j * matrix_array_imag

            # the line two above the block start (directly above "Real parts") names spin and k-point
            matches = re.search(pattern, file_data[start_inx_real - 2])
            if matches and len(matches.groups()) == 2:
                spin = Spin.up if matches.group(1) == "1" else Spin.down
                complex_matrices.setdefault(matches.group(2), {})[spin] = comp_matrix
            elif matches and len(matches.groups()) == 1:
                complex_matrices[matches.group(1)] = comp_matrix
            matrix_diagonal_values.append(comp_matrix.real.diagonal() - e_fermi)

        if matrix_real is None:
            raise ValueError("No matrix data ('Real parts'/'Imag parts' blocks) found in the provided file")

        # basis function labels: first token of every row of the last real block except its header
        first_tokens = [row.split()[0] for row in matrix_real]
        elements_basis_functions = [token for token in first_tokens if token != "basisfunction"]

        # average each diagonal element over all parsed matrices
        average_matrix_diagonal_values = np.array(matrix_diagonal_values, dtype=float).mean(axis=0)

        # get a dict with basis functions as keys and average values as values
        average_matrix_diag_dict = dict(zip(elements_basis_functions, average_matrix_diagonal_values))

        return matrix_diagonal_values, average_matrix_diag_dict, complex_matrices
