from contextlib import contextmanager
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from db import get_conn, put_conn
# FastAPI application instance; the @app.post/get/put/delete decorators below
# register the product CRUD routes on it.
app = FastAPI(title="Lakebase CRUD Demo")
# ── Pydantic models ─────────────────────────
# Request body for POST /products. Fields map 1:1 onto the INSERT column list
# used by create_product (name, price, category).
class ProductCreate(BaseModel):
    name: str
    price: float
    # Optional; when omitted, None is passed to the INSERT (i.e. NULL —
    # assumes the category column is nullable, TODO confirm).
    category: str | None = None
# Request body for PUT /products/{product_id}. Every field is optional; only
# fields whose value is not None are written. Because None means "leave
# unchanged", this model cannot be used to clear a column to NULL.
class ProductUpdate(BaseModel):
    name: str | None = None
    price: float | None = None
    category: str | None = None
# ── Connection context manager ──────────────
@contextmanager
def db():
    """Yield a pooled connection scoped to one transaction.

    The connection comes from ``get_conn()``. If the ``with`` body finishes
    normally, the transaction is committed. If the body (or the commit) raises
    anything, the transaction is rolled back and the exception is re-raised.
    This includes ``HTTPException`` raised by an endpoint and ``BaseException``
    subclasses such as ``KeyboardInterrupt`` or ``GeneratorExit``. In every
    case the connection is handed back with ``put_conn()``.

    Yields:
        The connection obtained from ``get_conn()``.
    """
    conn = get_conn()
    try:
        yield conn
        conn.commit()  # normal exit: commit
    except BaseException:
        # Catch BaseException, not just Exception: otherwise an interrupt or a
        # closed generator would skip the rollback, and the finally clause
        # would return the connection to the pool mid-transaction.
        conn.rollback()
        raise
    finally:
        put_conn(conn)  # always return the connection to the pool
# ── CREATE ──────────────────────────────────
# Insert a new product and answer 201 with its id and the created_at value
# returned by the database (the INSERT does not supply created_at itself).
@app.post("/products", status_code=201)
def create_product(body: ProductCreate):
    insert_sql = (
        "INSERT INTO products (name, price, category)\n"
        "VALUES (%s, %s, %s) RETURNING id, created_at"
    )
    with db() as conn:
        cursor = conn.cursor()
        cursor.execute(insert_sql, (body.name, body.price, body.category))
        inserted = cursor.fetchone()
    new_id, created_at = inserted
    # created_at is rendered with str(), not isoformat().
    return {"id": new_id, "created_at": str(created_at)}
# ── READ (list) ─────────────────────────────
# List products, optionally filtered by an exact category match.
#   category: when provided (including the empty string), only rows with that
#             category are returned; None means "no filter".
#   limit:    maximum number of rows; must be >= 0. Negative values are
#             rejected with 400 instead of surfacing as a PostgreSQL error/500.
# Rows are ordered by id so that LIMIT yields a stable, repeatable page.
@app.get("/products")
def list_products(category: str | None = None, limit: int = 50):
    if limit < 0:
        raise HTTPException(400, "limit must be non-negative")

    query = "SELECT id, name, price, category FROM products"
    params: tuple = ()
    if category is not None:  # explicit None check, like update_product
        query += " WHERE category = %s"
        params = (category,)
    # Without ORDER BY, PostgreSQL may return an arbitrary subset under LIMIT.
    query += " ORDER BY id LIMIT %s"
    params += (limit,)

    with db() as conn:
        cur = conn.cursor()
        cur.execute(query, params)
        rows = cur.fetchall()
    return [
        {"id": r[0], "name": r[1], "price": float(r[2]), "category": r[3]}
        for r in rows
    ]
# ── READ (single) ───────────────────────────
# Return one product by primary key, or 404 when no row has that id.
@app.get("/products/{product_id}")
def get_product(product_id: int):
    with db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id, name, price, category FROM products WHERE id = %s",
            (product_id,),
        )
        record = cursor.fetchone()
    if record:
        pid, name, price, category = record
        # price is coerced with float() (the driver may return e.g. Decimal).
        return {"id": pid, "name": name, "price": float(price), "category": category}
    raise HTTPException(404, "Product not found")
# ── UPDATE ──────────────────────────────────
# Partially update a product.
# - Only body fields that are not None are written, in the order name, price, category.
# - updated_at is always refreshed.
# - Responds 400 when the body carries no fields and 404 when the id does not exist.
@app.put("/products/{product_id}")
def update_product(product_id: int, body: ProductUpdate):
    # Column names come from this fixed tuple, never from the request, so
    # interpolating them into the SQL text is safe; values stay parameterized.
    candidates = [
        (column, getattr(body, column))
        for column in ("name", "price", "category")
    ]
    changes = [(column, value) for column, value in candidates if value is not None]
    if not changes:
        raise HTTPException(400, "No fields to update")

    assignments = [f"{column} = %s" for column, _ in changes]
    assignments.append("updated_at = CURRENT_TIMESTAMP")
    params = [value for _, value in changes] + [product_id]

    with db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            f"UPDATE products SET {', '.join(assignments)} WHERE id = %s RETURNING id",
            params,
        )
        updated = cursor.fetchone()
        if not updated:
            raise HTTPException(404, "Product not found")
    return {"updated": updated[0]}
# ── DELETE ──────────────────────────────────
# Delete a product by id and echo the deleted id; 404 if no row matched.
@app.delete("/products/{product_id}")
def delete_product(product_id: int):
    with db() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "DELETE FROM products WHERE id = %s RETURNING id",
            (product_id,),
        )
        deleted = cursor.fetchone()
        if not deleted:
            # Raised inside the transaction, so db() rolls it back
            # (nothing was deleted in this case anyway).
            raise HTTPException(404, "Product not found")
    return {"deleted": deleted[0]}