# Source code for molecule_benchmarks.dataset

import csv
import pickle
import random
from pathlib import Path
from typing import Optional, TypeVar

from rdkit import Chem

from molecule_benchmarks.utils import download_with_cache, mapper


class SmilesDataset:
    """Train/validation split of SMILES strings, with loaders for benchmark datasets.

    The ``load_*`` classmethods download a dataset, optionally subsample it
    with a fixed seed, and optionally cache the resulting object as a pickle
    file inside ``cache_dir``.
    """

    @staticmethod
    def _validate_max_train_samples(max_train_samples: Optional[int]) -> None:
        """Reject negative sample counts before any download or slicing happens.

        Raises:
            ValueError: If ``max_train_samples`` is negative.
        """
        if max_train_samples is not None and max_train_samples < 0:
            raise ValueError(
                f"max_train_samples must be non-negative, got {max_train_samples}"
            )

    @staticmethod
    def _cache_file(
        cache_dir: str | Path | None,
        prefix: str,
        max_train_samples: Optional[int],
    ) -> Optional[Path]:
        """Return the pickle path for this dataset/size, or None if caching is off."""
        if cache_dir is None:
            return None
        return Path(cache_dir) / f"{prefix}_num_samples={max_train_samples}.pkl"

    @staticmethod
    def _load_cached(cache_path: Optional[Path]):
        """Return the object pickled at ``cache_path``, or None if there is none."""
        if cache_path is None or not cache_path.exists():
            return None
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # ``cache_dir`` at a directory whose contents you trust.
        with open(cache_path, "rb") as f:
            return pickle.load(f)

    def _save_cached(self, cache_path: Optional[Path]) -> None:
        """Pickle this dataset to ``cache_path``; do nothing when caching is off."""
        if cache_path is None:
            return
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(cache_path, "wb") as f:
            pickle.dump(self, f)

    @staticmethod
    def _subsample(
        train_smiles: list[str],
        validation_smiles: list[str],
        max_train_samples: int,
        rng: random.Random,
    ) -> tuple[list[str], list[str], int]:
        """Draw at most ``max_train_samples`` training entries and shrink validation alike.

        The validation split is reduced by the same fraction as the training
        split. Training is sampled before validation so the draws match a
        freshly seeded generator used in that order.

        Returns:
            The sampled training list, the sampled validation list, and the
            (possibly clamped) number of training samples actually used.
        """
        max_train_samples = min(max_train_samples, len(train_smiles))
        fraction = max_train_samples / len(train_smiles)
        sampled_train = rng.sample(train_smiles, max_train_samples)
        sampled_validation = rng.sample(
            validation_smiles, int(len(validation_smiles) * fraction)
        )
        return sampled_train, sampled_validation, max_train_samples

    @classmethod
    def load_qm9_dataset(
        cls,
        max_train_samples: Optional[int] = None,
        cache_dir: str | Path | None = None,
    ) -> "SmilesDataset":
        """Load the QM9 dataset.

        The downloaded SMILES are shuffled with a fixed seed and split 80/20
        into train and validation.

        Args:
            max_train_samples: Upper bound on the number of training SMILES.
                The validation split is shrunk by the same fraction. ``None``
                keeps the full split.
            cache_dir: Directory passed to ``download_with_cache`` and used to
                store the pickled dataset. ``None`` disables the pickle cache.

        Returns:
            The loaded (or cached) dataset.

        Raises:
            ValueError: If ``max_train_samples`` is negative.
        """
        cls._validate_max_train_samples(max_train_samples)
        # The cache key uses the value as requested, before any clamping below.
        cache_path = cls._cache_file(cache_dir, "qm9", max_train_samples)
        cached = cls._load_cached(cache_path)
        if cached is not None:
            return cached

        ds_url = (
            "https://huggingface.co/datasets/n0w0f/qm9-csv/resolve/main/qm9_dataset.csv"
        )
        content = download_with_cache(ds_url, cache_dir=cache_dir)
        reader = csv.DictReader(content.splitlines())
        smiles = [row["smiles"] for row in reader]

        # A private generator gives the same shuffle as seeding the module-level
        # RNG with 42, without resetting random state for the rest of the process.
        rng = random.Random(42)
        rng.shuffle(smiles)

        num_train_samples = int(len(smiles) * 0.8)
        num_val_samples = len(smiles) - num_train_samples
        if max_train_samples is not None:
            max_train_samples = min(max_train_samples, num_train_samples)
            fraction = max_train_samples / num_train_samples
            num_val_samples = int(num_val_samples * fraction)
            num_train_samples = int(max_train_samples)

        train_smiles = smiles[:num_train_samples]
        validation_smiles = smiles[
            num_train_samples : num_train_samples + num_val_samples
        ]
        ds = cls(
            train_smiles=train_smiles,
            validation_smiles=validation_smiles,
            name=f"qm9_num_train_samples={max_train_samples}",
        )
        ds._save_cached(cache_path)
        return ds

    @classmethod
    def load_guacamol_dataset(
        cls,
        max_train_samples: Optional[int] = None,
        cache_dir: str | Path | None = None,
    ) -> "SmilesDataset":
        """Load the GuacaMol dataset.

        Args:
            max_train_samples: Upper bound on the number of training SMILES,
                drawn at random with a fixed seed. The validation split is
                shrunk by the same fraction. ``None`` keeps both splits whole.
            cache_dir: Directory passed to ``download_with_cache`` and used to
                store the pickled dataset. ``None`` disables the pickle cache.

        Returns:
            The loaded (or cached) dataset.

        Raises:
            ValueError: If ``max_train_samples`` is negative.
        """
        cls._validate_max_train_samples(max_train_samples)
        cache_path = cls._cache_file(cache_dir, "guacamol", max_train_samples)
        cached = cls._load_cached(cache_path)
        if cached is not None:
            return cached

        train_ds_url = "https://ndownloader.figshare.com/files/13612760"
        validation_ds_url = "https://ndownloader.figshare.com/files/13612766"
        # Both files are read as one SMILES per line.
        train_smiles = download_with_cache(
            train_ds_url, cache_dir=cache_dir
        ).splitlines()
        validation_smiles = download_with_cache(
            validation_ds_url, cache_dir=cache_dir
        ).splitlines()

        if max_train_samples is not None:
            # Private generator: reproducible draws, global RNG left untouched.
            train_smiles, validation_smiles, max_train_samples = cls._subsample(
                train_smiles, validation_smiles, max_train_samples, random.Random(42)
            )
        ds = cls(
            train_smiles=train_smiles,
            validation_smiles=validation_smiles,
            name=f"guacamol_num_train_samples={max_train_samples}",
        )
        ds._save_cached(cache_path)
        return ds

    @classmethod
    def load_moses_dataset(
        cls,
        max_train_samples: Optional[int] = None,
        cache_dir: str | Path | None = None,
    ) -> "SmilesDataset":
        """Load the Moses dataset.

        The ``train`` CSV is used for training and the ``test`` CSV for
        validation.

        Args:
            max_train_samples: Upper bound on the number of training SMILES,
                drawn at random with a fixed seed. The validation split is
                shrunk by the same fraction. ``None`` keeps both splits whole.
            cache_dir: Directory passed to ``download_with_cache`` and used to
                store the pickled dataset. ``None`` disables the pickle cache.

        Returns:
            The loaded (or cached) dataset.

        Raises:
            ValueError: If ``max_train_samples`` is negative.
        """
        cls._validate_max_train_samples(max_train_samples)
        cache_path = cls._cache_file(cache_dir, "moses", max_train_samples)
        cached = cls._load_cached(cache_path)
        if cached is not None:
            return cached

        def download_smiles(split: str) -> list[str]:
            """Download one split's CSV and return its ``SMILES`` column."""
            url = f"https://media.githubusercontent.com/media/molecularsets/moses/master/data/{split}.csv"
            content = download_with_cache(url, cache_dir=cache_dir)
            reader = csv.DictReader(content.splitlines())
            return [row["SMILES"] for row in reader]

        train_smiles = download_smiles("train")
        validation_smiles = download_smiles("test")
        if max_train_samples is not None:
            # Private generator: reproducible draws, global RNG left untouched.
            train_smiles, validation_smiles, max_train_samples = cls._subsample(
                train_smiles, validation_smiles, max_train_samples, random.Random(42)
            )
        ds = cls(
            train_smiles=train_smiles,
            validation_smiles=validation_smiles,
            name=f"moses_num_train_samples={max_train_samples}",
        )
        ds._save_cached(cache_path)
        return ds

    @classmethod
    def load_dummy_dataset(cls) -> "SmilesDataset":
        """Load a dummy dataset for testing purposes."""
        train_smiles = ["C1=CC=CC=C1", "C1=CC=CN=C1", "C1=CC=CO=C1"]
        validation_smiles = ["C1=CC=CC=C1", "C1=CC=CN=C1"]
        return cls(train_smiles=train_smiles, validation_smiles=validation_smiles)

    def __init__(
        self,
        train_smiles: list[str],
        validation_smiles: list[str],
        canonicalize: bool = True,
        name: Optional[str] = None,
    ) -> None:
        """Store the two splits, canonicalizing them unless told otherwise.

        Args:
            train_smiles: Training SMILES strings.
            validation_smiles: Validation SMILES strings.
            canonicalize: When true, both lists are passed through
                ``canonicalize_smiles_list`` before being stored.
            name: Optional label for the dataset.
        """
        # Training split (canonicalized when ``canonicalize`` is true).
        self.train_smiles: list[str] = (
            canonicalize_smiles_list(train_smiles) if canonicalize else train_smiles
        )  # type: ignore
        # Validation split (canonicalized when ``canonicalize`` is true).
        self.validation_smiles: list[str] = (
            canonicalize_smiles_list(validation_smiles)
            if canonicalize
            else validation_smiles
        )  # type: ignore
        # Free-form label; the loaders encode dataset and sample count here.
        self.name = name

    def get_train_smiles(self) -> list[str]:
        """Get the training SMILES."""
        return self.train_smiles

    def get_validation_smiles(self) -> list[str]:
        """Get the validation SMILES."""
        return self.validation_smiles

    def get_train_molecules(self) -> list[Chem.Mol | None]:
        """Get the training molecules, parsed with ``Chem.MolFromSmiles``."""
        return [Chem.MolFromSmiles(s) for s in self.train_smiles]

    def get_validation_molecules(self) -> list[Chem.Mol | None]:
        """Get the validation molecules, parsed with ``Chem.MolFromSmiles``."""
        return [Chem.MolFromSmiles(s) for s in self.validation_smiles]

    def __repr__(self) -> str:
        """String representation of the dataset."""
        return (
            f"SmilesDataset(train_size={len(self.train_smiles)}, "
            f"validation_size={len(self.validation_smiles)})"
        )


def _canonicalize_single_smiles(smiles: Optional[str]) -> Optional[str]:
    """Helper function to canonicalize a single SMILES string for multiprocessing."""
    if smiles is not None:
        try:
            return Chem.CanonSmiles(smiles)
        except Exception:
            return smiles
    return None


def canonicalize_smiles_list(
    smiles: list[str | None] | list[str],
    n_jobs: Optional[int] = None,
):
    """Canonicalize a list of SMILES strings using multiprocessing with progress tracking.

    Each entry is processed by ``_canonicalize_single_smiles`` through the
    callable returned by ``mapper``.

    Args:
        smiles (list[str | None] | list[str]): List of SMILES strings to canonicalize.
        n_jobs (Optional[int]): Number of parallel jobs to run. If None, uses all
            available cores.

    Returns:
        list[str | None]: List of canonicalized SMILES strings. ``None`` entries
        stay ``None``; a string that cannot be canonicalized is returned
        unchanged rather than replaced by ``None``.
    """
    return mapper(n_jobs=n_jobs, job_name="Canonicalizing SMILES")(
        _canonicalize_single_smiles, smiles
    )