from .converter import PandasConverter
from .rpcs import GrpcApi
from .utils import CiphermodeException, normalize_address, parse_sql_permissions
from .auth import AuthHandler
from .proto import common_pb2
from .onnx import convert_onnx_model_to_typed_value
[docs]
class CiphermodeApi:
    def __init__(self, address, auth_handler, cert=None, tls_domain=None, private_key=None, certificate_chain=None, *args, **kwargs):
        """
        Initialize the CiphermodeApi instance.

        Builds the gRPC stub used for every server call and the converter that
        turns server responses into pandas objects.

        Args:
            address (str): The address of the server.
            auth_handler: An authentication handler instance, forwarded to `GrpcApi`
                (presumably an `AuthHandler`; confirm against callers).
            cert (str, optional): Path to a TLS certificate file.
            tls_domain (str, optional): The domain protected by the TLS certificate.
            private_key (str, optional): The client's private key.
            certificate_chain (str, optional): The client's certificate chain.
            *args: Positional arguments forwarded to `PandasConverter`.
            **kwargs: Keyword arguments forwarded to `PandasConverter`.
        """
        # Everything the transport layer needs, in the positional order GrpcApi takes it.
        connection_args = (address, auth_handler, cert,
                           tls_domain, private_key, certificate_chain)
        self.stub = GrpcApi(*connection_args)
        self.converter = PandasConverter(*args, **kwargs)
[docs]
    def list_users(self):
        """
        List all users.
        Returns:
            DataFrame: A pandas DataFrame containing the list of users.
        """
        ids = self.list_users_ids()
        res = self.stub.populate_users(ids)
        return self.converter.list_users(res) 
[docs]
    def list_users_ids(self):
        """
        List the IDs of all users.
        Returns:
            list[str]: A list of user IDs.
        """
        return self.stub.list_users() 
[docs]
    def add_user_role(self, user_id, role):
        """
        Add a role to a user.
        Args:
            user_id (str): The ID of the user.
            role (str): The role to be added to the user.
        Returns:
            DataFrame: A pandas DataFrame containing the updated user information.
        """
        self.stub.add_user_role(user_id=user_id, role=role)
        users = self.stub.populate_users([user_id])
        return self.converter.list_users(users) 
[docs]
    def remove_user_role(self, user_id, role):
        """
        Remove a role from a user.
        Args:
            user_id (str): The ID of the user.
            role (str): The role to be removed from the user.
        Returns:
            DataFrame: A pandas DataFrame containing the updated user information.
        """
        self.stub.remove_user_role(user_id=user_id, role=role)
        users = self.stub.populate_users([user_id])
        return self.converter.list_users(users) 
[docs]
    def list_groups(self):
        """
        List all groups.
        Returns:
            DataFrame: A pandas DataFrame containing the list of groups.
        """
        ids = self.list_groups_ids()
        res = self.stub.populate_groups(ids)
        return self.converter.list_groups(res) 
[docs]
    def list_groups_ids(self):
        """
        List the IDs of all groups.
        Returns:
            list[str]: A list of group IDs.
        """
        return self.stub.list_groups() 
[docs]
    def run_gc(self):
        """
        Run garbage collection.
        Returns:
            int: The number of collected values.
        """
        return self.stub.run_gc() 
[docs]
    def node_connections(self):
        """
        Get node connections.
        Returns:
            DataFrame: A pandas DataFrame containing the node connections.
        """
        res = self.stub.node_connections()
        return self.converter.node_connections(res) 
[docs]
    def local_node_connections(self):
        """
        Get local node connections.
        Returns:
            list: Local node connections.
        """
        res = self.stub.local_node_connections()
        # TODO: reuse things from `self.converter.node_connections`
        return res.stats.connections 
[docs]
    def build_info(self):
        """
        Get the build information.
        Returns:
            Object: An object containing the build information.
        """
        return self.stub.build_info() 
[docs]
    def upload_dataset(self, name='', description='', type='columnwise', endpoint='', data=None, column_permissions='everything', sql_permissions='', include_report=True, publish=False, async_init=False, allow_secure_test=False):
        """
        Upload a dataset.

        Args:
            name (str, optional): The name of the dataset.
            description (str, optional): A description of the dataset.
            type (str, optional): The type of the dataset. Default is 'columnwise'; available options are
                {'typed_value', 'columnwise', 'rowwise', 'onnx_model'}.
            endpoint (str, optional): In case of non-local datasets (cloud storage, remote SQL server), the address of the dataset object.
            data (optional): In case of local datasets, the data to upload (CSV files for columnwise/rowwise types,
                binary data of an ONNX model for 'onnx_model', or TypedValue JSON otherwise).
            column_permissions (str, optional): The column permissions of the dataset. Default is 'everything'.
                Available options are {'everything', 'everything_local', None}; 'everything' and None leave the
                column permissions of the permission config unset.
            sql_permissions (str, optional): The SQL permissions of the dataset (parsed by `parse_sql_permissions`).
            include_report (bool, optional): Whether to include a report in the upload.
            publish (bool, optional): Whether to make the dataset visible for all organizations.
            async_init (bool, optional): Whether to download the dataset from the `endpoint` asynchronously.
            allow_secure_test (bool, optional): Whether to allow the dataset to be used in SecureTest computations.
                When set and `sql_permissions` is empty, a default set of SQL permissions is applied, so the
                dataset must then be columnwise.
        Returns:
            A pandas Series containing the uploaded dataset.
        Raises:
            CiphermodeException: If both endpoint and data are specified, if `type` or `column_permissions`
                is not a supported value, or if permissions are given for a non-columnwise dataset.
        """
        if endpoint and data:
            raise CiphermodeException(
                'Cannot specify both endpoint and data for dataset upload')
        dataset_types = {'columnwise': common_pb2.DatasetType.COLUMNWISE_TABLE,
                         'rowwise': common_pb2.DatasetType.ROWWISE_TABLE,
                         'typed_value': common_pb2.DatasetType.SINGLE_VALUE,
                         'onnx_model': common_pb2.DatasetType.SINGLE_VALUE}
        # Reject unknown types before doing any work (such as ONNX conversion), with a
        # descriptive message instead of a bare KeyError at upload time.
        if type not in dataset_types:
            raise CiphermodeException(
                'Unsupported dataset type {!r}, expected one of {}'.format(type, sorted(dataset_types)))
        permissions = common_pb2.PermissionConfig()
        # 'everything' (the default) and None leave the column permissions unset.
        if (column_permissions is not None) and (column_permissions != 'everything'):
            if type != 'columnwise':
                raise CiphermodeException(
                    'Only columnwise datasets can have column permissions')
            if column_permissions != 'everything_local':
                raise CiphermodeException(
                    "Unsupported column permissions {!r}, expected 'everything', 'everything_local' or None".format(column_permissions))
            permissions.column_permissions.global_permission.permission_type = \
                common_pb2.ColumnPermissions.Permission.PermissionType.EVERYTHING_LOCAL
        if allow_secure_test and not sql_permissions:
            default_permissions = """
            {
                plaintext_allowed: false,
                join_allowed: true,
                aggregate_allowed: true,
                grouping_allowed: false,
                filtering_allowed: true,
                aggregate_required: true,
                join_required: true,
            }
            """
            sql_permissions = f"default_permission {{ global : {default_permissions}, local : {default_permissions} }}"
        if sql_permissions:
            if type != 'columnwise':
                raise CiphermodeException(
                    'Only columnwise datasets can have SQL permissions')
            permissions.sql_column_permissions.MergeFrom(
                parse_sql_permissions(sql_permissions))
        if type == 'onnx_model':
            data = convert_onnx_model_to_typed_value(data)
        res = self.stub.upload_dataset(data, dataset_types[type],
                                       name, description, permissions, include_report, endpoint, publish, async_init)
        return self.converter.list_datasets(res).iloc[0]
[docs]
    def upload_and_publish_dataset(self, *args, **kwargs):
        """
        Upload a dataset and than make it visible for all organizations.
        See `upload_dataset` for arguments.
        Returns:
            DataFrame: A pandas DataFrame containing the uploaded and published dataset.
        """
        kwargs['publish'] = True
        return self.upload_dataset(*args, **kwargs) 
[docs]
    def publish_dataset(self, id):
        """
        Make the dataset visible for all organizations.
        Args:
            id (str): The ID of the dataset.
        Returns:
            DataFrame: A pandas DataFrame containing the published dataset.
        """
        res = self.stub.expose_dataset(id)
        return self.converter.list_datasets(res) 
[docs]
    def list_datasets(self):
        """
        List all datasets.
        Returns:
            DataFrame: A pandas DataFrame containing the list of datasets.
        """
        ids = self.list_datasets_ids()
        res = self.stub.populate_datasets(ids)
        return self.converter.list_datasets(res) 
[docs]
    def list_datasets_ids(self):
        """
        List the IDs of all datasets.
        Returns:
            list[str]: A list of dataset IDs.
        """
        return self.stub.list_datasets() 
[docs]
    def show_dataset(self, dataset_id):
        """
        Display the metadata about the dataset with the specified ID.
        Args:
            dataset_id (str): The ID of the dataset.
        Returns:
            DataFrame: A pandas DataFrame containing the dataset information.
        """
        res = self.stub.populate_datasets([dataset_id])
        (_, resp, *__) = res[0]
        dataset_values = self.stub.get_dataset_values(dataset_id)
        return self.converter.show_dataset(resp.dataset, dataset_values) 
[docs]
    def get_dataset(self, dataset_id):
        """
        Get the dataset with the specified ID.
        Args:
            dataset_id (str): The ID of the dataset.
        Returns:
            Dataset: The dataset with the specified ID.
        """
        (_, resp, *__) = self.stub.populate_datasets([dataset_id])[0]
        return resp.dataset 
[docs]
    def delete_dataset(self, dataset_id):
        """
        Delete dataset with the specified ID.
        Args:
            dataset_id (str): The ID of the dataset.
        Returns:
            bool: True if dataset was successfully deleted.
        """
        return self.stub.delete_dataset(dataset_id) is not None 
[docs]
    def get_report(self, dataset_id):
        """
        Get the report of the specified dataset.
        Args:
            dataset_id (str): The ID of the dataset.
        Returns:
            Report (str): The report of the specified dataset.
        """
        return self.stub.get_report(dataset_id) 
[docs]
    def upload_graph(self, serialized_graph):
        """
        Upload a serialized graph.
        Args:
            serialized_graph (str): The serialized Ciphercore graph to upload.
        Returns:
            DataFrame: A pandas DataFrame containing the uploaded graph information.
        """
        res = self.stub.upload_graph(serialized_graph)
        return self.converter.list_graphs(res) 
[docs]
    def download_graph(self, id):
        """
        Download a graph with the specified ID.
        Args:
            id (str): The ID of the graph.
        Returns:
            str: The serialized Ciphercore graph.
        """
        return self.stub.download_graph(id) 
[docs]
    def list_graphs(self):
        """
        List all graphs.
        Returns:
            DataFrame: A pandas DataFrame containing the list of graphs.
        """
        ids = self.list_graphs_ids()
        res = self.stub.populate_graphs(ids)
        return self.converter.list_graphs(res) 
[docs]
    def list_graphs_ids(self):
        """
        List the IDs of all graphs.
        Returns:
            list[str]: A list of graph IDs.
        """
        return self.stub.list_graphs() 
[docs]
    def create_computation(self, orchestrator, graphs_config, name, description, config=None):
        """
        Create a computation.
        Computation object specifies what computation to execute, regardless of the data. The same computation
        can be used multiple times with different datasets.
        Note that there are easier-to-use functions for specific computations (PSI, SQL, NN training, etc.).
        Args:
            orchestrator (str): The orchestrator type for the computation.
            graphs_config (dict): The "graph name -> graph ID" mapping.
            name (str): The name of the computation.
            description (str): The description of the computation.
            config (dict, optional): Additional orchestrator-specific configuration for the computation.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation information.
        """
        res = self.stub.create_computation(
            orchestrator, graphs_config, name, description, config=config)
        return self.converter.list_computations(res) 
[docs]
    def create_single_graph_computation(self, serialized_graph, name='', description=''):
        """
        Create a single graph computation.
        Args:
            serialized_graph (str): The serialized Ciphercore graph to create a computation for.
            name (str, optional): The name of the computation.
            description (str, optional): The description of the computation.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation information.
        """
        res = self.stub.upload_graph(serialized_graph)
        if len(res) != 1:
            return
        graph_id = res[0][0]
        res = self.stub.create_computation(
            'single_graph', {"graph": graph_id}, name, description)
        return self.converter.list_computations(res) 
[docs]
    def list_computations(self):
        """
        List all computations.
        Returns:
            DataFrame: A pandas DataFrame containing the list of computations.
        """
        ids = self.list_computations_ids()
        res = self.stub.populate_computations(ids)
        return self.converter.list_computations(res) 
[docs]
    def list_computations_ids(self):
        """
        List the IDs of all computations.
        Returns:
            list[str]: A list of computation IDs.
        """
        return self.stub.list_computations() 
[docs]
    def list_cloud_uploads(self):
        """
        List all cloud uploads.
        Returns:
            DataFrame: A pandas DataFrame containing the list of cloud uploads.
        """
        ids = self.stub.list_cloud_uploads()
        res = self.stub.populate_cloud_uploads(ids)
        return self.converter.list_cloud_uploads(res) 
[docs]
    def get_cloud_upload(self, id):
        """
        Get cloud upload with the specified ID.
        Returns:
            DataFrame: A pandas DataFrame containing the cloud upload information.
        """
        res = self.stub.populate_cloud_uploads([id])
        return self.converter.list_cloud_uploads(res) 
[docs]
    def get_psi_computation(self, first_dataset_columns, second_dataset_columns, sharded=True):
        """
        Create a PSI (Private Set Intersection) computation.
        Args:
            first_dataset_columns (list[str]): The list of columns from the first dataset to join.
            second_dataset_columns (list[str]): The list of columns from the second dataset to join.
            sharded (bool, optional): Whether to shard the computation. Default is True.
        Returns:
            str: The ID of the created computation.
        """
        if not isinstance(first_dataset_columns, list) or not isinstance(second_dataset_columns, list):
            raise ValueError(
                'first_dataset_columns and second_dataset_columns should be lists of strings.')
        if len(first_dataset_columns) != len(second_dataset_columns):
            raise ValueError(
                'first_dataset_columns and second_dataset_columns should have the same length.')
        # TODO: for efficiency reasons, maybe reuse existing computation, if columns match?
        join_column_pairs = []
        for first_col, second_col in zip(first_dataset_columns, second_dataset_columns):
            join_column_pairs.append(
                common_pb2.JoinColumnPair(
                    first_dataset_column=first_col,
                    second_dataset_column=second_col
                )
            )
        config = common_pb2.OrchestratorConfig(
            psi_config=common_pb2.PsiConfig(
                join_columns=join_column_pairs,
                sharded=sharded,
            )
        )
        return self.stub.create_computation(
            'psi', {}, 'PSI of {} and {}'.format(first_dataset_columns, second_dataset_columns), "", config=config)[0][0] 
[docs]
    def get_mlp_computation(self, layers, batch_size, optimizer, learning_rate, loss, epochs, precision):
        """
        Create an MLP (Multi-Layer Perceptron) computation.

        Args:
            layers (list): The list with the sizes of hidden layers in the MLP (note that the last one should be 1 in most cases).
            batch_size (int): The batch size for training.
            optimizer (str): The optimizer to use for training (we currently support 'adam', 'adagrad' and 'sgd').
            learning_rate (float): The learning rate for training.
            loss: The loss function to use for training: 'log_loss' or 'mse', or an `MlpConfig` loss enum
                value (as `create_mlp` passes), which is forwarded unchanged.
            epochs (int): The number of epochs for training.
            precision (int): The precision for training (it is conducted with fixed precision numbers, with `2**precision` as denominator).
        Returns:
            str: The ID of the created computation.
        """
        # The documented lowercase names are not the protobuf enum labels (MSE / LOG_LOSS),
        # so translate them here; any other value is handed to MlpConfig as-is.
        loss_by_name = {'mse': common_pb2.MlpConfig.MSE,
                        'log_loss': common_pb2.MlpConfig.LOG_LOSS}
        if isinstance(loss, str) and loss in loss_by_name:
            loss = loss_by_name[loss]
        config = common_pb2.OrchestratorConfig(
            ml_config=common_pb2.MlpConfig(hidden_layers=layers,
                                           batch_size=batch_size, optimizer=optimizer,
                                           learning_rate=learning_rate, loss=loss,
                                           epochs=epochs, precision=precision))
        return self.stub.create_computation(
            'neural_network_training', {}, f'MLP with {len(layers)} layers', str(config), config=config)[0][0]
[docs]
    def get_nn_inference_computation(self, batch_size, precision):
        """
        Create a neural network inference computation.

        Args:
            batch_size (int): The batch size for inference; should match the one used in training.
            precision (int): The fixed-point precision for inference; should match the one used in training.
        Returns:
            str: The ID of the created computation.
        """
        inference_config = common_pb2.NnInferenceConfig(
            batch_size=batch_size, precision=precision)
        config = common_pb2.OrchestratorConfig(nn_inference_config=inference_config)
        created = self.stub.create_computation(
            'neural_network_inference', {}, 'NN Inference', str(config), config=config)
        return created[0][0]
[docs]
    def get_llm_inference_computation(self, max_len, num_layers, embedding_dim, num_heads, temperature, top_p):
        """
        Create an LLM inference computation.

        Args:
            max_len (int): The maximum length of the generated text.
            num_layers (int): The number of layers in the transformer.
            embedding_dim (int): The embedding dimension of the transformer.
            num_heads (int): The number of attention heads in the transformer.
            temperature (float): The sampling temperature.
            top_p (float): The top-p value for sampling.
        Returns:
            str: The ID of the created computation.
        """
        llm_config = common_pb2.LlmInferenceConfig(
            max_len=max_len,
            num_layers=num_layers,
            embedding_dim=embedding_dim,
            num_heads=num_heads,
            temperature=temperature,
            top_p=top_p,
        )
        config = common_pb2.OrchestratorConfig(llm_inference_config=llm_config)
        created = self.stub.create_computation(
            'llm_inference', {}, 'LLM Inference', str(config), config=config)
        return created[0][0]
[docs]
    def get_sql_computation(self, query):
        """
        Create an SQL computation.

        Args:
            query (str): The SQL query to execute. Table names used in the query are
                bound to datasets later, in the computation session.
        Returns:
            str: The ID of the created computation.
        """
        # TODO: for efficiency reasons, maybe reuse existing computation, if query is the same?
        sql_config = common_pb2.SqlConfig(query=query)
        config = common_pb2.OrchestratorConfig(sql_config=sql_config)
        title = 'SQL query: {}'.format(query)
        created = self.stub.create_computation('sql', {}, title, "", config=config)
        return created[0][0]
[docs]
    def get_knn_computation(self, num_neighbors, has_labels=False):
        """
        Create a KNN (k-nearest-neighbors) computation.

        Args:
            num_neighbors (int): The number of neighbors to consider.
            has_labels (bool, optional): Whether label values are supplied; selects MEAN
                label aggregation instead of NONE. Default is False.
        Returns:
            str: The ID of the created computation.
        """
        # TODO: for efficiency reasons, maybe reuse existing computation, if num_neighbors is the same?
        aggregation = common_pb2.KnnConfig.LabelAggregation
        if has_labels:
            label_aggregation = aggregation.MEAN
        else:
            label_aggregation = aggregation.NONE
        knn_config = common_pb2.KnnConfig(
            num_neighbors=num_neighbors, label_aggregation=label_aggregation)
        config = common_pb2.OrchestratorConfig(knn_config=knn_config)
        title = 'KNN[neighbors={}]'.format(num_neighbors)
        created = self.stub.create_computation('nearest_neighbors', {}, title, "", config=config)
        return created[0][0]
[docs]
    def create_computation_session(self, computation_id, data_config, name='', description=''):
        """
        Create a computation session.
        Args:
            computation_id (str): The ID of the computation.
            data_config (dict): The mapping (name -> value ID). Names are orchestrator-specific (see orchestrator-specific functions for details, e.g. `create_psi`).
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            A pandas Series containing the created computation session information.
        """
        self.stub.expose_computation(computation_id)
        res = self.stub.create_computation_session(
            computation_id, data_config, name, description)
        return self.converter.list_computation_sessions(res).iloc[0] 
[docs]
    def create_psi(self, first_dataset_id, second_dataset_id, first_dataset_columns, second_dataset_columns, name='', description='', sharded=True):
        """
        Create a PSI (Private Set Intersection) computation session.
        Args:
            first_dataset_id (str): The ID of the first dataset.
            second_dataset_id (str): The ID of the second dataset.
            first_dataset_columns (list[str]): The column from the first dataset to join.
            second_dataset_columns (list[str]): The column from the second dataset to join.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
            sharded (bool, optional): Whether to shard the computation. Default is True.
        Returns:
            A pandas Series containing the created computation session information.
        """
        computation_id = self.get_psi_computation(
            first_dataset_columns, second_dataset_columns, sharded)
        data_config = {'input_0': first_dataset_id,
                       'input_1': second_dataset_id}
        return self.create_computation_session(computation_id, data_config, name, description) 
[docs]
    def create_sql(self, query, data_config, name='', description=''):
        """
        Create an SQL computation session.
        Args:
            query (str): The SQL query to execute.
            data_config (dict): The configuration of data for the computation.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation session information.
        """
        computation_id = self.get_sql_computation(query)
        return self.create_computation_session(computation_id, data_config, name, description) 
[docs]
    def create_mlp(self, train_datasets, validation_datasets, test_datasets, model_dataset,
                   layers=(100, 1), batch_size=64, optimizer='adam',
                   learning_rate=3e-4, loss='log_loss', epochs=3, precision=15,
                   name='', description=''):
        """
        Create an MLP (Multi-Layer Perceptron) training computation session.

        Args:
            train_datasets (list): The list of training dataset IDs (inputs 'training0', 'training1', ...).
            validation_datasets (list): The list of validation dataset IDs (inputs 'validation0', ...).
            test_datasets (list): The list of testing dataset IDs (inputs 'testing0', ...).
            model_dataset (str or None): ID of a model dataset bound to input 'model0'; a falsy value
                omits it (presumably an initial model to continue training from; confirm with the orchestrator).
            layers (sequence of int, optional): Hidden layer sizes in the MLP (in most cases, the last one should be 1). Default is (100, 1).
            batch_size (int, optional): Batch size for training. Default is 64.
            optimizer (str, optional): Optimizer to use for training. Default is 'adam', supported optimizers are 'adam', 'adagrad', 'sgd'.
            learning_rate (float, optional): Learning rate for training. Default is 3e-4.
            loss (str, optional): Loss function to use for training. Default is 'log_loss', supported losses are 'log_loss' and 'mse'.
            epochs (int, optional): Number of epochs for training. Default is 3.
            precision (int, optional): Precision for training. Default is 15. Training is performed in fixed-point arithmetic with denominator `2**precision`.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            A pandas Series describing the created computation session.
        Raises:
            KeyError: If `loss` is not 'log_loss' or 'mse'.
        """
        def config_for_datasets(dataset_ids, prefix):
            # Inputs are named '<prefix><index>', e.g. 'training0', 'training1', ...
            return {f'{prefix}{i}': dataset_id for i, dataset_id in enumerate(dataset_ids)}

        data_config = {**config_for_datasets(train_datasets, 'training'),
                       **config_for_datasets(validation_datasets, 'validation'),
                       **config_for_datasets(test_datasets, 'testing')}
        if model_dataset:
            data_config.update(config_for_datasets([model_dataset], 'model'))
        loss = {'mse': common_pb2.MlpConfig.MSE,
                'log_loss': common_pb2.MlpConfig.LOG_LOSS}[loss]
        computation_id = self.get_mlp_computation(
            layers, batch_size, optimizer, learning_rate, loss, epochs, precision)
        return self.create_computation_session(computation_id, data_config, name, description)
[docs]
    def create_nn_inference(self, inference_dataset_id, model_dataset_id, batch_size=64, precision=15, name='', description=''):
        """
        Create a neural network inference computation session.
        Args:
            inference_dataset_id (str): The ID of the inference dataset.
            model_dataset_id (str): The ID of the model dataset.
            batch_size (int, optional): The batch size for inference. Default is 64, should be the same as for training.
            precision (int, optional): The precision for inference. Default is 15, should be the same as for training.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation session information.
        """
        data_config = {'inference0': inference_dataset_id,
                       'model': model_dataset_id}
        computation_id = self.get_nn_inference_computation(
            batch_size, precision)
        return self.create_computation_session(computation_id, data_config, name, description) 
[docs]
    def create_llm_inference(self, inference_dataset_id, model_dataset_id, max_len=128, num_layers=8, embedding_dim=512,
                             num_heads=16, temperature=0.85, top_p=0.85,
                             name='', description=''):
        """
        Create a LLM inference computation session.
        Args:
            inference_dataset_id (str): The ID of the inference dataset.
            model_dataset_id (str): The ID of the model dataset.
            max_len (int, optional): The maximum length of the generated sequence.
            num_layers (int, optional): The number of layers in the model.
            embedding_dim (int, optional): The embedding dimension of the model.
            num_heads (int, optional): The number of attention heads in the model.
            temperature (float, optional): The temperature for sampling.
            top_p (float, optional): The top-p heuristic value for sampling.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation session information.
        """
        data_config = {'prompt': inference_dataset_id,
                       'model': model_dataset_id}
        computation_id = self.get_llm_inference_computation(
            max_len, num_layers, embedding_dim, num_heads, temperature, top_p)
        return self.create_computation_session(computation_id, data_config, name, description) 
[docs]
    def create_knn(self, key_dataset_id, query_dataset_id, num_neighbors, value_dataset_id=None, name='', description=''):
        """
        Create a KNN (k-Nearest-Neighbors) computation session.
        Args:
            key_dataset_id (str): The ID of the rowwise dataset with lookup keys (vectors).
            query_dataset_id (str): The ID of the rowwise dataset with lookup queries (vectors).
            num_neighbors (int): The number of neighbors to consider in the KNN computation.
            value_dataset_id (str, optional): The ID of the dataset with labels. Default is None.
            name (str, optional): The name of the session.
            description (str, optional): The description of the session.
        Returns:
            DataFrame: A pandas DataFrame containing the created computation session information.
        """
        data_config = {'keys': key_dataset_id, 'queries': query_dataset_id}
        if value_dataset_id:
            data_config['values'] = value_dataset_id
        computation_id = self.get_knn_computation(
            num_neighbors, has_labels=(value_dataset_id is not None))
        return self.create_computation_session(computation_id, data_config, name, description) 
[docs]
    def list_computation_sessions(self, filter_computation_session_ids=None, show_tags=False):
        """
        List computation sessions.
        Args:
            filter_computation_session_ids (list[str], optional): List of specific computation session IDs to return.
                If None, all computation sessions are returned. Default is None.
            show_tags (bool, optional): Whether to include the tags column.
        Returns:
            DataFrame: A pandas DataFrame containing the list of computation sessions.
        """
        ids = filter_computation_session_ids
        if not filter_computation_session_ids:
            ids = self.list_computation_sessions_ids()
        res = self.stub.populate_computation_sessions(ids)
        return self.converter.list_computation_sessions(res, show_tags) 
[docs]
    def list_computation_sessions_ids(self):
        """
        List computation session IDs.
        Returns:
            list[str]: A list of computation session IDs.
        """
        return self.stub.list_computation_sessions() 
[docs]
    def tag_computation_session(self, id, key, value=None):
        """
        Tag computation session.
        Args:
            id (str): The ID of the computation session to start.
            key (str): Tag key.
            value (str, optional): Tag value. If None, the tag with a given key is removed instead.
        """
        res = self.stub.tag_computation_session(id, key, value) 
[docs]
    def start_computation_session(self, id):
        """
        Start a specific computation session.
        Args:
            id (str): The ID of the computation session to start.
        Returns:
            DataFrame: A pandas DataFrame containing the started computation session information.
        """
        res = self.stub.start_computation_session(id)
        return self.converter.list_computation_sessions(res) 
[docs]
    def cancel_computation_session(self, id):
        """
        Cancel a specific computation session.
        Args:
            id (str): The ID of the computation session to cancel.
        Returns:
            DataFrame: A pandas DataFrame containing the cancelled computation session information.
        """
        res = self.stub.cancel_computation_session(id)
        return self.converter.list_computation_sessions(res) 
[docs]
    def download_computation_session_result(self, id, onnx=False):
        """
        Download the result of a specific computation session.
        Args:
            id (str): The ID of the computation session to download.
            onnx (bool, optional): Whether to convert the result to ONNX protobuf. Default is False.
        Returns:
            DataFrame: A pandas DataFrame containing the downloaded computation session result.
        Raises:
            CiphermodeException: if more than one of csv, onnx and float_array is set.
        """
        session = self.stub.populate_computation_sessions([id])[0][1].data
        computation = self.stub.populate_computations(
            [session.computation_id])[0][1].computation
        orchestrator = computation.orchestrator_name
        if onnx:
            if orchestrator not in ['neural_network_training']:
                raise CiphermodeException(
                    'Cannot convert to onnx for orchestrator {}'.format(orchestrator))
        results = session.metadata.results
        if len(results) == 0:
            raise CiphermodeException('Session contains no results')
        if len(results) > 1:
            raise CiphermodeException(
                'Not implemented: session contains multiple results')
        result = results[0]
        output_keys = list(result.outputs.keys())
        if len(output_keys) > 1:
            raise CiphermodeException(
                'Not implemented: session result contains multiple outputs')
        output_key = output_keys[0]
        payload = self.stub.download_computation_session_output(
            id, 0, output_key)
        output_format = result.outputs[output_key].output_format
        return self.converter.view_typed_value(payload, output_format, onnx) 
[docs]
    def upload_computation_session_result(self, id, endpoint):
        """
        Uploads the result of a computation session to a specified endpoint.
        Args:
            id (str): The ID of the computation session.
            endpoint (str): The endpoint to which the computation session result will be uploaded.
        Returns:
            DataFrame: A pandas DataFrame containing the new dataset.
        """
        session = self.stub.populate_computation_sessions([id])[0][1].data
        results = session.metadata.results
        if len(results) > 1:
            raise CiphermodeException(
                'Not implemented: session contains multiple results')
        result = results[0]
        output_keys = list(result.outputs.keys())
        if len(output_keys) > 1:
            raise CiphermodeException(
                'Not implemented: session result contains multiple outputs')
        output_key = output_keys[0]
        return self.stub.upload_computation_session_output(id, 0, output_key, endpoint) 
[docs]
    def save_computation_session_result(self, id, name='', description='', as_csv=False, include_summary=False, sql_permissions=None, publish=False):
        """
        Persist the result of a computation session as a new dataset.
        Args:
            id (str): The ID of the computation session.
            name (str, optional): Name for the new dataset.
            description (str, optional): Description for the new dataset.
            as_csv (bool, optional): Treat the result as a CSV-like table, producing a columnwise dataset.
            include_summary (bool, optional): Request a dataset summary; forwarded to the stub as ``include_report``.
            sql_permissions (str, optional): SQL permissions, parsed with ``parse_sql_permissions``
                and merged into the dataset's permission config. Ignored when falsy.
            publish (bool, optional): Make the dataset visible to all organizations.
        Returns:
            DataFrame: The new dataset, rendered by the converter's list_datasets.
        """
        permission_config = common_pb2.PermissionConfig()
        if sql_permissions:
            parsed_permissions = parse_sql_permissions(sql_permissions)
            permission_config.sql_column_permissions.MergeFrom(parsed_permissions)
        saved = self.stub.save_computation_session_result(
            id,
            name,
            description,
            permission_config,
            as_csv=as_csv,
            include_report=include_summary,
            publish=publish,
        )
        return self.converter.list_datasets(saved)
[docs]
    def list_data_requests(self, filter_computation_session_id=None):
        """
        Lists data requests.
        Args:
            filter_computation_session_id (str, optional): If provided, only data requests for this computation session ID will be returned.
        Returns:
            DataFrame: A pandas DataFrame containing the list of data requests.
        """
        ids = self.list_data_requests_ids()
        res = self.stub.populate_data_approvals(ids)
        return self.converter.list_data_approvals(res, filter_computation_session_id) 
[docs]
    def list_data_requests_ids(self):
        """
        Lists the IDs of data requests.
        Returns:
            list[str]: A list of data request IDs.
        """
        return self.stub.list_data_approvals() 
    def _update_data_approval(self, id, status=None, comment=''):
        res = self.stub.update_data_approval(id, status, comment)
        return self.converter.list_data_approvals(res)
[docs]
    def approve_data_request(self, id, comment=''):
        """
        Mark a data request as approved.
        Args:
            id (str): The ID of the data request to approve.
            comment (str, optional): A comment to attach to the data request.
        Returns:
            DataFrame: The approved data request, as rendered by the converter.
        """
        approved_status = common_pb2.DataApproval.APPROVED
        return self._update_data_approval(id, approved_status, comment)
[docs]
    def reject_data_request(self, id, comment=''):
        """
        Mark a data request as rejected.
        Args:
            id (str): The ID of the data request to reject.
            comment (str, optional): A comment to attach to the data request.
        Returns:
            DataFrame: The rejected data request, as rendered by the converter.
        """
        rejected_status = common_pb2.DataApproval.REJECTED
        return self._update_data_approval(id, rejected_status, comment)
[docs]
    def create_explore_dataset_intersection(self, dataset_id1, dataset_id2, column_names1, column_names2, use_approx_match_rate=False):
        """
        Creates an exploration of the intersection between two datasets.
        Args:
            dataset_id1 (str): The ID of the first dataset.
            dataset_id2 (str): The ID of the second dataset.
            column_names1 (list(str)): Names of the columns in the first dataset to compare.
            column_names2 (list(str)): Names of the columns in the second dataset to compare.
            use_approx_match_rate (bool, optional): Whether to use approximate match rate. Default is False.
        Returns:
            String: computation_session_id
        """
        return self.stub.create_explore_dataset_intersection(
            dataset_id1, dataset_id2, column_names1, column_names2, use_approx_match_rate=use_approx_match_rate) 
[docs]
    def poll_explore_dataset_intersection(self, session_id):
        """
        Polls the exploration of a dataset intersection.
        Args:
            session_id (str): The session id associated with the dataset intersection exploration.
        Returns:
            ExploreDatasetIntersectionResponse: Object containing explore computation details.
        """
        return self.stub.poll_explore_dataset_intersection(session_id) 
[docs]
    def list_user_events(self, timestamp_ms, num_events, user=''):
        """
        Lists user audit events up to a given timestamp. Admin only.
        Args:
            timestamp_ms (int): Timestamp, in milliseconds.
            num_events (int): Number of events to fetch.
            user (str, optional): Email address to filter events on.
        Returns:
            DataFrame: A pandas DataFrame containing user audit events.
        """
        res = self.stub.list_user_events(timestamp_ms, num_events, user)
        return self.converter.list_user_events(res) 
[docs]
    def list_node_events(self, timestamp_ms, num_events):
        """
        Lists node audit events up to a given timestamp. Admin only.
        Args:
            timestamp_ms (int): Timestamp, in milliseconds.
            num_events (int): Number of events to fetch.
        Returns:
            DataFrame: A pandas DataFrame containing node audit events.
        """
        res = self.stub.list_node_events(timestamp_ms, num_events)
        return self.converter.list_node_events(res) 
[docs]
    def hash_dataset_columns(self, dataset_id, hash_column_names, new_dataset_name, async_init=False):
        """
        Hashes entries of dataset with given column names to create a succinct representation of the input dataset.
        Succinct representations output by this method can be matched with `create_psi` to get hash values
        they have in common.
        Args:
            dataset_id (str): The dataset ID.
            hash_column_names (list[str]): Columns from the dataset to hash.
            new_dataset_name (str): New dataset name.
            async_init (bool, optional): Whether to download the dataset from the `endpoint` asynchronously.
    
        Returns:
            A pandas Series containing the dataset ID for the succinct representation.
            This dataset contains a single column of (de-duplicated) hash values, each value corresponding to 
            some set of rows in the input dataset where entries indexed by columns in `hash_column_names` 
            had the same hash.
        """
        res = self.stub.hash_dataset_columns(dataset_id, hash_column_names, new_dataset_name, async_init)
        return self.converter.list_datasets(res).iloc[0] 
[docs]
    def waterfall_gather(self, original_dataset_id, stage_session_ids, endpoint):
        """
        Post-processes the results of multiple PSI computations on hashed datasets output by `hash_dataset_columns`
        to obtain the indices of rows in the original dataset that matched, along with the index of the first computation they 
        matched in.
        Can be used to implement a multi-stage "waterfall" join by providing ordered session IDs for each stage,
        or to convert a dataset of hashes into a dataset of indices in the original dataset corresponding to these hashes.
        Args:
            original_dataset_id (str): The original dataset ID.
            stage_session_ids (list[str]): Waterfall session IDs. Each should correspond to a PSI computation (made by `create_psi`) on hashed datasets (made with `hash_dataset_columns`).
            endpoint (str): The endpoint to which the computation session result will be uploaded.
        Returns:
            A pandas Series containing the result of running a multi-stage waterfall match on stage_session_ids.
            
            Note that the result will be empty if called with non-empty endpoint - the result will be written directly to cloud storage.
        """
        return self.stub.waterfall_gather(original_dataset_id, stage_session_ids, endpoint) 
 
[docs]
def create_client(frontend_address,
                  auth_config='~/.ciphercore/auth_config',
                  token_path='~/.ciphercore/token',
                  custom_root_ca=None,
                  tls_domain='localhost',
                  private_key=None,
                  certificate_chain=None,
                  *args,
                  **kwargs):
    """
    Create a CiphermodeApi instance and initialize it.
    Args:
        frontend_address (str): The address of the server (normalized with normalize_address).
        auth_config (str, optional): Path to auth config.
        token_path (str, optional): Path to the token file.
        custom_root_ca (str, optional): Path to a TLS certificate file; its bytes are used as the root cert.
        tls_domain (str, optional): The domain protected by the TLS certificate.
        private_key (str, optional): Path to the client's private key.
        certificate_chain (str, optional): Path to the client's certificate chain.
        *args: Arguments for the PandasConverter.
        **kwargs: Kwargs for the PandasConverter.
    Returns:
        CiphermodeApi: An instance of the CiphermodeApi.
    """
    frontend_address = normalize_address(frontend_address)
    cert = _read_file_bytes(custom_root_ca)
    private_key = _read_file_bytes(private_key)
    certificate_chain = _read_file_bytes(certificate_chain)
    auth_handler = AuthHandler(
        frontend_address, auth_config, token_path, cert, tls_domain)
    # Pass the TLS arguments positionally so that *args lands in
    # CiphermodeApi.__init__'s own *args (forwarded to PandasConverter).
    # Unpacking *args after keyword arguments would instead bind them to
    # cert/tls_domain/... and fail with "got multiple values for argument".
    return CiphermodeApi(frontend_address, auth_handler, cert, tls_domain,
                         private_key, certificate_chain, *args, **kwargs)


def _read_file_bytes(path):
    """
    Read a whole file in binary mode, closing it afterwards.
    Args:
        path (str or None): Path to the file, or None.
    Returns:
        bytes or None: The file contents, or None when ``path`` is None.
    """
    if path is None:
        return None
    with open(path, 'rb') as f:
        return f.read()