Source code for mango.mpi

__doc__ = \
"""
======================================================
Message Passing Interface Utilities (:mod:`mango.mpi`)
======================================================
.. _mpi4py: http://mpi4py.scipy.org
.. currentmodule:: mango.mpi

Convenience functions for MPI (`mpi4py`_). This module imports
everything from the `mpi4py`_ namespace (if available).


Variables
=========

.. autodata:: haveMpi4py

.. autodata:: world
   
.. autodata:: size
   
.. autodata:: rank


Functions
=========

.. autosummary::
   :toctree: generated/

   getLoggers - Returns :samp:`(logger,rootLogger)` pair for MPI logging.
   initialiseLoggers - Initialises loggers to generate output at a specified logging-level.
   getCartShape - Calculates an MPI cartesian layout for a given dimension and MPI communicator.
   getCartShapeForSize - Calculates an MPI cartesian layout for a given dimension and number of MPI processes.
   
Classes
=======
.. autosummary::
   :toctree: generated/

   SplitStreamHandler - A :obj:`logging.Handler` class for splitting messages between a pair of files/streams.

"""

#    
#    :annotation: The `COMM_WORLD` MPI communicator.
#    :annotation: Number of MPI processes in world (world.Get_size()).
#    :annotation: Rank of this MPI process (world.Get_rank()).

#global rank, size, world, haveMpi4py

haveMpi4py = False #: Boolean indicating the availability of :samp:`mpi4py`.

import sys
# On linux the compiled _mango_open_mpi extension is imported with the dlopen
# flags temporarily set to RTLD_NOW|RTLD_GLOBAL; the previous flags are
# restored straight after the import.
# NOTE(review): presumably this makes the MPI library symbols globally visible
# to shared objects loaded later -- confirm against the extension build.
# NOTE(review): DLFCN is a legacy platform module; newer Pythons expose the
# same constants via the os module -- confirm which interpreters are supported.
if sys.platform.startswith('linux'):
    import DLFCN as dl
    _flags = sys.getdlopenflags()
    sys.setdlopenflags(dl.RTLD_NOW|dl.RTLD_GLOBAL)
    from . import _mango_open_mpi
    sys.setdlopenflags(_flags)
else:
    from . import _mango_open_mpi

from ._mango_open_mpi import usingMpi as _mangoUsingMpi

# Pull the whole mpi4py.MPI namespace into this module, but only when the
# extension reports that MPI is in use.  Any failure raised in here (not only
# ImportError, because of the bare except) leaves haveMpi4py as False.
try:
    if (_mangoUsingMpi()):
        from mpi4py.MPI import *
        haveMpi4py = True
except:
    haveMpi4py = False

import atexit
# MPE timing is initialised at import time and its finaliser is registered to
# run at interpreter exit.
_mango_open_mpi._initialiseMpeTiming()
atexit.register(_mango_open_mpi._finaliseMpeTiming)

# Single-process defaults, overwritten below when mpi4py was imported.
rank = 0 #: Rank of this MPI process (:samp:`world.Get_rank()`).
size = 1 #: Number of MPI processes in world (:samp:`world.Get_size()`).
world = None #: The :samp:`COMM_WORLD` MPI communicator.

# COMM_WORLD comes from the mpi4py.MPI star-import above.
if (haveMpi4py):
    world = COMM_WORLD
    rank  = COMM_WORLD.Get_rank()
    size  = COMM_WORLD.Get_size()

import logging, math, sys

def getFormattedRankString():
    """
    Returns the rank string which is embedded in logging prefixes.

    :rtype: str
    :return: The module-level :samp:`rank` zero-padded to the number of
       digits in the module-level :samp:`size`, or :samp:`"0"` when
       :samp:`mpi4py` is not available.
    """
    if not haveMpi4py:
        return "0"
    # Field width is the digit-count of the world size.
    numDigits = len(str(size))
    return "%0*d" % (numDigits, rank)

def getFormatter(mpiRnkLabel="MpiRnk"):
    """
    Creates a formatter which prefixes messages with a time-stamp and a
    rank label.

    :type mpiRnkLabel: str
    :param mpiRnkLabel: Label which appears after the time-stamp. When
       :samp:`mpi4py` is available, the formatted rank string is appended
       to this label.
    :rtype: logging.Formatter
    :return: Regular formatter for non-root-rank logging.
    """
    if haveMpi4py:
        rnkSuffix = getFormattedRankString()
        label = mpiRnkLabel + rnkSuffix
        alignLabel = "MpiRoot" + rnkSuffix
    else:
        label = mpiRnkLabel
        alignLabel = "Root"

    # Pad short labels so that messages line up with those written
    # under the root label.
    padding = " " * max(0, len(alignLabel) - len(label))
    return logging.Formatter(
        "%(asctime)s:" + label + ": " + padding + "%(message)s",
        "%H:%M:%S"
    )

def getRootFormatter():
    """
    Creates the formatter used for root-rank messages.

    :rtype: logging.Formatter
    :return: Formatter for root-rank logging, labelled :samp:`"MpiRoot"`
       when :samp:`mpi4py` is available and :samp:`"Root"` otherwise.
    """
    rootLabel = "MpiRoot" if haveMpi4py else "Root"
    return getFormatter(rootLabel)

def getRootLogger(logrNameSuffix, handlerList=None, formatter=None, rootRank=0):
    """
    Returns logger for logging events from only the MPI process with
    rank rootRank.

    :type logrNameSuffix: str
    :param logrNameSuffix: Name of returned logger is "root." + logrNameSuffix.
    :type handlerList: list
    :param handlerList: List of logging.Handler objects, these get added to
       the logger on the rootRank process. :samp:`None` (the default) means
       no handlers.
    :type formatter: logging.Formatter
    :param formatter: Gets set as the formatter of each handler on the rootRank
       process. If :samp:`None`, the formatter returned by :func:`getFormatter`
       is used.
    :type rootRank: int
    :param rootRank: MPI rank of the root process on which logging events are to be recorded.
    :rtype: logging.Logger
    :return: A logger object for logging from the root-rank MPI process only.
    """
    global rank
    if not haveMpi4py:
        # Without mpi4py this is the only process, so it is treated as the
        # root: the module-level rank is overwritten with rootRank.
        rank = rootRank
    if formatter is None:
        formatter = getFormatter()
    if handlerList is None:
        handlerList = ()

    logr = logging.getLogger("root." + logrNameSuffix)
    # Only the root-rank process gets handlers attached.
    if rank == rootRank:
        for handler in handlerList:
            handler.setFormatter(formatter)
            logr.addHandler(handler)

    return logr

def getNonRootLogger(logrNameSuffix, handlerList=None, formatter=None, rootRank=0):
    """
    Returns logger for logging events from all the MPI processes.

    :type logrNameSuffix: str
    :param logrNameSuffix: Name of returned logger is "rank." + logrNameSuffix.
    :type handlerList: list
    :param handlerList: List of logging.Handler objects, these get added to
       the logger on this MPI process. :samp:`None` (the default) means
       no handlers.
    :type formatter: logging.Formatter
    :param formatter: Gets set as the formatter of each handler on this MPI
       process. If :samp:`None`, the formatter returned by :func:`getFormatter`
       is used.
    :type rootRank: int
    :param rootRank: MPI rank of the root process.
    :rtype: logging.Logger
    :return: A logger object for logging from any-rank MPI processes.
    """
    global rank
    if not haveMpi4py:
        # Without mpi4py this is the only process, so it is treated as the
        # root: the module-level rank is overwritten with rootRank.
        rank = rootRank
    if formatter is None:
        formatter = getFormatter()
    if handlerList is None:
        handlerList = ()

    logr = logging.getLogger("rank." + logrNameSuffix)
    # Handlers are attached on every process, whatever its rank.
    for handler in handlerList:
        handler.setFormatter(formatter)
        logr.addHandler(handler)

    return logr

def getLoggers(name, rootRank=0):
    """
    Returns :samp:`(logger,rootLogger)` pair of :obj:`logging.Logger` objects.
    The objects will prefix messages with time and rank strings.

    :type name: str
    :param name: The name suffix for the loggers.
    :type rootRank: int
    :param rootRank: Rank of the root-logger MPI process.
    :rtype: :obj:`logging.Logger` pair
    :return: :samp:`(logger,rootLogger)` logging object pair.

    For example::

       >>> import mango.mpi
       >>> import logging
       >>> logger, rootLogger = mango.mpi.getLoggers("my.module.name")
       >>> mango.mpi.initialiseLoggers(["my.module.name",], logLevel=logging.INFO)
       >>> logger.info("This message appears for all MPI process ranks (including root rank).")
       >>> rootLogger.info("This message only appears for the root rank process.")

    """
    allRanksLogr = getNonRootLogger(name)
    # Handlers already attached to the all-ranks logger are given a freshly
    # created default formatter.
    for hndlr in allRanksLogr.handlers:
        hndlr.setFormatter(getFormatter())

    return allRanksLogr, getRootLogger(name, rootRank=rootRank)
class _Python2SplitStreamHandler(logging.Handler):
    """
    Python-2 flavour of a :obj:`logging.Handler` which writes each message
    to one of two streams, chosen by the level of the message.
    """

    def __init__(self, outstr=sys.stdout, errstr=sys.stderr, splitlevel=logging.WARNING):
        """
        Initialise with a pair of streams and the threshold level which
        decides the stream to which each message is written.

        :type outstr: file-like
        :param outstr: Receives messages whose level is
           less than :samp:`self.splitLevel`.
        :type errstr: file-like
        :param errstr: Receives messages whose level is
           greater-than-or-equal-to :samp:`self.splitLevel`.
        :type splitlevel: int
        :param splitlevel: Logging level threshold which splits messages
           between the two streams.
        """
        self.outStream = outstr
        self.errStream = errstr
        self.splitLevel = splitlevel
        logging.Handler.__init__(self)

    def emit(self, record):
        """
        Formats the record and writes it, followed by a newline, to the
        stream selected by the record level. The selected stream is then
        flushed. Errors are passed to :samp:`self.handleError`.
        """
        # mostly copy-paste from logging.StreamHandler
        try:
            text = self.format(record)
            belowSplit = record.levelno < self.splitLevel
            stream = self.outStream if belowSplit else self.errStream
            fs = "%s\n"
            try:
                if (isinstance(text, unicode) and getattr(stream, 'encoding', None)):
                    ufs = fs.decode(stream.encoding)
                    try:
                        stream.write(ufs % text)
                    except UnicodeEncodeError:
                        # Stream refused the unicode object, hand it bytes
                        # in the encoding of the stream instead.
                        stream.write((ufs % text).encode(stream.encoding))
                else:
                    stream.write(fs % text)
            except UnicodeError:
                stream.write(fs % text.encode("UTF-8"))
            stream.flush()
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            self.handleError(record)


class _Python3SplitStreamHandler(logging.Handler):
    """
    Python-3 flavour of a :obj:`logging.Handler` which writes each message
    to one of two streams, chosen by the level of the message.
    """

    #: String written after every formatted message.
    terminator = '\n'

    def __init__(self, outstr=sys.stdout, errstr=sys.stderr, splitlevel=logging.WARNING):
        """
        Initialise with a pair of streams and the threshold level which
        decides the stream to which each message is written.

        :type outstr: file-like
        :param outstr: Receives messages whose level is
           less than :samp:`self.splitLevel`.
        :type errstr: file-like
        :param errstr: Receives messages whose level is
           greater-than-or-equal-to :samp:`self.splitLevel`.
        :type splitlevel: int
        :param splitlevel: Logging level threshold which splits messages
           between the two streams.
        """
        self.outStream = outstr
        self.errStream = errstr
        self.splitLevel = splitlevel
        logging.Handler.__init__(self)

    def flush(self):
        """
        Flushes both streams (those which have a :samp:`flush` attribute)
        while holding the handler lock.
        """
        self.acquire()
        try:
            for strm in (self.outStream, self.errStream):
                if strm and hasattr(strm, "flush"):
                    strm.flush()
        finally:
            self.release()

    def emit(self, record):
        """
        Emit a record.

        The record is formatted with :samp:`self.format` and written, followed
        by :samp:`self.terminator`, to the stream selected by the record
        level. Both streams are then flushed. Errors are passed
        to :samp:`self.handleError`.
        """
        try:
            text = self.format(record)
            belowSplit = record.levelno < self.splitLevel
            stream = self.outStream if belowSplit else self.errStream
            stream.write(text)
            stream.write(self.terminator)
            self.flush()
        except (KeyboardInterrupt, SystemExit): #pragma: no cover
            raise
        except:
            self.handleError(record)


# Pick the implementation which matches the running interpreter.
if sys.version_info[0] >= 3:
    class SplitStreamHandler(_Python3SplitStreamHandler):
        """
        A :obj:`logging.Handler` class for splitting messages between a pair
        of files/streams depending on the logging-level.
        """
else:
    class SplitStreamHandler(_Python2SplitStreamHandler):
        """
        A :obj:`logging.Handler` class for splitting messages between a pair
        of files/streams depending on the logging-level.
        """
def initialiseLoggers(logrNameList, handlerClass=SplitStreamHandler, logLevel=logging.WARN, rootRank=0):
    """
    Initialises specified (by name suffix) loggers to generate output
    at the specified logging level. If the specified named loggers do not exist,
    they are created.

    :type logrNameList: :obj:`list` of :obj:`str`
    :param logrNameList: List of logger names.
    :type handlerClass: :obj:`logging.Handler` type.
    :param handlerClass: The handler class for output of log messages,
       for example :obj:`SplitStreamHandler` or :obj:`logging.StreamHandler`.
       It is instantiated without arguments.
    :type logLevel: int
    :param logLevel: Log level for messages, typically
       one of :obj:`logging.DEBUG`, :obj:`logging.INFO`, :obj:`logging.WARN`, :obj:`logging.ERROR`
       or :obj:`logging.CRITICAL`. See :ref:`levels`.
    :type rootRank: int
    :param rootRank: MPI rank of the root-logger process.

    See :func:`getLoggers` for example usage.
    """
    allRanksFormatter = getFormatter()
    rootRankFormatter = getRootFormatter()

    for logrName in logrNameList:
        allRanksLogr, rootRankLogr = getLoggers(logrName, rootRank=rootRank)

        # Every process writes the messages of the all-ranks logger.
        allRanksHandler = handlerClass()
        allRanksHandler.setFormatter(allRanksFormatter)
        allRanksLogr.addHandler(allRanksHandler)
        allRanksLogr.setLevel(logLevel)

        # The module-level rank is compared here, after the getLoggers call,
        # because that call can reassign rank when mpi4py is unavailable.
        if rank != rootRank:
            # Non-root processes get a do-nothing handler (when the logging
            # module provides one) on the root logger.
            if hasattr(logging, "NullHandler"):
                rootRankLogr.addHandler(logging.NullHandler())
            continue

        rootRankHandler = handlerClass()
        rootRankHandler.setFormatter(rootRankFormatter)
        rootRankLogr.addHandler(rootRankHandler)
        rootRankLogr.setLevel(logLevel)
def getCartShapeForSize(dimension, size):
    """
    Returns an MPI cartesian layout for the given cartesian-dimension (:samp:`dimension`)
    and number of MPI processes (:samp:`size`).

    :type dimension: int
    :param dimension: Spatial dimension for returned cartesian layout.
    :type size: int
    :param size: Number of MPI processes to use in layout.
    :rtype: list
    :return: list of :samp:`dimension` elements, sorted in ascending order, such
       that the product (multiplication) of the elements equals :samp:`size`
       (for :samp:`size >= 1`).
    """
    shape = [1] * dimension
    numRemaining = size
    d = 0
    while (d < dimension) and (numRemaining > 0):
        # Start from the rounded (dimension-d)-th root of what is left, then
        # walk down to the nearest exact divisor.
        f = int(round(math.pow(float(numRemaining), 1.0/float(dimension-d))))
        while (f > 1) and ((numRemaining % f) != 0):
            f -= 1
        shape[d] = f
        # Floor division keeps the remainder an exact integer; true division
        # produces a float on Python 3, which is inexact above 2**53.
        numRemaining //= f
        d += 1
    shape.sort()

    return shape
def getCartShape(dimension, communicator=None):
    """
    Returns :samp:`getCartShapeForSize(dimension, communicator.Get_size())`.

    :type dimension: int
    :param dimension: Spatial dimension for returned cartesian layout.
    :type communicator: :obj:`mpi4py.MPI.Comm`
    :param communicator: If :samp:`None`, uses :obj:`world`.
    :rtype: list
    :return: :samp:`getCartShapeForSize(dimension, communicator.Get_size())`,
       a size of :samp:`1` is used when there is no communicator at all
       (:obj:`world` is also :samp:`None`).
    """
    # Identity tests: the communicator object is never asked to compare
    # itself with None.
    if communicator is None:
        communicator = world
    commSz = 1 if communicator is None else communicator.Get_size()

    return getCartShapeForSize(dimension, commSz)
# Export every name currently in the module namespace which does not start
# with an underscore.
__all__ = [s for s in dir() if not s.startswith('_')]

# Loggers for this module itself; they are created after __all__ has been
# computed, so they are not part of it.
logger,rootLogger = getLoggers(__name__)