torch.jit.freeze
-
torch.jit.freeze(mod, preserved_attrs=None, optimize_numerics=True)
Freezing a ScriptModule will clone it and attempt to inline the cloned module’s submodules, parameters, and attributes as constants in the TorchScript IR Graph. By default, forward will be preserved, as well as attributes & methods specified in preserved_attrs. Additionally, any attribute that is modified within a preserved method will be preserved.

Freezing currently only accepts ScriptModules that are in eval mode.
- Parameters
  - mod (ScriptModule) – a module to be frozen
  - preserved_attrs (Optional[List[str]]) – a list of attributes to preserve in addition to the forward method. Attributes modified in preserved methods will also be preserved.
  - optimize_numerics (bool) – If True, a set of optimization passes will be run that does not strictly preserve numerics. Full details of optimization can be found at torch.jit.optimize_frozen_module.
- Returns
  Frozen ScriptModule.
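The behaviour described above (the input is cloned, parameters are inlined, and a ScriptModule is returned) can be checked with a minimal sketch. The choice of `torch.nn.Sequential` with a `Linear` layer and the variable names are illustrative assumptions, not part of the original page; the tolerance in the comparison is there because `optimize_numerics=True` does not promise bit-identical results.

```python
import torch

# A built-in module keeps the sketch self-contained; any eval-mode ScriptModule would do.
model = torch.nn.Sequential(torch.nn.Linear(4, 3), torch.nn.ReLU()).eval()
scripted = torch.jit.script(model)
frozen = torch.jit.freeze(scripted)

# freeze() returns a new ScriptModule and leaves the scripted original untouched.
is_script_module = isinstance(frozen, torch.jit.ScriptModule)
original_param_count = len(list(scripted.named_parameters()))  # Linear weight and bias
frozen_param_count = len(list(frozen.named_parameters()))      # inlined as constants

# Compare outputs with a tolerance rather than exact equality.
x = torch.rand(2, 4)
outputs_match = torch.allclose(frozen(x), scripted(x), atol=1e-6)
```

Because freezing works on a clone, `scripted` can still be used (or frozen again with different `preserved_attrs`) after this call.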
Example (Freezing a simple module with a Parameter):
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))
        self.linear = torch.nn.Linear(N, M)

    def forward(self, input):
        output = self.weight.mm(input)
        output = self.linear(output)
        return output

scripted_module = torch.jit.script(MyModule(2, 3).eval())
frozen_module = torch.jit.freeze(scripted_module)
# parameters have been removed and inlined into the Graph as constants
assert len(list(frozen_module.named_parameters())) == 0
# See the compiled graph as Python code
print(frozen_module.code)
Example (Freezing a module with preserved attributes):
import torch

class MyModule2(torch.nn.Module):
    def __init__(self):
        super(MyModule2, self).__init__()
        self.modified_tensor = torch.tensor(10.)
        self.version = 1

    def forward(self, input):
        self.modified_tensor += 1
        return input + self.modified_tensor

scripted_module = torch.jit.script(MyModule2().eval())
frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
# we've manually preserved `version`, so it still exists on the frozen module and can be modified
assert frozen_module.version == 1
frozen_module.version = 2
# `modified_tensor` is detected as being mutated in the forward, so freezing preserves
# it to retain model semantics
assert frozen_module(torch.tensor(1)) == torch.tensor(12)
# now that we've run it once, the next result will be incremented by one
assert frozen_module(torch.tensor(1)) == torch.tensor(13)
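Example (Eval-mode requirement, a minimal sketch):
The description states that freezing only accepts ScriptModules in eval mode. The sketch below illustrates what that looks like in practice. It assumes the rejection surfaces as a `RuntimeError`, which this page does not state explicitly; `training_module` and `freeze_error` are illustrative names.

```python
import torch

# Modules start in training mode, and scripting carries that flag over.
training_module = torch.jit.script(torch.nn.Linear(4, 3))

try:
    torch.jit.freeze(training_module)
    freeze_error = None
except RuntimeError as err:
    # Assumption: the eval-mode check is reported as a RuntimeError.
    freeze_error = err

# Switching to eval mode first lets freezing proceed.
frozen = torch.jit.freeze(training_module.eval())
```

Calling `.eval()` before scripting, as the examples above do, has the same effect as calling it on the scripted module.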
Note
If you’re not sure why an attribute is not being inlined as a constant, you can run dump_alias_db on frozen_module.forward.graph to see if freezing has detected the attribute is being modified.
© 2019 Torch Contributors
Licensed under the 3-clause BSD License.
https://pytorch.org/docs/1.8.0/generated/torch.jit.freeze.html