Setting up a multi table synthesiser

When dealing with a multi table schema the first part that needs to be updated is the Data Schema. In this example we will demonstrate how to set up a multi table schema on the movielens dataset, adding in tables step by step.

Setup

In order to run this tutorial, you must download the MovieLens-100K dataset.

You can use the following snippet to set up MovieLens-100K locally by downloading and formatting the dataset. The resultant .csv files will be located in ./movielens/.

Copy the Python code below into a file titled download.py and then run python3 download.py.

import json
from pathlib import Path
from typing import Dict

import pandas as pd

# Root URL of the MovieLens 100K archive files (u.data, u.item, u.user).
BASE_URL = "https://files.grouplens.org/datasets/movielens/ml-100k/"
# Formatted csv files are written next to this script, under ./movielens/.
DATA_FOLDER = Path(__file__).absolute().parent / "movielens"


def get_ratings():
    """Download the u.data ratings file and return it as a DataFrame."""
    columns = ["user_id", "movie_id", "rating", "timestamp"]
    ratings = pd.read_csv(BASE_URL + "u.data", sep="\t", names=columns)

    # Convert the unix timestamp to a datetime and scale ratings
    # from 1..5 down to the 0..1 range.
    ratings["timestamp"] = pd.to_datetime(ratings["timestamp"], unit="s")
    ratings["rating"] = (ratings["rating"] - 1) / 4

    return ratings


def get_movies():
    """Download the u.item movie file and return it as a cleaned DataFrame."""
    genre_columns = [
        "unknown", "Action", "Adventure", "Animation", "Childrens", "Comedy",
        "Crime", "Documentary", "Drama", "Fantasy", "Film_Noir", "Horror",
        "Musical", "Mystery", "Romance", "Sci_Fi", "Thriller", "War", "Western",
    ]
    headers = ["movie_id", "movie_title", "release_date", "video_release_date", "IMDb_URL"]
    headers += genre_columns

    movies = pd.read_csv(
        BASE_URL + "u.item",
        encoding="latin1",
        sep="|",
        parse_dates=["release_date"],
        names=headers,
    )
    # Drop the two columns the schema does not use, then any incomplete rows.
    movies = movies.drop(columns=["video_release_date", "unknown"]).dropna(axis=0)
    return movies


def get_users():
    """Download the u.user demographics file and return it as a DataFrame."""
    columns = ["user_id", "age", "gender", "occupation", "zip_code"]
    return pd.read_csv(
        BASE_URL + "u.user",
        encoding="latin1",
        sep="|",
        names=columns,
    )


def create_directory():
    """Ensure the local data folder exists.

    ``mkdir(parents=True, exist_ok=True)`` is race-free (no check-then-create
    window) and also creates any missing parent directories, unlike the
    original ``if not exists: mkdir()`` pattern.
    """
    DATA_FOLDER.mkdir(parents=True, exist_ok=True)


def main():
    """Fetch every MovieLens table and write each one out as a csv file."""
    create_directory()

    # File name -> fetcher; insertion order matches the original write order.
    tables = {
        "ratings.csv": get_ratings,
        "movie.csv": get_movies,
        "user.csv": get_users,
    }
    for filename, fetch in tables.items():
        fetch().to_csv(DATA_FOLDER / filename, index=False)


# Only run the download when executed as a script, not when imported.
if __name__ == "__main__":
    main()

The first table - movie

movie_table = TabularTable(
    name="movie",
    dtypes=[
        # Purely synthetic, non-statistical five digit primary key.
        IdType(
            col="movie_id",
            settings=NumericalIdSettings(length=5),
            primary_key=True,
        ),
        DatetimeType(col="release_date", format="%Y-%m-%d"),
    ]
    + [
        # The genre flags are one-hot encoded columns, so each one is a
        # plain categorical dtype generated from a single list of names.
        CategoryType(col=genre)
        for genre in (
            "Action", "Adventure", "Animation", "Childrens", "Comedy",
            "Crime", "Documentary", "Drama", "Fantasy", "Film_Noir",
            "Horror", "Musical", "Mystery", "Romance", "Sci_Fi",
            "Thriller", "War", "Western",
        )
    ],
)

Key Notes

  • In this case we have set up our movie table as a TabularTable. This is the standard table you will use where each row is an independent item.
  • We have set it up with a basic primary key column on movie_id which will have purely generated, non-statistical numerical IDs. All available ID settings can be found here.
  • All data types are categorical, as these are one-hot encoded columns, except for the release_date column, which is a Datetime; for it we have had to specify the strftime format for parsing dates.

The second table - user

user_table = TabularTable(
    name="user",
    dtypes=[
        # Purely synthetic, non-statistical five digit primary key.
        IdType(col="user_id", settings=NumericalIdSettings(length=5), primary_key=True),
        IntType(col="age"),
        CategoryType(col="gender"),
        CategoryType(col="occupation"),
        # Statistically generated postcodes, tied to location entity 1.
        PostcodeType(col="zip_code", entity_id=1),
    ],
)

Key Notes

  • Again we have defined it using a TabularTable
  • A simple numerical ID has been used for the user_id column.
  • In this table we have a zip_code column for which we use the PostcodeType. This allows us to create statistically generated zip/postcodes based on geographic location.

Linking together with the final table - ratings

This is where we create links between the tables using the ForeignKeyType

ratings_table = TabularTable(
    # Must be "ratings" to match the DataLocationInput name used later (and
    # the full script); the original "ratings_table" did not match it.
    name="ratings",
    dtypes=[
        # Composite primary key: both halves are foreign keys that reference
        # the movie and user tables defined above.
        ForeignKeyType(col="movie_id", primary_key=True, ref=("movie", "movie_id")),
        ForeignKeyType(col="user_id", primary_key=True, ref=("user", "user_id")),
        FloatType(col="rating"),
        DatetimeType(col="timestamp", format="%Y-%m-%d %H:%M:%S"),
    ],
)

Key Notes

  • This table contains a composite key - demonstrated by the primary_key=True flag being set on both movie_id and user_id.
  • Both parts of the composite key are also foreign keys referencing the previous tables.
  • There is a one to many relationship between users and ratings.
  • There is also a one to many relationship between movies and ratings.

Let's ‘fix’ the movie table.

Say you don't care about protecting privacy in the movie table. It contains no customer information, just publicly available movie data, so we decide we still want to use the information in the training process, but we don't wish to create synthetic movies.

We still wish to protect the privacy of our customers. Since both the user and ratings tables contain information about users or their interactions, we will leave them as is.

We will make use of a ReferenceTable to fix the movie table.

movie_table = ReferenceTable(
    name="movie",
    dtypes=[
        # RealType re-uses the real movie IDs from the training data; it can
        # only be used inside a ReferenceTable.
        RealType(col="movie_id", primary_key=True),
        DatetimeType(col="release_date", format="%Y-%m-%d"),
    ]
    + [
        # One-hot encoded genre flags, all plain categorical columns.
        CategoryType(col=genre)
        for genre in (
            "Action", "Adventure", "Animation", "Childrens", "Comedy",
            "Crime", "Documentary", "Drama", "Fantasy", "Film_Noir",
            "Horror", "Musical", "Mystery", "Romance", "Sci_Fi",
            "Thriller", "War", "Western",
        )
    ],
)

Key Notes

  • We have changed TabularTable to ReferenceTable to ‘fix’ the movie table.
  • We have changed movie_id column from IdType to RealType. This means that real IDs from the training data will be re-used. This type can only be used inside a ReferenceTable.

The full data schema

data_schema = DataSchema(
    tables=[
        user_table,
        movie_table,
        ratings_table,
    ],
    # The user table's PostcodeType references entity_id=1, so the matching
    # LocationEntity is declared here too — mirroring the full script; the
    # original snippet omitted it.
    entities=[
        LocationEntity(entity_id=1, num_clusters=1000, locales=[GeoLocales.en_US])
    ],
)

Now all we have left to do is point to our new data locations.

Data input

In this example we will stick to .csv files for input.

# Helper names so the schema tables and the input files stay in sync.
USER = "user"
MOVIE = "movie"
RATINGS = "ratings"

data_path = "movielens/"

# One csv input per table; each name must match a table name in the schema.
data_input = [
    DataLocationInput(name=table, location=data_path + f"{table}.csv")
    for table in (USER, MOVIE, RATINGS)
]

Key Notes

  • We've defined 3 helper variables USER, MOVIE and RATINGS to help make sure we reference our tables correctly.

Finally, putting this all together in a single script...

Full script

from os.path import exists

from hazy_client2 import SynthDocker
from hazy_configurator import (
    CategoryType,
    DataLocationInput,
    DataLocationOutput,
    DataSchema,
    DatetimeType,
    FloatType,
    ForeignKeyType,
    GenerationConfig,
    GeoLocales,
    IdType,
    IntType,
    LocationEntity,
    NumericalIdSettings,
    PostcodeType,
    RealType,
    ReferenceTable,
    TabularTable,
    TrainingConfig,
)

# Replace me with the Hazy supplied docker image!
DOCKER_IMAGE = "docker_image:tag"
# Client handle used below for both training and generation.
synth = SynthDocker(image=DOCKER_IMAGE)

# Helper names so the table definitions, inputs and outputs stay in sync.
USER = "user"
MOVIE = "movie"
RATINGS = "ratings"

# Folder produced by download.py.
data_path = "movielens/"

# The genre flags are one-hot encoded columns sharing the same categorical
# dtype, so they are generated from a single tuple of names.
GENRE_COLUMNS = (
    "Action", "Adventure", "Animation", "Childrens", "Comedy", "Crime",
    "Documentary", "Drama", "Fantasy", "Film_Noir", "Horror", "Musical",
    "Mystery", "Romance", "Sci_Fi", "Thriller", "War", "Western",
)

training_config = TrainingConfig(
    model_output="movielens.hmf",
    data_schema=DataSchema(
        tables=[
            # Movie data is public, so the table is "fixed": real rows and
            # IDs from the training data are re-used rather than synthesised.
            ReferenceTable(
                name=MOVIE,
                dtypes=[
                    RealType(col="movie_id", primary_key=True),
                    DatetimeType(col="release_date", format="%Y-%m-%d"),
                ]
                + [CategoryType(col=genre) for genre in GENRE_COLUMNS],
            ),
            # User data is private and therefore fully synthesised.
            TabularTable(
                name=USER,
                dtypes=[
                    IdType(
                        col="user_id",
                        settings=NumericalIdSettings(length=5),
                        primary_key=True,
                    ),
                    IntType(col="age"),
                    CategoryType(col="gender"),
                    CategoryType(col="occupation"),
                    # Tied to the LocationEntity declared under `entities`.
                    PostcodeType(col="zip_code", entity_id=1),
                ],
            ),
            # Ratings link users to movies through a composite key whose two
            # halves are foreign keys into the tables above.
            TabularTable(
                name=RATINGS,
                dtypes=[
                    ForeignKeyType(
                        col="movie_id", primary_key=True, ref=(MOVIE, "movie_id")
                    ),
                    ForeignKeyType(
                        col="user_id", primary_key=True, ref=(USER, "user_id")
                    ),
                    FloatType(col="rating"),
                    DatetimeType(col="timestamp", format="%Y-%m-%d %H:%M:%S"),
                ],
            ),
        ],
        entities=[
            LocationEntity(entity_id=1, num_clusters=1000, locales=[GeoLocales.en_US])
        ],
    ),
    data_input=[
        DataLocationInput(name=table, location=data_path + f"{table}.csv")
        for table in (USER, MOVIE, RATINGS)
    ],
)

generation_config = GenerationConfig(
    model="movielens.hmf",
    # One csv output per table, mirroring the training inputs.
    data_output=[
        DataLocationOutput(name=table, location=f"output/{table}.csv")
        for table in (USER, MOVIE, RATINGS)
    ],
)

# Train the synthesiser and confirm the model file was written.
synth.train(cfg=training_config)
assert exists("movielens.hmf"), "Synthesiser should generate .hmf model file!"

# Generate synthetic data and confirm each output table was written.
synth.generate(cfg=generation_config)
for table in (USER, MOVIE, RATINGS):
    assert exists(f"output/{table}.csv"), "Synthesiser should generate synthetic data!"