Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Datasets/CIFAR10/CIFAR10.swift
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExamp
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

// Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))
let imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))

let mean = Tensor<Float>([0.485, 0.456, 0.406])
let std = Tensor<Float>([0.229, 0.224, 0.225])
Expand All @@ -125,8 +125,8 @@ func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExamp
func loadCIFARTrainingFiles() -> CIFARExample {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return CIFARExample(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
label: _Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: _Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

Expand Down
2 changes: 1 addition & 1 deletion Datasets/MNIST/MNIST.swift
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ fileprivate func fetchDataset(
return (
images:
Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
.transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
.transposed(permutation: [0, 2, 3, 1]) / 255, // NHWC
labels: Tensor(labels)
)
}
Expand Down
2 changes: 1 addition & 1 deletion GAN/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
// Add padding.
gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
// Transpose to create single image.
gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
gridImage = gridImage.transposed(permutation: [0, 2, 1, 3])
gridImage = gridImage.reshaped(
to: [
(imageHeight + 2) * testImageGridSize,
Expand Down
2 changes: 1 addition & 1 deletion MiniGo/Models/PythonCheckpointReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class PythonCheckpointReader {
let countSuffix = layerCounts[layerName] == nil ? "" : "_\(layerCounts[layerName]!)"
let tensorName = layerName + countSuffix + "/" + weightName
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
return Raw.restoreV2(prefix: StringTensor(path),
return _Raw.restoreV2(prefix: StringTensor(path),
tensorNames: StringTensor([tensorName]),
shapeAndSlices: StringTensor([""]))
}
Expand Down
14 changes: 7 additions & 7 deletions Support/Image.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ public struct Image {
}

public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) {
let loadedFile = Raw.readFile(filename: StringTensor(url.absoluteString))
let loadedJpeg = Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
let loadedFile = _Raw.readFile(filename: StringTensor(url.absoluteString))
let loadedJpeg = _Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
if byteOrdering == .bgr {
self.imageData = .uint8(
data: Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, false, true])))
data: _Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, false, true])))
} else {
self.imageData = .uint8(data: loadedJpeg)
}
Expand All @@ -59,21 +59,21 @@ public struct Image {
outputImageData = Tensor<UInt8>(adjustedData)
}

let encodedJpeg = Raw.encodeJpeg(
let encodedJpeg = _Raw.encodeJpeg(
image: outputImageData, format: .grayscale, quality: quality, xmpMetadata: "")
Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
_Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
}

public func resized(to size: (Int, Int)) -> Image {
switch self.imageData {
case let .uint8(data):
return Image(
tensor: Raw.resizeBilinear(
tensor: _Raw.resizeBilinear(
images: Tensor<UInt8>([data]),
size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
case let .float(data):
return Image(
tensor: Raw.resizeBilinear(
tensor: _Raw.resizeBilinear(
images: Tensor<Float>([data]),
size: Tensor<Int32>([Int32(size.0), Int32(size.1)])))
}
Expand Down
8 changes: 4 additions & 4 deletions Transformer/Model.swift
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func causallyMasked(_ dotProducts: Tensor<Float>, enable: Bool = false) -> Tenso
}
let (queryTimeSteps, keyTimeSteps) = (dotProducts.shape[1], dotProducts.shape[2])
let ones = Tensor<Float>(ones: [1, queryTimeSteps, keyTimeSteps])
let mask = Raw.matrixBandPart(
let mask = _Raw.matrixBandPart(
ones,
numLower: Tensor(Int32(-1)),
numUpper: Tensor(Int32(queryTimeSteps - keyTimeSteps)))
Expand Down Expand Up @@ -138,7 +138,7 @@ func splitHeads(_ input: Tensor<Float>, headCount: Int) -> Tensor<Float> {
let (batchSize, timeSteps, features) = (input.shape[0], input.shape[1], input.shape[2])
let featuresPerHead = features / headCount
let splitLastDim = input.reshaped(to: [batchSize, timeSteps, headCount, featuresPerHead])
let movedToFront = splitLastDim.transposed(withPermutations: 0, 2, 1, 3)
let movedToFront = splitLastDim.transposed(permutation: 0, 2, 1, 3)
return movedToFront.reshaped(to: [batchSize * headCount, timeSteps, featuresPerHead])
}

Expand All @@ -149,7 +149,7 @@ func joinHeads(_ input: Tensor<Float>, headCount: Int) -> Tensor<Float> {
let batchSize = generalizedBatch / headCount
let features = featuresPerHead * headCount
let splitFirstDim = input.reshaped(to: [batchSize, headCount, timeSteps, featuresPerHead])
let movedToBack = splitFirstDim.transposed(withPermutations: 0, 2, 1, 3)
let movedToBack = splitFirstDim.transposed(permutation: 0, 2, 1, 3)
return movedToBack.reshaped(to: [batchSize, timeSteps, features])
}

Expand All @@ -173,7 +173,7 @@ func _vjpSplitQKV(_ input: Tensor<Float>)
-> (AttentionInput, (AttentionInput.TangentVector) -> Tensor<Float>) {
let value = splitQKV(input)
return (value, { seed in
return Raw.concatV2([seed.query, seed.key, seed.value], axis: Tensor<Int32>(2))
return _Raw.concatV2([seed.query, seed.key, seed.value], axis: Tensor<Int32>(2))
})
}

Expand Down
2 changes: 1 addition & 1 deletion Transformer/Operators.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func batchedMatmul<Scalar : Numeric>(
adjointLeft: Bool = false,
adjointRight: Bool = false
) -> Tensor<Scalar> {
return Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight)
return _Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight)
}

@usableFromInline
Expand Down
2 changes: 1 addition & 1 deletion Transformer/PythonCheckpointReader.swift
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func readTensor<Scalar: TensorFlowScalar>(
scalarType: Scalar.Type
) -> Tensor<Scalar> {
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
return Raw.restoreV2(prefix: StringTensor(path),
return _Raw.restoreV2(prefix: StringTensor(path),
tensorNames: StringTensor([name]),
shapeAndSlices: StringTensor([""]))
}
Expand Down
2 changes: 1 addition & 1 deletion Transformer/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ for _ in 0..<100 {
let lastLogit = logits.slice(
lowerBounds: [0, timeSteps - 1, 0],
upperBounds: [batchSize, timeSteps, vocabSize]) / temperature
tokens = Raw.multinomial(logits: lastLogit.squeezingShape(at: 1), numSamples: Tensor<Int32>(1))
tokens = _Raw.multinomial(logits: lastLogit.squeezingShape(at: 1), numSamples: Tensor<Int32>(1))
print(encoder.decode(tokens[0].makeNumpyArray()), terminator: "")
}
print()