Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions cmd/sst/add_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package main

import (
"testing"

"github.com/sst/sst/v3/pkg/project"
)

func TestGetAliasName(t *testing.T) {
tests := []struct {
name string
entry *project.ProviderLockEntry
want string
}{
{
name: "simple provider",
entry: &project.ProviderLockEntry{
Name: "stripe",
Package: "pulumi-stripe",
Alias: "stripe",
},
want: "stripe",
},
{
name: "strip official suffix",
entry: &project.ProviderLockEntry{
Name: "stripe-official",
Package: "@sst-provider/stripe-official",
Alias: "stripe",
},
want: "stripe",
},
{
name: "strip community suffix from alias",
entry: &project.ProviderLockEntry{
Name: "@scope/pulumi-foo-community",
Package: "@scope/pulumi-foo-community",
Alias: "foocommunity",
},
want: "foo",
},
{
name: "package input still uses alias",
entry: &project.ProviderLockEntry{
Name: "@paynearme/pulumi-jetstream",
Package: "@paynearme/pulumi-jetstream",
Alias: "jetstream",
},
want: "jetstream",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getAliasName(tt.entry); got != tt.want {
t.Fatalf("got %q, want %q", got, tt.want)
}
})
}
}
32 changes: 22 additions & 10 deletions cmd/sst/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ import (

var version = "dev"

func getAliasName(entry *project.ProviderLockEntry) string {
name := entry.Name
if entry.Name == entry.Package {
name = entry.Alias
}
for _, suffix := range []string{"official", "community"} {
if !strings.HasSuffix(entry.Name, "-"+suffix) && !strings.HasSuffix(entry.Package, "-"+suffix) {
continue
}
name = strings.TrimSuffix(name, "-"+suffix)
name = strings.TrimSuffix(name, suffix)
}
return name
}

func main() {
// check if node_modules/.bin/sst exists
nodeModulesBinPath := filepath.Join("node_modules", ".bin", "sst")
Expand Down Expand Up @@ -520,7 +535,10 @@ var root = &cli.Command{
"```ts title=\"sst.config.ts\"",
"{",
" providers: {",
" aws: \"6.27.0\"",
" aws: {",
" package: \"@pulumi/aws\",",
" version: \"6.27.0\"",
" }",
" }",
"}",
"```",
Expand All @@ -537,6 +555,7 @@ var root = &cli.Command{
"{",
" providers: {",
" aws: {",
" package: \"@pulumi/aws\",",
" version: \"6.26.0\"",
" }",
" }",
Expand Down Expand Up @@ -597,15 +616,8 @@ var root = &cli.Command{
if err != nil {
return util.NewReadableError(err, "Could not find provider "+pkg)
}
// When the user passed a full package name (e.g. @paynearme/pulumi-jetstream),
// use the alias as the config key and set the package override
providerName := entry.Name
pkgOverride := ""
if entry.Name == entry.Package {
providerName = entry.Alias
pkgOverride = entry.Package
}
err = p.Add(providerName, entry.Version, pkgOverride)
providerName := getAliasName(entry)
err = p.Add(providerName, entry.Version, entry.Package)
if err != nil {
return util.NewReadableError(err, err.Error())
}
Expand Down
118 changes: 94 additions & 24 deletions platform/src/ast/add.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import ts from "typescript";
import prettier from "prettier";

const config = process.argv[2];
const pkg = process.argv[3];
const provider = process.argv[3];
const version = process.argv[4];
const pkgName = process.argv[5] || "";

Expand Down Expand Up @@ -72,40 +72,110 @@ if (!providersProperty) {

if (!ts.isObjectLiteralExpression(providersProperty.initializer)) {
console.error(
'The "providers" property must be a plain object, not a dynamic expression like a ternary or variable.'
'The "providers" property must be a plain object, not a dynamic expression like a ternary or variable.',
);
process.exit(1);
}

if (
providersProperty.initializer.properties.find(
(property) => property.name.getText().replaceAll('"', "") === pkg,
)
) {
process.exit(0);
function getPropertyName(property) {
return property.name.getText().replace(/^['"]|['"]$/g, "");
}
// Create a new property node
let newValue;
if (pkgName) {
newValue = ts.factory.createObjectLiteralExpression([
ts.factory.createPropertyAssignment(
"package",
ts.factory.createStringLiteral(pkgName),
),

function createStringProperty(name, value) {
return ts.factory.createPropertyAssignment(
name,
ts.factory.createStringLiteral(value),
);
}

function createProviderValue(versionValue) {
if (!pkgName) {
return ts.factory.createStringLiteral(versionValue);
}

return ts.factory.createObjectLiteralExpression(
[
createStringProperty("package", pkgName),
createStringProperty("version", versionValue),
],
false,
);
}

function upsertObjectProperty(properties, name, initializer, overwrite) {
const index = properties.findIndex(
(property) =>
ts.isPropertyAssignment(property) && getPropertyName(property) === name,
);

if (index === -1) {
properties.push(ts.factory.createPropertyAssignment(name, initializer));
return;
}

if (!overwrite) {
return;
}

properties.splice(
index,
1,
ts.factory.createPropertyAssignment(
"version",
ts.factory.createStringLiteral(version),
properties[index].name,
initializer,
),
], false);
} else {
newValue = ts.factory.createStringLiteral(version);
);
}

function updateProviderValue(initializer) {
if (!pkgName) {
return ts.factory.createStringLiteral(version);
}

if (ts.isStringLiteralLike(initializer)) {
return createProviderValue(initializer.text);
}

if (!ts.isObjectLiteralExpression(initializer)) {
return createProviderValue(version);
}

const properties = [...initializer.properties];
upsertObjectProperty(
properties,
"package",
ts.factory.createStringLiteral(pkgName),
true,
);
upsertObjectProperty(
properties,
"version",
ts.factory.createStringLiteral(version),
false,
);
return ts.factory.createObjectLiteralExpression(properties, false);
}

const existingIndex = providersProperty.initializer.properties.findIndex(
(property) =>
ts.isPropertyAssignment(property) &&
getPropertyName(property) === provider,
);

const newProperty = ts.factory.createPropertyAssignment(
ts.factory.createStringLiteral(pkg),
newValue,
ts.factory.createStringLiteral(provider),
existingIndex === -1
? createProviderValue(version)
: updateProviderValue(
providersProperty.initializer.properties[existingIndex].initializer,
),
);

providersProperty.initializer.properties.push(newProperty);
if (existingIndex === -1) {
providersProperty.initializer.properties.push(newProperty);
} else {
providersProperty.initializer.properties.splice(existingIndex, 1, newProperty);
}

const printer = ts.createPrinter();
const modifiedCode = printer.printNode(
Expand Down
55 changes: 44 additions & 11 deletions platform/test/ast/add.test.ts
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import { describe, it, expect, beforeEach } from "vitest";
import { describe, it, expect } from "vitest";
import { execFileSync } from "child_process";
import fs from "fs";
import path from "path";
import os from "os";

const ADD_SCRIPT = path.resolve(__dirname, "../../src/ast/add.mjs");
const PROVIDER = "grafana";
const PKG = "@pulumiverse/grafana";
const VERSION = "0.0.1";

function run(config: string) {
const tmp = path.join(os.tmpdir(), `sst-add-test-${Date.now()}.ts`);
fs.writeFileSync(tmp, config);
execFileSync("node", [ADD_SCRIPT, tmp, PKG, VERSION]);
execFileSync("node", [ADD_SCRIPT, tmp, PROVIDER, VERSION, PKG]);
const result = fs.readFileSync(tmp, "utf-8");
fs.unlinkSync(tmp);
return result;
Expand All @@ -27,7 +28,9 @@ describe("add provider", () => {
};
},
});`);
expect(result).toContain(`"${PKG}": "${VERSION}"`);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "${VERSION}"`);
});

it("arrow function with block body", () => {
Expand All @@ -39,7 +42,9 @@ describe("add provider", () => {
};
},
});`);
expect(result).toContain(`"${PKG}": "${VERSION}"`);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "${VERSION}"`);
});

it("arrow function with concise body", () => {
Expand All @@ -49,7 +54,9 @@ describe("add provider", () => {
providers: {},
}),
});`);
expect(result).toContain(`"${PKG}": "${VERSION}"`);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "${VERSION}"`);
});

it("function expression", () => {
Expand All @@ -61,7 +68,9 @@ describe("add provider", () => {
};
},
});`);
expect(result).toContain(`"${PKG}": "${VERSION}"`);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "${VERSION}"`);
});

it("adds providers key when missing", () => {
Expand All @@ -73,20 +82,44 @@ describe("add provider", () => {
},
});`);
expect(result).toContain("providers");
expect(result).toContain(`"${PKG}": "${VERSION}"`);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "${VERSION}"`);
});

it("skips if provider already exists", () => {
it("adds package to existing string provider", () => {
const config = `export default $config({
app(input) {
return {
name: "my-app",
providers: { "${PKG}": "0.0.0" },
providers: { "${PROVIDER}": "0.0.0" },
};
},
});`;
const result = run(config);
expect(result).toContain(`"${PKG}": "0.0.0"`);
expect(result).not.toContain(VERSION);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "0.0.0"`);
});

it("adds package to existing object provider", () => {
const config = `export default $config({
app(input) {
return {
name: "my-app",
providers: {
"${PROVIDER}": {
version: "0.0.0",
region: "us-east-1",
},
},
};
},
});`;
const result = run(config);
expect(result).toContain(`${PROVIDER}: {`);
expect(result).toContain(`package: "${PKG}"`);
expect(result).toContain(`version: "0.0.0"`);
expect(result).toContain(`region: "us-east-1"`);
});
});
Loading