import os
import torch
import torch.nn as nn
import torch._inductor.pattern_matcher as pm
from torch._higher_order_ops.auto_functionalize import auto_functionalized
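# Minimal repro: register mutating custom ops (myops::add, myops::relu, myops::add_relu),
# then use an Inductor post-grad PatternMatcherPass to fuse the auto_functionalized
# add + relu pair into a single add_relu call. The Inductor cache is redirected to a
# local directory so the generated code can be inspected.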
from torch._inductor.compile_fx import compile_fx

cache_dir = "/home/xytpai/workspace/work/temp"
envs = {
    "TORCHINDUCTOR_CACHE_DIR": os.path.join(cache_dir, "inductor"),
}
for k, v in envs.items():
    os.environ[k] = v


@torch.library.custom_op("myops::add", mutates_args=["result"])
def myops_add(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    torch.add(x, y, out=result)


@torch.library.custom_op("myops::relu", mutates_args=["result"])
def myops_relu(result: torch.Tensor, x: torch.Tensor) -> None:
    result.copy_(x)
    torch.relu_(result)


@torch.library.custom_op("myops::add_relu", mutates_args=["result"])
def myops_add_relu(result: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> None:
    z = x + y
    result.copy_(z)
    torch.relu_(result)


def pattern(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    at1 = auto_functionalized(torch.ops.myops.add.default, result=result_add, x=x, y=y)
    at2 = auto_functionalized(torch.ops.myops.relu.default, result=result, x=at1[1])
    return at2[1]


def replacement(result: torch.Tensor, result_add: torch.Tensor, x: torch.Tensor, y: torch.Tensor):
    at = auto_functionalized(torch.ops.myops.add_relu.default, result=result, x=x, y=y)
    return at[1]


inputs = [
    torch.empty(5, 4, dtype=torch.float),  # result
    torch.empty(5, 4, dtype=torch.float),  # result_add
    torch.empty(5, 4, dtype=torch.float),  # x
    torch.empty(5, 4, dtype=torch.float),  # y
]

pm_pass = pm.PatternMatcherPass(pass_name="fusion_pass")
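# Trace `pattern`/`replacement` with fwd_only and register the rewrite in the fusion pass.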
pm.register_replacement(pattern, replacement, inputs, pm.fwd_only, pm_pass)


def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
    print(graph)
    _count = pm_pass.apply(graph)
    print(_count)
    print(graph)
    graph.eliminate_dead_code()
    return graph


def custom_backend(graph: torch.fx.GraphModule, example_inputs):
    from torch._inductor import config

    current_config = config.get_config_copy()
    current_config["post_grad_custom_post_pass"] = custom_pass
    return compile_fx(graph, example_inputs, config_patches=current_config)


# def fw_add(x, y):
#     out = torch.empty_like(x)
#     torch.ops.myops.add(out, x, y)
#     return out

# def fw_relu(x):
#     out = torch.empty_like(x)
#     torch.ops.myops.relu(out, x)
#     return out


@torch.compile(backend=custom_backend)
class SimpleModel(nn.Module):
    @torch._inductor.config.patch(enable_auto_functionalized_v2=False)
    # def forward(self, x, y):
    #     x = fw_add(x, y)
    #     x = fw_relu(x)
    #     return x
    def forward(self, x, y):
        out = torch.empty_like(x)
        out2 = torch.empty_like(x)
        torch.ops.myops.add(out, x, y)
        torch.ops.myops.relu(out2, out)
        return out2


model = SimpleModel()
x = torch.rand(10, 10)
y = torch.rand(10, 10)
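# The first call triggers compilation; custom_pass prints the post-grad graph before and
# after the rewrite, plus the number of pattern matches that were replaced.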
z = model(x, y)
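# Optional sanity check (a sketch): the compiled result should match the eager
# reference relu(x + y) if the custom ops and the fusion behave correctly.
torch.testing.assert_close(z, torch.relu(x + y))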