随机生成pytorch算子测试序列且保证算子参数合法

随机生成pytorch算子测试序列且保证算子参数合法

  • 代码
  • 输出

背景:

1.一些对维度进行操作的算子的单算子测试,结果正常,但多个算子组合在一起,结果就不对。是否能给一个算子列表,随机生成它们的组合呢

功能描述:

1.此程序用于在 CUDA 环境中生成随机张量并对其施加一系列随机选择的操作

2.程序首先随机生成张量的形状和内容,然后随机选择一个操作(如 reshapetransposematmul 等),并生成适当的参数以执行该操作

3.最终输出变换后的张量并打印相关操作信息

4.整个过程在100次不同的种子下每次进行10次操作,以保证操作的多样性和结果的随机性

代码

import torch
import random
from functools import reduce
from operator import mul
import numpy as npmax_size = 4096  # 每个维度的最大大小
max_tensor_elements = 1*4096*4096  # 张量中元素的总数限制min_dim_size = 1  # 最小维度大小
max_dim_size = max_size  # 扩大这个范围可以更快生成符合要求的大小def generate_random_shape(dim, max_attempts=10):for _ in range(max_attempts):shape = [random.randint(min_dim_size, max_dim_size) for _ in range(dim)]if reduce(mul, shape, 1) <= max_tensor_elements:return tuple(shape)# 兜底策略,防止尝试次数用尽:再遍历生成的随机形状,逐个将维度缩小直到符合限制shape = [random.randint(1, max_size) for _ in range(dim)]current_elements = reduce(mul, shape, 1)while current_elements > max_tensor_elements:for i in range(len(shape)):if shape[i] > 1:shape[i] //= 2current_elements = reduce(mul, shape, 1)if current_elements <= max_tensor_elements:breakreturn tuple(shape)def generate_random_input(shape):return torch.randn(shape).to("cuda").half()def generate_random_operator(input_shape):operators = ['unsqueeze', 'repeat', 'permute', 'transpose', 'reshape', 'expand', 'contiguous', 'matmul', 'mul', 'concat',"view"]return random.choice(operators)def generate_random_reshape(input_shape):# 计算输入张量的总元素数total_elements = np.prod(input_shape)divisors = []# 找到 total_elements 的所有约数for i in range(1, int(np.sqrt(total_elements)) + 1):if total_elements % i == 0:divisors.append(i)if i != total_elements // i:divisors.append(total_elements // i)dimensions = []remaining_elements = total_elements# 随机选择新的维度并且保证元素数量不变while remaining_elements > 1 and len(dimensions) < len(input_shape):divisor = np.random.choice(divisors)dimensions.append(divisor)remaining_elements //= divisordivisors = [d for d in divisors if remaining_elements % d == 0]if remaining_elements > 1:dimensions.append(remaining_elements)    np.random.shuffle(dimensions)    return tuple(dimensions)def generate_reshape_params(tensor):return generate_random_reshape(tensor.shape)def random_transpose_params(tensor):return random.sample(range(tensor.dim()), 2)def generate_repeat_params(input_shape):while True:repeats = [random.randint(1, 4) for _ in input_shape]if reduce(mul, [dim * repeat for dim, repeat in zip(input_shape, repeats)], 1) <= max_tensor_elements:return tuple(repeats)def generate_expand_params(input_shape):expanded_shape = []while True:expanded_shape = [random.randint(min(2,dim), dim*2) if dim == 1 else dim for dim in input_shape]if reduce(mul, expanded_shape, 1) <= max_tensor_elements:breakreturn expanded_shapedef generate_random_operator_parameters(input_shape, operator, input_tensor):if operator == 'unsqueeze':return (random.randint(0, len(input_shape) - 1),)if operator == 'repeat':return generate_repeat_params(input_shape)if operator == 'permute':return random.sample(range(len(input_shape)), len(input_shape))if operator == 'transpose':return random_transpose_params(input_tensor)if operator in ['reshape',"view"]:return generate_reshape_params(input_tensor)if operator == 'expand':return generate_expand_params(input_shape)if operator == 'matmul':if input_tensor.dim() == 1:return ()return (input_tensor.size(-1), random.randint(1, max_size))if operator in ['contiguous','mul']:return ()if operator == 'concat':return (random.randint(0, len(input_shape) - 1),)def execute_operator(input_tensor, operator, operator_parameters):if operator == 'unsqueeze':return input_tensor.unsqueeze(*operator_parameters)if operator == 'repeat':return input_tensor.repeat(operator_parameters)if operator == 'permute':return input_tensor.permute(operator_parameters)if operator == 'transpose':return input_tensor.transpose(*operator_parameters)if operator == 'reshape':return input_tensor.reshape(operator_parameters)if operator == 'view':return input_tensor.view(operator_parameters)    if operator == 'expand':return input_tensor.expand(operator_parameters)if operator == 'contiguous':return input_tensor.contiguous()if operator == 'matmul':if input_tensor.dim() ==1:return input_tensorother = torch.randn(*operator_parameters).to(input_tensor.device).type_as(input_tensor)return torch.matmul(input_tensor, other)if operator == 'mul':return input_tensor * input_tensorif operator == 'concat':return torch.cat((input_tensor, input_tensor), dim=operator_parameters[0])def main():for seed in range(2):random.seed(seed)np.random.seed(seed)torch.random.manual_seed(seed)for i in range(10):input_shape = generate_random_shape(random.randint(2, 5))input_tensor = generate_random_input(input_shape)operator = generate_random_operator(input_shape)operator_parameters = generate_random_operator_parameters(input_shape, operator, input_tensor)output_tensor = execute_operator(input_tensor, operator, operator_parameters)print(f"seed:{seed:03d} seq:{i:02d} {operator:<10} input:{str(input_shape):<32} param:{str(operator_parameters):<32} output:{str(output_tensor.shape):<32}")print(output_tensor.cpu().numpy().reshape(-1)[:8])torch.cuda.empty_cache()
if __name__ == '__main__':main()

输出

seed:000 seq:00 repeat     input:(7, 42, 26, 36, 56)              param:(1, 1, 1, 1, 1)                  output:torch.Size([7, 42, 26, 36, 56])
seed:000 seq:01 view       input:(248, 227, 276)                  param:(92, 908, 186)                   output:torch.Size([92, 908, 186])
seed:000 seq:02 view       input:(18, 21, 51, 32, 17)             param:(17, 4536, 136)                  output:torch.Size([17, 4536, 136])
seed:000 seq:03 reshape    input:(2548, 3565)                     param:(644, 65, 217)                   output:torch.Size([644, 65, 217])
seed:000 seq:04 reshape    input:(46, 42, 14, 57, 7)              param:(28, 266, 3, 483)                output:torch.Size([28, 266, 3, 483])
seed:000 seq:05 contiguous input:(222, 100, 597)                  param:()                               output:torch.Size([222, 100, 597])
seed:000 seq:06 view       input:(15, 27, 56, 8, 59)              param:(3, 3, 20160, 1, 59)             output:torch.Size([3, 3, 20160, 1, 59])
seed:000 seq:07 view       input:(1461, 1161)                     param:(188469, 9)                      output:torch.Size([188469, 9])
seed:000 seq:08 reshape    input:(19, 29, 19, 17, 54)             param:(31407, 1, 3, 17, 6, 1)          output:torch.Size([31407, 1, 3, 17, 6, 1])
seed:000 seq:09 transpose  input:(12, 126, 46, 157)               param:[2, 3]                           output:torch.Size([12, 126, 157, 46])
[-0.581   0.568   1.187   2.46   -0.1392 -0.3362  0.2076 -0.662 ]
seed:001 seq:00 view       input:(119, 354, 236)                  param:(4, 1, 17, 146202)               output:torch.Size([4, 1, 17, 146202])
seed:001 seq:01 reshape    input:(60, 961, 178)                   param:(3, 3421160)                     output:torch.Size([3, 3421160])
seed:001 seq:02 expand     input:(16, 10, 34, 37, 58)             param:[16, 10, 34, 37, 58]             output:torch.Size([16, 10, 34, 37, 58])
seed:001 seq:03 concat     input:(12, 44, 12, 26, 55)             param:(1,)                             output:torch.Size([12, 88, 12, 26, 55])
seed:001 seq:04 expand     input:(48, 9, 28, 20, 68)              param:[48, 9, 28, 20, 68]              output:torch.Size([48, 9, 28, 20, 68])
seed:001 seq:05 repeat     input:(16, 16, 162, 233)               param:(1, 1, 1, 1)                     output:torch.Size([16, 16, 162, 233])
seed:001 seq:06 expand     input:(25, 426, 19, 63)                param:[25, 426, 19, 63]                output:torch.Size([25, 426, 19, 63])
seed:001 seq:07 permute    input:(153, 153, 380)                  param:[2, 1, 0]                        output:torch.Size([380, 153, 153])
seed:001 seq:08 permute    input:(3091, 1445)                     param:[1, 0]                           output:torch.Size([1445, 3091])
seed:001 seq:09 mul        input:(142, 254, 388)                  param:()                               output:torch.Size([142, 254, 388])
[3.31   0.3372 0.2354 0.1373 0.594  2.326  0.7344 2.16  ]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mzph.cn/web/18771.shtml

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PHP 汉字转拼音

使用 overtrue/pinyin 库将汉字转换为拼音 在这篇文章中&#xff0c;我将向大家介绍如何使用 overtrue/pinyin 库来将汉字转换为拼音。这是一个非常方便的PHP库&#xff0c;能够帮助我们轻松地进行汉字到拼音的转换。 安装 overtrue/pinyin 库 首先&#xff0c;我们需要通过 C…

redis--集群节点维护

添加节点 因公司业务发展迅猛&#xff0c;现有的三主三从redis cluster架构可能无法满足现有业务的并发写入需求&#xff0c;因此公司紧急采购一台服务器192.168.7.107&#xff0c;需要将其动态添加到集群当中其不能影响业务使用和数据丢失&#xff0c;则添加过程如下: 同步之…

Pandas-中axis的用法

在Pandas中&#xff0c;min(axis)方法是计算DataFrame或Series中每行或每列的最小值的函数。该函数可以接受一个参数axis&#xff0c;用于指定计算最小值的方向。当axis0时&#xff0c;表示沿着行的方向计算最小值&#xff1b;当axis1时&#xff0c;表示沿着列的方向计算最小值…

买入看跌期权怎么理解?

今天带你了解买入看跌期权怎么理解&#xff1f;看跌期权买入者往往预期市场价格将下跌。 买入看跌期权怎么理解&#xff1f; 买入看跌期权是指购买者支付权利金&#xff0c;获得以特定价格向期权出售者卖出一定数量的某种特定商品的权利。看跌期权买入者往往预期市场价格将下跌…

【YOLOv5/v7改进系列】引入AKConv——即插即用的卷积块

一、导言 介绍了一种名为AKConv&#xff08;Alterable Kernel Convolution&#xff09;的新型卷积操作&#xff0c;旨在解决标准卷积操作存在的两个根本性问题。首先&#xff0c;标准卷积操作受限于局部窗口&#xff0c;无法捕获来自其他位置的信息&#xff0c;且其采样形状固…

软件系统测试的类型和方法介绍

测试是软件开发过程中至关重要的一环&#xff0c;负责验证和确认软件系统是否符合预期的需求&#xff0c;并帮助开发团队消除潜在的缺陷。系统测试作为软件测试中不可缺少的过程&#xff0c;是根据预先制定的测试计划和测试用例&#xff0c;以检查软件系统功能、性能、安全性和…

JavaScript tab选项卡切换

下面是一个简单的JavaScript代码示例&#xff0c;展示如何使用tab选项卡来切换内容。 HTML代码&#xff1a; <div class"tab"><button class"tablinks" onclick"openTab(event, tab1)">选项卡1</button><button class&qu…

仿真51单片机程序(下载安装+Proteus)

我是看的这个大佬的:http://t.csdnimg.cn/Z07SZ 大佬写的很详细了,我就不献丑了. 贴上俩个运行成功的截图,有碰到问题的欢迎交流.

初识BootLoader

一、 BootLoader的概念 引导加载程序是系统加电后运行的第一段软件代码。回忆一下PC的体系结构我们可以知道&#xff0c;PC机中的引导加载程序由BIOS&#xff08;本质是一段固件程序&#xff09;和位于硬盘MBR中的BootLoader&#xff08;如LILO、GRUB等&#xff09;组成。BIOS…

【学习Day2】计算机基础

✍&#x1f3fb;记录学习过程中的输出&#xff0c;坚持每天学习一点点~ ❤️希望能给大家提供帮助~欢迎点赞&#x1f44d;&#x1f3fb;收藏⭐评论✍&#x1f3fb;指点&#x1f64f; 1.4 校验码 奇偶校验 ● 奇偶校验码的编码方法是&#xff1a; 由若干位有效信息的头部或者…

探寻数据处理的高效之道:从Python内置方法到NumPy的飞跃

新书上架~&#x1f447;全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我&#x1f446;&#xff0c;收藏下次不迷路┗|&#xff40;O′|┛ 嗷~~ 目录 一、引言&#xff1a;为什么要学习NumPy&#xff1f; 二、案例展示&#xff1a;创建整数序列…

idm软件是做什么的 IDM是啥软件 idm软件怎么下载 idm软件怎么下载

一、IDM是啥软件 IDM 是由美国 Tonec 公司开发的 Windows 软件&#xff0c;该软件最初于 2005 年发布。IDM全称Internet Download Manager&#xff0c;是一款Windows平台老牌而功能强大的下载加速器&#xff0c;专注于互联网数据下载。这款软件是一款不错的轻量级下载工具&…

关于如何在 smartforms 中 debug

发现一旦smartforms 复杂起来&#xff0c;Debug的时候就一下子找不到指定位置&#xff0c;所以如何才能最简单的找到指定位置呢 以这个为案例 然后打上断点即可debug

Java实现AES,DES,RSA加密

Java的Cipher类 Cipher类提供了加密和解密的功能。 Cipher类可以完成aes&#xff0c;des&#xff0c;des3和rsa等加密方式 AES加密算法 介绍 这个标准用来替代原先的DES,AES加密过程涉及到4种操作&#xff0c;分别是字节替代、行移位、列混淆和轮密钥加。解密过程分别为对…

相对论表明速度越快时间越慢,为什么速度会影响时间?

在物理学的宏伟殿堂中&#xff0c;相对论以其深邃的洞察力&#xff0c;挑战了我们对时间和空间的传统认识。1905年&#xff0c;阿尔伯特爱因斯坦提出了狭义相对论&#xff0c;揭示了在所有惯性参照系中&#xff0c;光速是常数的惊人事实。 随后在1915年&#xff0c;他进一步发展…

YOLOv5数据集的文件结构和文件格式以及标注工具LabelImg的说明文档

文章目录 一 概述二 文件结构与数据格式2.1 数据集的文件结构2.2 数据格式2.2 文件结构2.3 标注文件的注意事项 三 手动标注YOLOv5数据集3.1 标注工具的选择3.2 标注流程 四 总结与注意事项4.1 labelImg的使用技巧与说明4.2 注意事项 一 概述 YOLOv5 是一个采用深度学习技…

Linux 编译屏障之 ACCESS_ONCE()

文章目录 1. 前言2. 背景3. 为什么要有 ACCESS_ONCE() &#xff1f;4. ACCESS_ONCE() 代码实现5. ACCESS_ONCE() 实例分析6. ACCESS() 的演进7. 参考资料 1. 前言 限于作者能力水平&#xff0c;本文可能存在谬误&#xff0c;因此而给读者带来的损失&#xff0c;作者不做任何承…

基于匹配追踪和最大重叠离散小波变换的ECG心电信号R波检测(MATLAB 2018a)

准确识别心电信号的R波是进行HRV分析的前提。因此&#xff0c;开发出准确的心电信号R波检测方法十分重要。近几十年来&#xff0c;提出的R峰检测方法主要分为两个阶段。第1阶段是预处理阶段&#xff0c;目的是对受不同噪声影响的原始心电信号进行降噪处理&#xff0c;从而实现增…

基于SpringBoot+Html+Mysql的餐厅点餐管理系统外卖点餐系统

博主介绍&#xff1a; 大家好&#xff0c;本人精通Java、Python、C#、C、C编程语言&#xff0c;同时也熟练掌握微信小程序、Php和Android等技术&#xff0c;能够为大家提供全方位的技术支持和交流。 我有丰富的成品Java、Python、C#毕设项目经验&#xff0c;能够为学生提供各类…

算法与数据结构高手养成:朴素的贪心法(上)最优化策略

✨✨ 欢迎大家来访Srlua的博文&#xff08;づ&#xffe3;3&#xffe3;&#xff09;づ╭❤&#xff5e;✨✨ &#x1f31f;&#x1f31f; 欢迎各位亲爱的读者&#xff0c;感谢你们抽出宝贵的时间来阅读我的文章。 我是Srlua小谢&#xff0c;在这里我会分享我的知识和经验。&am…