diff options
| -rw-r--r-- | .dockerignore | 3 | ||||
| -rw-r--r-- | .github/workflows/build-docker.yaml | 46 | ||||
| -rw-r--r-- | .github/workflows/ci.yml | 21 | ||||
| -rw-r--r-- | .gitignore | 2 | ||||
| -rw-r--r-- | Dockerfile | 22 | ||||
| -rw-r--r-- | compose.yaml | 11 | ||||
| -rw-r--r-- | db.go | 87 | ||||
| -rw-r--r-- | db_test.go | 75 | ||||
| -rw-r--r-- | go.mod | 5 | ||||
| -rw-r--r-- | go.sum | 2 | ||||
| -rw-r--r-- | main.go | 115 | ||||
| -rw-r--r-- | server.go | 61 | ||||
| -rw-r--r-- | server_test.go | 106 | ||||
| -rw-r--r-- | shortid.go | 38 | ||||
| -rw-r--r-- | shortid_test.go | 86 |
15 files changed, 680 insertions, 0 deletions
diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..bf1a0ae --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.git +*.db +/nilink diff --git a/.github/workflows/build-docker.yaml b/.github/workflows/build-docker.yaml new file mode 100644 index 0000000..e18bdf3 --- /dev/null +++ b/.github/workflows/build-docker.yaml @@ -0,0 +1,46 @@ +name: Build and Push Docker Image + +on: + push: + tags: + - 'v*' + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata + id: meta + uses: docker/metadata-action@v5 + with: + images: ghcr.io/${{ github.repository }} + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + + - name: Build and push + uses: docker/build-push-action@v5 + with: + context: . + push: true + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..4b26086 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,21 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-go@v5 + with: + go-version: "1.24" + cache-dependency-path: go.sum + + - run: go vet ./... + - run: go test ./... diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6d575ce --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +*.db +/nilink diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..4bc51bc --- /dev/null +++ b/Dockerfile @@ -0,0 +1,22 @@ +FROM golang:1.24 AS builder + +WORKDIR /app + +RUN apt-get update && \ + apt-get install -y libsqlite3-dev + +COPY go.mod go.sum ./ +RUN go mod download + +COPY . . +RUN CGO_ENABLED=1 GOOS=linux go build -o nilink . + +########################################## +FROM gcr.io/distroless/cc-debian12 + +WORKDIR /app +COPY --from=builder /app/nilink /app + +EXPOSE 8080 +ENTRYPOINT ["/app/nilink"] +CMD ["serve", "-addr", ":8080", "-db", "/data/nilink.db"] diff --git a/compose.yaml b/compose.yaml new file mode 100644 index 0000000..77eb8c1 --- /dev/null +++ b/compose.yaml @@ -0,0 +1,11 @@ +services: + nilink: + build: + context: . + ports: + - '127.0.0.1:8080:8080' + volumes: + - ./data:/app/data + environment: + TZ: Asia/Tokyo + restart: always @@ -0,0 +1,87 @@ +package main + +import ( + "database/sql" + "fmt" + + _ "github.com/mattn/go-sqlite3" +) + +type link struct { + ID int64 + URL string + CreatedAt string +} + +func openDB(path string) (*sql.DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, err + } + if _, err := db.Exec("PRAGMA journal_mode=WAL"); err != nil { + db.Close() + return nil, err + } + if _, err := db.Exec("PRAGMA foreign_keys=ON"); err != nil { + db.Close() + return nil, err + } + if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS links ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT NOT NULL, + created_at TEXT NOT NULL DEFAULT (datetime('now')) + )`); err != nil { + db.Close() + return nil, err + } + return db, nil +} + +func insertLink(db *sql.DB, url string) (int64, error) { + res, err := db.Exec("INSERT INTO links (url) VALUES (?)", url) + if err != nil { + return 0, err + } + return res.LastInsertId() +} + +func deleteLink(db *sql.DB, id int64) error { + res, err := db.Exec("DELETE FROM links WHERE id = ?", id) + if err != nil { + return err + } + n, err := res.RowsAffected() + if err != nil { + return err + } + if n == 0 { + return fmt.Errorf("link not found") + } + return nil +} + +func getURL(db *sql.DB, id int64) (string, error) { + var url string + err := db.QueryRow("SELECT url FROM links WHERE id = ?", id).Scan(&url) + if err != nil { + return "", fmt.Errorf("link not found") + } + return url, nil +} + +func listLinks(db *sql.DB) ([]link, error) { + rows, err := db.Query("SELECT id, url, created_at FROM links ORDER BY id") + if err != nil { + return nil, err + } + defer rows.Close() + var links []link + for rows.Next() { + var l link + if err := rows.Scan(&l.ID, &l.URL, &l.CreatedAt); err != nil { + return nil, err + } + links = append(links, l) + } + return links, rows.Err() +} diff --git a/db_test.go b/db_test.go new file mode 100644 index 0000000..5a9377d --- /dev/null +++ b/db_test.go @@ -0,0 +1,75 @@ +package main + +import ( + "database/sql" + "testing" +) + +func testDB(t *testing.T) *sql.DB { + t.Helper() + db, err := openDB(":memory:") + if err != nil { + t.Fatalf("openDB: %v", err) + } + t.Cleanup(func() { db.Close() }) + return db +} + +func TestInsertAndGet(t *testing.T) { + db := testDB(t) + id, err := insertLink(db, "https://example.com") + if err != nil { + t.Fatalf("insertLink: %v", err) + } + url, err := getURL(db, id) + if err != nil { + t.Fatalf("getURL: %v", err) + } + if url != "https://example.com" { + t.Errorf("getURL = %q, want %q", url, "https://example.com") + } +} + +func TestDeleteThenGet(t *testing.T) { + db := testDB(t) + id, err := insertLink(db, "https://example.com") + if err != nil { + t.Fatalf("insertLink: %v", err) + } + if err := deleteLink(db, id); err != nil { + t.Fatalf("deleteLink: %v", err) + } + if _, err := getURL(db, id); err == nil { + t.Error("getURL after delete should fail") + } +} + +func TestListLinks(t *testing.T) { + db := testDB(t) + insertLink(db, "https://a.com") + insertLink(db, "https://b.com") + links, err := listLinks(db) + if err != nil { + t.Fatalf("listLinks: %v", err) + } + if len(links) != 2 { + t.Fatalf("listLinks returned %d rows, want 2", len(links)) + } + if links[0].URL != "https://a.com" || links[1].URL != "https://b.com" { + t.Errorf("unexpected URLs: %v", links) + } +} + +func TestGetNotFound(t *testing.T) { + db := testDB(t) + if _, err := getURL(db, 999); err == nil { + t.Error("getURL(999) should fail") + } +} + +func TestDeleteNotFound(t *testing.T) { + db := testDB(t) + if err := deleteLink(db, 999); err == nil { + t.Error("deleteLink(999) should fail") + } +} @@ -0,0 +1,5 @@ +module nilink + +go 1.25.5 + +require github.com/mattn/go-sqlite3 v1.14.33 // indirect @@ -0,0 +1,2 @@ +github.com/mattn/go-sqlite3 v1.14.33 h1:A5blZ5ulQo2AtayQ9/limgHEkFreKj1Dv226a1K73s0= +github.com/mattn/go-sqlite3 v1.14.33/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= @@ -0,0 +1,115 @@ +package main + +import ( + "flag" + "fmt" + "os" +) + +func main() { + if len(os.Args) < 2 { + fmt.Fprintln(os.Stderr, "usage: nilink <serve|add|remove|list>") + os.Exit(1) + } + switch os.Args[1] { + case "serve": + cmdServe(os.Args[2:]) + case "add": + cmdAdd(os.Args[2:]) + case "remove": + cmdRemove(os.Args[2:]) + case "list": + cmdList(os.Args[2:]) + default: + fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1]) + os.Exit(1) + } +} + +func cmdAdd(args []string) { + flags := flag.NewFlagSet("add", flag.ExitOnError) + dbPath := flags.String("db", "data/nilink.db", "database path") + flags.Parse(args) + + if flags.NArg() != 1 { + fmt.Fprintln(os.Stderr, "usage: nilink add [-db path] <url>") + os.Exit(1) + } + url := flags.Arg(0) + + db, err := openDB(*dbPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + defer db.Close() + + id, err := insertLink(db, url) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + short, err := encodeID(id) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Printf("%s %s\n", short, url) +} + +func cmdRemove(args []string) { + flags := flag.NewFlagSet("remove", flag.ExitOnError) + dbPath := flags.String("db", "data/nilink.db", "database path") + flags.Parse(args) + + if flags.NArg() != 1 { + fmt.Fprintln(os.Stderr, "usage: nilink remove [-db path] <short-id>") + os.Exit(1) + } + shortID := flags.Arg(0) + + id, err := decodeID(shortID) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + + db, err := openDB(*dbPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + defer db.Close() + + if err := deleteLink(db, id); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} + +func cmdList(args []string) { + flags := flag.NewFlagSet("list", flag.ExitOnError) + dbPath := flags.String("db", "data/nilink.db", "database path") + flags.Parse(args) + + db, err := openDB(*dbPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + defer db.Close() + + links, err := listLinks(db) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + for _, l := range links { + short, err := encodeID(l.ID) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + fmt.Printf("%s %s %s\n", short, l.URL, l.CreatedAt) + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..df98893 --- /dev/null +++ b/server.go @@ -0,0 +1,61 @@ +package main + +import ( + "database/sql" + "flag" + "fmt" + "net/http" + "os" + "strings" +) + +func newMux(db *sql.DB) http.Handler { + mux := http.NewServeMux() + mux.HandleFunc("/robots.txt", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain") + fmt.Fprint(w, "User-agent: *\nDisallow: /\n") + }) + mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/") + if path == "" { + http.Error(w, "not found", http.StatusNotFound) + return + } + if strings.Contains(path, "/") { + http.Error(w, "not found", http.StatusNotFound) + return + } + id, err := decodeID(path) + if err != nil { + http.Error(w, "not found", http.StatusNotFound) + return + } + url, err := getURL(db, id) + if err != nil { + http.Error(w, "not found", http.StatusNotFound) + return + } + http.Redirect(w, r, url, http.StatusMovedPermanently) + }) + return mux +} + +func cmdServe(args []string) { + fs := flag.NewFlagSet("serve", flag.ExitOnError) + addr := fs.String("addr", ":8080", "listen address") + dbPath := fs.String("db", "nilink.db", "database path") + fs.Parse(args) + + db, err := openDB(*dbPath) + if err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } + defer db.Close() + + fmt.Fprintf(os.Stderr, "listening on %s\n", *addr) + if err := http.ListenAndServe(*addr, newMux(db)); err != nil { + fmt.Fprintf(os.Stderr, "error: %v\n", err) + os.Exit(1) + } +} diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..22abd6a --- /dev/null +++ b/server_test.go @@ -0,0 +1,106 @@ +package main + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func testServer(t *testing.T) *httptest.Server { + t.Helper() + db, err := openDB(":memory:") + if err != nil { + t.Fatalf("openDB: %v", err) + } + t.Cleanup(func() { db.Close() }) + insertLink(db, "https://example.com") + return httptest.NewServer(newMux(db)) +} + +func TestRobotsTxt(t *testing.T) { + srv := testServer(t) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/robots.txt") + if err != nil { + t.Fatalf("GET /robots.txt: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + t.Errorf("status = %d, want 200", resp.StatusCode) + } + ct := resp.Header.Get("Content-Type") + if ct != "text/plain" { + t.Errorf("Content-Type = %q, want text/plain", ct) + } + body, _ := io.ReadAll(resp.Body) + want := "User-agent: *\nDisallow: /\n" + if string(body) != want { + t.Errorf("body = %q, want %q", body, want) + } +} + +func TestRedirect(t *testing.T) { + srv := testServer(t) + defer srv.Close() + + short, _ := encodeID(1) + client := &http.Client{CheckRedirect: func(*http.Request, []*http.Request) error { + return http.ErrUseLastResponse + }} + resp, err := client.Get(srv.URL + "/" + short) + if err != nil { + t.Fatalf("GET /%s: %v", short, err) + } + defer resp.Body.Close() + if resp.StatusCode != 301 { + t.Errorf("status = %d, want 301", resp.StatusCode) + } + loc := resp.Header.Get("Location") + if loc != "https://example.com" { + t.Errorf("Location = %q, want %q", loc, "https://example.com") + } +} + +func TestNotFoundInvalidID(t *testing.T) { + srv := testServer(t) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/ZZZZ") + if err != nil { + t.Fatalf("GET /ZZZZ: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != 404 { + t.Errorf("status = %d, want 404", resp.StatusCode) + } +} + +func TestNotFoundRoot(t *testing.T) { + srv := testServer(t) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/") + if err != nil { + t.Fatalf("GET /: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != 404 { + t.Errorf("status = %d, want 404", resp.StatusCode) + } +} + +func TestNotFoundNested(t *testing.T) { + srv := testServer(t) + defer srv.Close() + + resp, err := http.Get(srv.URL + "/nested/path") + if err != nil { + t.Fatalf("GET /nested/path: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != 404 { + t.Errorf("status = %d, want 404", resp.StatusCode) + } +} diff --git a/shortid.go b/shortid.go new file mode 100644 index 0000000..9cfd076 --- /dev/null +++ b/shortid.go @@ -0,0 +1,38 @@ +package main + +import ( + "encoding/base32" + "encoding/binary" + "fmt" + "strings" +) + +var b32 = base32.StdEncoding.WithPadding(base32.NoPadding) + +func encodeID(id int64) (string, error) { + if id < 0 { + return "", fmt.Errorf("id out of range: %d", id) + } + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(id)) + // Trim leading zero bytes, keep at least 2 bytes (4 chars). + i := 0 + for i < 6 && buf[i] == 0 { + i++ + } + return b32.EncodeToString(buf[i:]), nil +} + +func decodeID(s string) (int64, error) { + s = strings.ToUpper(s) + buf, err := b32.DecodeString(s) + if err != nil { + return 0, fmt.Errorf("invalid short id: %w", err) + } + if len(buf) == 0 || len(buf) > 8 { + return 0, fmt.Errorf("invalid short id") + } + padded := make([]byte, 8) + copy(padded[8-len(buf):], buf) + return int64(binary.BigEndian.Uint64(padded)), nil +} diff --git a/shortid_test.go b/shortid_test.go new file mode 100644 index 0000000..54bc763 --- /dev/null +++ b/shortid_test.go @@ -0,0 +1,86 @@ +package main + +import "testing" + +func TestEncodeKnownValues(t *testing.T) { + tests := []struct { + id int64 + want string + }{ + {0, "AAAA"}, + {1, "AAAQ"}, + {256, "AEAA"}, + {65535, "777Q"}, + } + for _, tt := range tests { + got, err := encodeID(tt.id) + if err != nil { + t.Fatalf("encodeID(%d): %v", tt.id, err) + } + if got != tt.want { + t.Errorf("encodeID(%d) = %q, want %q", tt.id, got, tt.want) + } + } +} + +func TestDecodeKnownValues(t *testing.T) { + tests := []struct { + s string + want int64 + }{ + {"AAAA", 0}, + {"AAAQ", 1}, + {"AEAA", 256}, + {"777Q", 65535}, + } + for _, tt := range tests { + got, err := decodeID(tt.s) + if err != nil { + t.Fatalf("decodeID(%q): %v", tt.s, err) + } + if got != tt.want { + t.Errorf("decodeID(%q) = %d, want %d", tt.s, got, tt.want) + } + } +} + +func TestDecodeCaseInsensitive(t *testing.T) { + got, err := decodeID("aaaq") + if err != nil { + t.Fatalf("decodeID(aaaq): %v", err) + } + if got != 1 { + t.Errorf("decodeID(aaaq) = %d, want 1", got) + } +} + +func TestRoundTrip(t *testing.T) { + for _, id := range []int64{0, 1, 100, 256, 1000, 65535, 100000, 1000000} { + s, err := encodeID(id) + if err != nil { + t.Fatalf("encodeID(%d): %v", id, err) + } + got, err := decodeID(s) + if err != nil { + t.Fatalf("decodeID(%q): %v", s, err) + } + if got != id { + t.Errorf("roundtrip(%d): got %d", id, got) + } + } +} + +func TestEncodeOutOfRange(t *testing.T) { + if _, err := encodeID(-1); err == nil { + t.Error("encodeID(-1) should fail") + } +} + +func TestDecodeInvalid(t *testing.T) { + invalids := []string{"", "A", "!!!!", "AAAAAA"} + for _, s := range invalids { + if _, err := decodeID(s); err == nil { + t.Errorf("decodeID(%q) should fail", s) + } + } +} |
