Nothing Special   »   [go: up one dir, main page]

Skip to content

Commit

Permalink
Make nnapi cat converter accept flex inputs
Browse files Browse the repository at this point in the history
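Summary: Make the NNAPI `cat` converter accept flexible-size ("flex") inputs instead of requiring fixed-size input tensors. Output dimensions that are flexible are computed at runtime. The concatenation dimension is the sum of the inputs' sizes along that dimension, and every other flexible dimension is forwarded from the first input.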

Test Plan: pytest test/test_nnapi.py::TestNNAPI::test_cat

Reviewed By: anshuljain1

Differential Revision: D29480747

fbshipit-source-id: 161803054ff1a4c2c750fc30a5f0fc6d8a24b2c9
  • Loading branch information
Akshit Khurana authored and facebook-github-bot committed Jul 9, 2021
1 parent 9e81d3d commit 76c0f22
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
11 changes: 11 additions & 0 deletions test/test_nnapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ def forward(self, t1, t2):
nhwc(torch.randn(1, 4, 3, 3)),
])

self.check(
CatModule(1),
[
torch.randn(1, 2, 3, 3),
torch.randn(1, 4, 3, 3),
],
convert_args=[
torch.zeros(0, 0, 0, 0),
torch.zeros(0, 0, 0, 0)
])

def test_pointwise_unary(self):
for op in ["relu", "sigmoid"]:
with self.subTest(op):
Expand Down
13 changes: 11 additions & 2 deletions torch/backends/_nnapi/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ def add_cat(self, node):
out_oper = None
out_dim_size = 0
for inp in tensors:
in_id, in_oper = self.get_tensor_operand_by_jitval_fixed_size(inp)
in_id, in_oper = self.get_tensor_operand_by_jitval(inp)
if out_oper is None:
out_shape = change_element(in_oper.shape, dim, -1)
out_oper = in_oper._replace(shape=out_shape)
Expand All @@ -1085,10 +1085,19 @@ def add_cat(self, node):
else:
nnapi_dim = dim

out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
for idx, d in enumerate(out_oper.shape):
if d == 0:
if idx == dim:
shape = " + ".join(flex_name(ip_id, dim) for ip_id in in_ids)
self.compute_operand_shape(out_id, idx, shape)
else:
self.forward_operand_shape(out_id, idx, in_ids[0], idx)

inputs = in_ids + [self.add_immediate_int_scalar(nnapi_dim)]

outputs = [None] * 1
outputs[0] = self.add_tensor_operand(node.outputsAt(0), out_oper)
outputs[0] = out_id

self.add_operation(NNAPI_OperationCode.CONCATENATION, inputs, outputs)

Expand Down

0 comments on commit 76c0f22

Please sign in to comment.