Source code for zoo.orca.learn.tf2.estimator

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import itertools
import logging
import pickle

import numpy as np
import ray

from zoo.orca.data.ray_xshards import RayXShards
from zoo.orca.learn.dl_cluster import RayDLCluster
from zoo.orca.learn.tf2.tf_runner import TFRunner
from zoo.orca.learn.ray_estimator import Estimator as OrcaRayEstimator
from zoo.orca.learn.utils import maybe_dataframe_to_xshards, dataframe_to_xshards, \
    convert_predict_xshards_to_dataframe, update_predict_xshards, \
    process_xshards_of_pandas_dataframe
from zoo.orca.data.utils import process_spark_xshards
from zoo.ray import RayContext

logger = logging.getLogger(__name__)


class Estimator(object):
    """Factory namespace for Orca estimators backed by TensorFlow 2."""

    @staticmethod
    def from_keras(*,
                   model_creator,
                   config=None,
                   verbose=False,
                   workers_per_node=1,
                   compile_args_creator=None,
                   backend="tf2",
                   cpu_binding=True,
                   ):
        """
        Create an Estimator for tensorflow 2.

        All arguments are keyword-only and are forwarded unchanged to
        ``TensorFlow2Estimator``.

        :param model_creator: (dict -> Model) This function takes in the `config`
               dict and returns a compiled TF model.
        :param config: (dict) configuration passed to 'model_creator',
               'data_creator'. Also contains `fit_config`, which is passed
               into `model.fit(data, **fit_config)` and
               `evaluate_config` which is passed into `model.evaluate`.
        :param verbose: (bool) Prints output of one model if true.
        :param workers_per_node: (Int) worker number on each node. default: 1.
        :param compile_args_creator: (dict -> dict of loss, optimizer and metrics) Only used when
               the backend="horovod". This function takes in the `config` dict and returns a
               dictionary like {"optimizer": tf.keras.optimizers.SGD(lr), "loss":
               "mean_squared_error", "metrics": ["mean_squared_error"]}
        :param backend: (string) You can choose "horovod" or "tf2" as backend. Default: `tf2`.
        :param cpu_binding: (bool) Whether to bind threads to specific CPUs. Default: True
        :return: a ``TensorFlow2Estimator`` built from the given arguments.
        """
        estimator_kwargs = dict(
            model_creator=model_creator,
            config=config,
            verbose=verbose,
            workers_per_node=workers_per_node,
            backend=backend,
            compile_args_creator=compile_args_creator,
            cpu_binding=cpu_binding,
        )
        return TensorFlow2Estimator(**estimator_kwargs)
def make_data_creator(refs):
    """
    Wrap `refs` in a function with the data-creator call signature.

    :param refs: the object the generated creator should hand back.
    :return: a function taking `(config, batch_size)` that ignores both
             arguments and returns `refs` itself (no copy is made).
    """

    def _creator(config, batch_size):
        # Both arguments exist only to satisfy the creator signature.
        return refs

    return _creator
def data_length(data):
    """
    Return the size of the first dimension of the features stored in `data`.

    :param data: a mapping with an "x" entry. The entry is either a numpy
           array or an indexable collection whose first element exposes `.shape`.
    :return: `shape[0]` of the array itself, or of the first element when
             "x" is not a numpy array.
    """
    features = data["x"]
    # For a collection of arrays, only the first one is consulted.
    leading = features if isinstance(features, np.ndarray) else features[0]
    return leading.shape[0]
class TensorFlow2Estimator(OrcaRayEstimator):
    """
    Orca estimator that drives a group of remote `TFRunner` workers through Ray.

    Workers are created in `__init__` (either through `RayDLCluster` for the
    "tf2" backend or through `HorovodRayRunner` for the "horovod" backend) and
    are then used by `fit`, `evaluate`, `predict`, `save`, `load` and
    `get_model`.
    """

    def __init__(self,
                 model_creator,
                 compile_args_creator=None,
                 config=None,
                 verbose=False,
                 backend="tf2",
                 workers_per_node=1,
                 cpu_binding=True):
        """
        Start the remote workers and set up the chosen distributed backend.

        :param model_creator: (dict -> Model) function that takes the `config`
               dict and returns a model.
        :param compile_args_creator: (dict -> dict) required when
               backend="horovod"; passed on to the workers.
        :param config: (dict) configuration passed to the workers. It must not
               contain "batch_size". "inter_op_parallelism" defaults to 1 and
               "intra_op_parallelism" defaults to the number of cpu cores of a
               ray node divided by `workers_per_node`. The dict given by the
               caller is copied, not modified.
        :param verbose: (bool) passed on to the workers.
        :param backend: (string) "tf2" or "horovod".
        :param workers_per_node: (int) number of workers on each node.
        :param cpu_binding: (bool) passed to `RayDLCluster` ("tf2" backend only).
        :raises Exception: if "batch_size" is in `config` or if `backend` is
                neither "tf2" nor "horovod".
        """
        self.model_creator = model_creator
        self.compile_args_creator = compile_args_creator
        # Shallow-copy so the defaults filled in below do not leak into the
        # dictionary owned by the caller.
        self.config = {} if config is None else dict(config)
        self.verbose = verbose

        ray_ctx = RayContext.get()
        if "batch_size" in self.config:
            raise Exception("Please do not specify batch_size in config. Input batch_size in the"
                            " fit/evaluate function of the estimator instead.")

        if "inter_op_parallelism" not in self.config:
            self.config["inter_op_parallelism"] = 1

        if "intra_op_parallelism" not in self.config:
            self.config["intra_op_parallelism"] = ray_ctx.ray_node_cpu_cores // workers_per_node

        if backend == "horovod":
            assert compile_args_creator is not None, "compile_args_creator should not be None," \
                                                     " when backend is set to horovod"

        params = {
            "model_creator": model_creator,
            "compile_args_creator": compile_args_creator,
            "config": self.config,
            "verbose": self.verbose,
        }

        if backend == "tf2":
            cores_per_node = ray_ctx.ray_node_cpu_cores // workers_per_node
            num_nodes = ray_ctx.num_ray_nodes * workers_per_node

            self.cluster = RayDLCluster(
                num_workers=num_nodes,
                worker_cores=cores_per_node,
                worker_cls=TFRunner,
                worker_param=params,
                cpu_binding=cpu_binding
            )
            self.remote_workers = self.cluster.get_workers()
            ips = ray.get(
                [worker.get_node_ip.remote() for worker in self.remote_workers])
            ports = ray.get(
                [worker.find_free_port.remote() for worker in self.remote_workers])

            # One "ip:port" entry per worker, in worker order.
            urls = ["{ip}:{port}".format(ip=ip, port=port)
                    for ip, port in zip(ips, ports)]
            ray.get([worker.setup.remote() for worker in self.remote_workers])
            # Get setup tasks in order to throw errors on failure
            world_size = len(self.remote_workers)
            ray.get([
                worker.setup_distributed.remote(urls, rank, world_size)
                for rank, worker in enumerate(self.remote_workers)])
        elif backend == "horovod":
            # it is necessary to call self.run first to set horovod environment
            from zoo.orca.learn.horovod.horovod_ray_runner import HorovodRayRunner
            horovod_runner = HorovodRayRunner(ray_ctx,
                                              worker_cls=TFRunner,
                                              worker_param=params,
                                              workers_per_node=workers_per_node)
            horovod_runner.run(lambda: print("worker initialized"))
            self.remote_workers = horovod_runner.remote_workers
            ray.get([worker.setup.remote() for worker in self.remote_workers])
            ray.get([worker.setup_horovod.remote() for worker in self.remote_workers])
        else:
            raise Exception("Only \"tf2\" and \"horovod\" are legal "
                            "values of backend, but got {}".format(backend))

        self.num_workers = len(self.remote_workers)

    def fit(self, data, epochs=1, batch_size=32, verbose=1,
            callbacks=None, validation_data=None, class_weight=None,
            steps_per_epoch=None, validation_steps=None, validation_freq=1,
            data_config=None, feature_cols=None,
            label_cols=None):
        """
        Train this tensorflow model with train data.

        :param data: train data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param epochs: Number of epochs to train the model. Default: 1.
        :param batch_size: Batch size used for training. Default: 32.
        :param verbose: Prints output of one model if true.
        :param callbacks: List of Keras compatible callbacks to apply during training.
        :param validation_data: validation data. Validation data type should be the same
               as train data.
        :param class_weight: Optional dictionary mapping class indices (integers) to a weight
               (float) value, used for weighting the loss function. This can be useful to tell
               the model to "pay more attention" to samples from an under-represented class.
        :param steps_per_epoch: Total number of steps (batches of samples) before declaring one
               epoch finished and starting the next epoch. If `steps_per_epoch` is `None`, the
               epoch will run until the input dataset is exhausted. When passing an infinitely
               repeating dataset, you must specify the `steps_per_epoch` argument.
        :param validation_steps: Total number of steps (batches of samples) to draw before
               stopping when performing validation at the end of every epoch. Default: None.
        :param validation_freq: Only relevant if validation data is provided. Integer or
               `collections.abc.Container` instance (e.g. list, tuple, etc.). If an integer,
               specifies how many training epochs to run before a new validation run is
               performed, e.g. `validation_freq=2` runs validation every 2 epochs. If a
               Container, specifies the epochs on which to run validation, e.g.
               `validation_freq=[1, 2, 10]` runs validation at the end of the 1st, 2nd, and
               10th epochs.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data.
               Only used when data is a Spark DataFrame or an XShards of Pandas DataFrame.
               Default: None.
        :return: a copy of the first entry of the statistics collected from the workers.
        """
        params = dict(
            epochs=epochs,
            batch_size=batch_size,
            verbose=verbose,
            callbacks=callbacks,
            class_weight=class_weight,
            steps_per_epoch=steps_per_epoch,
            validation_steps=validation_steps,
            validation_freq=validation_freq,
            data_config=data_config
        )

        from zoo.orca.data import SparkXShards
        data, validation_data = maybe_dataframe_to_xshards(data, validation_data,
                                                           feature_cols, label_cols,
                                                           mode="fit",
                                                           num_workers=self.num_workers)

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data, validation_data = process_xshards_of_pandas_dataframe(data, feature_cols,
                                                                            label_cols,
                                                                            validation_data, "fit")
            ray_xshards = process_spark_xshards(data, self.num_workers)

            if validation_data is None:
                def transform_func(worker, partition_refs):
                    # `params` is shared; it is unpacked at call time, so each
                    # worker gets the creator assigned just above.
                    params["data_creator"] = make_data_creator(partition_refs)
                    return worker.step.remote(**params)

                worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers,
                                                                        transform_func)
            else:
                val_ray_xshards = process_spark_xshards(validation_data, self.num_workers)

                def zip_func(worker, this_partition_refs, that_partition_refs):
                    params["data_creator"] = make_data_creator(this_partition_refs)
                    params["validation_data_creator"] = \
                        make_data_creator(that_partition_refs)
                    return worker.step.remote(**params)

                worker_stats = ray_xshards.zip_reduce_shards_with_actors(val_ray_xshards,
                                                                         self.remote_workers,
                                                                         zip_func)
        else:
            # `data` / `validation_data` are creator functions here; every
            # worker receives the same arguments.
            params["data_creator"] = data
            params["validation_data_creator"] = validation_data

            worker_stats = ray.get([worker.step.remote(**params)
                                    for worker in self.remote_workers])
            # NOTE(review): flattening is applied on this path only; confirm
            # whether results of the XShards path have the same nesting.
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats

    def evaluate(self, data, batch_size=32, num_steps=None, verbose=1,
                 sample_weight=None, callbacks=None, data_config=None,
                 feature_cols=None, label_cols=None):
        """
        Evaluates the model on the validation data set.

        :param data: evaluate data. It can be XShards, Spark DataFrame or creator function which
               returns Iter or DataLoader.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature, 'y': label}, where feature(label) is a numpy array or a tuple of
               numpy arrays.
        :param batch_size: Batch size used for evaluation. Default: 32.
        :param num_steps: Total number of steps (batches of samples) before declaring the
               evaluation round finished. Ignored with the default value of `None`.
        :param verbose: Prints output of one model if true.
        :param sample_weight: Optional Numpy array of weights for the training samples, used
               for weighting the loss function. You can either pass a flat (1D) Numpy array with
               the same length as the input samples (1:1 mapping between weights and samples),
               or in the case of temporal data, you can pass a 2D array with shape
               (samples, sequence_length), to apply a different weight to every timestep of
               every sample.
        :param callbacks: List of Keras compatible callbacks to apply during evaluation.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :param label_cols: Label column name(s) of data.
               Only used when data is a Spark DataFrame or an XShards of Pandas DataFrame.
               Default: None.
        :return: validation result
        """
        logger.info("Starting validation step.")
        params = dict(
            batch_size=batch_size,
            verbose=verbose,
            sample_weight=sample_weight,
            steps=num_steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards

        data, _ = maybe_dataframe_to_xshards(data,
                                             validation_data=None,
                                             feature_cols=feature_cols,
                                             label_cols=label_cols,
                                             mode="evaluate",
                                             num_workers=self.num_workers)

        if isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(data, feature_cols, label_cols)

            # Make the partition count equal to the worker count before the
            # shards are handed to the actors.
            if data.num_partitions() != self.num_workers:
                data = data.repartition(self.num_workers)

            ray_xshards = RayXShards.from_spark_xshards(data)

            def transform_func(worker, partition_refs):
                params["data_creator"] = make_data_creator(partition_refs)
                return worker.validate.remote(**params)

            worker_stats = ray_xshards.reduce_partitions_for_actors(self.remote_workers,
                                                                    transform_func)
        else:
            # data_creator functions; should return Iter or DataLoader
            params["data_creator"] = data

            worker_stats = ray.get([worker.validate.remote(**params)
                                    for worker in self.remote_workers])
            # NOTE(review): flattening is applied on this path only; confirm
            # whether results of the XShards path have the same nesting.
            worker_stats = list(itertools.chain.from_iterable(worker_stats))
        stats = worker_stats[0].copy()
        return stats

    def _predict_spark_xshards(self, xshards, params):
        """
        Run `predict` on the remote workers for every shard of `xshards`.

        :param xshards: SparkXShards holding the prediction input.
        :param params: dict of keyword arguments for the workers' `predict`;
               its "data_creator" entry is overwritten for each call.
        :return: SparkXShards holding the prediction output.
        """
        ray_xshards = RayXShards.from_spark_xshards(xshards)

        def transform_func(worker, shards_ref):
            params["data_creator"] = make_data_creator(shards_ref)
            return worker.predict.remote(**params)

        pred_shards = ray_xshards.transform_shards_with_actors(self.remote_workers,
                                                               transform_func)
        spark_xshards = pred_shards.to_spark_xshards()
        return spark_xshards

    def predict(self, data, batch_size=None, verbose=1,
                steps=None, callbacks=None, data_config=None,
                feature_cols=None):
        """
        Predict the input data

        :param data: predict input data. It can be XShards or Spark DataFrame.
               If data is XShards, each partition can be a Pandas DataFrame or a dictionary of
               {'x': feature}, where feature is a numpy array or a tuple of numpy arrays.
        :param batch_size: Batch size used for inference. Default: None.
        :param verbose: Prints output of one model if true.
        :param steps: Total number of steps (batches of samples) before declaring the prediction
               round finished. Ignored with the default value of None.
        :param callbacks: List of Keras compatible callbacks to apply during prediction.
        :param data_config: An optional dictionary that can be passed to data creator function.
        :param feature_cols: Feature column name(s) of data. Only used when data is a Spark
               DataFrame or an XShards of Pandas DataFrame. Default: None.
        :return: the value produced by `convert_predict_xshards_to_dataframe` when `data` is a
                 Spark DataFrame, or by `update_predict_xshards` when `data` is an XShards.
        :raises ValueError: if `data` is neither a Spark DataFrame nor an XShards.
        """
        logger.info("Starting predict step.")
        params = dict(
            verbose=verbose,
            batch_size=batch_size,
            steps=steps,
            callbacks=callbacks,
            data_config=data_config,
        )
        from zoo.orca.data import SparkXShards
        from pyspark.sql import DataFrame

        if isinstance(data, DataFrame):
            xshards, _ = dataframe_to_xshards(data,
                                              validation_data=None,
                                              feature_cols=feature_cols,
                                              label_cols=None,
                                              mode="predict")
            pred_shards = self._predict_spark_xshards(xshards, params)
            result = convert_predict_xshards_to_dataframe(data, pred_shards)
        elif isinstance(data, SparkXShards):
            if data._get_class_name() == 'pandas.core.frame.DataFrame':
                data = process_xshards_of_pandas_dataframe(data, feature_cols)
            pred_shards = self._predict_spark_xshards(data, params)
            result = update_predict_xshards(data, pred_shards)
        else:
            raise ValueError("Only xshards or Spark DataFrame is supported for predict")

        return result

    def get_model(self):
        """
        Returns the learned model.

        `get_state` is requested from every worker, but only the state of the
        first worker is fetched and used to rebuild the model locally.

        :return: the learned model.
        """
        state_refs = [w.get_state.remote() for w in self.remote_workers]
        state = ray.get(state_refs[0])
        return self._get_model_from_state(state)

    def save(self, checkpoint):
        """
        Saves the model at the provided checkpoint.

        The state of the first worker is written to the file with `pickle`.

        :param checkpoint: (str) Path to the target checkpoint file.
        :return: the `checkpoint` path that was passed in.
        """

        # Some model might need to aggregate variables during checkpointing
        # which requires both the chief and workers to participate in the
        # allreduce communication protocol.
        # So we need to call get_state on every remote workers, otherwise
        # it might get stuck
        state_refs = [w.get_state.remote() for w in self.remote_workers]

        state = ray.get(state_refs[0])

        with open(checkpoint, "wb") as f:
            pickle.dump(state, f)

        return checkpoint

    def load(self, checkpoint, **kwargs):
        """
        Loads the model from the provided checkpoint.

        :param checkpoint: (str) Path to target checkpoint file.
        :param kwargs: accepted but not used by this method.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load
        # checkpoint files that come from a trusted source.
        with open(checkpoint, "rb") as f:
            state = pickle.load(f)

        # Put the state in the object store once and share the reference with
        # every worker.
        state_id = ray.put(state)
        ray.get([worker.set_state.remote(state_id)
                 for worker in self.remote_workers])

    def shutdown(self):
        """
        Shuts down workers and releases resources.

        The `shutdown` and `__ray_terminate__` calls are submitted to every
        worker without waiting for their results.
        """
        for worker in self.remote_workers:
            worker.shutdown.remote()
            worker.__ray_terminate__.remote()

    def _get_model_from_state(self, state):
        """
        Creates model and load weights from state

        :param state: mapping with a "weights" entry.
        :return: a model built by `model_creator(config)` with those weights set.
        """

        # keep the same behavior as `set_state` in `load` do
        model = self.model_creator(self.config)
        model.set_weights(state["weights"])

        return model