Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion python/paddle/v2/fluid/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import numpy as np
from . import core
import proto.framework_pb2 as framework_pb2
import contextlib

__all__ = [
'Block', 'Variable', 'Program', 'Operator', 'default_startup_program',
'default_main_program'
'default_main_program', 'program_guard', 'switch_startup_program',
'switch_main_program'
]


Expand Down Expand Up @@ -659,8 +661,83 @@ def __init__(self, block, shape, dtype, **kwargs):


def default_startup_program():
"""
Get default startup program. In startup program, Paddle will initialize
parameters, initialize nccl handle, etc.

Returns:
Program: startup program
"""
return _startup_program_


def default_main_program():
"""
Get default main program. The main program is used for training or testing.

Returns:
Program: main program
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need this default_main_program() function, not use _main_program_ directly?

"""
return _main_program_


def switch_main_program(program):
"""
Switch the main program to a new program.

Args:
program(Program): The new main program

Returns:
Program: The previous main program
"""
global _main_program_
prev_program = _main_program_
_main_program_ = program
return prev_program
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, why return prev_program, not _main_program_? We are switch to a new program, not previous one



def switch_startup_program(program):
"""
Switch the startup program to a new program
Args:
program(Program): The new startup program

Returns:
Program: The previous startup program
"""
global _startup_program_
prev_program = _startup_program_
_startup_program_ = program
return prev_program


@contextlib.contextmanager
def program_guard(main_program, startup_program=None):
"""
Switch program with `with` statement

Examples:
>>> with program_guard(Program()):
>>> data = fluid.layers.data(...)
>>> hidden = fluid.layers.fc(...)

Args:
main_program(Program): New main program inside `with` statement
startup_program(Program): New startup program inside `with` statement.
None means do not change startup program.

Returns:
None
"""
if not isinstance(main_program, Program):
raise TypeError("main_program should be Program")
main_program = switch_main_program(main_program)
if startup_program is not None:
if not isinstance(startup_program, Program):
raise TypeError("startup_program should be Program")
startup_program = switch_startup_program(startup_program)
yield
switch_main_program(main_program)
if startup_program is not None:
switch_startup_program(startup_program)
Loading